#!/usr/libexec/platform-python

# Tool creating a maven repository out of rpms built by OBS
# Copyright (C) 2019  SUSE Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import argparse
from datetime import datetime
import errno
import logging
import os
import os.path
import re
import shutil
import sys
import subprocess
import tempfile
import urllib.request
import xml.etree.ElementTree as ET
import yaml
import zlib


# Default the root logger to INFO; main() later sets the level chosen on the
# command line.
logging.basicConfig(level=logging.INFO)


# RPM version comparison shamelessly taken from:
# https://stackoverflow.com/questions/3206319/how-do-i-compare-rpm-versions-in-python/42967591#42967591
try:
    from rpm import labelCompare as _compare_rpm_labels
except ImportError:
    # Pure-Python stand-in used when the rpm bindings cannot be imported.
    # It emulates RPM field comparison with these rules:
    #
    # * Each string is scanned for alphabetic fields [a-zA-Z]+ and numeric
    #   fields [0-9]+; anything else ([^a-zA-Z0-9]*) is a separator.
    # * Fields are compared pairwise, in order of appearance.
    # * Two alphabetic fields compare lexicographically, two numeric fields
    #   compare as integers.
    # * A numeric field is always greater (newer) than an alphabetic one.
    # * When one string has no fields left, the other one is greater (newer).

    logging.warning("Failed to import 'rpm', emulating RPM label comparisons")

    from itertools import zip_longest

    _subfield_pattern = re.compile(
        r'(?P<junk>[^a-zA-Z0-9]*)((?P<text>[a-zA-Z]+)|(?P<num>[0-9]+))'
    )

    def _iter_rpm_subfields(field):
        """Split ``field`` into sortable ``(kind, value)`` tuples.

        Alphabetic runs become ``(0, text)`` and digit runs become
        ``(1, int_value)``, so plain tuple ordering ranks any number above
        any text.
        """
        for match in _subfield_pattern.finditer(field):
            text, num = match.group('text', 'num')
            yield (0, text) if text is not None else (1, int(num))

    def _compare_rpm_field(lhs, rhs):
        """Compare two label fields; return -1, 0 or 1."""
        # Identical values (this includes None on both sides) need no parsing
        if lhs == rhs:
            return 0
        # From here on both values are expected to be strings
        pairs = zip_longest(_iter_rpm_subfields(lhs), _iter_rpm_subfields(rhs))
        for left, right in pairs:
            if left == right:
                # Same subfield on both sides: look at the next one
                continue
            if left is None:
                # LHS ran out of subfields first: it is the older one
                return -1
            if right is None:
                # RHS ran out of subfields first: LHS is the newer one
                return 1
            # First differing subfield decides
            return 1 if left > right else -1
        # Every subfield matched
        return 0

    def _compare_rpm_labels(lhs, rhs):
        """Compare two ``(epoch, version, release)`` labels; return -1, 0 or 1."""
        lhs_epoch, lhs_version, lhs_release = lhs
        rhs_epoch, rhs_version, rhs_release = rhs
        field_pairs = (
            (lhs_epoch, rhs_epoch),
            (lhs_version, rhs_version),
            (lhs_release, rhs_release),
        )
        # The first field that differs, in epoch/version/release order, wins
        for left, right in field_pairs:
            outcome = _compare_rpm_field(left, right)
            if outcome:
                return outcome
        return 0


class Rpm:
    """One package entry taken from a repository's primary metadata.

    Attributes:
        path: the location exactly as passed in.
        name: the location with everything up to and including its first
            '/' removed (the whole location when it contains no '/').
        mtime: the modification time passed in.
        pkgname: the package name passed in.
        epoch, version, release: the 'epoch', 'ver' and 'rel' values read
            from ``version_node`` (None when missing).
    """

    def __init__(self, location, mtime, name, version_node):
        self.pkgname = name
        self.mtime = mtime
        self.path = location
        # Keep only what follows the first '/' of the location
        self.name = location.split('/', 1)[-1]
        self.epoch, self.version, self.release = (
            version_node.get(key) for key in ('epoch', 'ver', 'rel')
        )

    def compare(self, other):
        """Compare ``other`` against this package.

        Note the argument order: the result is the label comparison of
        ``other`` versus ``self``, i.e. positive when ``other`` is newer,
        negative when it is older and 0 when both labels are equal.
        """
        theirs = (other.epoch, other.version, other.release)
        mine = (self.epoch, self.version, self.release)
        return _compare_rpm_labels(theirs, mine)


class Repo:
    """Access to one published repository of an OBS project.

    Everything is fetched from ``<uri>/<project>/<repository>/...`` where each
    ':' of the project name has been expanded to ':/'.
    """

    def __init__(self, uri, project, repository):
        """Remember where the repository lives; nothing is downloaded yet.

        Args:
            uri: base URL of the download server.
            project: OBS project name, e.g. ``home:user``.
            repository: repository name inside the project.
        """
        self.uri = uri
        self.project = project.replace(':', ':/')
        self.repository = repository
        # None means "primary metadata not parsed yet"; an empty list is a
        # valid, already parsed, result.
        self._rpms = None

    def _url(self, path):
        """Return the URL of ``path`` inside this repository."""
        return "{}/{}/{}/{}".format(self.uri, self.project, self.repository, path)

    def find_primary(self):
        """Return the URL of the primary metadata file named by repomd.xml."""
        ns = {'repo': 'http://linux.duke.edu/metadata/repo', 'rpm': 'http://linux.duke.edu/metadata/rpm'}
        # The context manager closes the response even if parsing fails
        with urllib.request.urlopen(self._url("repodata/repomd.xml")) as f:
            doc = ET.fromstring(f.read())
        primary_href = doc.find("./repo:data[@type='primary']/repo:location", ns).get("href")
        return self._url(primary_href)

    @staticmethod
    def _latest_by_name(rpms):
        """Map each package name to the newest of the given packages."""
        latest_rpms = {}
        for rpm in rpms:
            latest = latest_rpms.get(rpm.pkgname)
            # compare() is positive only when its argument is newer; a plain
            # truth test would also accept -1 and let older packages win.
            if latest is None or latest.compare(rpm) > 0:
                latest_rpms[rpm.pkgname] = rpm
        return latest_rpms

    def parse_primary(self):
        """Download the primary metadata and cache the newest rpm per package.

        Only x86_64 and noarch packages are considered.
        """
        ns = {"c": "http://linux.duke.edu/metadata/common", "rpm": "http://linux.duke.edu/metadata/rpm"}
        with urllib.request.urlopen(self.find_primary()) as f:
            # 16 + MAX_WBITS makes zlib expect a gzip header
            primary_xml = zlib.decompress(f.read(), 16 + zlib.MAX_WBITS)
        doc = ET.fromstring(primary_xml)

        all_rpms = [
            Rpm(n.find("c:location", ns).get("href"),
                int(n.find("c:time", ns).get("file")),
                n.find("c:name", ns).text,
                n.find("c:version", ns))
            for n
            in doc.findall(".//c:package", ns)
            if n.find("c:arch", ns).text in ["x86_64", "noarch"]
        ]
        self._rpms = list(self._latest_by_name(all_rpms).values())

    @property
    def rpms(self):
        """Newest rpm of every package, parsed on first access."""
        # Test against None so that an empty repository is not re-downloaded
        # on every access.
        if self._rpms is None:
            self.parse_primary()
        return self._rpms

    def get_binary(self, path, target, mtime):
        """
        Equivalent of osc.core.get_binary_file

        Download ``path`` from the repository into the local file ``target``
        and set the access and modification times of ``target`` to ``mtime``.
        """
        with urllib.request.urlopen(self._url(path)) as f, open(target, 'wb') as target_f:
            shutil.copyfileobj(f, target_f)
        os.utime(target, (mtime, mtime))


class Artifact:
    def __init__(self, config, repositories, group):
        """Build an artifact description from one configuration entry.

        Args:
            config: mapping with the mandatory keys 'artifact' and
                'repository' and the optional keys 'group', 'package',
                'rpm' (deprecated spelling of 'package'), 'arch' and 'jar'.
            repositories: mapping of repository name to repository object.
            group: group used when the entry defines none.

        Raises:
            RuntimeError: the named repository is not in ``repositories``.
        """
        self.artifact = config['artifact']
        self.group = config.get('group', group)

        # 'rpm' still takes precedence over 'package' but is deprecated
        legacy_rpm = config.get('rpm')
        self.package = legacy_rpm or config.get('package', self.artifact)
        if legacy_rpm:
            logging.warning('artifact "rpm" property is deprecated')

        self.arch = config.get('arch', 'noarch')

        repo_name = config['repository']
        self.repository = repositories.get(repo_name)
        if not self.repository:
            raise RuntimeError('Missing repository definition: ' + repo_name)

        self.jar = config.get('jar')

    def get_binary(self):
        """Return the one repository package that provides this artifact.

        The package name is used as a regular expression anchored at the
        start of the file name. Unless it already ends with '-', it is
        completed with '-[0-9]' so that only the version may follow. Files
        whose name contains one of the excluded words are ignored.

        Raises:
            RuntimeError: no file, or more than one file, matches.
        """
        if self.package.endswith("-"):
            file_pattern = self.package
        else:
            file_pattern = self.package + "-[0-9]"
        excluded = ['javadoc', 'examples', 'manual', 'test', 'demo']

        candidates = []
        for rpm in self.repository.rpms:
            if any(word in rpm.name for word in excluded):
                continue
            if re.match(file_pattern, rpm.name):
                candidates.append(rpm)

        if not candidates:
            raise RuntimeError('Found no file matching "{}" for {}'.format(file_pattern, self.artifact))
        if len(candidates) > 1:
            raise RuntimeError('Found more than one file for {}:\n  {}'.format(
                self.artifact, "\n  ".join([rpm.name for rpm in candidates])))
        return candidates[0]

    def fetch_binary(self, file, tmp):
        """Download ``file`` from the artifact's repository into ``tmp``.

        Returns:
            The path of the downloaded file, or None when no regular file
            exists at that path after the download.
        """
        target_file = os.path.join(tmp, file.name)
        logging.info('Downloading %s' % target_file)
        self.repository.get_binary(file.path, target_file, file.mtime)
        return target_file if os.path.isfile(target_file) else None

    def extract(self, rpm_file, tmp):
        """Extract this artifact's jar from a downloaded rpm.

        The rpm is inspected with ``rpm`` and unpacked with
        ``rpm2cpio | cpio``; the selected jar ends up directly in ``tmp``.

        Args:
            rpm_file: path of the rpm file.
            tmp: working directory receiving the jar.

        Returns:
            A ``(jar_path, version)`` tuple. ``version`` comes from the jar
            file name when it looks like ``<artifact>-<digits and dots>.jar``
            and from the rpm Version tag otherwise.

        Raises:
            RuntimeError: the rpm holds no candidate jar, no or several jars
                match the artifact, or cpio fails.
        """
        # Get the rpm tags
        rpm_process = subprocess.Popen(['rpm', '-q', '--xml', '-p', rpm_file], stdout=subprocess.PIPE)
        rpm_tags = ET.fromstring(rpm_process.communicate()[0])

        # Get the file links list and package version
        links = [node.text for node in rpm_tags.findall('.//rpmTag[@name="Filelinktos"]/*')]
        version = rpm_tags.find(".//rpmTag[@name='Version']/*").text

        # Get the files
        rpm_process = subprocess.Popen(['rpm', '-qlp', rpm_file], stdout=subprocess.PIPE)
        files = rpm_process.communicate()[0].splitlines()

        # Keep the entries that have no link target. An entry without a
        # matching link slot (e.g. the placeholder line rpm prints for a
        # package without files) is kept too, so the "no jar" error below is
        # raised instead of an IndexError.
        not_linked = [
            f.decode('utf-8')
            for (i, f) in enumerate(files)
            if i >= len(links) or links[i] is None
        ]

        logging.debug('not linked:\n  %s' % '\n  '.join(not_linked))

        pattern = self.jar if self.jar is not None else self.artifact
        end_pattern = r'[^/]*\.jar' if self.jar is None or not self.jar.endswith('.jar') else ''
        full_pattern = '^/usr/.*/%s%s$' % (pattern, end_pattern)

        logging.debug('full pattern: %s' % full_pattern)

        jars = [f for f in not_linked if re.match('^/usr/.*/{}'.format(end_pattern), f)]
        if not jars:
            raise RuntimeError('Found no jar to extract in ' + rpm_file)
        if len(jars) == 1:
            jar_entry = jars[0]
        else:
            # Several candidates: only the one matching the artifact is wanted
            to_extract = [f for f in jars if re.match(full_pattern, f)]
            if not to_extract:
                raise RuntimeError('Found no jar matching {} in {}'.format(pattern, rpm_file))
            if len(to_extract) > 1:
                raise RuntimeError('Found several jars to extract in {}:\n  {}'.format(
                    rpm_file, '\n  '.join(to_extract)))
            jar_entry = to_extract[0]

        # try harder to guess the version number from the jar file since it may be different from the rpm
        matcher = re.search(r'%s-([0-9.]+)\.jar' % self.artifact, os.path.basename(jar_entry))
        if matcher:
            version = matcher.group(1)

        dst_path = os.path.join(tmp, os.path.basename(jar_entry))
        logging.info('extracting %s to %s' % (jar_entry, dst_path))

        # Run both tools with tmp as their working directory rather than
        # changing the working directory of this whole process.
        member = '.' + jar_entry
        rpm2cpio = subprocess.Popen(('rpm2cpio', rpm_file), stdout=subprocess.PIPE, cwd=tmp)
        try:
            cpio = subprocess.Popen(('cpio', '-id', member),
                                    stdin=rpm2cpio.stdout, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                    cwd=tmp)
            # communicate() drains both pipes; wait() alone may block forever
            # once a pipe buffer is full.
            cpio_err = cpio.communicate()[1]
        finally:
            # Release our end of the pipe and reap rpm2cpio in every case
            rpm2cpio.stdout.close()
            rpm2cpio.wait()

        if cpio.returncode != 0:
            raise RuntimeError("Failed to extract jar file {}: {}".format(
                jar_entry, cpio_err.decode('utf-8', 'replace').strip()))
        shutil.copy(os.path.join(tmp, member), dst_path)
        shutil.rmtree(os.path.join(tmp, 'usr'))

        return (dst_path, version)


    def deploy(self, jar, version, repo, mtime):
        """Install a jar into the maven repository layout under ``repo``.

        Writes ``<repo>/<group>/<artifact>/<version>/<artifact>-<version>.jar``
        with its access and modification times set to ``mtime``, a minimal
        pom next to it, and records the version in
        ``<repo>/<group>/<artifact>/maven-metadata-local.xml``.

        Args:
            jar: path of the jar file to copy.
            version: version string to deploy the jar as.
            repo: root folder of the maven repository.
            mtime: timestamp to set on the deployed jar.
        """
        artifact_folder = os.path.join(repo, self.group, self.artifact)
        version_folder = os.path.join(artifact_folder, version)
        os.makedirs(version_folder, exist_ok=True)

        jar_path = os.path.join(version_folder, '%s-%s.jar' % (self.artifact, version))
        logging.info('deploying %s to %s' % (jar, jar_path))
        shutil.copyfile(jar, jar_path)
        logging.debug('Setting mtime %d on %s' % (mtime, jar_path))
        os.utime(jar_path, (mtime, mtime))

        pom = (
            '<?xml version="1.0" encoding="UTF-8"?>\n'
            '<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0\n'
            '         http://maven.apache.org/xsd/maven-4.0.0.xsd"\n'
            '         xmlns="http://maven.apache.org/POM/4.0.0"\n'
            '         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">\n'
            '  <modelVersion>4.0.0</modelVersion>\n'
            '  <groupId>%s</groupId>\n'
            '  <artifactId>%s</artifactId>\n'
            '  <version>%s</version>\n'
            '  <description>POM was created from obs-to-maven</description>\n'
            '</project>'
        ) % (self.group, self.artifact, version)
        pom_path = os.path.join(version_folder, '%s-%s.pom' % (self.artifact, version))
        with open(pom_path, 'w', encoding='utf-8') as fd:
            fd.write(pom)

        # Maintain metadata file repo/group/artifact/maven-metadata-local.xml
        metadata_path = os.path.join(artifact_folder, "maven-metadata-local.xml")
        update_time = datetime.now().strftime('%Y%m%d%H%M%S')
        write_anew = True
        if os.path.isfile(metadata_path):
            versions_node = lastupdated_node = None
            try:
                doc = ET.parse(metadata_path)
            except ET.ParseError:
                doc = None
            if doc is not None:
                # Both nodes are children of <versioning>, not of the root
                versions_node = doc.find('versioning/versions')
                lastupdated_node = doc.find('versioning/lastUpdated')
            if versions_node is None or lastupdated_node is None:
                logging.warning('Invalid XML file: creating a new one: %s' % metadata_path)
            else:
                known_versions = [node.text for node in versions_node.findall('version')]
                # Deploying the same version again must not list it twice
                if version not in known_versions:
                    ET.SubElement(versions_node, 'version').text = version
                # Keep <release> pointing at the version just deployed, as a
                # freshly written file does
                release_node = doc.find('versioning/release')
                if release_node is not None:
                    release_node.text = version
                lastupdated_node.text = update_time
                doc.write(metadata_path, encoding='UTF-8', xml_declaration=True)
                write_anew = False

        # Something wrong happened or we don't have the file: create a fresh one
        if write_anew:
            xml = (
                '<?xml version="1.0" encoding="UTF-8"?>\n'
                '<metadata>\n'
                '  <groupId>%s</groupId>\n'
                '  <artifactId>%s</artifactId>\n'
                '  <versioning>\n'
                '    <release>%s</release>\n'
                '    <versions>\n'
                '      <version>%s</version>\n'
                '    </versions>\n'
                '    <lastUpdated>%s</lastUpdated>\n'
                '  </versioning>\n'
                '</metadata>\n'
            ) % (self.group, self.artifact, version, version, update_time)
            with open(metadata_path, 'w', encoding='utf-8') as fd:
                fd.write(xml)


    def process(self, repo, tmp):
        """Bring the deployed jar of this artifact up to date.

        The artifact is skipped when a jar already deployed for it under
        ``repo`` carries the same mtime as the repository package; otherwise
        the package is downloaded into ``tmp``, its jar extracted and
        deployed.

        Args:
            repo: root folder of the maven repository.
            tmp: scratch directory for downloads and extraction.

        Raises:
            RuntimeError: the package cannot be found, downloaded or
                unpacked, or its file name does not look like an rpm name.
        """
        logging.info('Processing artifact %s' % self.artifact)
        file = self.get_binary()

        # Check if one of the artifact's jar has the same mtime. If so no need to update
        marker = '%s%s%s' % (os.path.sep, self.artifact, os.path.sep)
        mtimes = [
            os.stat(os.path.join(root, f)).st_mtime
            for root, dirs, files in os.walk(repo)
            if marker in root
            for f in files
            if f.endswith('.jar')
        ]

        if any(file.mtime == int(mtime) for mtime in mtimes):
            logging.info('Skipping artifact %s' % self.artifact)
            return

        logging.debug('package mtime: %d, [%s]' % (file.mtime, ', '.join(['%f' % t for t in mtimes])))
        rpm_file = self.fetch_binary(file, tmp)
        if rpm_file is None:
            # file is a package object, not a string: report its file name
            raise RuntimeError('Failed to download ' + file.name)

        # Sanity check of the file name only: the version actually deployed
        # is the one returned by extract()
        if re.match(r'.*-([^-]+)-[^-]+.[^.]+.rpm', rpm_file) is None:
            raise RuntimeError('Failed to get version of ' + rpm_file)

        # Extract the jar and find out the version
        (jar, version) = self.extract(rpm_file, tmp)

        # Install in the repository
        self.deploy(jar, version, repo, file.mtime)


class Configuration:
    """Settings read from the YAML configuration file.

    Attributes:
        url: base URL of the download server (the 'url' key, defaulting to
            the openSUSE download server).
        repo: output maven repository path, as passed in.
        artifacts: one Artifact per entry of the 'artifacts' key.
    """

    def __init__(self, config_path, repo):
        """Load ``config_path``; a missing or empty file yields no artifacts.

        Args:
            config_path: path of the YAML configuration file.
            repo: path of the output maven repository.
        """
        data = {}
        if os.path.isfile(config_path):
            with open(config_path, 'r') as f:
                # An empty YAML document loads as None
                data = yaml.safe_load(f) or {}
        self.url = data.get('url', "https://download.opensuse.org/repositories")
        self.repo = repo
        # "or" also covers keys that are present but left empty in the file
        repos = {
            name: Repo(self.url, repo_conf['project'], repo_conf['repository'])
            for name, repo_conf
            in (data.get('repositories') or {}).items()
        }

        default_group = data.get('group', 'suse')
        self.artifacts = [
            Artifact(artifact, repos, default_group)
            for artifact in (data.get('artifacts') or [])
        ]

def main():
    """Command line entry point.

    Reads the configuration named on the command line and processes every
    artifact it lists into the output repository.

    Returns:
        0 on success, 1 when a RuntimeError was reported.
    """
    parser = argparse.ArgumentParser(
        description="OBS to Maven repository synchronization tool",
        conflict_handler='resolve',
        formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument('config', help="Path to the YAML configuration file")
    parser.add_argument('out', help="Path to the output maven repository")

    parser.add_argument(
        "-d", "--debug",
        help="Show debug messages",
        action="store_const",
        dest="loglevel",
        const=logging.DEBUG,
        default=logging.INFO)

    args = parser.parse_args()

    logging.getLogger().setLevel(args.loglevel)
    logging.debug('Reading configuration')

    ret = 0
    tmp = tempfile.mkdtemp(prefix="obsmvn-")
    try:
        # Configuration errors are RuntimeErrors too: report them the same way
        config = Configuration(args.config, args.out)
        for artifact in config.artifacts:
            artifact.process(config.repo, tmp)
    except RuntimeError as e:
        logging.error(e)
        ret = 1
    finally:
        # Remove the scratch directory even when an unexpected error escapes
        shutil.rmtree(tmp)
    return ret

# When run as a script, use main()'s return value as the process exit status.
if __name__ == '__main__':
    sys.exit(main())
