#!/usr/bin/python3.13

# This file is a part of sedbgmux, an open source DebugMux client.
# Copyright (c) 2022-2025  Vadim Yanitskiy <fixeria@osmocom.org>
#
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import logging
import argparse
import time
import cmd2
import sys

from typing import List

from sedbgmux.io import DbgMuxIOModem
from sedbgmux.dump import DumpIONative
from sedbgmux import DbgMuxPeer
from sedbgmux import DbgMuxClient
from sedbgmux import DbgMuxFrame

from sedbgmux.ch import DbgMuxConnTerminal
from sedbgmux.ch import DbgMuxConnFileLogger
from sedbgmux.ch import DbgMuxConnUdpProxy
from sedbgmux.ch import DbgMuxConnWalker
from sedbgmux.ch import DbgMuxConnHexDump

from sedbgmux.tvp.shell import TvpCommandSet

# local logger for this module
log = logging.getLogger('sedbgmux-shell')


class SEDbgMuxApp(cmd2.Cmd):
    DESC = 'DebugMux client for [Sony] Ericsson phones and modems'

    # Command categories
    CATEGORY_CONN = 'Connection management commands'
    CATEGORY_DBGMUX = 'DebugMux specific commands'

    def __init__(self, argv) -> None:
        if argv.script != '':
            log.info(f'Executing script {argv.script}')
        super().__init__(startup_script=argv.script,
                         allow_cli_args=False,
                         include_py=True)

        if argv.verbose > 0:
            logging.root.setLevel(logging.DEBUG)
            self.debug = True
        if argv.verbose_module is not None:
            logger = logging.getLogger(argv.verbose_module)
            logger.setLevel(logging.DEBUG)

        self.intro = f'Welcome to {self.DESC}!'
        self.default_category = 'Built-in commands'
        self.argv = argv

        # Init the I/O layer, DebugMux peer and client
        self.io = DbgMuxIOModem(self.argv)
        self.peer = DbgMuxPeer(self.io)
        self.client = DbgMuxClient(self.peer)

        # Optionally dump DebugMux frames to a file
        if argv.dump_file is not None:
            dump = DumpIONative(argv.dump_file, readonly=False)
            self.peer.enable_dump(dump)

        # Modem connection state
        self.set_connected(False)

        self.py_locals = { 'client' : self.client }

    def _tab_data_providers(self) -> List[cmd2.CompletionItem]:
        ''' Generate a list of DPRef values for tab-completion '''
        return [cmd2.CompletionItem(f'0x{DPRef:04x}', DPName)
                for DPRef, DPName in self.client.data_providers.items()]

    def _tab_connections(self) -> List[cmd2.CompletionItem]:
        ''' Generate a list of ConnRef values for tab-completion '''
        return [cmd2.CompletionItem(f'{ConnRef:#04x}', f'{DPRef=:#04x} {ch}')
                for ConnRef, (DPRef, ch) in self.client.active_conn.items()]

    def _tab_msg_types(self) -> List[cmd2.CompletionItem]:
        ''' Generate a list of DbgMuxFrame.MsgType values for tab-completion '''
        return [cmd2.CompletionItem(f'{val:#02x}', desc)
                for val, desc in DbgMuxFrame.MsgType.decmapping.items()]

    def _tab_modules(self) -> List[cmd2.CompletionItem]:
        ''' Generate a list of sedbgmux module names for tab-completion '''
        return sorted([name for name in sys.modules.keys() if name.startswith('sedbgmux')])

    def set_connected(self, state: bool) -> None:
        self.connected: bool = state
        if self.connected:
            self.prompt = f'DebugMux (\'{self.argv.serial_port}\')# '
            self.enable_category(self.CATEGORY_DBGMUX)
        else:
            self.prompt = f'Modem (\'{self.argv.serial_port}\')> '
            msg = 'You must be connected to use this command'
            self.disable_category(self.CATEGORY_DBGMUX, msg)

    set_log_level_parser = cmd2.Cmd2ArgumentParser()
    set_log_level_parser.add_argument('MODULE',
                                      type=str,
                                      choices_provider=_tab_modules,
                                      help='Module name (e.g. sedbgmux.peer)')
    set_log_level_parser.add_argument('LEVEL',
                                      type=str,
                                      default='DEBUG', nargs='?',
                                      choices=[*logging.getLevelNamesMapping().keys()],
                                      help='Logging level (default: %(default)s)')

    @cmd2.with_argparser(set_log_level_parser)
    def do_set_log_level(self, opts) -> None:
        ''' Set logging level '''
        logger = logging.getLogger(opts.MODULE)
        level = logging.getLevelName(opts.LEVEL)
        logger.setLevel(level)

    @cmd2.with_category(CATEGORY_CONN)
    def do_connect(self, opts) -> None:
        ''' Connect to the modem and switch it to DebugMux mode '''
        self.io.connect()
        self.peer.start()
        self.client.start()
        self.set_connected(True)

    @cmd2.with_category(CATEGORY_CONN)
    def do_disconnect(self, opts) -> None:
        ''' Disconnect from the modem '''
        self.client.stop()
        self.peer.stop()
        self.io.disconnect()
        self.set_connected(False)

    @cmd2.with_category(CATEGORY_CONN)
    def do_status(self, opts) -> None:
        ''' Print connection info and statistics '''
        if not self.connected:
            self.poutput('Not connected')
            return
        self.poutput(f'Connected to \'{self.argv.serial_port}\'')
        self.poutput(f'Baudrate: {self.argv.serial_baudrate}')
        self.poutput(f'TxCount (Ns): {self.peer.tx_count}')
        self.poutput(f'RxCount (Nr): {self.peer.rx_count}')

    show_parser = cmd2.Cmd2ArgumentParser()
    show_sparser = show_parser.add_subparsers(dest='command', required=True)
    show_sparser.add_parser('target-info')
    show_sparser.add_parser('data-providers')
    show_sparser.add_parser('connections')

    @cmd2.with_argparser(show_parser)
    @cmd2.with_category(CATEGORY_CONN)
    def do_show(self, opts) -> None:
        ''' Show various information '''
        match opts.command:
            case 'target-info':
                self.poutput('Name: ' + (self.client.target_name or '(unknown)'))
                self.poutput('IMEI: ' + (self.client.target_imei or '(unknown)'))
                self.poutput(f'ID: {self.client.target_id:#08x}')
            case 'data-providers':
                for (DPRef, DPName) in self.client.data_providers.items():
                    self.poutput(f'Data Provider ({DPRef=:#04x}): {DPName}')
            case 'connections':
                for (ConnRef, ConnInfo) in self.client.active_conn.items():
                    (DPRef, ch) = ConnInfo
                    self.poutput(f'Connection ({DPRef=:#04x}, {ConnRef=:#04x}): {ch}')
                for (DPRef, ch) in self.client.pending_conn.items():
                    self.poutput(f'Pending Connection ({DPRef=:#04x}): {ch}')

    @cmd2.with_category(CATEGORY_DBGMUX)
    def do_enquiry(self, opts) -> None:
        ''' Enquiry target identifier and available Data Providers '''
        self.client.enquiry()

    ping_parser = cmd2.Cmd2ArgumentParser()
    ping_parser.add_argument('-p', '--payload',
                             type=str, default='Knock, knock!',
                             help='Ping payload')
    ping_parser.add_argument('-t', '--timeout',
                             type=float, default=1.0,
                             help='Ping timeout (default %(default)s)')

    @cmd2.with_argparser(ping_parser)
    @cmd2.with_category(CATEGORY_DBGMUX)
    def do_ping(self, opts) -> None:
        ''' Send a Ping to the target, expect Pong '''
        self.client.ping(opts.payload, opts.timeout)

    establish_parser = cmd2.Cmd2ArgumentParser()
    establish_parser.add_argument('DPRef',
                                  type=lambda v: int(v, 16),
                                  choices_provider=_tab_data_providers,
                                  help='DPRef of a Data Provider in hex')
    establish_sparser = establish_parser.add_subparsers(dest='handler', required=True,
                                                        help='Connection handler')
    ch_terminal = establish_sparser.add_parser('terminal',
                                               help=DbgMuxConnTerminal.__doc__)
    ch_walker = establish_sparser.add_parser('walker',
                                             help=DbgMuxConnWalker.__doc__)
    ch_file_logger = establish_sparser.add_parser('file-logger',
                                                  help=DbgMuxConnFileLogger.__doc__)
    ch_file_logger.add_argument('FILE', type=argparse.FileType('ab', 0),
                                completer=cmd2.Cmd.path_complete,
                                help='File name or \'-\' for stdout')
    ch_udp_proxy = establish_sparser.add_parser('udp-proxy',
                                                help=DbgMuxConnUdpProxy.__doc__)
    ch_udp_proxy.add_argument('-la', '--local-addr', dest='laddr', type=str,
                              default=DbgMuxConnUdpProxy.LADDR_DEF[0],
                              help='Local address (default: %(default)s)')
    ch_udp_proxy.add_argument('-lp', '--local-port', dest='lport', type=int,
                              default=DbgMuxConnUdpProxy.LADDR_DEF[1],
                              help='Local port (default: %(default)s)')
    ch_udp_proxy.add_argument('-ra', '--remote-addr', dest='raddr', type=str,
                              default=DbgMuxConnUdpProxy.RADDR_DEF[0],
                              help='Remote address (default: %(default)s)')
    ch_udp_proxy.add_argument('-rp', '--remote-port', dest='rport', type=int,
                              default=DbgMuxConnUdpProxy.RADDR_DEF[1],
                              help='Remote port (default: %(default)s)')
    ch_hexdump = establish_sparser.add_parser('hexdump',
                                              help=DbgMuxConnHexDump.__doc__)
    ch_hexdump.add_argument('-ls', '--line-size',
                            type=int, default=16,
                            help='Bytes per line (default: %(default)s)')
    ch_hexdump.add_argument('-ll', '--log-level',
                            type=str, default='INFO',
                            choices=[*logging.getLevelNamesMapping().keys()],
                            help='Logging level (default: %(default)s)')

    @cmd2.with_argparser(establish_parser)
    @cmd2.with_category(CATEGORY_DBGMUX)
    def do_establish(self, opts) -> None:
        ''' Establish connections with Data Providers '''
        # select and instantiate a DbgMuxConnHandler
        match opts.handler:
            case 'terminal':
                ch = DbgMuxConnTerminal()
            case 'walker':
                ch = DbgMuxConnWalker()
            case 'file-logger':
                ch = DbgMuxConnFileLogger(opts.FILE)
            case 'udp-proxy':
                ch = DbgMuxConnUdpProxy(laddr=(opts.laddr, opts.lport),
                                        raddr=(opts.raddr, opts.rport))
            case 'hexdump':
                ch = DbgMuxConnHexDump(opts.line_size, opts.log_level)
        # establish connection with a data provider, bind selected handler
        if not self.client.conn_establish(opts.DPRef, ch):
            self.perror('Could not establish connection')
            return
        # handler specific magic here
        match opts.handler:
            case 'terminal':
                ch.attach()  # blocking until Ctrl + [CD]
                self.client.conn_terminate(ch.ConnRef)
            case 'walker':
                ch.walk()  # blocking
                self.client.conn_terminate(ch.ConnRef)

    terminate_parser = cmd2.Cmd2ArgumentParser()
    terminate_parser.add_argument('ConnRef',
                                  type=lambda v: int(v, 16),
                                  choices_provider=_tab_connections,
                                  help='ConnRef in hex')

    @cmd2.with_argparser(terminate_parser)
    @cmd2.with_category(CATEGORY_DBGMUX)
    def do_terminate(self, opts) -> None:
        ''' Terminate connection with a Data Provider '''
        if not self.client.conn_terminate(opts.ConnRef):
            self.perror('Could not terminate connection')

    @cmd2.with_category(CATEGORY_DBGMUX)
    def do_disable_crc(self, opts) -> None:
        ''' Disable CRC calculation in DebugMux frames, send/expect a dummy value '''
        self.peer.disable_crc()

    @cmd2.with_category(CATEGORY_DBGMUX)
    def do_establish_tvp(self, opts) -> None:
        ''' Establish connections with the Tvp Data Provider '''
        if not self.client.data_providers:
            self.poutput('Enquiring the target')
            self.client.enquiry()
            time.sleep(0.5)
        # establish connection
        DPRef = self.client.find_dpref_by_name('Tvp')
        cmds = TvpCommandSet(self.client)
        if not self.client.conn_establish(DPRef, cmds.ch):
            self.perror('Could not establish connection')
            return
        # enable Tvp specific commands
        self.register_command_set(cmds)
        self.tvp_cmds = cmds
        # update prompt
        self.prompt = f'DebugMux+Tvp (\'{self.argv.serial_port}\')# '

    @cmd2.with_category(CATEGORY_DBGMUX)
    def do_terminate_tvp(self, opts) -> None:
        ''' Terminate connection with the Tvp Data Provider '''
        # terminate connection
        ConnRef = self.client.find_connref_by_name('Tvp')
        self.client.conn_terminate(ConnRef)
        # unregister Tvp specific commands
        self.unregister_command_set(self.tvp_cmds)
        del self.tvp_cmds
        # update prompt
        self.prompt = f'DebugMux (\'{self.argv.serial_port}\')# '

    send_msg_parser = cmd2.Cmd2ArgumentParser()
    send_msg_parser.add_argument('MsgType',
                                 type=lambda v: int(v, 16),
                                 choices_provider=_tab_msg_types,
                                 help='Message type')
    send_msg_parser.add_argument('DataHex', type=str,
                                 nargs='?', default='',
                                 help='Message bytes (hex)')

    @cmd2.with_argparser(send_msg_parser)
    @cmd2.with_category(CATEGORY_DBGMUX)
    def do_send_msg(self, opts) -> None:
        ''' Send an arbitrary DebugMux message (raw bytes) '''
        data = bytes.fromhex(opts.DataHex)
        self.peer.send(opts.MsgType, data, raw=True)

    send_conn_data_parser = cmd2.Cmd2ArgumentParser()
    send_conn_data_parser.add_argument('ConnRef',
                                       type=lambda v: int(v, 16),
                                       choices_provider=_tab_connections,
                                       help='ConnRef in hex')
    send_conn_data_parser.add_argument('ConnDataHex', type=str,
                                       nargs='?', default='',
                                       help='ConnData bytes (hex)')

    @cmd2.with_argparser(send_conn_data_parser)
    @cmd2.with_category(CATEGORY_DBGMUX)
    def do_send_conn_data(self, opts) -> None:
        ''' Send an arbitrary ConnData (raw bytes) '''
        data = bytes.fromhex(opts.ConnDataHex)
        if not self.client.conn_send_data(opts.ConnRef, data):
            self.perror('Could not send data')

    def do_list_tvp_probes(self, opts) -> None:
        ''' List all known Tvp probes '''
        from sedbgmux.tvp import TvpProbeId
        for p in TvpProbeId:
            self.poutput(f'{p.value}\t\t{p.name}')

    def do_list_tvp_callbacks(self, opts) -> None:
        ''' List all known Tvp callbacks '''
        from sedbgmux.tvp import TvpCallbackId
        for c in TvpCallbackId:
            self.poutput(f'{c.value}\t\t{c.name}')


# Command line interface of the shell
ap = argparse.ArgumentParser(description=SEDbgMuxApp.DESC,
                             prog='sedbgmux-shell')

# Generic options: verbosity and an optional start-up script
ap.add_argument('-v', '--verbose',
                action='count',
                default=0,
                help='print debug logging')
ap.add_argument('-vm', '--verbose-module',
                type=str,
                metavar='MODULE',
                help='print debug logging for a specific module')
ap.add_argument('--script',
                type=str,
                default='',
                help='script with commands to be executed automatically at start-up')

# Serial port settings, listed under their own heading in --help
group = ap.add_argument_group('connection parameters')
group.add_argument('-p', '--serial-port',
                   type=str,
                   metavar='PORT',
                   default='/dev/ttyACM0',
                   help='serial port path (default %(default)s)')
group.add_argument('--serial-baudrate',
                   type=int,
                   metavar='BAUDRATE',
                   default=115200,
                   help='serial port speed (default %(default)s)')
group.add_argument('--serial-timeout',
                   type=float,
                   metavar='TIMEOUT',
                   default=0.5,
                   help='serial port read timeout (default %(default)s)')
group.add_argument('--serial-flow',
                   type=str,
                   metavar='FLOW',
                   default='none',
                   # 'none' plus every software and hardware flow control mode
                   choices=(['none']
                            + list(DbgMuxIOModem.MODEM_FLOW_SOFT)
                            + list(DbgMuxIOModem.MODEM_FLOW_HARD)),
                   help='serial port flow control (default %(default)s)')
group.add_argument('--dump-file',
                   type=str,
                   metavar='FILE',
                   help='save Rx/Tx DebugMux frames to a file')

# Root logger: INFO by default. The leading '\r' returns the cursor to
# column 0, presumably so that log lines emitted while the cmd2 prompt is
# displayed do not start in the middle of the prompt line.
logging.basicConfig(level=logging.INFO,
                    format='\r[%(levelname)s] %(name)s:%(lineno)d %(message)s')

if __name__ == '__main__':
    argv = ap.parse_args()
    app = SEDbgMuxApp(argv)
    # the value returned by cmdloop() becomes the process exit status
    rc = app.cmdloop()
    sys.exit(rc)
