# m25: a simple macro processor.
# Copyright 2003, 2011, 2017, 2024 Adam Sampson <ats@offog.org>
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import os
import re
import sys

if sys.version_info.major > 2:
	from io import StringIO
else:
	from cStringIO import StringIO

class M25Error(Exception):
	"""Raised for any failure during M25 processing.

	The exception message is the given text, prefixed with
	"<pos>: " when a position is supplied and followed by the repr
	of the token when a token is supplied.
	"""
	def __init__(self, text, token = None, pos = None):
		message = text
		if token is not None:
			message = message + " " + repr(token)
		if pos is not None:
			message = str(pos) + ": " + message
		super(M25Error, self).__init__(message)

class Position:
	"""A filename, line number and column within an input stream.

	Line and column default to 1; when no filename is given the
	placeholder "(data)" is used instead.
	"""
	def __init__(self, filename = None, line = 1, char = 1):
		self.filename = "(data)" if filename is None else filename
		self.line = line
		self.char = char

	def copy(self):
		"""Return a new, independent Position with the same values."""
		return Position(self.filename, self.line, self.char)

	def advance(self, c):
		"""Step past the character c.

		A newline moves to column 1 of the next line; any other
		character moves one column along the current line.
		"""
		if c != "\n":
			self.char += 1
			return
		self.line += 1
		self.char = 1

	def __repr__(self):
		return "Position(%r, %r, %r)" % (self.filename, self.line, self.char)

	def __str__(self):
		# Rendered as filename:line:column, as used in error messages.
		location = (self.filename, self.line, self.char)
		return "%s:%d:%d" % location

class Context:
	"""One scope of variables, optionally chained to a parent scope.

	A Context created without a parent is a root scope and is
	populated with the built-in commands, each stored as a pair
	(argument names, command callable).
	"""
	def __init__(self, parent = None):
		self.vars = {}
		self.parent = parent

		if parent is not None:
			return

		# Root scope only: register the built-in commands, in order.
		builtins = (
			("define", ["p", "v"], define_command),
			("redefine", ["p", "v"], redefine_command),
			("let", ["p", "v"], let_command),
			("set", ["p", "v"], set_command),
			("if", ["condition", "true"], if_command),
			("if_else", ["condition", "true", "false"], if_else_command),
			("expr", ["expression"], expr_command),
			("system", ["command"], system_command),
			("match", ["regexp", "string"], match_command),
			("replace", ["regexp", "replacement", "string"], replace_command),
			("ignore", ["thing"], ignore_command),
			("strip", ["string"], strip_command),
			("literal", ["string"], literal_command),
			("include", ["file"], include_command),
		)
		for (name, argnames, command) in builtins:
			self.define(name, (argnames, command))

	def define(self, key, value):
		"""Bind key to value in this scope, shadowing any outer binding."""
		self.vars[key] = value

	def define_literal(self, key, value):
		"""Bind key in this scope to a no-argument command that writes value."""
		self.vars[key] = ([], make_literal_command(value))

	def set(self, key, value):
		"""Rebind key in the nearest enclosing scope that already defines it.

		Raises M25Error if no scope in the chain defines key.
		"""
		scope = self
		while scope is not None:
			if key in scope.vars:
				scope.vars[key] = value
				return
			scope = scope.parent
		raise M25Error("Set unknown variable", key)

	def get(self, key):
		"""Return the value bound to key in the nearest enclosing scope.

		Raises M25Error if no scope in the chain defines key.
		"""
		scope = self
		while scope is not None:
			if key in scope.vars:
				return scope.vars[key]
			scope = scope.parent
		raise M25Error("Get unknown variable", key)

class Tokeniser:
	"""Tokenise M25 input.

	Splits a string into 'string', 'identifier', 'bracketed' and
	'eof' tokens, keeping track of the position of each one.
	"""

	# Characters that may appear in a macro name after a backslash.
	identifier_chars = "ABCDEFGHJIKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-"
	# Characters skipped after an identifier or a bracketed group.
	whitespace_chars = " \t"
	# Characters that end a run of plain text.
	meta_chars = "\\{"

	def __init__(self, s, pos):
		"""Prepare to tokenise s, whose first character is at pos.

		pos is copied, so the caller's object is not modified.
		"""
		self.s = s
		self.pos = pos.copy()
		self.len = len(s)
		self.i = 0

	def advance(self, count = 1):
		"""Consume count characters, updating the current position."""
		for _ in range(count):
			self.pos.advance(self.s[self.i])
			self.i += 1

	def discard_whitespace(self):
		"""Skip spaces, tabs and backslash-newline pairs."""
		while self.i != self.len:
			if self.s[self.i:self.i + 2] == "\\\n":
				self.advance(2)
			elif self.s[self.i] in self.whitespace_chars:
				self.advance()
			else:
				break

	def get_identifier(self):
		"""Read an identifier starting at the current character.

		Whitespace following the identifier is discarded.
		"""
		pos = self.pos.copy()
		start = self.i
		while self.i != self.len and self.s[self.i] in self.identifier_chars:
			self.advance()
		end = self.i
		self.discard_whitespace()
		return ('identifier', self.s[start:end], pos)

	def get_bracketed(self):
		"""Read the contents of a bracketed group.

		Called with the opening brace already consumed. Nested
		braces are balanced, and a backslash hides the character
		after it (so \\{ and \\} do not count). The returned text
		excludes the outer braces; whitespace following the closing
		brace is discarded.

		Raises M25Error if the input ends before the group is
		closed.
		"""
		pos = self.pos.copy()
		start = self.i
		level = 1
		while level > 0:
			if self.i >= self.len:
				# Without this check the last character of the
				# group would be silently dropped.
				raise M25Error("Unterminated bracket", pos = pos)
			c = self.s[self.i]
			self.advance()
			if c == '\\':
				# Skip \{ and \}. A backslash that is the last
				# character has nothing to skip; the check above
				# then reports the unterminated group.
				if self.i < self.len:
					self.advance()
			elif c == '{':
				level += 1
			elif c == '}':
				level -= 1
		# self.i is now just past the closing brace.
		end = self.i - 1
		self.discard_whitespace()
		return ('bracketed', self.s[start:end], pos)

	def get(self):
		"""Get the next token from the input. Returns a triple (type,
		token, position). The position is that of the start of the
		token's text.

		The type is 'string' (literal text), 'identifier' (a macro
		name introduced by a backslash), 'bracketed' (the contents
		of a {...} group) or 'eof' (token is None).

		Raises M25Error if a metacharacter is the last character of
		the input, or if a bracketed group is never closed."""

		if self.i == self.len:
			return ('eof', None, self.pos.copy())

		# Find the extent of the plain text before the next
		# metacharacter (or the end of the input).
		start = self.i
		end = self.len
		for c in self.meta_chars:
			found = self.s.find(c, start)
			if found != -1 and found < end:
				end = found
		if end != self.i:
			pos = self.pos.copy()
			self.advance(end - self.i)
			return ('string', self.s[start:end], pos)

		# The current character is a metacharacter.
		c = self.s[self.i]
		self.advance()
		if self.i == self.len:
			raise M25Error("Metacharacter at end of file", pos = self.pos)

		if c == '\\':
			c = self.s[self.i]
			if c in self.identifier_chars:
				return self.get_identifier()
			elif c == '\n':
				# Backslash-newline is a line continuation.
				self.advance()
				return self.get()
			else:
				# Any other escaped character stands for itself.
				pos = self.pos.copy()
				self.advance()
				return ('string', c, pos)
		elif c == '{':
			return self.get_bracketed()
		else:
			raise M25Error("Internal error: metacharacter not handled", pos = self.pos)

def expand_arg(arg, out, ctx):
	"""Expand a (text, position) argument pair, writing the result to out."""
	text = arg[0]
	textpos = arg[1]
	return expand(text, out, ctx, textpos)
def expand_arg_string(arg, ctx):
	"""Expand a (text, position) argument pair and return the result as a string."""
	text = arg[0]
	textpos = arg[1]
	return expand_string(text, ctx, textpos)

def define_command(ctx, pos, args, out, use_set = 0, expand = 0):
	"""The define, redefine, let and set commands.

	args["p"] holds the prototype: the macro name followed by its
	argument names, separated by whitespace. args["v"] holds the
	macro body.

	use_set -- if true, update an existing definition with ctx.set
	           rather than creating one in ctx with ctx.define.
	expand  -- if true, expand the body now and store the result;
	           otherwise store the body text to be expanded on use.
	           (Inside this function the name shadows the
	           module-level expand function.)

	Raises M25Error if the prototype contains no name.
	"""
	# split() with no separator ignores leading, trailing and repeated
	# whitespace and yields [] for a blank prototype, so the check
	# below can actually fire and no empty names are produced.
	proto = args["p"][0].split()
	if not proto:
		raise M25Error("Must give name for define/set", pos = pos)

	if expand:
		# The expanded text no longer corresponds to a place in the
		# input, so it gets a default position.
		value = expand_arg_string(args["v"], ctx)
		valuepos = Position()
	else:
		value = args["v"][0]
		valuepos = args["v"][1]

	name = proto[0]
	argnames = proto[1:]
	# Bind the body and its position as defaults so the macro keeps
	# the values it was defined with.
	expansion = lambda c, p, a, r, v = value, vp = valuepos: macro_command(c, p, a, r, v, vp)

	if use_set:
		ctx.set(name, (argnames, expansion))
	else:
		ctx.define(name, (argnames, expansion))

def redefine_command(ctx, pos, args, out):
	"""redefine: replace an existing definition, storing the body unexpanded."""
	return define_command(ctx, pos, args, out, use_set = 1, expand = 0)
def let_command(ctx, pos, args, out):
	"""let: create a definition in the current scope from the body expanded now."""
	return define_command(ctx, pos, args, out, use_set = 0, expand = 1)
def set_command(ctx, pos, args, out):
	"""set: replace an existing definition with the body expanded now."""
	return define_command(ctx, pos, args, out, use_set = 1, expand = 1)

def if_command(ctx, pos, args, out):
	"""if: expand the "true" argument when the condition holds.

	The condition is expanded and then evaluated as a Python
	expression.
	"""
	# NOTE: eval runs arbitrary Python taken from the input being
	# processed, so only process trusted input. The local name "cond"
	# is visible to the evaluated expression.
	cond = expand_arg_string(args["condition"], ctx)
	if not eval(cond):
		return
	expand_arg(args["true"], out, ctx)

def if_else_command(ctx, pos, args, out):
	"""if_else: expand the "true" or the "false" argument.

	The condition is expanded and then evaluated as a Python
	expression; its truth value selects which argument is expanded.
	"""
	# NOTE: eval runs arbitrary Python taken from the input being
	# processed, so only process trusted input. The local name "cond"
	# is visible to the evaluated expression.
	cond = expand_arg_string(args["condition"], ctx)
	branch = "true" if eval(cond) else "false"
	expand_arg(args[branch], out, ctx)

def expr_command(ctx, pos, args, out):
	"""expr: expand the argument, evaluate it as a Python expression and write str() of the result."""
	# NOTE: eval runs arbitrary Python taken from the input being
	# processed, so only process trusted input. The local name
	# "expression" is visible to the evaluated expression.
	expression = expand_arg_string(args["expression"], ctx)
	value = eval(expression)
	out.write(str(value))

def system_command(ctx, pos, args, out):
	"""system: run a shell command and write its standard output to out.

	The command argument is expanded and passed to "/bin/sh -c".
	Afterwards "system_exit_status" is defined in ctx as the string
	form of the status value returned by os.waitpid (the raw wait
	status, not a decoded exit code).
	"""
	command = expand_arg_string(args["command"], ctx)
	cmd = ["/bin/sh", "-c", command]
	(pr, pw) = os.pipe()
	pid = os.fork()
	if pid == 0:
		# Child process. os.execvp never returns on success and raises
		# on failure, so the finally clause is what guarantees the
		# child cannot carry on running as a copy of the parent.
		try:
			os.close(pr)
			os.dup2(pw, 1)
			os.execvp(cmd[0], cmd)
		finally:
			os._exit(20)
	# Parent process: close the write end so that reading sees EOF
	# once the child has finished with it.
	os.close(pw)
	f = os.fdopen(pr)
	try:
		output = f.read()
	finally:
		f.close()
	(deadpid, stat) = os.waitpid(pid, 0)
	out.write(output)
	ctx.define_literal("system_exit_status", str(stat))

# Compiled patterns, keyed by their source text.
regexp_cache = {}
def get_regexp(regexp):
	"""Return regexp compiled with re.MULTILINE, compiling it only on first use."""
	compiled = regexp_cache.get(regexp)
	if compiled is None:
		compiled = re.compile(regexp, re.MULTILINE)
		regexp_cache[regexp] = compiled
	return compiled

def match_command(ctx, pos, args, out):
	"""match: search the expanded string for the regexp.

	Writes "1" if the regexp matches and "0" otherwise. On a match,
	each capture group is defined in ctx as a literal named by its
	number ("1", "2", ...). The regexp argument is used as written,
	without expansion.
	"""
	r = get_regexp(args["regexp"][0])
	match = r.search(expand_arg_string(args["string"], ctx))
	if match is None:
		out.write("0")
		return
	# A group that took no part in the match would otherwise be None,
	# which cannot be written when the literal is later expanded; use
	# the empty string for such groups.
	for (number, text) in enumerate(match.groups(""), 1):
		ctx.define_literal(str(number), text)
	out.write("1")

def replace_command(ctx, pos, args, out):
	"""replace: substitute every match of the regexp in the expanded string.

	For each match the replacement argument is expanded in a new
	child scope in which the capture groups are defined as literals
	named by their numbers ("1", "2", ...). The regexp argument is
	used as written, without expansion.
	"""
	r = get_regexp(args["regexp"][0])
	string = expand_arg_string(args["string"], ctx)
	def rep(match):
		newctx = Context(ctx)
		# A group that took no part in the match would otherwise be
		# None, which cannot be written when the literal is expanded;
		# use the empty string for such groups.
		for (number, text) in enumerate(match.groups(""), 1):
			newctx.define_literal(str(number), text)
		return expand_arg_string(args["replacement"], newctx)
	out.write(r.sub(rep, string))

def ignore_command(ctx, pos, args, out):
	"""ignore: expand the argument but throw away everything it writes."""
	class Discard:
		"""A stand-in output object whose write does nothing."""
		def write(self, s):
			pass
	nowhere = Discard()
	expand_arg(args["thing"], nowhere, ctx)

def strip_command(ctx, pos, args, out):
	"""strip: expand the argument and write it without leading or trailing whitespace."""
	expanded = expand_arg_string(args["string"], ctx)
	out.write(expanded.strip())

def literal_command(ctx, pos, args, out):
	"""literal: write the argument text exactly as given, without expanding it."""
	raw = args["string"][0]
	out.write(raw)

def include_command(ctx, pos, args, out):
	"""include: expand the named file into out, in the current context.

	The filename is the argument text as written, without expansion.
	"""
	filename = args["file"][0]
	expand_file(filename, out, ctx)

def make_literal_command(value):
	"""Return a command callable that writes value to its output argument."""
	# value is bound as a default argument so each command keeps the
	# value it was created with.
	def write_literal(c, p, a, out, v = value):
		return out.write(v)
	return write_literal

def macro_command(ctx, pos, args, out, text, textpos):
	"""Expand the body of a user-defined macro.

	Each argument is expanded in the calling context and defined as
	a literal in a new child scope; text (located at textpos) is then
	expanded in that scope, writing to out.
	"""
	scope = Context(ctx)
	for (name, raw) in args.items():
		scope.define_literal(name, expand_arg_string(raw, ctx))
	expand(text, out, scope, textpos)

def expand(s, out, ctx = None, pos = None):
	"""Expand the M25 string s in context ctx, writing the result to out.

	A new root Context and a default Position are used when ctx or
	pos is not given. Raises M25Error if a command is not followed by
	a bracketed group for each of its arguments.
	"""
	if ctx is None:
		ctx = Context()
	if pos is None:
		pos = Position()

	tokens = Tokeniser(s, pos)
	while True:
		(kind, token, pos) = tokens.get()
		if kind == 'eof':
			return
		if kind == 'string':
			# Plain text is copied through unchanged.
			out.write(token)
			continue
		if kind == 'bracketed':
			# A bare {...} group is expanded in its own child scope.
			expand(token, out, Context(ctx), pos)
			continue
		if kind != 'identifier':
			continue

		# A command: collect one bracketed group per declared
		# argument, then run it.
		(argnames, expansion) = ctx.get(token)
		args = {}
		for name in argnames:
			(argkind, argtext, argpos) = tokens.get()
			if argkind != 'bracketed':
				raise M25Error("Argument expected to " + token, pos = argpos)
			args[name] = (argtext, argpos)
		expansion(ctx, pos, args, out)

def expand_file(filename, out, ctx = None):
	"""Expand an M25 file in the given context to the given output file.

	The whole file is read first; positions in error messages refer
	to filename.
	"""
	# The with block closes the file even if reading fails, and before
	# expansion starts, so an error during expansion cannot leak it.
	with open(filename, "r") as f:
		text = f.read()
	expand(text, out, ctx, Position(filename))

def expand_string(s, ctx = None, pos = None):
	"""Expand the M25 string s in the given context and return the output as a string."""
	buf = StringIO()
	expand(s, buf, ctx, pos)
	text = buf.getvalue()
	buf.close()
	return text
