#!/usr/bin/python3

# Copyright (c) 2017, SUSE LLC, All rights reserved.
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3.0 of the License, or (at your option) any later version.
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public
# License along with this library.

"""This script obtains information from the configured region server in the
   cloud environment and uses the information to register the guest with
   the SMT server based on the information provided by the region server.

   The configuration is in INI format and is located in
   /etc/regionserverclnt.cfg

   Logic:
   1.) Check if we are in the same region
       + Comparing information received from the region server and the
         cached data
   2.) Check if already registered
   3.) Register"""

import configparser
import base64
import copy
import glob
import json
import logging
import os
import pickle
import re
import requests
import stat
import subprocess
import sys
import time
import urllib.parse
import uuid

import cloudregister.registerutils as utils

from cloudregister import smt
from lxml import etree
from M2Crypto import X509

error_exp = re.compile('^error', re.IGNORECASE)


# ----------------------------------------------------------------------------
def get_products():
    products = []
    try:
        cmd = subprocess.Popen(
            ["zypper", "--no-remote", "-x", "products"], stdout=subprocess.PIPE
        )
        product_xml = cmd.communicate()
    except:
        errMsg = 'Could not get product list %s' % cmd[1]
        logging.error(errMsg)
        return

    # Detrmine the base product
    baseProdSet = '/etc/products.d/baseproduct'
    baseprodName = None
    if os.path.islink(baseProdSet):
        baseprod = os.path.realpath(baseProdSet)
        baseprodName = baseprod.split(os.sep)[-1].split('.')[0]
    else:
        errMsg = 'No baseproduct installed system cannot be registerd'
        logging.error(errMsg)
        return

    product_tree = etree.fromstring(product_xml[0].decode())
    for child in product_tree.find("product-list"):
        name = child.attrib['name']
        if name == baseprodName:
            continue
        vers = child.attrib['version']
        arch = child.attrib['arch']
        prod = name + "/" + vers + "/" + arch
        if prod not in products:
            products.append(prod)

    return products


# ----------------------------------------------------------------------------
# Support custom config file with -f command line option
config_file = None

if '-f' in sys.argv:
    idx = sys.argv.index('-f')
    config_file = sys.argv[idx+1]

force_new_registration = False
if '--force-new' in sys.argv:
    force_new_registration = True

if '--delay' in sys.argv:
    arg_index = sys.argv.index('--delay')
    delay_time = sys.argv[arg_index + 1]
    time.sleep(int(delay_time))

user_smt_ip = None
if '--smt-ip' in sys.argv:
    arg_index = sys.argv.index('--smt-ip')
    user_smt_ip = sys.argv[arg_index + 1]
    # Validation has to wait for Python 3 implementation

user_smt_fqdn = None
if '--smt-fqdn' in sys.argv:
    arg_index = sys.argv.index('--smt-fqdn')
    user_smt_fqdn = sys.argv[arg_index + 1]

user_smt_fp = None
if '--smt-fp' in sys.argv:
    arg_index = sys.argv.index('--smt-fp')
    user_smt_fp = sys.argv[arg_index + 1]

if user_smt_ip or user_smt_fqdn or user_smt_fp:
    if not (user_smt_ip and user_smt_fqdn and user_smt_fp):
        msg = '--smt-ip, --smt-fqdn, and --smt-fp must be used together'
        print(msg, file=sys.stderr)

cfg = utils.get_config(config_file)
utils.start_logging()

if not os.path.isdir(utils.REGISTRATION_DATA_DIR):
    os.makedirs(utils.REGISTRATION_DATA_DIR)

if force_new_registration:
    logging.info('Forced new registration')

if user_smt_ip:
    msg = 'Using user specified SMT server:\n'
    msg += '\n\t"IP:%s"' % user_smt_ip
    msg += '\n\t"FQDN:%s"' % user_smt_fqdn
    msg += '\n\t"Fingerprint:%s"' % user_smt_fp
    logging.info(msg)

cached_smt_servers = utils.get_available_smt_servers()
# Forced registration or user specified SMT, clear existing registration
# data
if (force_new_registration and cached_smt_servers) or user_smt_ip:
    if utils.is_zypper_running():
        msg = 'zypper is running: Registration with the update '
        msg += 'infrastructure is only possible if zypper is not running.\n'
        msg += 'Please re-run the force registration process after zypper '
        msg += 'has completed'
        print(msg)
        sys.exit(1)
    utils.remove_registration_data()
    utils.clean_smt_cache()
    cached_smt_servers = []

# Proxy setup
proxies = None
proxy = utils.set_proxy()
if proxy:
    http_proxy = os.environ.get('http_proxy')
    https_proxy = os.environ.get('https_proxy')
    proxies = {'http_proxy': http_proxy,
               'https_proxy': https_proxy}
    logging.info('Using proxy settings: %s' % proxies)

if user_smt_ip:
    smt_xml = '<regionSMTdata><smtInfo '
    smt_xml += 'fingerprint="%s" ' % user_smt_fp
    smt_xml += 'SMTserverIP="%s" ' % user_smt_ip
    smt_xml += 'SMTserverName="%s"' % user_smt_fqdn
    smt_xml += '/></regionSMTdata>'
    region_smt_data = etree.fromstring(smt_xml)
else:
    # Check if we are in the same region
    # This implies that at least one of the cached servers is also in the
    # data received in the SMT data server
    region_smt_data = utils.fetch_smt_data(cfg, proxies)

registration_smt = utils.get_current_smt()

region_smt_servers = {'cached': [], 'new': []}
# Compare the smt information received from the SMT data server with
# the cached data
for child in region_smt_data:
    smt_server = smt.SMT(child)
    for cached_smt in cached_smt_servers:
        if cached_smt == smt_server:
            cached_smt_servers.remove(cached_smt)
            region_smt_servers['cached'].append(cached_smt)
            break
    else:
        region_smt_servers['new'].append(smt_server)

# If we have extra SMT data check if the extra data clean up the entire cache
if cached_smt_servers:
    logging.info('Have extra cached SMT data, clearing cache')
    for smt in cached_smt_servers:
        if registration_smt and smt.is_equivalent(registration_smt):
            msg = 'Extra cached server is current registration target, '
            msg += 'cleaning up registration'
            logging.info(msg)
            utils.remove_registration_data()
            registration_smt = None
            break
    # Clean the cache and re-write all the cache data later
    utils.clean_smt_cache()
    cached_smt_servers = []

if region_smt_servers['new']:
    # Create a new cache
    smt_count = len(region_smt_servers['cached']) + 1
    for smt_server in region_smt_servers['new']:
        store_file_name = (
            utils.REGISTRATION_DATA_DIR +
            utils.AVAILABLE_SMT_SERVER_DATA_FILE_NAME % smt_count
        )
        utils.store_smt_data(store_file_name, smt_server)
        smt_count += 1

# We no longer need to differentiate between new and existing SMT servers
region_smt_servers = region_smt_servers['cached'] + region_smt_servers['new']

# Check if the target SMT for the registration is alive or if we can
# find a server that is alive in this region
if registration_smt:
    registration_smt_cache_file_name = (
            utils.REGISTRATION_DATA_DIR +
            utils.REGISTERED_SMT_SERVER_DATA_FILE_NAME
        )
    alive = registration_smt.is_responsive()
    if alive:
        msg = 'Instance is registered, and SMT server is reachable, '
        msg += 'nothing to do'
        # The cache data may have been cleared, write if necessary
        if not os.path.exists(registration_smt_cache_file_name):
            utils.store_smt_data(
                registration_smt_cache_file_name,
                registration_smt
            )
        logging.info(msg)
        sys.exit(0)
    else:
        # The configured server is not resposive, lets check if we can
        # find another server
        new_target = utils.find_equivalent_smt_server(
            registration_smt,
            region_smt_servers
        )
        if new_target:
            msg = 'Configured SMT unresponsive, switching to equivalent '
            msg += 'SMT server with ip %s' % new_target.get_ip()
            utils.replace_hosts_entry(registration_smt, new_target)
            utils.store_smt_data(
                registration_smt_cache_file_name,
                registration_smt
            )
        else:
            msg = 'Configured SMT unresponsive, could not find '
            msg += 'a replacement SMT server in this region. '
            msg += 'Possible network configuration issue'
            logging.error(msg)
            sys.exit(1)

# Figure out which server is responsive and use it as registration target
registration_target = None
tested_smt_servers = []
for smt in region_smt_servers:
    smt_ip = smt.get_ip()
    tested_smt_servers.append(smt_ip)
    alive = smt.is_responsive()
    if alive:
        registration_target = smt
        # Use the first server that responds
        break

if not registration_target:
    logging.error('No response from: %s' % tested_smt_servers)
    sys.exit(1)

# Add the target SMT server to the hosts file
utils.add_hosts_entry(registration_target)

# Create location to store data if it does not exist
if not os.path.exists(utils.REGISTRATION_DATA_DIR):
    os.system('mkdir -p %s' % utils.REGISTRATION_DATA_DIR)

# Write the data of the current target server
utils.set_as_current_smt(registration_target)

# Check if we need to send along any instance data
instance_data_filepath = utils.REGISTRATION_DATA_DIR + str(uuid.uuid4())
if cfg.has_section('instance'):
    if cfg.has_option('instance', 'dataProvider'):
        instance_data_cmd = cfg.get('instance', 'dataProvider')
        cmd = instance_data_cmd.split()[0]
        if cmd != 'none':
            if cmd[0] != '/':
                try:
                    p = subprocess.Popen(
                        ['which %s' % cmd],
                        stdout=subprocess.PIPE,
                        stderr=subprocess.PIPE,
                        close_fds=True
                    )
                except:
                    errMsg = 'Could not find configured dataProvider: %s' % cmd
                    logging.error(errMsg)
            if os.access(cmd, os.X_OK):
                os.system("%s > %s" % (instance_data_cmd,
                                       instance_data_filepath))
            else:
                msg = 'Configured dataProvider "%s" is not executable' % cmd
                logging.error(msg)

register11 = '/usr/lib/suseRegister/bin/clientSetup4SMT.sh'
register12 = '/usr/sbin/SUSEConnect'
if (os.path.exists(register11) and os.access(register11, os.X_OK)):
    base_registered = False
    failed_smts = []
    while not base_registered:
        cmd = register11
        cmd += ' --host %s ' % registration_target.get_FQDN()
        cmd += '--fingerprint %s ' % registration_target.get_fingerprint()
        cmd += '--yes '

        if os.path.exists(instance_data_filepath):
            cmd += '--regdata %s' % instance_data_filepath

        if force_new_registration:
            cmd += ' > /dev/null 2>&1'

        logging.info('Registration: %s' % cmd)
        res = os.system(cmd)

        if res:
            failed_smts.append(registration_target.get_ip())
            if len(failed_smts) == len(region_smt_servers):
                logging.error('Registration failed')
                sys.exit(1)
            for smt in region_smt_servers:
                if (
                        smt.get_ip() != registration_target.get_ip() and
                        smt.get_ip() not in failed_smts
                ):
                    error_msg = 'Registration with %s failed. Trying %s'
                    logging.error(
                        error_msg % (
                            registration_target.get_ip(),
                            smt.get_ip()
                        )
                    )
                    utils.remove_registration_data()
                    utils.add_hosts_entry(smt)
                    registration_target = smt
                    break
        else:
            base_registered = True

    # registration was successful, let's check if the repos do
    # really exist and if not complete the registration by calling
    # suse_register with the restore-repos option
    reposExist = utils.has_repos(registration_target.get_FQDN())
    if not reposExist:
        cmd = "suse_register --restore-repos"
        res = os.system(cmd)
        if res:
            logging.info('Repositories were not restored')

    utils.enable_repository('SLE11-Public-Cloud-Module')

elif (os.path.exists(register12) and os.access(register12, os.X_OK)):
    # get product list
    products = get_products()
    if products is None:
        sys.exit(1)

    if not utils.import_smt_cert(registration_target):
        logging.error('SMT certificate import failed')
        sys.exit(1)

    # Register the base product first
    base_registered = False
    failed_smts = []
    while not base_registered:
        cmd = [register12, '--url']
        cmd.append('https://%s' % registration_target.get_FQDN())
        if os.path.exists(instance_data_filepath):
            cmd.append('--instance-data')
            cmd.append(instance_data_filepath)
        p = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        res = p.communicate()
        for item in res:
            entry = item.decode()
            if error_exp.match(entry):
                failed_smts.append(registration_target.get_ip())
                if len(failed_smts) == len(region_smt_servers):
                    logging.error('Baseproduct registration failed')
                    logging.error('\t%s' % entry)
                    sys.exit(1)
                for smt in region_smt_servers:
                    if (
                            smt.get_ip() != registration_target.get_ip() and
                            smt.get_ip() not in failed_smts
                    ):
                        error_msg = 'Registration with %s failed. Trying %s'
                        logging.error(
                            error_msg % (
                                registration_target.get_ip(),
                                smt.get_ip()
                            )
                        )
                        utils.remove_registration_data()
                        utils.add_hosts_entry(smt)
                        registration_target = smt
                        break
                break
        else:
            base_registered = True

    for product in products:
        cmd = [
            register12,
            '--url',
            'https://%s' % registration_target.get_FQDN(),
            '--product',
            product]
        if os.path.exists(instance_data_filepath):
            cmd.append('--instance-data')
            cmd.append(instance_data_filepath)

        logging.info('Registration: %s' % ' '.join(cmd))
        p = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        res = p.communicate()

        for item in res:
            entry = item.decode()
            if 'Error:' in entry:
                logging.error('\tRegistration failed: %s' % entry)
else:
    logging.error('No registration executable found')
    sys.exit(1)

if os.path.exists(instance_data_filepath):
    os.unlink(instance_data_filepath)

# Enable Nvidia repo if repo(s) are configured and destination can be reached
if utils.has_nvidia_support():
    nvidia_repo_names = utils.find_repos('nvidia')
    for repo_name in nvidia_repo_names:
        url = urllib.parse.urlparse(utils.get_repo_url(repo_name))
        cmd = ['ping', '-c', '2', url.hostname]
        if utils.exec_subprocess(cmd):
            msg = 'Cannot reach host: "%s", will not enable repo "%s"'
            logging.info(msg % (url.hostname, repo_name))
        else:
            utils.enable_repository(repo_name)
