# GARstow portfile (i.e. Makefile) parser
# Copyright 2003, 2004, 2005, 2006, 2007, 2009, 2010, 2012 Adam Sampson <ats@offog.org>

import sys, os, re, StringIO, copy, subprocess

import cache
from config import *
from utils import *

class PortfileError(Exception):
	"""Raised when a portfile cannot be parsed, expanded, loaded or validated."""

@memoised
def parse_portfile(fn, mtime = 0):
	"""Parse the restricted subset of GNU Make syntax allowed in GARStow
	Makefiles.

	fn is the path of the file to read.  mtime is not looked at here;
	presumably it exists so that @memoised keys on it and a file whose
	mtime changed is re-parsed (TODO confirm against utils.memoised).

	Returns a list of tuples, one per directive, in file order:
	  ("var" | "varadd" | "varsimple" | "varmaybe", name, value)
	      for "=", "+=", ":=", "?=" (define/endef yields "var")
	  ("include" | "-include" | "export" | "unexport", argtext)
	  ("rule", target, [command, ...])  commands without their leading tab
	  ("ifdef" | "ifndef", name)
	  ("ifeq" | "ifneq", lhs, rhs)
	  ("else",) and ("endif",)
	Names and values are returned unexpanded.

	Raises PortfileError for an unrecognised line or for a define that
	has no matching endef.
	"""

	# Read physical lines, folding backslash continuations into logical lines.
	s = ""
	lines = []
	f = open(fn, "r")
	try:
		for l in f.readlines():
			# Only strip a real newline: the last line may not have one.
			if l.endswith("\n"):
				l = l[:-1]

			# Fold together continuations
			# FIXME: This logic is not correct -- GNU make collapses
			# whitespace around a line continuation down to a single space
			# in variable definitions, but passes continuations through to
			# the shell in rules.
			if l.endswith("\\"):
				s += l[:-1]
			else:
				lines.append(s + l)
				s = ""
	finally:
		f.close()
	# A continuation on the very last line must not lose its text.
	if s != "":
		lines.append(s)

	parsed = []
	n = 0
	nlines = len(lines)
	while n < nlines:
		l = re.sub(r'\s+$', '', lines[n])

		# Remove comments
		l = re.sub(r'\s*#.*$', '', l)

		# Skip comments and blank lines
		if l == "":
			n += 1
			continue

		# Variable definitions
		m = re.match(r'^(\S+)\s*(=|\+=|:=|\?=)\s*(.*)$', l)
		if m is not None:
			kinds = {"=": "var", "+=": "varadd", ":=": "varsimple", "?=": "varmaybe"}
			parsed.append((kinds[m.group(2)], m.group(1), m.group(3)))
			n += 1
			continue

		# Long-form variable definitions: the body lines are taken raw
		# (no comment stripping) up to the matching endef.
		m = re.match(r'^define\s*(\S+)$', l)
		if m is not None:
			name = m.group(1)
			content = ""
			n += 1
			while True:
				if n >= nlines:
					raise PortfileError("Unterminated define: " + name)
				if lines[n].rstrip() == "endef":
					n += 1
					break
				content += lines[n] + "\n"
				n += 1
			parsed.append(("var", name, content))
			continue

		# Inclusions and exports
		m = re.match(r'^(-?include|export|unexport)\s*(.*)$', l)
		if m is not None:
			parsed.append((m.group(1), m.group(2)))
			n += 1
			continue

		# Conditionals are recognised before rules: an ifeq/ifneq
		# argument containing ":" would otherwise match the rule pattern.

		# ifdef/ifndef
		m = re.match(r'^(ifdef|ifndef)\s+(\S+)$', l)
		if m is not None:
			parsed.append((m.group(1), m.group(2)))
			n += 1
			continue

		# ifeq/ifneq
		# FIXME: this is too simple
		m = re.match(r'^(ifeq|ifneq)\s+"(.*)"\s+"(.*)"$', l)
		if m is not None:
			parsed.append((m.group(1), m.group(2), m.group(3)))
			n += 1
			continue

		# else/endif
		m = re.match(r'^(else|endif)$', l)
		if m is not None:
			parsed.append((m.group(1),))
			n += 1
			continue

		# Rules: the target line is followed by tab-indented command lines
		# (blank lines in between are skipped).  Prerequisites are ignored.
		m = re.match(r'^(\S.*):(\s*.+)?$', l)
		if m is not None:
			name = m.group(1)
			n += 1
			commands = []
			while n < nlines:
				if lines[n].strip() == "":
					n += 1
					continue
				m = re.match(r'^\t(.*)$', lines[n])
				if m is None:
					break
				commands.append(m.group(1))
				n += 1
			parsed.append(("rule", name, commands))
			continue

		raise PortfileError("Unrecognised line: " + l)

	return parsed

def split_whitespace(s):
	"""Split s into words separated by runs of whitespace.

	Leading and trailing whitespace produce no empty words, and an empty or
	all-blank string gives an empty list.  This matches how GNU make treats
	word lists, so e.g. "include $(EMPTY)" includes nothing.
	"""
	return s.split()

def match_pats(s, pats):
	"""Return True if word s matches any of the GNU make-style patterns in
	pats, as used by $(filter) and $(filter-out).

	A pattern without "%" must equal s exactly.  Otherwise the first "%"
	matches any (possibly empty) run of characters: s must start with the
	text before it and end with the text after it.  Backslash-quoting of
	"%" is not supported.
	"""
	for p in pats:
		i = p.find("%")
		if i == -1:
			if p == s:
				return True
			continue
		prefix = p[:i]
		suffix = p[i + 1:]
		# The length check stops prefix and suffix from overlapping in s.
		if len(s) >= len(prefix) + len(suffix) and s.startswith(prefix) and s.endswith(suffix):
			return True
	return False

class Portfile:
	"""A GARStow portfile loaded into memory.

	Loading follows include/-include directives, evaluates ifdef/ifndef/
	ifeq/ifneq conditionals as it goes, and records:

	vars        -- variable name -> value.  "=" and "?=" definitions keep
	               their text unexpanded; ":=" definitions (and "+="
	               additions to them) are stored already expanded.
	vartypes    -- variable name -> "normal" or "simple" (":=").
	varsources  -- variable name -> file whose definition set it.
	rules       -- target -> list of unexpanded command lines.
	exports     -- exported variable names (as dict keys).
	load_errors -- style problems found while loading; validate() raises
	               the first one.
	providence  -- (normalised path, mtime) for every file consulted,
	               with mtime -1 for files that could not be stat-ed;
	               load_portfile() compares these against the filesystem
	               to decide whether a cached copy is still current.
	"""

	def __init__(self, fn, **keys):
		"""Load the portfile fn; keys are passed on to load()."""
		self.vars = {}
		self.vartypes = {}
		self.varsources = {}
		self.rules = {}
		self.exports = {}
		self.load_errors = []
		self.providence = []

		self.filename = fn
		self.load(fn, **keys)

	def expand(self, s):
		"""Return s with make-style references expanded.

		"$$" becomes "$"; "$(name)", "${name}" and "$c" become the
		expanded value of the variable; "$(func args)" calls function().
		The text of a reference is itself expanded first, so nested forms
		such as $($(X)) work.  Raises PortfileError on a trailing lone "$"
		or unbalanced brackets.
		"""
		closers = {"(": ")", "{": "}"}
		obits = []
		o = 0
		end = len(s)
		while 1:
			i = s.find('$', o)
			if i == -1:
				obits.append(s[o:])
				break
			obits.append(s[o:i])
			if i + 1 >= end:
				raise PortfileError("Incomplete $ sequence in: " + s)

			c = s[i + 1]
			if c == "$":
				# $$
				obits.append("$")
				o = i + 2
				continue
			elif c in closers:
				# $(v) or ${v}.  As in GNU make, only the opening
				# delimiter's own bracket type is counted when looking
				# for the matching close.
				opener = c
				closer = closers[c]
				level = 1
				i += 2
				start = i
				while level > 0:
					nextl = s.find(opener, i)
					nextr = s.find(closer, i)
					if nextr == -1:
						raise PortfileError("Mismatched brackets in: " + s)
					elif nextl != -1 and nextl < nextr:
						level += 1
						i = nextl + 1
					else:
						level -= 1
						i = nextr + 1
				wanted = s[start:i - 1]
				o = i
			else:
				# $v
				wanted = c
				o = i + 2

			wanted = self.expand(wanted)
			# A space separates a function name from its arguments.
			i = wanted.find(" ")
			if i == -1:
				obits.append(self.variable(wanted))
			else:
				obits.append(self.function(wanted[:i], wanted[i + 1:]))

		return "".join(obits)

	def variable(self, name):
		"""Return the expanded value of name.

		Falls back to the process environment, then to "".
		"""
		if name in self.vars:
			return self.expand(self.vars[name])
		elif name in os.environ:
			return os.environ[name]
		else:
			return ""

	def shell_setup(self, with_env = False):
		"""Prepare a child process for a $(shell) command.

		It is passed as preexec_fn, so it runs in the child before the
		command is executed.  It changes into the portfile's directory and,
		if with_env is true, also copies exported variables into the
		environment.
		"""
		if with_env:
			for e in self.exports.keys():
				os.environ[e] = self.variable(e)
		port_dir = os.path.dirname(self.filename)
		if port_dir != "":
			os.chdir(port_dir)

	def function(self, name, args):
		"""Evaluate the GNU make function name on the argument text args.

		args has already been expanded, because expand() expands the whole
		reference before calling here.  Unlike GNU make, every argument is
		therefore evaluated eagerly.  Raises PortfileError for unknown
		functions or bad arguments.
		"""
		if name == "wildcard":
			# Let the shell do the globbing.
			# NOTE(review): unlike GNU make, a pattern with no matches
			# is echoed back literally by the shell.
			name = "shell"
			args = "echo " + args
		if name == "shell":
			child = subprocess.Popen(args, stdout=subprocess.PIPE,
			                         preexec_fn=self.shell_setup,
			                         shell=True)
			v = child.communicate()[0]
			# As GNU make does: drop one trailing newline and turn the
			# remaining newlines into spaces.
			if v.endswith("\r\n"):
				v = v[:-2]
			elif v.endswith("\n"):
				v = v[:-1]
			v = v.replace("\r\n", " ")
			v = v.replace("\n", " ")
			return v
		elif name == "subst":
			a = args.split(",", 2)
			if len(a) < 3:
				raise PortfileError("Too few arguments to subst")
			return a[2].replace(a[0], a[1])
		elif name == "patsubst":
			a = args.split(",", 2)
			if len(a) < 3:
				raise PortfileError("Too few arguments to patsubst")
			# FIXME: this does not implement the full \ rules
			# (which we never use in GARStow)
			n = a[0].find("%")
			m = a[1].find("%")
			if n == -1 or m == -1:
				raise PortfileError("First two arguments to patsubst must contain %")
			oldprefix = a[0][:n]
			oldsuffix = a[0][n + 1:]
			newprefix = a[1][:m]
			newsuffix = a[1][m + 1:]
			words = []
			for word in a[2].split():
				# As in GNU make, "%" may match an empty stem.  The
				# length check stops prefix and suffix from overlapping.
				if len(word) >= len(oldprefix) + len(oldsuffix) and word.startswith(oldprefix) and word.endswith(oldsuffix):
					stem = word[len(oldprefix):len(word) - len(oldsuffix)]
					words.append(newprefix + stem + newsuffix)
				else:
					words.append(word)
			return " ".join(words)
		elif name == "addprefix":
			a = args.split(",", 1)
			if len(a) < 2:
				raise PortfileError("Too few arguments to addprefix")
			return " ".join([a[0] + x for x in split_whitespace(a[1])])
		elif name == "addsuffix":
			a = args.split(",", 1)
			if len(a) < 2:
				raise PortfileError("Too few arguments to addsuffix")
			return " ".join([x + a[0] for x in split_whitespace(a[1])])
		elif name == "if":
			a = args.split(",", 2)
			if len(a) < 2:
				raise PortfileError("Too few arguments to if")
			if a[0].strip() != "":
				return a[1]
			elif len(a) == 3:
				return a[2]
			else:
				return ""
		elif name in ("filter", "filter-out"):
			a = args.split(",", 1)
			if len(a) < 2:
				raise PortfileError("Too few arguments to " + name)
			pats = a[0].split()
			# filter keeps the matching words; filter-out keeps the others.
			keep = (name == "filter")
			return " ".join([w for w in split_whitespace(a[1]) if match_pats(w, pats) == keep])
		elif name == "sort":
			return " ".join(sorted(set(args.split())))
		elif name == "strip":
			# GNU make also collapses internal whitespace runs to one space.
			return " ".join(args.split())
		elif name in ("firstword", "lastword"):
			words = args.split()
			if len(words) == 0:
				return ""
			elif name == "firstword":
				return words[0]
			else:
				return words[-1]
		elif name == "words":
			return str(len(args.split()))
		elif name == "wordlist":
			a = args.split(",", 2)
			if len(a) < 3:
				raise PortfileError("Too few arguments to wordlist")
			try:
				first = int(a[0])
				last = int(a[1])
			except ValueError:
				raise PortfileError("Invalid arguments to wordlist")
			# The indices given are 1-based and inclusive.  GNU make
			# rejects a start below 1 and yields nothing when the end is
			# before the start.  Without these checks a zero or negative
			# index would become a Python slice counted from the end.
			if first < 1:
				raise PortfileError("Invalid arguments to wordlist")
			if last < first:
				return ""
			words = a[2].split()
			return " ".join(words[first - 1:last])
		else:
			raise PortfileError("Unknown function " + name)

	# The conventional order of variables and rules in a portfile; load()
	# reports definitions that appear out of this order.
	var_order = ["GARNAME", "GARVERSION", "GARREVISION",
	             "CATEGORIES", "NONFREE",
	             "MASTER_SITES", "MASTER_SUBDIR",
	             "DISTFILE_SITES", "DISTFILE_SUBDIR",
	             "SIGFILE_SITES", "SIGFILE_SUBDIR",
	             "PATCHFILE_SITES", "PATCHFILE_SUBDIR",
	             "UPSTREAMNAME", "DISTNAME", "DISTEXT",
	             "DISTFILES", "SIGFILES", "PATCHFILES",
	             "NOCHECKSUM",
	             "PATCHOPTS", "PATCHDIR",
	             "LIBDEPS", "BUILDDEPS",
	             "WORKSRC", "WORKOBJ",
	             "DESCRIPTION", "HOME_URL", "BLURB",
	             "CONFIGURE_SCRIPTS", "BUILD_SCRIPTS", "TEST_SCRIPTS", "INSTALL_SCRIPTS",
	             "CONFIGURE_ARGS", "BUILD_ARGS", "TEST_ARGS", "INSTALL_ARGS",
	             "CONFIGURE_ENV", "BUILD_ENV", "TEST_ENV", "INSTALL_ENV",
	             "NEED_USERS", "NEED_GROUPS",
	             "COLLISIONS", "DECONFLICT", "COMPATLIBS",
	             "CFLAGS", "CPPFLAGS"]
	# Entries ending in "-" stand for any rule with that prefix
	# (see rule_pos).
	rule_order = ["pre-extract", "post-extract",
	              "pre-patch", "post-patch",
	              "pre-configure", "configure-", "post-configure",
	              "pre-build", "build-", "post-build",
	              "pre-install", "install-", "post-install",
	              "pre-stow"]

	# Lazily built name -> index maps for the two lists above.
	var_order_ = None
	rule_order_ = None
	def var_pos(self, s):
		"""Return the position of variable s in var_order, or -1."""
		if self.var_order_ is None:
			self.var_order_ = dict((v, i) for i, v in enumerate(self.var_order))
		return self.var_order_.get(s, -1)
	def rule_pos(self, s):
		"""Return the position of rule s in rule_order, or -1.

		A name with no exact entry is looked up by its prefix up to and
		including the first "-" (so "configure-foo" uses "configure-").
		"""
		if self.rule_order_ is None:
			self.rule_order_ = dict((r, i) for i, r in enumerate(self.rule_order))
		if s in self.rule_order_:
			return self.rule_order_[s]
		i = s.find("-")
		if i != -1:
			return self.rule_order_.get(s[:i + 1], -1)
		return -1

	def load(self, fn, toplevel = True, is_include = False, top_fn = None, ignore_missing = False):
		"""Read fn and merge its definitions into this Portfile.

		toplevel       -- true for the portfile itself, false for included
		                  files.  The ordering/style checks that feed
		                  load_errors are only applied at top level.
		is_include     -- start as if an include had already been seen,
		                  which suppresses the checks on where definitions
		                  sit relative to include directives.
		top_fn         -- file whose directory include names are resolved
		                  against (defaults to fn).  Nested includes are
		                  therefore relative to the top file, not the
		                  including one.
		ignore_missing -- silently skip fn if it cannot be stat-ed
		                  ("-include").

		Raises PortfileError for unbalanced conditionals and unknown
		directives.
		"""
		norm_fn = os.path.normpath(fn)
		try:
			mtime = os.stat(norm_fn).st_mtime
		except OSError:
			# The file can't be stat-ed; probably doesn't exist.
			# But we need to record it in the providence anyway,
			# in case it gets created in the future.
			mtime = -1
		self.providence.append((norm_fn, mtime))

		if mtime == -1 and ignore_missing:
			return

		parsed = parse_portfile(norm_fn, mtime)

		if top_fn is None:
			top_fn = fn

		# One entry per open conditional, true while its current branch
		# is active.  A directive takes effect only if every entry is true.
		cond_stack = []
		seen_include = is_include
		# Highest var_order/rule_order positions seen so far, used for the
		# "out of sequence" checks.
		var_max = -1
		varadd_max = -1
		rule_max = -1
		for p in parsed:
			kind = p[0]
			if kind in ("ifdef", "ifndef"):
				name = self.expand(p[1])
				cond_stack.append(name in self.vars)
				if kind == "ifndef":
					cond_stack[-1] = not cond_stack[-1]
			elif kind in ("ifeq", "ifneq"):
				vala = self.expand(p[1])
				valb = self.expand(p[2])
				cond_stack.append(vala == valb)
				if kind == "ifneq":
					cond_stack[-1] = not cond_stack[-1]
			elif kind == "else":
				if not cond_stack:
					raise PortfileError("else without matching if in " + norm_fn)
				cond_stack[-1] = not cond_stack[-1]
			elif kind == "endif":
				if not cond_stack:
					raise PortfileError("endif without matching if in " + norm_fn)
				cond_stack.pop()

			if False in cond_stack:
				continue

			if kind == "rule":
				name = self.expand(p[1])
				if toplevel and not seen_include:
					self.load_errors.append("Rule " + name + " declared before includes")
				pos = self.rule_pos(name)
				if pos != -1:
					if toplevel and pos < rule_max:
						self.load_errors.append("Rule " + name + " out of sequence")
					rule_max = pos
				if name in self.rules:
					self.load_errors.append("Rule " + name + " already defined")
				self.rules[name] = p[2]
			elif kind == "var":
				name = self.expand(p[1])
				if toplevel and seen_include and not is_include:
					self.load_errors.append("Variable " + name + " declared after includes")
				pos = self.var_pos(name)
				if pos != -1:
					if toplevel and pos < var_max:
						self.load_errors.append("Variable " + name + " out of sequence")
					var_max = pos
				if name in self.vars:
					self.load_errors.append("Variable " + name + " already defined")
				self.vars[name] = p[2]
				self.vartypes[name] = "normal"
				self.varsources[name] = fn
			elif kind == "varadd":
				name = self.expand(p[1])
				if toplevel and not seen_include:
					self.load_errors.append("Variable addition " + name + " declared before includes")
				pos = self.var_pos(name)
				if pos != -1:
					if toplevel and pos < varadd_max:
						self.load_errors.append("Variable addition " + name + " out of sequence")
					varadd_max = pos
				if not name in self.vars:
					self.vars[name] = ""
					self.vartypes[name] = "normal"
					self.varsources[name] = fn
				else:
					self.vars[name] += " "
				# Recursive variables stay unexpanded; simple ones are
				# extended with already-expanded text.
				if self.vartypes[name] == "normal":
					self.vars[name] += p[2]
				else:
					self.vars[name] += self.expand(p[2])
			elif kind == "varsimple":
				name = self.expand(p[1])
				# Special case: don't complain if we're doing:
				# FOO := something:$(FOO)
				if name in self.vars and self.vartypes[name] == "simple" and p[2].find("$(" + name + ")") == -1:
					self.load_errors.append("Variable " + name + " already defined")
				self.vars[name] = self.expand(p[2])
				self.vartypes[name] = "simple"
				self.varsources[name] = fn
			elif kind == "varmaybe":
				name = self.expand(p[1])
				if not name in self.vars:
					self.vars[name] = p[2]
					self.vartypes[name] = "normal"
					self.varsources[name] = fn
			elif kind in ("include", "-include"):
				for ifn in split_whitespace(self.expand(p[1])):
					relname = os.path.join(os.path.dirname(top_fn), ifn)
					self.load(relname, False, top_fn = top_fn, ignore_missing = (kind == "-include"))
				seen_include = True
			elif kind == "export":
				for v in split_whitespace(self.expand(p[1])):
					self.exports[v] = 1
			elif kind == "unexport":
				for v in split_whitespace(self.expand(p[1])):
					if v in self.exports:
						del self.exports[v]
			elif kind not in ("ifdef", "ifndef", "ifeq", "ifneq", "else", "endif"):
				raise PortfileError("Unrecognised token: " + kind)

		if cond_stack:
			raise PortfileError("Missing endif in " + norm_fn)

	def validate(self, fn = None):
		"""Check that the loaded port follows GARStow conventions.

		fn (default: the file loaded) is used to find the category
		directory, i.e. the grandparent of the portfile, which must be
		listed in CATEGORIES.  Raises PortfileError describing the first
		problem found, including the first entry of load_errors.
		"""
		if fn is None:
			fn = self.filename
		if self.load_errors != []:
			raise PortfileError(self.load_errors[0])
		must_define = ["GARNAME", "GARVERSION", "CATEGORIES",
		               "DESCRIPTION"]
		for v in must_define:
			if not v in self.vars:
				raise PortfileError("Variable " + v + " must be defined")
		if "INSTALL_SCRIPTS" in self.vars and not "DISTFILES" in self.vars:
			raise PortfileError("DISTFILES must be defined for non-stub port")
		if ("CONFIGURE_SCRIPTS" in self.vars or "BUILD_SCRIPTS" in self.vars) and not "INSTALL_SCRIPTS" in self.vars:
			raise PortfileError("INSTALL_SCRIPTS must be defined for buildable port")
		for stage in ["CONFIGURE", "BUILD", "INSTALL"]:
			# Only complain when the _ARGS setting comes from the port's
			# own Makefile rather than an included file.
			if (stage + "_ARGS") in self.vars and not (stage + "_SCRIPTS") in self.vars and os.path.basename(self.varsources[stage + "_ARGS"]) == "Makefile":
				raise PortfileError(stage + "_ARGS specified without " + stage + "_SCRIPTS")
		must_not_be_empty = ["GARNAME", "GARVERSION", "DESCRIPTION"]
		for v in must_not_be_empty:
			if self.vars[v].strip() == "":
				raise PortfileError(v + " must not be empty")
		cats = self.vars["CATEGORIES"].strip().split()
		if len(cats) < 1:
			raise PortfileError("At least one category must be specified")
		if not os.path.isabs(fn):
			fn = os.path.normpath(os.path.join(os.getcwd(), fn))
		cat_dn = os.path.dirname(os.path.dirname(fn))
		if cat_dn != "":
			cat = os.path.basename(cat_dn)
			if not cat in cats:
				raise PortfileError("Directory name " + cat + " is not listed in categories")
		if "BLURB" in self.vars and self.vars["BLURB"].strip() == "FIXME":
			raise PortfileError("BLURB set to FIXME -- replace with HOME_URL or remove")

	def show(self, f = sys.stdout):
		"""Write a Makefile-like dump of the port to f.

		The dump contains every variable fully expanded (sorted by name),
		the export list, then each rule (sorted) with its commands expanded.
		"""
		for v in sorted(self.vars):
			f.write(v + " = " + self.variable(v) + "\n")
		f.write("\n")
		f.write("export " + " ".join(self.exports.keys()) + "\n")
		f.write("\n")
		for r in sorted(self.rules):
			f.write(r + ":\n")
			for l in self.rules[r]:
				f.write("\t" + self.expand(l) + "\n")
			f.write("\n")

def load_portfile(fn):
	"""Load a portfile, checking to see whether there's an up-to-date
	cached copy first."""
	port = cache.get(fn)
	if port is not None:
		for prov_fn, prov_mtime in port.providence:
			try:
				new_mtime = os.stat(prov_fn).st_mtime
				if new_mtime != prov_mtime:
					port = None
			except OSError:
				port = None
	if port is None:
		try:
			port = Portfile(fn, is_include = True)
		except PortfileError, p:
			die("Error loading portfile ", fn, ": ", p)
		cache.put(fn, port)
	return port

