#!/usr/bin/env python
# Copyright (C) 2023-2024  Nexedi SA and Contributors.
#
# This program is free software: you can Use, Study, Modify and Redistribute
# it under the terms of the GNU General Public License version 3, or (at your
# option) any later version, as published by the Free Software Foundation.
#
# You can also Link and Combine this program with other software covered by
# the terms of any of the Free Software licenses or any of the Open Source
# Initiative approved licenses and Convey the resulting work. Corresponding
# source of such a combination shall include the source code for all other
# software used.
#
# This program is distributed WITHOUT ANY WARRANTY; without even the implied
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
#
# See COPYING file for full licensing terms.
# See https://www.nexedi.com/licensing for rationale and options.
"""Program tapsplit brings a tap interface into a state with several child
interfaces, each covering a part of the original interface's address space.

Usage: tapsplit <interface> <nchildren>
"""

# TODO Relying on tapsplit should be removed once SlapOS is improved to provide
#      several TAP interfaces to instances. See discussion at
#      https://lab.nexedi.com/nexedi/slapos/merge_requests/1471#note_194356
#      for details.

import netifaces
import netaddr
from socket import AF_INET6
from math import log2, ceil

import sys
import subprocess
import json


# LinkDB holds a snapshot of the state of all network interfaces, taken via
# `ip link show` at the moment the object is created.
class LinkDB:
    def __init__(db):
        db.linkv = ip('link', 'show')

    # ifget looks interface up by name in the snapshot and returns its entry.
    # KeyError is raised if the snapshot has no interface with that name.
    def ifget(db, ifname):
        matching = (entry for entry in db.linkv if entry['ifname'] == ifname)
        found = next(matching, None)
        if found is None:
            raise KeyError('interface %r not found' % ifname)
        return found

# main implements `tapsplit <interface> <nchildren>`.
#
# It determines IPv6 network and owner of <interface>, splits the network into
# 1+<nchildren> subnetworks and leaves the first one for <interface> itself.
# For every other subnetwork it makes sure that there is TAP interface named
# <interface>-<i> which is up, has the subnetwork address assigned with
# noprefixroute, has route to subnet[1] and has route to the whole subnetwork
# via subnet[1]. Steps whose result is already in place are skipped with a note.
# In the end all other interfaces named <interface>-* are deleted.
def main():
    if len(sys.argv) < 3:
        sys.exit('Usage: tapsplit <interface> <nchildren>')
    tap = sys.argv[1]
    n   = int(sys.argv[2])
    # explicit check instead of assert: asserts are stripped under `python -O`
    if n < 0:
        raise ValueError('nchildren must be >= 0 (got %d)' % n)

    # determine tap's network address and owner
    ldb = LinkDB()
    tapinfo = ldb.ifget(tap)
    owner = tapinfo['linkinfo']['info_data']['user']
    net   = ifnet6(tap)

    print('%s: split %s by %d' % (tap, net, n))

    # do the split
    # with leaving first range for the original tap
    subtap_set = set()
    for i, subnet in enumerate(netsplit(net, 1+n)):
        if i == 0:
            print('preserve         %s' % subnet)
            continue    # leave this range for original tap

        subtap = '%s-%d' % (tap, i)
        subtap_set.add(subtap)
        print('-> %s   %s' % (subtap, subnet))
        # note is only called within current iteration, so it always sees
        # subtap of that iteration.
        def note(msg):
            print(' # %s: %s' % (subtap, msg))

        # create subtap
        try:
            link = ldb.ifget(subtap)
        except KeyError:
            run('ip', 'tuntap', 'add', 'dev', subtap, 'mode', 'tap', 'user', owner)
            # ldb is a snapshot taken before the interface was created
            # -> query the new interface directly
            link = ip('link', 'show', 'dev', subtap)[0]
        else:
            note('already exists')

        # set it up
        if 'UP' not in link['flags']:
            run('ip', 'link',   'set', subtap, 'up')
        else:
            note('already up')

        # add subnet address
        addrv = []
        for entry in ip('-6', 'addr', 'show', 'dev', subtap):
            addrv.extend(entry['addr_info'])
        for addr in addrv:
            have = netaddr.IPNetwork('%s/%s' % (addr['local'], addr['prefixlen']))
            # .get: `ip -json` emits address flags such as noprefixroute only
            # when the flag is set, so the key can be absent.
            if have == subnet and addr.get('noprefixroute'):
                note('already has %s addr' % str(subnet))
                break
        else:
            run('ip', 'addr',   'add', str(subnet), 'dev', subtap, 'noprefixroute')

        # routes are queried once, before any route is added below
        rtv = ip('-6', 'route', 'show', 'dev', subtap)

        # add /128 route to subnet::1
        for rt in rtv:
            if rt['dst'] == str(subnet[1])  and  'gateway' not in rt:
                note('already has %s route' % str(subnet[1]))
                break
        else:
            run('ip', 'route',  'add', str(subnet[1]), 'dev', subtap)

        # add route to subnet via subnet::1
        for rt in rtv:
            if rt['dst'] == str(subnet)  and  rt.get('gateway') == str(subnet[1]):
                note('already has %s route' % str(subnet))
                break
        else:
            run('ip', 'route',  'add', str(subnet), 'dev', subtap, 'via', str(subnet[1]))

    # remove other existing children
    child_prefix = '%s-' % tap
    for ifname in netifaces.interfaces():
        if ifname.startswith(child_prefix) and (ifname not in subtap_set):
            print('-> del %s' % ifname)
            run('ip', 'link', 'del', ifname)


# netsplit splits network into n subnetworks.
#
# The prefix is extended by the smallest number of bits that gives at least n
# subnetworks, and the first n of them are returned.
def netsplit(net, n): # -> []subnet
    extra_bits = ceil(log2(n))
    sub_prefixlen = net.prefixlen + extra_bits
    candidates = list(net.subnet(sub_prefixlen))
    return candidates[:n]

# ifnet6 returns IPv6 network address associated with given interface.
#
# Link-local addresses are skipped. RuntimeError is raised if the interface
# has no other IPv6 address, or if it has more than one. The result is
# netaddr.IPNetwork normalized to its CIDR form.
def ifnet6(ifname):
    addr = None
    net  = None
    # .get: an interface without any IPv6 address has no AF_INET6 entry;
    # report that via RuntimeError below instead of a bare KeyError.
    for iaddr in netifaces.ifaddresses(ifname).get(AF_INET6, []):
        a = iaddr['addr']
        if '%' in a: # link-local
            a = netaddr.IPAddress(a.split('%')[0])
            assert a.is_link_local(), a
            continue

        if addr is not None:
            raise RuntimeError('%s: multiple addresses:  %s and %s' % (ifname, addr, a))

        addr = netaddr.IPAddress(a)
        # netmask is split as '<mask>/<prefixlen>'; only prefixlen is needed
        _, plen = iaddr['netmask'].split('/')
        net = netaddr.IPNetwork('%s/%d' % (a, int(plen)))

    if addr is None:
        raise RuntimeError('%s: no non link-local addresses' % ifname)

    # normalize network
    # ex 2401:5180:0:66:a7ff:ffff:ffff:ffff/71  ->  2401:5180:0:66:a600::/71
    return net.cidr

# run prints command `*argv` and then executes it as action.
# subprocess.CalledProcessError is raised if the command exits with non-zero status.
def run(*argv):
    cmdline = ' '.join(argv)
    print(' # ' + cmdline)
    subprocess.check_call(argv)

# ip runs `ip -json -details *argv` and returns its output decoded from JSON.
def ip(*argv):
    cmd = ['ip', '-json', '-details']
    cmd.extend(argv)
    output = subprocess.check_output(cmd)
    return json.loads(output)


# run main only when the file is executed as a program, not when it is imported.
if __name__ == '__main__':
    main()
