#!/usr/bin/python3.13

# This file is a part of sedbgmux, an open source DebugMux client.
# Copyright (c) 2023  Vadim Yanitskiy <fixeria@osmocom.org>
#
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import abc
import sys
import logging
import argparse
import math
import re

from typing import Any, Optional

from sedbgmux.dump import *
from sedbgmux import DbgMuxFrame

# Logger shared by everything in this script; with the basicConfig() format
# used below, its records are tagged with the name 'sedbgmux-dump'.
log: logging.Logger = logging.getLogger(name='sedbgmux-dump')


def dumpio_auto(fname: str, *args, **kw) -> DumpIO:
    ''' Automatic dump format detection (by filename) '''
    # NOTE: the docstring above is shown to users by 'list-formats' and
    # 'format-info', so keep it short and unchanged.
    #
    # Checked in order, first match wins: the more specific '.socat.dump'
    # suffix has to be tested before the generic '.dump' one.
    by_suffix = (
        (('.socat.dump',), DumpIOSocat),
        (('.dump',), DumpIONative),
        (('.pcap', '.pcapng', '.pcap.gz', '.pcapng.gz'), DumpIOBtPcap),
    )
    for suffixes, dump_cls in by_suffix:
        if fname.endswith(suffixes):
            # remaining positional/keyword arguments go to the format class
            return dump_cls(fname, *args, **kw)
    raise DumpIOError(f'Could not detect format automatically for {fname}')


class ConnDataDec(abc.ABC):
    ''' ConnData message decoder '''

    # Decoder-specific arguments: argument name -> human readable description.
    # Subclasses override this; the names are expected to match the keyword
    # arguments accepted by the subclass constructor.
    ARGS: dict[str, Any] = {}

    @abc.abstractmethod
    def parse(self, data: bytes, direction: str) -> None:
        ''' Parse a message, print to stdout or a file

        :param data: payload of a single ConnData message
        :param direction: direction of the DebugMux frame carrying it
        '''

    def close(self) -> None:
        ''' Release any resources held by the decoder.

        No-op by default; decoders owning e.g. an output file override it.
        '''


class ConnDataText(ConnDataDec):
    ''' Decode as plain text '''
    # NOTE: the docstring above and ARGS below are shown verbatim by the
    # 'list-decoders' / 'decoder-info' commands.

    ARGS = {
        'filename' : 'File for writing decoded strings (default: stdout)',
        'encoding' : 'Data encoding (default: ascii)',
        'errors'   : 'Decoding error handler, e.g. strict, replace, '
                     'ignore, backslashreplace (default: strict)',
        }

    def __init__(self, **kw) -> None:
        # How ConnData payloads are turned into text (see bytes.decode()).
        self.encoding: str = kw.get('encoding', 'ascii')
        self.errors: str = kw.get('errors', 'strict')
        # Fail early (LookupError) on an unknown encoding name rather than on
        # the first ConnData message, and before an output file is created.
        b''.decode(self.encoding)
        if 'filename' in kw:
            # Always written as UTF-8, so that writing the decoded text does
            # not depend on (and cannot fail because of) the locale.
            self.f = open(kw['filename'], 'w', encoding='utf-8')
        else: # use stdout by default
            self.f = sys.stdout

    def parse(self, data: bytes, direction: str) -> None:
        ''' Decode a ConnData payload and write the resulting text out.

        :param data: raw ConnData payload
        :param direction: frame direction (not used by this decoder)
        :raises UnicodeDecodeError: if data cannot be decoded and the
            'errors' handler is 'strict' (the default)
        '''
        self.f.write(data.decode(self.encoding, self.errors))

    def close(self) -> None:
        ''' Close the output file, if any (stdout is left open) '''
        if self.f is not sys.stdout:
            self.f.close()


class ConnDataTvp(ConnDataDec):
    ''' Decode as Tvp (Test and Verification Protocol) frames '''
    # NOTE: the docstring above is shown verbatim by 'list-decoders'.

    def __init__(self, **kw) -> None:
        # Imported here, so these modules are only loaded once this decoder
        # is actually instantiated.
        from sedbgmux.tvp import TvpFrame
        from construct import Select
        # Candidate frame layouts, handed to Select in this order.
        variants = (
            TvpFrame.CommandReq,
            TvpFrame.CommandNack,
            TvpFrame.CommandAck,
            TvpFrame.ProbeInject,
            TvpFrame.Error,
            TvpFrame.InjectDataReq,
        )
        self.parser = Select(*variants)

    def parse(self, data: bytes, direction: str) -> None:
        ''' Decode data as a Tvp frame and print the result to stdout.

        Decoding failures are printed rather than raised, so that a single
        undecodable message does not abort the whole dump.
        '''
        print(f'  Tvp {direction} frame')
        try:
            report = f'  {self.parser.parse(data)}\n'
        except Exception as err:  # deliberate best-effort decoding
            report = f'  Tvp decoding error: {err}\n'
        print(report)


class SEDbgMuxDumpApp:
    FORMATS = {
        'auto'      : dumpio_auto,
        'native'    : DumpIONative,
        'socat'     : DumpIOSocat,
        'btpcap'    : DumpIOBtPcap,
        'cdcpcap'   : DumpIOCdcPcap,
    }

    DECODERS = {
        'text'      : ConnDataText,
        'tvp'       : ConnDataTvp,
    }

    def __init__(self, argv) -> None:
        if argv.verbose > 0:
            logging.root.setLevel(logging.DEBUG)
        if argv.verbose_module is not None:
            logger = logging.getLogger(argv.verbose_module)
            logger.setLevel(logging.DEBUG)

        match argv.command:
            case 'parse':
                self.do_parse(argv)
            case 'convert':
                self.do_convert(argv)
            case 'format-info':
                self.do_format_info(argv)
            case 'decoder-info':
                self.do_decoder_info(argv)
            case 'list-formats':
                self.do_list_formats()
            case 'list-decoders':
                self.do_list_decoders()

    def parse_args(self, args: str) -> dict:
        ''' Parse comma-separated key=value arguments '''
        l = re.findall(r'(\w+)(=[^,]+)?,?', args)
        return {key : val[1:] for (key, val) in l}

    def do_parse(self, argv) -> None:
        fmt = self.FORMATS[argv.format]
        fmt_args = self.parse_args(argv.format_args)
        dump = fmt(argv.input, readonly=True, **fmt_args)

        if argv.num_records != math.inf:
            log.warning(f'Parsing up to {argv.num_records} record(s)')
        if argv.record_nr is not None:
            log.warning(f'Parsing only specific record(s): {argv.record_nr}')

        if argv.conn_data_ref is not None:
            log.warning(f'Printing only ConnData messages with '
                        f'ConnRef={argv.conn_data_ref:#04x}')
            # -cdr/--conn-data-ref implies -cd/--conn-data
            argv.conn_data = True
        elif argv.conn_data:
            log.warning('Printing only ConnData messages')

        self.cdd: Optional[ConnDataDec] = None
        if argv.conn_data_dec:
            log.info(f'Decoding ConnData messages as {argv.conn_data_dec}')
            cdd_args = self.parse_args(argv.conn_data_dec_args)
            self.cdd = self.DECODERS[argv.conn_data_dec](**cdd_args)

        num_records: int = 0
        while num_records < argv.num_records:
            try:
                record: DumpIORecord = dump.read()
                record.nr = num_records
                num_records += 1
            except DumpIOEndOfFile:
                break
            # record number filtering
            if argv.record_nr is not None:
                if record.nr not in argv.record_nr:
                    continue
            self._print_record(argv, record)

    def do_convert(self, argv) -> None:
        # input
        fmt = self.FORMATS[argv.input_format]
        fmt_args = self.parse_args(argv.input_format_args)
        di = fmt(argv.input, readonly=True, **fmt_args)
        # output
        fmt = self.FORMATS[argv.output_format]
        fmt_args = self.parse_args(argv.output_format_args)
        do = fmt(argv.output, readonly=False, **fmt_args)

        if argv.num_records != math.inf:
            log.warning(f'Converting up to {argv.num_records} record(s)')

        num_records: int = 0
        while num_records < argv.num_records:
            try:
                record: DumpIORecord = di.read()
                record.nr = num_records
                do.write(record)
                num_records += 1
            except DumpIOEndOfFile:
                break
        log.info(f'Converted {num_records} records')

    def do_format_info(self, argv) -> None:
        fmt = self.FORMATS[argv.format]
        print(f'Description: {fmt.__doc__}')
        if fmt == dumpio_auto or not fmt.ARGS:
            return
        print('Arguments:')
        for name, desc in fmt.ARGS.items():
            print(f'  {name}\t\t{desc}')

    def do_decoder_info(self, argv) -> None:
        dec = self.DECODERS[argv.decoder]
        print(f'Description: {dec.__doc__}')
        if not dec.ARGS:
            return
        print('Arguments:')
        for name, desc in dec.ARGS.items():
            print(f'  {name}\t\t{desc}')

    def do_list_formats(self) -> None:
        for name, fmt in self.FORMATS.items():
            print(f'{name}\t\t{fmt.__doc__}')

    def do_list_decoders(self) -> None:
        for name, dec in self.DECODERS.items():
            print(f'{name}\t\t{dec.__doc__}')

    def _print_record(self, argv, record: DumpIORecord) -> None:
        frame = DbgMuxFrame.Frame.parse(record.data)
        if argv.conn_data:  # print only ConnData messages
            if frame['MsgType'] != DbgMuxFrame.MsgType.ConnData:
                return
            # optionally filter by ConnRef
            if argv.conn_data_ref is not None:
                msg = DbgMuxFrame.MsgConnData.parse(frame['MsgData'])
                if msg['ConnRef'] != argv.conn_data_ref:
                    return
        # Record information
        print(record)
        # DebugMux frame header
        print(f'  DebugMux {record.direction} frame',
              '(Ns={TxCount:03d}, Nr={RxCount:03d}, fcs=0x{FCS:04x})'.format(**frame),
              frame['MsgType'], frame['MsgData'].hex())
        fcs: int = DbgMuxFrame.fcs_func(record.data[:-2])
        if fcs != frame['FCS']:
            msg = f'Indicated {frame['FCS']:#04x} != calculated {fcs:#04x}'
            if not argv.ignore_bad_fcs:
                raise DumpIOFcsError(msg)
            print(f'  Bad FCS: {msg}')
        # DebugMux frame payload
        if argv.decode_payload:
            msg = DbgMuxFrame.Msg.parse(frame['MsgData'], MsgType=frame['MsgType'])
            if msg == b'':
                return
            print(f'  {msg}')
            if frame['MsgType'] == DbgMuxFrame.MsgType.ConnData:
                if self.cdd is not None:
                    self.cdd.parse(msg['Data'], record.direction)


def _hex_int(value: str) -> int:
    ''' argparse type converter: hexadecimal integer ('1f' or '0x1f') '''
    try:
        return int(value, 16)
    except ValueError:
        # argparse prints an ArgumentTypeError message verbatim, whereas a
        # ValueError from a lambda is reported as "invalid <lambda> value".
        raise argparse.ArgumentTypeError(
            f'invalid hexadecimal value: {value!r}') from None


# Command line interface: global options plus one sub-parser per command
ap = argparse.ArgumentParser(prog='sedbgmux-dump',
                             description='DebugMux dump management utility')
sp = ap.add_subparsers(dest='command', metavar='command', required=True,
                       help='sub-command help')

ap.add_argument('-v', '--verbose', action='count', default=0,
                help='print debug logging')
ap.add_argument('-vm', '--verbose-module', metavar='MODULE', type=str,
                help='print debug logging for a specific module')

parse = sp.add_parser('parse', help='parse a dump file')
parse.add_argument('input', metavar='INPUT', type=str,
                   help='input file to be parsed')
parse.add_argument('-f', '--format', type=str, default='auto',
                   choices=[*SEDbgMuxDumpApp.FORMATS.keys()],
                   help='input file format (default: %(default)s)')
parse.add_argument('-fa', '--format-args', type=str, default='',
                   help='format specific argument(s) (example: foo=1,bar=2,zoo)')
parse.add_argument('-dp', '--decode-payload', action='store_true',
                   help='decode DebugMux frame contents')
parse.add_argument('-cd', '--conn-data', action='store_true',
                   help='show only ConnData messages')
parse.add_argument('-cdr', '--conn-data-ref',
                   type=_hex_int, metavar='ConnRef',
                   help='filter ConnData messages by ConnRef (implies -cd/--conn-data)')
parse.add_argument('-cdd', '--conn-data-dec',
                   choices=[*SEDbgMuxDumpApp.DECODERS.keys()],
                   help='ConnData decoder to use')
parse.add_argument('-cdda', '--conn-data-dec-args', type=str, default='',
                   help='ConnData decoder argument(s) (example: foo=1,bar=2,zoo)')
# math.inf as the default means "no limit"
parse.add_argument('-nr', '--num-records', type=int, default=math.inf,
                   help='number of records to parse (default: all)')
parse.add_argument('-rn', '--record-nr', type=int, action='append',
                   help='parse only specific record number(s)')
parse.add_argument('--ignore-bad-fcs', action='store_true',
                   help='do not abort parsing on FCS mismatch')

convert = sp.add_parser('convert', help='convert between different formats')
convert.add_argument('input', metavar='INPUT', type=str,
                     help='input file to be converted')
convert.add_argument('output', metavar='OUTPUT', type=str,
                     help='output file')
convert.add_argument('-if', '--input-format', type=str, default='auto',
                     choices=[*SEDbgMuxDumpApp.FORMATS.keys()],
                     help='input file format (default: %(default)s)')
convert.add_argument('-ifa', '--input-format-args', type=str, default='',
                     help='input file format argument(s) (example: foo=1,bar=2,zoo)')
convert.add_argument('-of', '--output-format', type=str, default='auto',
                     choices=[*SEDbgMuxDumpApp.FORMATS.keys()],
                     help='output file format (default: %(default)s)')
convert.add_argument('-ofa', '--output-format-args', type=str, default='',
                     help='output file format argument(s) (example: foo=1,bar=2,zoo)')
convert.add_argument('-nr', '--num-records', type=int, default=math.inf,
                     help='number of records to convert (default: all)')

format_info = sp.add_parser('format-info', help='show format info')
format_info.add_argument(dest='format', metavar='FORMAT', type=str,
                         choices=[*SEDbgMuxDumpApp.FORMATS.keys()],
                         help='dump file format')

decoder_info = sp.add_parser('decoder-info', help='show decoder info')
decoder_info.add_argument(dest='decoder', metavar='DECODER', type=str,
                          choices=[*SEDbgMuxDumpApp.DECODERS.keys()],
                          help='ConnData decoder')

sp.add_parser('list-formats', help='list all supported formats')
sp.add_parser('list-decoders', help='list ConnData decoders')

logging.basicConfig(
    format='[%(levelname)s] %(name)s:%(lineno)d %(message)s', level=logging.INFO)

if __name__ == '__main__':
    argv = ap.parse_args()
    app = SEDbgMuxDumpApp(argv)
