#!/usr/bin/python
#
# rpmhack - tool to handle "patching" of RPM databases
#
#   Copyright (C) 2020 Olaf Kirch <okir@suse.de>
#
#   This program is free software; you can redistribute it and/or modify
#   it under the terms of the GNU General Public License as published by
#   the Free Software Foundation; either version 2 of the License, or
#   (at your option) any later version.
#
#   This program is distributed in the hope that it will be useful,
#   but WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#   GNU General Public License for more details.
#
#   You should have received a copy of the GNU General Public License
#   along with this program; if not, write to the Free Software
#   Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

import rpm
import rpmhack
import sys
import os
import exceptions

# Patch file written by "rpmhack diff" unless --output is given.
default_patchfile	= "rpmhack.patch"

# Global option state, set from the command line by the option parsing
# helpers. Every global that other functions read is defined here so it is
# never unbound; opt_rootdir matches the default of --root-dir.
opt_patchdir = None
opt_patchfile = None
opt_dbpath = None
opt_rootdir = "/"
opt_debug = False

class PatchFile:
	"""An RPM "patch" file: a tar archive describing changes to an RPM DB.

	write() produces a bzip2-compressed tar archive containing
	  - "index": one line per change, "<kind> <name> <version> <release>
	    [<epoch>]", where <kind> is RPM_INSTALL, RPM_UPGRADE or RPM_REMOVE
	    ("# no changes" when there are none), and
	  - for every install/upgrade, the serialized RPM header stored as a
	    member named after the header's NEVR.
	read() loads such an archive back into the change list.
	"""

	RPM_INSTALL	= 'i'
	RPM_UPGRADE	= 'u'
	RPM_REMOVE	= 'e'

	validModes = (RPM_INSTALL, RPM_UPGRADE, RPM_REMOVE)

	def __init__(self, path):
		self.path = path
		# (kind, header) pairs, in the order they were added/read
		self._changes = []

	def changes(self):
		"""Return the recorded changes as a list of (kind, header) pairs."""
		return self._changes

	def add(self, kind, hdrOrList):
		"""Record a change of type `kind` for one header or a list of headers."""
		if isinstance(hdrOrList, list):
			for pkg in hdrOrList:
				self.add(kind, pkg)
		else:
			self._changes.append((kind, hdrOrList))

	def write(self):
		"""Write all recorded changes to self.path as a tar.bz2 archive."""
		import tarfile

		lines = []
		TF = tarfile.open(self.path, mode = "w:bz2")
		try:
			for (kind, hdr) in self._changes:
				if kind == PatchFile.RPM_INSTALL or kind == PatchFile.RPM_UPGRADE:
					self.appendHeader(TF, hdr)

				entry = "%s %s %s %s" % (kind, hdr.name, hdr.version, hdr.release)
				# The header member is named after hdr.NEVR, which includes
				# the epoch when the package has one. Record the epoch so
				# readIndex() can rebuild the very same member name.
				if hdr.epoch is not None:
					entry += " %s" % hdr.epoch
				lines.append(entry + "\n")

			changeList = "".join(lines)
			if not changeList:
				changeList = "# no changes"

			self.appendData(TF, "index", changeList)
		finally:
			# Closing flushes the bz2 stream and writes the end-of-archive
			# blocks; without it the patch file is left truncated.
			TF.close()

	def appendHeader(self, TF, hdr):
		"""Serialize `hdr` and store it in TF under the member name hdr.NEVR."""
		import tempfile

		tf = tempfile.TemporaryFile()
		outfile = rpm.fd(tf, "wb")
		hdr.write(outfile)

		outfile.seek(0)
		self.appendData(TF, hdr.NEVR, outfile.read())

	def appendData(self, TF, name, data):
		"""Add `data` to TF as a regular file member called `name`.

		Text (unicode) is encoded as UTF-8 first; byte strings are stored
		unchanged.
		"""
		import tarfile
		import io

		if not isinstance(data, bytes):
			data = data.encode("utf-8")

		tarInfo = tarfile.TarInfo(name)
		tarInfo.size = len(data)
		TF.addfile(tarInfo, io.BytesIO(data))

	def read(self):
		"""Load the change list from the archive at self.path."""
		import tarfile

		TF = tarfile.open(self.path, mode = "r")
		try:
			self.readIndex(TF)
		finally:
			TF.close()

	def readIndex(self, TF):
		"""Parse the "index" member of TF into self._changes.

		Blank lines and lines starting with '#' are ignored. For install and
		upgrade entries the full header is loaded via readHeader(); remove
		entries only carry name/version/release[/epoch].

		Raises ValueError for an unknown operation or a truncated line.
		"""
		memberName = "%s[index]" % self.path
		rawIndex = TF.extractfile("index")

		self._changes = []

		for lineno, line in enumerate(rawIndex.readlines(), 1):
			if not isinstance(line, str):
				# tar members yield bytes on Python 3
				line = line.decode("utf-8")

			x = line.split()
			if not x or x[0].startswith('#'):
				continue

			kind = x[0]
			# validate before touching x[1..3] so a bad line gives a clear error
			if kind not in PatchFile.validModes or len(x) < 4:
				raise ValueError("%s:%s: invalid operation: %s" % (memberName, lineno, line.rstrip()))

			hdr = rpm.hdr()
			hdr['name'] = x[1]
			hdr['version'] = x[2]
			hdr['release'] = x[3]
			if len(x) > 4:
				# epoch is a numeric tag
				hdr['epoch'] = int(x[4])

			if kind == PatchFile.RPM_INSTALL or kind == PatchFile.RPM_UPGRADE:
				hdr = self.readHeader(TF, hdr.NEVR)

			self._changes.append((kind, hdr))

	def readHeader(self, TF, name):
		"""Load the RPM header stored in archive member `name`.

		Raises ValueError if the patch file has no such member.
		"""
		import tempfile

		memberName = "%s[%s]" % (self.path, name)
		try:
			rawHeader = TF.extractfile(name)
		except KeyError:
			raise ValueError("%s: no such member in patch file" % memberName)

		# Copy the member into a real temporary file before handing it to
		# rpm.hdr() (presumably rpm needs an actual file descriptor here).
		tf = tempfile.TemporaryFile()

		tf.write(rawHeader.read())

		tf.seek(0)
		return rpm.hdr(tf)

class TransactionSet:
	"""rpm.TransactionSet wrapper used to order and check a DB-only patch.

	The underlying transaction runs with JUSTDB | NOSCRIPTS | NOTRIGGERS |
	TEST flags; it is only used for ordering and dependency checking.
	apply() then performs the additions/removals through rpmhack.db.
	"""

	def __init__(self, dbpath, verifySignatures = False):
		sigFlags = 0
		if not verifySignatures:
			sigFlags = rpm._RPMVSF_NOSIGNATURES

		self._dbpath = dbpath
		self._ts = self.openDB(sigFlags)

		tsFlags = rpm.RPMTRANS_FLAG_JUSTDB
		tsFlags |= rpm.RPMTRANS_FLAG_NOSCRIPTS
		tsFlags |= rpm.RPMTRANS_FLAG_NOTRIGGERS
		tsFlags |= rpm.RPMTRANS_FLAG_TEST
		self._ts.setFlags(tsFlags)

		# NEVR -> header for every package queued for installation
		self._headers = {}

	def openDB(self, sigFlags = 0):
		"""Create an rpm.TransactionSet on self._dbpath (or the default DB) and open it."""
		global opt_rootdir

		path = self._dbpath
		if path is not None:
			message("Opening DB %s" % path)
			rpm.addMacro('_dbpath', path)
		elif opt_rootdir:
			message("Opening default DB (rootdir %s)" % opt_rootdir)
		else:
			message("Opening default DB")

		try:
			ts = rpm.TransactionSet(opt_rootdir, sigFlags)
			ts.openDB()
		finally:
			# Only pop the definition we pushed ourselves: an unconditional
			# delMacro() would remove the system-configured _dbpath.
			if path is not None:
				rpm.delMacro('_dbpath')

		return ts

	def add(self, kind, hdr):
		"""Queue one PatchFile change (install, upgrade or remove) for `hdr`."""
		debug("--- %s %s" % (kind, hdr.name))
		if kind == PatchFile.RPM_REMOVE:
			self._ts.addErase(hdr.name)
			return

		if self._findPackage(hdr.name, hdr.NEVR):
			message("%s already listed in %s - nothing to do" % (hdr.NEVR, self._dbpath or "RPM DB"))
			return

		if kind == PatchFile.RPM_INSTALL:
			self._ts.addInstall(hdr, None, 'i')
		elif kind == PatchFile.RPM_UPGRADE:
			self._ts.addInstall(hdr, None, 'u')
		else:
			error("%s: don't know how to handle operation type %s" % (self.__class__.__name__, kind))
			return

		# apply() needs the full header to write it into the DB
		self._headers[hdr.NEVR] = hdr

	def isEmpty(self):
		"""Return True if no transaction elements have been queued."""
		for _ in self._ts:
			return False
		return True

	def order(self):
		self._ts.order()

	def check(self):
		"""Run the dependency check; returns the problem list from rpm."""
		return self._ts.check()

	def show(self, verbose = False):
		"""Print one line per transaction element (plus its files if verbose)."""
		for te in self._ts:
			if te.Type() == rpm.TR_ADDED:
				how = "add"
			else:
				how = "rem"

			if verbose:
				print("%s %s" % (how, te.NEVR()))
				fi = te.FI()
				# NOTE(review): fi is indexed explicitly; iterating fi
				# directly may yield different objects - confirm first.
				for i in range(len(fi)):
					print("  %s" % fi[i])
			else:
				print("%s %s (%u files)" % (how, te.NEVR(), len(te.FI())))

	def apply(self):
		"""Write the ordered transaction into the DB via rpmhack.db.

		Added elements use the headers passed to add(); removed elements are
		located in the DB by NEVR. Returns a list of problem descriptions
		(empty on success).
		"""
		db = rpmdbOpen(self._dbpath, mode = "w")

		problems = []
		for te in self._ts:
			if te.Type() == rpm.TR_ADDED:
				debug("Trying to add %s" % te.NEVR())
				h = self._headers.get(te.NEVR())
				if h is None:
					problems.append("Cannot add %s: no header info provided" % te.NEVR())
					continue

				db.add(h)
			else:
				debug("Trying to remove %s" % te.NEVR())
				teInstance = self._findInstance(te.N(), te.NEVR())
				if teInstance is None:
					problems.append("Cannot remove %s: unable to find DB instance" % te.NEVR())
					continue

				db.remove(teInstance)

		return problems

	def printIssues(self, list):
		"""Print the dependency problems returned by check(), one per line.

		Each entry is unpacked as ((name, version, release),
		(reqName, reqVersion), needFlags, suggestedPkg, senseFlags).
		"""
		versionFlags = rpm.RPMSENSE_LESS | rpm.RPMSENSE_GREATER | rpm.RPMSENSE_EQUAL

		for (pkgNVR, reqNV, needFlags, suggestedPkg, senseFlags) in list:
			pkg = "%s-%s-%s" % pkgNVR

			if not (needFlags & versionFlags):
				# Unversioned dependency: print just the name instead of
				# "name = " followed by a (presumably empty) version.
				req = "%s" % reqNV[0]
			else:
				reqRelation = "="
				if needFlags & rpm.RPMSENSE_LESS:
					reqRelation = "<"
				if needFlags & rpm.RPMSENSE_GREATER:
					reqRelation = ">"
				if needFlags & rpm.RPMSENSE_EQUAL:
					reqRelation += "="

				if reqRelation == "==":
					# exact version: print as name-version
					reqRelation = "-"
				else:
					reqRelation = " %s " % reqRelation

				req = "%s%s%s" % (reqNV[0], reqRelation, reqNV[1])

			if senseFlags == rpm.RPMDEP_SENSE_CONFLICTS:
				sense = "conflicts with"
			elif senseFlags == rpm.RPMDEP_SENSE_REQUIRES:
				sense = "requires"
			else:
				sense = "<unknown relation 0x%x>" % senseFlags

			print("%s %s %s" % (pkg, sense, req))

	def _findMatch(self, name, NEVR):
		"""Return (header, DB instance) of the installed package with this NEVR.

		Returns (None, None) if no package called `name` has that NEVR.
		"""
		mi = self._ts.dbMatch('name', name)
		for h in mi:
			if h.NEVR == NEVR:
				# instance() refers to the record the iterator is positioned on
				return (h, mi.instance())

		return (None, None)

	def _findPackage(self, name, NEVR):
		"""Return the installed header matching NEVR, or None."""
		return self._findMatch(name, NEVR)[0]

	def _findInstance(self, name, NEVR):
		"""Return the DB instance number of the package matching NEVR, or None."""
		return self._findMatch(name, NEVR)[1]

def rpmdbOpen(dbpath, mode = None):
	"""Open an rpmhack.db handle on `dbpath`, or on the default DB if None.

	For non-write modes (mode None or not starting with 'w') the directory
	must already exist; otherwise an error is reported and ValueError is
	raised. The _dbpath macro is pushed only for the duration of the open
	and popped again afterwards, also if opening fails.
	"""
	if dbpath is None:
		# No macro pushed here, so none must be popped: an unconditional
		# delMacro() would remove the system-configured _dbpath.
		return rpmhack.db(mode = mode)

	if mode is None or not mode.startswith('w'):
		if not os.path.isdir(dbpath):
			error("rpmdbOpen: %s does not seem to exist" % dbpath)
			raise ValueError("rpmdbOpen - invalid database?")

	rpm.addMacro('_dbpath', dbpath)
	try:
		return rpmhack.db(mode = mode)
	finally:
		rpm.delMacro('_dbpath')

def error(msg):
	"""Write msg to stderr prefixed with "Error: ", adding a newline if missing.

	stdout is flushed first so the error appears after any output already
	written there.
	"""
	sys.stdout.flush()

	sys.stderr.write("Error: ")
	sys.stderr.write(msg)
	# endswith() instead of msg[-1]: an empty message must not raise IndexError
	if not msg.endswith('\n'):
		sys.stderr.write('\n')

def message(msg):
	"""Write msg to stdout, adding a trailing newline if it lacks one."""
	sys.stdout.write(msg)
	# endswith() instead of msg[-1]: an empty message must not raise IndexError
	if not msg.endswith('\n'):
		sys.stdout.write('\n')

def debug(msg):
	"""Write msg to stderr prefixed with "D: " when opt_debug is set.

	A trailing newline is added if msg lacks one; nothing is written when
	debugging is disabled.
	"""
	global opt_debug

	if not opt_debug:
		return

	sys.stderr.write("D: ")
	sys.stderr.write(msg)
	# endswith() instead of msg[-1]: an empty message must not raise IndexError
	if not msg.endswith('\n'):
		sys.stderr.write('\n')

def preparseOptionParser(usageMsg):
	"""Return an OptionParser preloaded with the options common to all modes.

	Adds --root-dir (default "/") and --debug; processCommonOptions()
	applies them after parsing.
	"""
	import optparse

	# (this line used to be indented with spaces instead of a tab; mixed
	# indentation is rejected by Python 3 and by "python2 -tt")
	parser = optparse.OptionParser(usage = usageMsg)

	parser.add_option('--root-dir', default = "/",
		help = "Installation root directory")
	parser.add_option('--debug', default = False, action = 'store_true',
		help = "Enable debugging output (rpmhack and RPM library)")

	return parser

def processCommonOptions(opts):
	"""Apply the options added by preparseOptionParser() to the module globals.

	Sets opt_rootdir from --root-dir and, with --debug, raises the RPM log
	verbosity to debug and enables rpmhack's own debug output.
	"""
	global opt_rootdir
	global opt_debug

	# Always assign opt_rootdir so later readers never find it unset; an
	# empty --root-dir falls back to "/", the option's default.
	opt_rootdir = opts.root_dir or "/"

	if opts.debug:
		rpm.setVerbosity(rpm.RPMLOG_DEBUG)
		opt_debug = True

def diffParseOptions():
	"""Parse the "rpmhack diff" command line and return its positional args.

	Sets opt_patchfile from --output (default: default_patchfile). Unless
	one or two database paths are given, exits through the parser's
	error() with a usage message.
	"""
	global opt_patchfile

	p = preparseOptionParser("rpmhack diff [global options] origDB [newDB]")
	p.add_option('--output', default = default_patchfile,
		help = "File to write patch information to. Default: %s" % default_patchfile)

	(opts, args) = p.parse_args()

	processCommonOptions(opts)
	opt_patchfile = opts.output

	if len(args) not in (1, 2):
		# the parser is "p"; the former "parser.error" raised NameError
		p.error("Invalid number of arguments")

	return args

def patchParseOptions():
	"""Parse the "rpmhack patch" command line and return its positional args.

	Sets opt_dbpath from --dbpath (None selects the default RPM DB). Unless
	exactly one patch file is given, exits through the parser's error()
	with a usage message.
	"""
	global opt_dbpath

	p = preparseOptionParser("rpmhack patch [global options] patchFile")
	p.add_option('--dbpath', default = None,
		help = "Path to the RPM database to modify. Default to system RPM DB")

	(opts, args) = p.parse_args()

	processCommonOptions(opts)
	opt_dbpath = opts.dbpath

	if len(args) != 1:
		# the parser is "p"; the former "parser.error" raised NameError
		p.error("Invalid number of arguments")

	return args

def getPackages(dbpath = None):
	"""Return a dict mapping package name to header for every package in a DB.

	dbpath selects the database directory; None (the default, as used by
	"rpmhack diff" without a second DB) opens the default RPM database.
	If several packages share a name, only the last one returned by
	db.packages() is kept.
	"""
	db = rpmdbOpen(dbpath)

	return dict((h.name, h) for h in db.packages())

def prettyName(pkg):
	"""Format a package as "<name>-<version>" for user-facing messages."""
	return "{0}-{1}".format(pkg.name, pkg.version)

def doDiff():
	"""Implement "rpmhack diff": compare two RPM DBs and write a patch file.

	The first argument is the original DB; the second, if given, is the new
	DB, otherwise the default RPM DB is used. Packages are matched by name:
	ones missing from the new DB are removals, ones only in the new DB are
	additions, and a changed version/release/epoch is an upgrade.
	"""
	global opt_patchfile

	args = diffParseOptions()

	origPackages = getPackages(args[0])
	if len(args) == 2:
		newPackages = getPackages(args[1])
	else:
		# compare against the default RPM DB
		newPackages = getPackages(None)

	added = []
	removed = []
	upgraded = []

	for p in origPackages.values():
		newP = newPackages.get(p.name)
		if newP is None:
			message("Removed %s" % prettyName(p))
			removed.append(p)
		elif (p.version, p.release, p.epoch) != (newP.version, newP.release, newP.epoch):
			message("Upgraded %s" % prettyName(p))
			# the patch must carry the new header
			upgraded.append(newP)

	for p in newPackages.values():
		if p.name not in origPackages:
			message("Added %s" % prettyName(p))
			added.append(p)

	pf = PatchFile(opt_patchfile)
	pf.add(PatchFile.RPM_INSTALL, added)
	pf.add(PatchFile.RPM_UPGRADE, upgraded)
	pf.add(PatchFile.RPM_REMOVE, removed)
	pf.write()

	# changes is a method; testing the bound method itself was always true
	if not pf.changes():
		message("Nothing changed; created empty patch")

	message("Wrote patch information to %s" % opt_patchfile)

def doPatch():
	"""Implement "rpmhack patch": replay a patch file against an RPM DB.

	Builds a transaction from the patch's changes, orders and checks it,
	prints it and applies it. Exits with status 1 on dependency problems or
	apply failures and with status 0 on success; returns 0 without applying
	anything when the transaction is empty.
	"""
	global opt_dbpath

	patchPath = patchParseOptions()[0]

	patch = PatchFile(patchPath)
	patch.read()

	debug("Creating TransactionSet")
	txn = TransactionSet(opt_dbpath)
	for kind, hdr in patch.changes():
		txn.add(kind, hdr)

	if txn.isEmpty():
		message("Empty RPM patch; nothing to do")
		return 0

	debug("Ordering TransactionSet")
	txn.order()

	debug("Checking TransactionSet")
	depIssues = txn.check()
	if depIssues:
		error("Unable to apply RPM transaction set")
		txn.printIssues(depIssues)
		sys.exit(1)

	txn.show()

	debug("Applying TransactionSet")
	failures = txn.apply()
	if failures:
		error("Problems applying patch")
		for failure in failures:
			message("  %s" % failure)
		sys.exit(1)

	message("Successfully applied patch")
	sys.exit(0)

prog = os.path.basename(sys.argv[0])
progPrefix = "rpmhack-"

# The operation mode comes either from the program name (e.g. a link called
# "rpmhack-diff") or from the first command line argument.
if prog.startswith(progPrefix):
	# strip the whole prefix; the former prog[6:] kept "k-" in front
	mode = prog[len(progPrefix):]
elif len(sys.argv) > 1:
	mode = sys.argv[1]
	del sys.argv[1]
	# For error reporting: optparse derives its program name from argv[0]
	sys.argv[0] = progPrefix + mode
else:
	error("rpmhack: no operation mode given (expected diff or patch)")
	sys.exit(1)

if mode == "diff":
	doDiff()
elif mode == "patch":
	doPatch()
else:
	error("rpmhack: unknown operation mode %s" % mode)
	sys.exit(1)
