#!/usr/bin/python3
# vim: set ts=4 sw=4 tw=0 et pm=:
import time
import getopt
import sys
import re
import threading
import os.path
import argparse

import iridium.iridium_extractor_flowgraph

# Largest queue size reported by the flowgraph for the current report;
# print_stats() overwrites it on every poll and resets it after printing.
queue_len_max = 0

# Per-interval burst counters: print_stats() recomputes each of them as
# "current flowgraph counter minus the matching *_total below".
out_count = in_count = drop_count = 0
in_ok_count = out_ok_count = 0

# Running totals already accounted for in previous reports.
in_count_total = out_count_total = drop_count_total = 0
in_ok_count_total = out_ok_count_total = 0

# Number of samples already accounted for in previous reports.
sample_count_total = 0

# Wall-clock time of the previous report (0 until the first one is printed)
# and of module start-up.
last_print = 0
t0 = time.time()

# Filled in by the start-up code in the __main__ section; read by print_stats().
sample_rate = None
offline = False

# Legacy / convenience format names mapped to the SigMF-style datatype names.
# Insertion order is kept identical to the historical table.
sample_formats = {
    **dict.fromkeys(('float', 'fc32', 'cfile'), 'cf32_le'),
    'sc16': 'ci16_le',
    **dict.fromkeys(('hackrf', 'sc8'), 'ci8'),
    'rtl': 'cu8',
}


def _per_second(amount, seconds):
    """Return ``amount / seconds``, or 0.0 when ``seconds`` is not positive."""
    return amount / seconds if seconds > 0 else 0.0


def print_stats(tb):
    """Print one line of burst statistics to stderr every second, forever.

    On each pass the counters exposed by ``tb`` are polled, the difference to
    the totals recorded on the previous pass gives the per-interval counts,
    and both per-interval and since-start (``t0``) rates are written to
    stderr as a single `` | ``-separated line. When not in offline mode a
    warning is printed if fewer than 98% of the expected samples arrived
    during the interval (only once more than 3 seconds have passed since
    ``t0``).

    The function never returns; it is meant to be run in a daemon thread.
    It reads the module globals ``sample_rate``, ``offline`` and ``t0`` and
    updates the module-level counter globals declared below.

    Args:
        tb: Object providing ``get_queue_size()``, ``get_max_queue_size()``,
            ``get_n_detected_bursts()``, ``get_n_handled_bursts()``,
            ``get_n_access_ok_bursts()``, ``get_n_access_ok_sub_bursts()``,
            ``get_n_dropped_bursts()`` and ``get_sample_count()``.
    """
    global last_print, queue_len_max, out_count, in_count
    global drop_count, drop_count_total, in_ok_count, out_ok_count, sample_count_total
    global in_ok_count_total, out_ok_count_total, out_count_total, in_count_total, t0
    while True:
        # The current queue size is polled but not part of the report.
        # NOTE(review): call kept in case the getter matters to tb -- confirm.
        tb.get_queue_size()
        queue_len_max = tb.get_max_queue_size()

        # Per-interval deltas: current counter minus what was already reported.
        in_count = tb.get_n_detected_bursts() - in_count_total
        out_count = tb.get_n_handled_bursts() - out_count_total
        in_ok_count = tb.get_n_access_ok_bursts() - in_ok_count_total
        out_ok_count = tb.get_n_access_ok_sub_bursts() - out_ok_count_total
        drop_count = tb.get_n_dropped_bursts() - drop_count_total
        sample_count = tb.get_sample_count() - sample_count_total

        now = time.time()
        if last_print == 0:
            # Nothing has been printed yet: measure the first interval from
            # start-up instead of from the Unix epoch, which would squash
            # every rate on the first line to ~0.
            last_print = t0
        dt = now - last_print
        elapsed = now - t0

        in_rate = _per_second(in_count, dt)
        out_rate = _per_second(out_count, dt)
        ok_rate = _per_second(out_ok_count, dt)
        sample_rate_ratio = _per_second(float(sample_count) / float(sample_rate), dt)

        if in_count != 0:
            in_ok_ratio = in_ok_count / float(in_count)
            out_ok_ratio = out_ok_count / float(in_count)
        else:
            in_ok_ratio = 0
            out_ok_ratio = 0

        # Fold this interval into the running totals.
        in_count_total += in_count
        out_count_total += out_count
        in_ok_count_total += in_ok_count
        out_ok_count_total += out_ok_count
        drop_count_total += drop_count
        sample_count_total += sample_count

        in_rate_avg = _per_second(in_count_total, elapsed)
        ok_rate_avg = _per_second(out_ok_count_total, elapsed)

        if in_count_total != 0:
            ok_ratio_total = out_ok_count_total / float(in_count_total)
        else:
            ok_ratio_total = 0

        # Offline runs show the processing speed relative to real time in
        # place of the input burst rate.
        if offline:
            second_field = "srr: %5.1f%%" % (sample_rate_ratio * 100)
        else:
            second_field = "i: %3d/s" % in_rate

        fields = [
            "%d" % now,
            second_field,
            "i_avg: %3d/s" % in_rate_avg,
            "q_max: %4d" % queue_len_max,
            "i_ok: %3d%%" % (in_ok_ratio * 100),
            "o: %4d/s" % out_rate,
            "ok: %3d%%" % (out_ok_ratio * 100),
            "ok: %3d/s" % ok_rate,
            "ok_avg: %3d%%" % (ok_ratio_total * 100),
            "ok: %10d" % out_ok_count_total,
            "ok_avg: %3d/s" % ok_rate_avg,
            "d: %d" % drop_count_total,
        ]
        print(" | ".join(fields), file=sys.stderr)

        # Skip the warning during the first seconds after start-up.
        if not offline and sample_rate_ratio < 0.98 and (last_print - t0) > 3:
            print("WARNING: your SDR seems to be losing samples. ~%dk samples lost (%.0f%%)" % (sample_rate * (1 - sample_rate_ratio) / 1000, (1 - sample_rate_ratio) * 100), file=sys.stderr)
        sys.stderr.flush()

        queue_len_max = 0
        in_count = 0
        last_print = time.time()

        time.sleep(1)


if __name__ == "__main__":
    # Script entry point: parse the command line, work out where the samples
    # come from (stdin, raw file, SDR .conf, SigMF, wav), derive sample rate /
    # center frequency / sample format from the available metadata, optionally
    # write SigMF metadata, then build and run the extractor flowgraph.

    # ------------------------------------------------------------------
    # Command line
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser(description='Receive and demodulate Iridium transmissions.',
                                     epilog='For more detailed usage information see https://github.com/muccc/gr-iridium/')
    parser.add_argument('-b', '--max-bursts', dest='max_bursts', type=int, default=0,
                        help='Maximum allowed number of active bursts. Default: automatic')
    parser.add_argument('-w', '--burst-width', dest='burst_width', type=int, default=40000,
                        help='Width in Hz of the burst filter. Default: 40 kHz.')
    parser.add_argument('-c', '--center', dest='center', type=int,
                        help='Center frequency for the source or the file in Hz.')
    parser.add_argument('-d', '--db', dest='threshold', type=float, default=18,
                        help='Burst detection threshold in dB. Default: 18 dB.')
    parser.add_argument('--debug-id', dest='debug_id', type=int,
                        help='Write debug data for a certain frame ID. Last digit must be 0. See documentation for details.')
    parser.add_argument('-D', '--decimation', dest='decimation', type=int,
                        help='Number of channel to split the signal into. Must be even or 1. Default: 1.')
    parser.add_argument('--file-info', dest='file_info', action="store",
                        help='Manually set the file info field in the output data.')
    parser.add_argument('-f', '--format', dest='fmt', action='store', choices=('cu8', 'ci8', 'ci16_le', 'cf32_le', 'rtl', 'hackrf', 'sc16', 'float', 'fc32', 'cfile'),
                        help='Format of the samples of a recording. If not specified the file extension will be used.')
    # This only exists to support legacy scripts
    # TODO: Handle an optional "False"
    parser.add_argument('--multi-frame', dest='handle_multiple_frames_per_burst', action="store_true", default=True,
                        help='Enable handling of multiple frames per burst. Enabled by default.')
    parser.add_argument('-o', '--offline', dest='offline', action='store_true',
                        help='Enable offline mode when reading from stdin.')
    parser.add_argument('-q', '--queuelen', dest='max_queue_len', type=int, default=500,
                        help='Number of bursts that are allowed to queue up per channel. Default: 500.')
    parser.add_argument('-r', '--rate', dest='sample_rate', type=int,
                        help='Sample rate of the source or the file in Hz. Must be divisible by 100000.')
    parser.add_argument('--raw-capture', dest='raw_capture_filename', action="store",
                        help='Write a copy of the samples to a SigMF recording.')
    parser.add_argument('--generate-sigmf-meta', dest='sigmf_meta_filename', action="store",
                        help='Create a sigmf-meta file based on the input format.')
    parser.add_argument('-v', '--verbose', dest='verbose', action='store_true',
                        help='Enable verbose output.')

    parser.add_argument('filename', action="store", nargs='?', default='-',
                        help='File containing either raw samples or a configuration for an SDR. If not present stdin will be used to read raw samples.')
    args = parser.parse_args()

    center = None
    fmt = None
    decimation = 1
    samples_per_symbol = 10
    cfg = {}

    # ------------------------------------------------------------------
    # Input selection
    # ------------------------------------------------------------------
    if args.filename in ('-', '/dev/stdin'):
        filename = "/dev/stdin"
        cfg['source'] = 'stdin'
    else:
        filename = args.filename
        if not os.path.exists(filename):
            print("Specified input file `%s` does not exist" % filename, file=sys.stderr)
            sys.exit(1)
        # Only regular files are processed in offline mode by default.
        if os.path.isfile(filename):
            offline = True
        cfg['source'] = 'file'
        cfg['file'] = filename

    # SDR configuration file: the sections of the .conf replace cfg entirely.
    if filename.endswith(".conf"):
        import configparser
        config = configparser.ConfigParser()
        config.read(filename)
        cfg = {x: dict(config.items(x)) for x in config.sections()}
        if 'osmosdr-source' in cfg:
            d = cfg['osmosdr-source']
            cfg['source'] = 'osmosdr'
            offline = False
        elif 'soapy-source' in cfg:
            d = cfg['soapy-source']
            cfg['source'] = 'soapy'
            offline = False
        elif 'zeromq-sub-source' in cfg:
            d = cfg['zeromq-sub-source']
            cfg['source'] = 'zeromq-sub'
            offline = False
        elif 'uhd-source' in cfg:
            d = cfg['uhd-source']
            cfg['source'] = 'uhd'
            offline = False
        else:
            print("Unsupported .conf file format", file=sys.stderr)
            sys.exit(1)

        sample_rate = int(d['sample_rate'])
        center = int(d['center_freq'])

        if 'demodulator' in config:
            d = config['demodulator']

            if 'samples_per_symbol' in d:
                samples_per_symbol = int(d['samples_per_symbol'])

            if 'decimation' in d:
                decimation = int(d['decimation'])

        fmt = 'float'

    # SigMF recording given as either half of a .sigmf-meta/.sigmf-data pair.
    sigmf = None
    if filename.endswith(".sigmf-meta") or filename.endswith(".sigmf-data"):
        import json
        meta = os.path.splitext(filename)[0] + '.sigmf-meta'
        # Context manager so the metadata file is closed right after parsing.
        with open(meta) as meta_file:
            sigmf = json.load(meta_file)
        cfg['source'] = 'file'
        cfg['file'] = os.path.splitext(filename)[0] + '.sigmf-data'

    # SigMF archive (tar): metadata is parsed, data is handed over as a file object.
    if filename.endswith(".sigmf"):
        import json
        import tarfile
        # The archive is intentionally left open: cfg['object'] reads from it later.
        tar = tarfile.open(filename)
        members = tar.getmembers()
        meta = next((x for x in members if x.name.endswith(".sigmf-meta")), None)
        tardata = next((x for x in members if x.name.endswith(".sigmf-data")), None)
        if meta is None or tardata is None:
            print("SigMF archive `%s` is missing its .sigmf-meta or .sigmf-data member" % filename, file=sys.stderr)
            sys.exit(1)
        sigmf = json.load(tar.extractfile(meta))
        cfg['source'] = 'object'
        cfg['object'] = tar.extractfile(tardata)
        cfg.pop('file', None)

    # wav recording: format and sample rate come from the wav header.
    if fmt == 'wav' or (fmt is None and filename.lower().endswith(".wav")):
        import wave
        wav = wave.open(filename)
        if wav.getnchannels() != 2:
            raise TypeError("wav file must have 2 channels")
        if wav.getsampwidth() == 2:
            fmt = 'ci16_le'
        else:
            raise TypeError("unknown wav format")
        sample_rate = wav.getframerate()
        cfg['source'] = 'object'
        cfg['object'] = wav.getfp()
        cfg.pop('file', None)

    # ------------------------------------------------------------------
    # SigMF metadata (every field is optional / best effort)
    # ------------------------------------------------------------------
    if sigmf is not None:
        try:
            sample_rate = int(sigmf['global']['core:sample_rate'])
        except (KeyError, IndexError, TypeError, ValueError, OverflowError):
            # Missing or unusable: later fallbacks (command line, file name) apply.
            pass
        try:
            center = int(sigmf['captures'][0]['core:frequency'])
        except (KeyError, IndexError, TypeError, ValueError, OverflowError):
            # Missing or unusable: later fallbacks (command line, file name) apply.
            pass
        try:
            fmt = sigmf['global']['core:datatype']
        except (KeyError, IndexError, TypeError):
            print("[sigmf] datatype missing.", file=sys.stderr)
        try:
            import datetime
            dt = datetime.datetime.fromisoformat(sigmf['captures'][0]['core:datetime'].replace("Z", "+00:00"))
            if args.file_info is None:
                args.file_info = f'i-{dt.timestamp():.20g}-t1'
        except (KeyError, IndexError, TypeError, ValueError, AttributeError, OverflowError, OSError):
            print("[sigmf] timestamp missing.", file=sys.stderr)

    # ------------------------------------------------------------------
    # Command line options override anything derived from the input
    # ------------------------------------------------------------------
    if args.sample_rate is not None:
        sample_rate = args.sample_rate

    if args.center is not None:
        center = args.center

    if args.fmt is not None:
        fmt = args.fmt

    if args.decimation is not None:
        decimation = args.decimation

    if args.offline:
        offline = True

    # check filename extension for sample format
    extension = os.path.splitext(filename)[-1][1:]
    if fmt is None and (extension in sample_formats or extension in sample_formats.values()):
        fmt = extension

    # compatibility mappings
    if fmt in sample_formats:
        fmt = sample_formats[fmt]

    # grab things from filename as fallback
    if sample_rate is None and (g := re.search(r"-s(\d+(\.\d+)?(e[+-]\d+)?)-", filename)):
        sample_rate = int(float(g.group(1)))

    if args.file_info is None and (g := re.search(r"-t(\d{14}(\+\d{4})?)\.", filename)):
        try:
            import datetime
            # "%z" is only appended when the optional UTC offset group matched.
            dt = datetime.datetime.strptime(g.group(1), "%Y%m%d%H%M%S" + "%z" * bool(g.group(2)))
            args.file_info = f'i-{dt.timestamp():.20g}-t1'
        except ValueError:
            print("Couldn't parse timestamp in filename.", file=sys.stderr)

    if center is None and (g := re.search(r"-f(\d+(\.\d+)?(e[+-]\d+)?)-", filename)):
        center = int(float(g.group(1)))

    if args.file_info is None:
        if offline:
            args.file_info = os.path.splitext(os.path.basename(filename))[0]
        else:
            args.file_info = ''

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------
    if sample_rate is None:
        print("Sample rate missing!", file=sys.stderr)
        sys.exit(1)
    if sample_rate % 100000:
        print("Sample rate must be divisible by 100000!", file=sys.stderr)
        sys.exit(1)
    if center is None:
        print("Need to specify center frequency!", file=sys.stderr)
        sys.exit(1)
    if fmt is None:
        print("Need to specify sample format (use -h to get a list of supported formats)!", file=sys.stderr)
        sys.exit(1)
    if decimation < 1:
        print("Decimation must be > 0", file=sys.stderr)
        sys.exit(1)
    if decimation > 1 and decimation % 2:
        print("Decimations > 1 must be even", file=sys.stderr)
        sys.exit(1)
    if args.debug_id is not None:
        if not os.path.isdir('/tmp/signals'):
            print("/tmp/signals directory missing!", file=sys.stderr)
            sys.exit(1)

    # ------------------------------------------------------------------
    # Optional: describe the input as a SigMF metadata file
    # ------------------------------------------------------------------
    if args.sigmf_meta_filename is not None:
        import json
        import datetime
        dt = None
        # Prefer the timestamp embedded in the file info, else the file's ctime.
        if args.file_info is not None and (g := re.search(r"i-(\d+(\.\d+)?)-", args.file_info)):
            try:
                dt = datetime.datetime.fromtimestamp(float(g.group(1)), datetime.timezone.utc)
            except ValueError:
                pass
        if dt is None:
            dt = datetime.datetime.fromtimestamp(os.path.getctime(filename), datetime.timezone.utc)

        data = {
            'global': {
                'core:datatype': fmt,
                'core:description': 'autogenerated for ' + os.path.basename(filename),
                'core:recorder': 'iridium-extractor',
                "core:num_channels": 1,
                'core:sample_rate': sample_rate,
                'core:version': '1.0.0'
            },
            'captures': [
                {
                    'core:datetime': dt.astimezone(datetime.timezone.utc).replace(microsecond=0, tzinfo=None).isoformat() + 'Z',
                    'core:frequency': center,
                    'core:sample_start': 0
                }
            ]
        }

        with open(args.sigmf_meta_filename, 'w') as outfile:
            json.dump(data, outfile, indent=4)

    # ------------------------------------------------------------------
    # Optional: metadata for the raw capture written alongside processing
    # ------------------------------------------------------------------
    if args.raw_capture_filename is not None:
        import json
        import datetime
        # Timezone-aware "now" in UTC; tzinfo is stripped so isoformat() does
        # not append "+00:00" in front of the explicit 'Z'.
        capture_start = datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0, tzinfo=None)
        data = {
            'global': {
                'core:datatype': 'ci16_le',
                'core:description': 'iridium-extractor capture via ' + os.path.basename(filename),
                'core:recorder': 'iridium-extractor',
                "core:num_channels": 1,
                'core:sample_rate': sample_rate,
                'core:version': '1.0.0'
            },
            'captures': [
                {
                    'core:datetime': capture_start.isoformat() + 'Z',
                    'core:frequency': center,
                    'core:sample_start': 0
                }
            ]
        }

        if cfg['source'] == 'soapy':
            data['global']['core:hw'] = cfg['soapy-source']['driver']

        with open(args.raw_capture_filename + '.sigmf-meta', 'w') as outfile:
            json.dump(data, outfile, indent=4)

    # ------------------------------------------------------------------
    # Build and run the flowgraph
    # ------------------------------------------------------------------
    tb = iridium.iridium_extractor_flowgraph.FlowGraph(center_frequency=center, sample_rate=sample_rate, decimation=decimation,
                                                       filename=filename, sample_format=fmt,
                                                       threshold=args.threshold, burst_width=args.burst_width,
                                                       offline=offline, max_queue_len=args.max_queue_len,
                                                       handle_multiple_frames_per_burst=args.handle_multiple_frames_per_burst,
                                                       raw_capture_filename=args.raw_capture_filename,
                                                       debug_id=args.debug_id,
                                                       max_bursts=args.max_bursts,
                                                       verbose=args.verbose,
                                                       file_info=args.file_info,
                                                       samples_per_symbol=samples_per_symbol,
                                                       config=cfg)

    # Daemon thread: the statistics loop never returns and must not keep the
    # process alive once tb.run() has finished.
    statistics_thread = threading.Thread(target=print_stats, args=(tb,))
    statistics_thread.daemon = True
    statistics_thread.start()

    tb.run()

    print("Done.", file=sys.stderr)
