#!/usr/bin/python
# Copyright (c) 2015 SUSE Linux GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from pprint import pformat
import os
import sys
import re
import logging
import cmdln
from FileListExtractor import FileListExtractor, Channel, RepoData
import pickle
import signal
import subprocess
import time
import yaml

try:
    from xml.etree import cElementTree as ET
except ImportError:
    import cElementTree as ET

# XDG base-directory lookup: user config dir first, then system-wide dirs.
USER_CONFIG_DIR = os.environ.get('XDG_CONFIG_HOME', os.path.expanduser('~/.config'))
SYS_CONFIG_DIRS = os.environ.get('XDG_CONFIG_DIRS', '/etc/xdg').split(':')
CONFIG_DIRS = [USER_CONFIG_DIR] + SYS_CONFIG_DIRS
# Config files are searched in a 'rpmlint-backports-tool' subdirectory of each base dir.
CONFIG_DIRS = [os.path.join(d, 'rpmlint-backports-tool') for d in CONFIG_DIRS]
# Package-name regex patterns that are always excluded from file-list extraction.
BLACKLIST = ['java-.*-ibm.*']
# Default database file produced by 'merge' and uploaded by 'update_obs'.
DBFILE = 'SLE-file-package-map.db'
# Matches channel names of SLE modules / SUSE Manager products, which are
# not service-pack versioned (see service_pack()).
module_re = re.compile(".*(sle-module|sle-manager)", flags=re.IGNORECASE)


def service_pack(name, SP=None):
    """Return *name* with the service-pack suffix spliced in.

    SLE 12 channel names and URLs carry a '12' component; for service pack
    'SPx' that component becomes '12-SPx'.  The name passes through
    unchanged when no SP is requested, when SP is 'SP0' (the GA release),
    or when the name refers to a module/manager product (those are not
    service-pack versioned).  A None name yields None.
    """
    if name is None:
        return None
    # All pass-through cases collapse into one condition; short-circuit
    # order keeps module_re from being consulted unnecessarily.
    unversioned = (not SP) or SP == 'SP0' or module_re.match(name)
    if unversioned:
        return name
    return re.sub(r'([_\-]12)([_\-:])', r'\1-%s\2' % SP, name)


class Tool(cmdln.Cmdln):
    def __init__(self, *args, **kwargs):
        """Initialize the tool and its file-list extractor.

        Fix: the base class must be initialized with the arguments
        unpacked; the original passed the args tuple and kwargs dict as
        two positional arguments, silently misconfiguring cmdln.Cmdln.
        """
        cmdln.Cmdln.__init__(self, *args, **kwargs)

        # Extractor shared by all subcommands; channel file lists are kept
        # on disk so check_bsk/check_exported can re-parse them later.
        self.ex = FileListExtractor()
        self.ex.save_channels_files = True
        # Lazy caches backing the 'channels' and 'config' properties.
        self._channels = None
        self._config = None

    @property
    def config(self):
        """Parsed configuration dict; raises if none has been loaded yet."""
        cfg = self._config
        if cfg is None:
            raise Exception("No config file loaded")
        return cfg

    @config.setter
    def config(self, data):
        # Replace the cached configuration wholesale.
        self._config = data

    @property
    def channels(self):
        """Channel objects built from the config; computed once and cached."""
        cached = self._channels
        if cached is None:
            cached = self._channels = self._get_channels()
        return cached

    def get_optparser(self):
        """Build the global option parser shared by all subcommands."""
        parser = cmdln.CmdlnOptionParser(self)
        # (flag, kwargs) specs; registered in this order so help output
        # stays stable.
        global_options = [
            ("--dry", dict(action="store_true", help="dry run")),
            ("--debug", dict(action="store_true", help="debug output")),
            ("--verbose", dict(action="store_true", help="verbose")),
            ("--servicepack", dict(type="string", dest="servicepack",
                                   help="service pack", default=None)),
            ("--config", dict(type="string", help="config file", default=None)),
            ("--workpath", dict(type="string", help="work path", default='.')),
        ]
        for flag, kwargs in global_options:
            parser.add_option(flag, **kwargs)
        return parser

    @property
    def servicepacks(self):
        """Service packs to operate on; --servicepack overrides the config."""
        from_cli = self.options.servicepack
        if from_cli:
            return [from_cli]
        return self.config.get('servicepacks', None)

    def _get_channels(self):
        """Build channel objects from the 'channels' section of the config.

        For every configured channel ID, up to three URL templates
        ('product', 'update', 'local') are expanded for each requested
        service pack and architecture.  Module/manager channels and local
        channels are not service-pack versioned and get only 'SP0'.

        Returns a list of Channel/RepoData objects.
        Raises Exception when a mandatory config key is missing.
        """
        channels = []
        chan_url = {}
        chan_url['product'] = self.config.get('product_channels_url', None)
        chan_url['update'] = self.config.get('update_channels_url', None)
        chan_url['local'] = self.config.get('local_channels_url', '{name}')
        channel_configs = self.config.get('channels', None)
        channel_schema = self.config.get('channel_schema', '_channel')

        if channel_configs is None:
            raise Exception("No channels in config")
        if chan_url['product'] is None:
            raise Exception("No product_channels_url in config")
        if chan_url['update'] is None:
            raise Exception("No update_channels_url in config")

        for ID, config in channel_configs.items():
            archs = config.get('archs', ['x86_64'])
            for channeltype in ['product', 'update', 'local']:
                name = config.get(channeltype, None)
                if name is None:
                    continue
                servicepacks = self.servicepacks or ['SP0']
                # Modules and local channels exist only once — no SP variants.
                if module_re.match(name) or channeltype == 'local':
                    servicepacks = ['SP0']
                for servicepack in servicepacks:
                    # 'SPx@snapshot' selects a specific snapshot of that SP.
                    if '@' in servicepack:
                        servicepack, snapshot = servicepack.split('@', 1)
                    else:
                        snapshot = ''
                    for arch in archs:
                        url = chan_url[channeltype].format(name=name,
                                                           arch=arch,
                                                           snapshot=snapshot
                                                           )
                        url = service_pack(url, SP=servicepack)
                        # Channel name: basename of the template name, with
                        # the SP suffix spliced in and the arch appended.
                        sp_name = service_pack(os.path.split(name)[1], SP=servicepack)
                        sp_name = '%s-%s' % (sp_name, arch)
                        if channel_schema == 'repomd':
                            newchannel = RepoData(sp_name, url, channeltype, logger=self.logger)
                        else:
                            newchannel = Channel(sp_name, url, channeltype, logger=self.logger)
                        newchannel.id = ID
                        newchannel.exclude = config.get('exclude', False)
                        newchannel.whitelist = config.get('whitelist', None)
                        channels.append(newchannel)

        return channels

    def filter_channels(self, channel_names):
        """Return channels whose name is in *channel_names* (all if empty)."""
        if not channel_names:
            return self.channels
        wanted = set(channel_names)
        return [chan for chan in self.channels if chan.name in wanted]

    def _load_config(self, configfile):
        """Locate and load a YAML configuration file into self.config.

        *configfile* is tried verbatim first; if that is not a regular
        file, the XDG config directories are searched for it, also with
        '.conf' and '.config' suffixes appended.  The first candidate that
        parses wins.

        Returns True on success; raises Exception when nothing could be
        found and loaded.
        """
        def _load(path):
            # Best-effort load of a single candidate; failure is logged at
            # debug level and reported via the return value.
            self.logger.debug("Loading %s" % path)
            try:
                with open(path) as f:
                    # safe_load: the config is plain YAML; yaml.load()
                    # without an explicit Loader is deprecated and can
                    # instantiate arbitrary Python objects.
                    self.config = yaml.safe_load(f)
                self.logger.info("Loaded config file: %s" % path)
                return True
            except Exception as e:
                self.logger.debug("Failed to load %s. %s" % (path, e))
                return False

        # First try path exactly as specified
        if os.path.isfile(os.path.abspath(configfile)) and _load(configfile):
            return True

        # Search for config file in standard locations
        # Load first match and return
        filenames = [configfile, configfile + '.conf', configfile + '.config']
        for path in CONFIG_DIRS:
            for fname in filenames:
                file_path = os.path.abspath(os.path.join(path, fname))
                self.logger.debug("Checking for %s" % file_path)
                if not os.path.isfile(file_path):
                    continue
                if _load(file_path):
                    return True

        raise Exception('Failed to find and load config: %s' % configfile)

    def postoptparse(self):
        """Apply global options: logging level, config file, working dir."""
        logging.basicConfig()
        self.logger = logging.getLogger(self.optparser.prog)
        opts = self.options
        if opts.debug:
            self.logger.setLevel(logging.DEBUG)
            self.ex.debug = True
        elif opts.verbose:
            self.logger.setLevel(logging.INFO)

        if opts.config:
            self._load_config(opts.config)

        if opts.workpath:
            # All subcommands operate on files relative to the work path.
            try:
                os.chdir(os.path.abspath(opts.workpath))
            except Exception as e:
                raise Exception("Failed to switch to workpath %s: %s" % (opts.workpath, e))
        self.ex.logger = self.logger

    def do_list(self, subcmd, opts, *channel_names):
        """${cmd_name}: list channels

        ${cmd_usage}
        ${cmd_option_list}
        """

        # Print name, URL and on-disk file for each selected channel.
        for chan in sorted(self.filter_channels(channel_names)):
            print("%s\n  %s\n  %s" % (chan.name, chan.url, chan.channel_filename()))

    @cmdln.option("-f", "--force", action="store_true", help="force something")
    def do_pickle(self, subcmd, opts, *channel_names):
        """${cmd_name}: generate pickle for channels

        ${cmd_usage}
        ${cmd_option_list}
        """

        selected = sorted(self.filter_channels(channel_names))
        self._pickle(selected, force=opts.force)

    def _pickle(self, channels, force=False):
        """Fetch file lists for *channels* and pickle them to disk.

        Each channel's freshly extracted data is compared against its
        previous pickle (if any): existing product channels are skipped
        unless *force* is set, channels with no new packages are left
        untouched, otherwise old and new data are merged and rewritten.
        A human-readable dump is written next to each pickle.

        Returns True if at least one channel's pickle was (re)written.
        """
        self.ex.set_blacklist(BLACKLIST)  # What is this for?
        self.ex.set_file_blacklist(self.config.get('file_blacklist', None))

        def process(channel):
            # Returns True when this channel's pickle was updated.
            self.logger.debug("processing %s", channel.name)
            fn = channel.pickle_filename()
            olddata = None
            if os.path.exists(fn):
                # Product channels are immutable; refresh only on --force.
                if channel.type == 'product' and not force:
                    return False
                with open(fn, 'rb') as f:
                    olddata = pickle.load(f)
            data = self.ex.readFileLists([channel])
            if olddata is not None:
                missing = olddata['pkgnames'] - data['pkgnames']
                new = data['pkgnames'] - olddata['pkgnames']
                if new:
                    self.logger.info("%s - new: %s", channel.name, ', '.join(new))
                if missing:
                    self.logger.warning("%s - packages vanished: %s", channel.name, ', '.join(missing))
                if not new:
                    self.logger.info("%s - no change" % channel.name)
                    return False
                data = self.ex.merge(olddata, data)
            with open(fn, 'wb') as f:
                pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
            # pformat() returns text, so the dump file must be opened in
            # text mode; the original 'wb' raised TypeError on Python 3.
            with open(channel.dump_filename(), 'w') as f:
                f.write(pformat(data))
            return True

        changed = False
        for channel in channels:
            if process(channel):
                changed = True

        return changed

    @cmdln.option("-o", "--output", dest='filename', metavar='FILE', help="output to FILE")
    @cmdln.option("--shelve", action="store_true", help="save as shelve file")
    def do_merge(self, subcmd, opts, *files):
        """${cmd_name}: merge .p files after call to pickle

        ${cmd_usage}
        ${cmd_option_list}
        """

        target = opts.filename
        if not target:
            raise Exception('filename missing')

        return self._merge(target, files, shelve=opts.shelve)

    def _merge(self, filename, files, shelve=False):
        """Merge pickled channel data files into a single database.

        *files* defaults to the pickles of all non-excluded channels.  The
        result is first written to '<filename>.new*' and then renamed into
        place.  With shelve=True a dumbdbm-backed shelve (protocol 2,
        readable by both Python 2 and 3) is produced; otherwise a plain
        pickle named *filename*.
        """
        if not files:
            files = [i.pickle_filename() for i in sorted(self.channels) if not i.exclude]

        res = {
            'filenames': dict(),
            'pkgnames': set(),
        }

        for fn in files:
            with open(fn, 'rb') as f:
                self.logger.debug("merging %s" % fn)
                data = pickle.load(f)
                res = self.ex.merge(res, data)

        tmpfn = filename + '.new'
        if shelve:
            # Alias the import so it does not shadow the 'shelve' parameter.
            import shelve as shelve_mod
            # Python2 uses bsddb by default, Python3 does not support bsddb
            # Explictly use dumbdbm for compatiblity with 2 and 3 (i.e. SLES 12
            # and SLES 15). Also ensure same protocol (2) is used.
            try:
                mod = __import__("dumbdbm")
            except Exception as e:
                self.logger.debug(e)
                self.logger.debug('trying dbm.dumb')
                # fromlist is required so __import__ returns the dbm.dumb
                # submodule rather than the top-level dbm package.
                mod = __import__("dbm.dumb", fromlist=['dumb'])
            d = shelve_mod.Shelf(mod.open(tmpfn, 'n'), protocol=2)
            d.update(res)
            d.close()

            # dumbdbm creates 3 files: filename.dat, filename.dir, filename.bak
            # We need the .dir and .dat files
            for ext in ['.dat', '.dir']:
                os.rename(tmpfn + ext, filename + ext)
            os.remove(tmpfn + '.bak')
        else:
            with open(tmpfn, 'wb') as f:
                pickle.dump(res, f, pickle.HIGHEST_PROTOCOL)
            # Plain-pickle mode produces a single file; rename it into
            # place.  (The original unconditionally renamed the dumbdbm
            # extensions here, which failed in this branch.)
            os.rename(tmpfn, filename)

    @cmdln.option("-f", "--force", action="store_true", help="force something")
    @cmdln.option('-c', "--check", metavar="check", action="append", help="which file to check against the rest")
    def do_check_bsk(self, subcmd, opts, *channel_names):
        """${cmd_name}: check bsk consistency

        checks if files in BSK are actually just subpackages of stuff that
        already is in other products. Also prints duplicates and what would be
        left if the duplicates and subpackages were removed.

        Only operates locally. Needs previous run of pickle

        ${cmd_usage}
        ${cmd_option_list}
        """
        pkgs = dict()       # srcpkg -> set(binpkgs)
        checkpkgs = dict()  # srcpkg -> set(binpkgs)

        b2chan = dict()     # binpkg -> set(channel)

        if not opts.check:
            opts.check = [c for c in self.channels if c.id == 'BSK']
        # NOTE(review): when -c is passed on the command line, opts.check
        # holds plain strings, yet the code below treats entries as channel
        # objects (_pickle(), channel_filename()) — confirm intended usage.

        channels = self.filter_channels(channel_names)
        channels = [c for c in channels if c.name != 'sle_exceptions']

        def parse(dst, channels):
            # Build srcpkg -> set(binpkg) from each channel's saved XML,
            # skipping product meta packages and debug packages; also
            # record which channels ship each binary in b2chan.
            for channel in channels:
                fn = channel.channel_filename()
                with open(fn, 'rb') as f:
                    root = ET.parse(f).getroot()
                    for binaries in root.findall('binaries'):
                        for node in binaries.findall('binary'):
                            name = node.attrib['name']
                            package = node.attrib['package']
                            if package.startswith('_product:'):
                                continue
                            if name.endswith('-debuginfo') or name.endswith('-debugsource'):
                                continue
                            dst.setdefault(package, set()).add(name)
                            b2chan.setdefault(name, set()).add(channel.name)

        parse(pkgs, channels)
        # Refresh the check channels' pickles/XML before comparing.
        self._pickle(opts.check)
        parse(checkpkgs, opts.check)

        # bsk_all starts as all check source packages; entries are removed
        # as they are explained by subpackage splits or duplicates.
        bsk_all = set(checkpkgs.keys())
        for p in sorted(checkpkgs.keys()):
            if p in pkgs:
                missing = checkpkgs[p] - pkgs[p]
                overlap = pkgs[p] & checkpkgs[p]
                if missing:
                    bsk_all.remove(p)
                    print("separate subpackage %s: %s" % (p, ', '.join(sorted(missing))))
                if overlap:
                    if p in bsk_all:
                        bsk_all.remove(p)
                    for b in overlap:
                        print("duplicate %s: %s" % (b, ', '.join(sorted(b2chan[b]))))

        # Whatever remains is unique to the check channels.
        for p in sorted(bsk_all):
            print("left %s: %s" % (p, ','.join(sorted(checkpkgs[p]))))

    def do_check_exported(self, subcmd, opts, *channel_names):
        """${cmd_name}: check OBS exported packages

        check if binary rpms are exported in obs

        ${cmd_usage}
        ${cmd_option_list}
        """

        exported = set()

        import osc.conf
        # Explicitly import osc.core: importing osc.conf alone does not
        # guarantee the osc.core submodule used below is loaded.
        import osc.core

        self.ex._init_osc()

        apiurl = osc.conf.config['apiurl']
        apipath = service_pack('build/openSUSE.org:SUSE:SLE-12:GA/standard/x86_64/_repository',
                               self.options.servicepack)
        u = osc.core.makeurl(apiurl, apipath.split('/'), ['view=binaryversions'])
        r = osc.core.http_GET(u)
        root = ET.parse(r).getroot()
        for node in root.findall('binary'):
            name = node.attrib['name']
            # Strip the trailing '.rpm'; skip debug packages.
            name = name[:-len('.rpm')]
            if name.endswith('-debuginfo') or name.endswith('-debugsource'):
                continue
            if name.endswith('-debuginfo-32bit') or name.endswith('-debugsource-32bit'):
                continue
            exported.add(name)

        channels = self.filter_channels(channel_names)

        b2chan = dict()  # binpkg -> set(channel)

        def parse(channels):
            # Collect the binaries each channel needs, skipping product
            # meta packages, debug packages and blacklisted names.
            blacklist = re.compile('(' + '|'.join(BLACKLIST + ['update-test-.*']) + ')')
            for channel in channels:
                fn = channel.channel_filename()
                with open(fn, 'rb') as f:
                    root = ET.parse(f).getroot()
                    for binaries in root.findall('binaries'):
                        for node in binaries.findall('binary'):
                            name = node.attrib['name']
                            package = node.attrib['package']
                            if package.startswith('_product:'):
                                continue
                            if name.endswith('-debuginfo') or name.endswith('-debugsource'):
                                continue
                            if name.endswith('-debuginfo-32bit') or name.endswith('-debugsource-32bit'):
                                continue
                            if blacklist.match(name):
                                continue
                            b2chan.setdefault(name, set()).add(channel.name)

        parse(channels)
        needed = set(b2chan.keys())

        # Report binaries the channels need but OBS does not export, and
        # exported binaries no channel needs.
        for p in sorted(needed - exported):
            print("missing %s: %s" % (p, ', '.join(b2chan[p])))
        for p in sorted(exported - needed):
            print("extra %s" % p)

    def do_update_obs(self, subcmd, opts):
        """${cmd_name}: upload dbfile to rpmlint-backports-data package
        """
        # Upload the default database produced by the 'merge' subcommand.
        dbfile = DBFILE
        self._upload(dbfile)

    def _upload(self, dbfile):
        """Upload *dbfile* into the rpmlint-backports-data package in OBS.

        Raises Exception when the database file does not exist yet or when
        no target project ('rpmlint_prj') is configured.
        """
        import osc.core
        if not os.path.exists(dbfile):
            raise Exception("Can't find %s. Run pickle and merge" % dbfile)
        self.logger.info("uploading")
        self.ex._init_osc()
        # Accept the correctly spelled config key, but fall back to the
        # historical misspelling 'rpmling_apiurl' so old configs keep working.
        apiurl = self.config.get('rpmlint_apiurl',
                                 self.config.get('rpmling_apiurl', 'https://api.opensuse.org'))
        prj = self.config.get('rpmlint_prj', None)
        if prj is None:
            raise Exception("No project defined for obs update!")
        pkg = 'rpmlint-backports-data'
        u = osc.core.makeurl(apiurl, ['source', prj, pkg, dbfile], {})
        r = osc.core.http_PUT(u, file=dbfile)
        self.logger.info(r.read())

    @cmdln.option('-n', '--interval', metavar="minutes", type="int", help="periodic interval in minutes")
    @cmdln.option('--git-commit', action="store_true", help="commit changes to git")
    def do_run_bot(self, subcmd, opts):
        """${cmd_name}: update files, upload

        ${cmd_usage}
        ${cmd_option_list}
        """

        class ExTimeout(Exception):
            """raised on timeout"""

        # Python 2/3 compatibility: raw_input() was renamed to input() in
        # Python 3; the original NameError'd there.
        try:
            read_line = raw_input
        except NameError:
            read_line = input

        if opts.interval:
            # SIGALRM interrupts the blocking read below when the interval
            # expires.
            def alarm_called(nr, frame):
                raise ExTimeout()
            signal.signal(signal.SIGALRM, alarm_called)

        while True:
            try:
                channels = sorted(self.channels)
                if self._pickle(channels):
                    if opts.git_commit:
                        try:
                            subprocess.check_output(['git', 'init'])
                            for c in channels:
                                subprocess.check_output(['git', 'add', c.dump_filename()])
                            subprocess.check_output(['git', 'commit', '-m', 'update'])
                        except subprocess.CalledProcessError as e:
                            self.logger.warning("### git ERROR: %s" % e)
                    self.logger.info("merging")
                    self._merge(DBFILE, None, shelve=True)
                    self._upload(DBFILE)
            except Exception as e:
                # Keep the bot alive across individual update failures.
                self.logger.error("### ERROR: %s" % e)

            if opts.interval:
                print("sleeping %d minutes. Press enter to check now ..." % opts.interval)
                signal.alarm(opts.interval * 60)
                try:
                    read_line()
                except ExTimeout:
                    pass
                except EOFError:
                    # no tty available, let's sleep then; the interval is in
                    # minutes, so convert to seconds (original slept 60x too
                    # short here)
                    time.sleep(opts.interval * 60)
                signal.alarm(0)
                continue
            break


if __name__ == "__main__":
    app = Tool()
    try:
        sys.exit(app.main())
    except Exception as e:
        # app.logger is only created in postoptparse(); fall back to a plain
        # logger so failures during option parsing are still reported instead
        # of raising AttributeError.
        logger = getattr(app, 'logger', None) or logging.getLogger(__name__)
        logger.error('%s. exiting...' % e, exc_info=True)
        sys.exit(1)

# vim: sw=4 et
