#!/usr/bin/env python

# Copyright (c) 2016-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import re
import subprocess
import sys
import zipfile
from functools import wraps

import pyredex.unpacker as unpacker
from pyredex.utils import abs_glob, make_temp_dir

# Registry of available commands, keyed by command name. Entries are added by
# the @command decorator below and looked up by the script entry point.
ALL_COMMANDS = {}

def log(*args):
    """Print the given values to stderr, space-separated, ending in a newline.

    Diagnostics go through here so they stay off stdout.
    """
    print(*args, sep=' ', end='\n', file=sys.stderr)

def command(name, desc):
    """Decorator factory that registers a function as a named command.

    The decorated function is wrapped in a callable taking
    ``(dexfiles, apk_dir, *args, **kwds)``, which forwards everything to the
    original function. The wrapper gets ``name`` and ``desc`` attributes and
    is stored in ``ALL_COMMANDS`` under ``name``; the wrapper (not the
    original function) is what the decorator returns.
    """
    def decorator(func):
        @wraps(func)
        def registered(dexfiles, apk_dir, *args, **kwds):
            return func(dexfiles, apk_dir, *args, **kwds)

        # Metadata read back when listing commands and when announcing a run.
        registered.name, registered.desc = name, desc
        ALL_COMMANDS[name] = registered
        return registered
    return decorator

# -------------------------- Command Definitions -----------------------------

@command(name='dexdump',
         desc='Calls dexdump with dexes from apk. '
              'Output written to stdout. Accepts all dexdump options.')
def dexdump(dexfiles, apk_dir, *args):
    """Run the external ``dexdump`` tool over the given dex files.

    Extra positional arguments are placed on the command line ahead of the
    dex file paths. ``apk_dir`` is part of the common command signature and
    is not used here. ``subprocess.check_call`` raises
    ``subprocess.CalledProcessError`` if the tool exits non-zero.
    """
    options = list(args)
    subprocess.check_call(['dexdump'] + options + dexfiles)


@command(name='redexdump',
         desc='Calls redexdump with dexes from apk. '
              'Output written to stdout. Accepts all redexdump options.')
def redexdump(dexfiles, apk_dir, *args):
    """Run the external ``redexdump`` tool over the given dex files.

    Extra positional arguments are placed on the command line ahead of the
    dex file paths. ``apk_dir`` is part of the common command signature and
    is not used here. ``subprocess.check_call`` raises
    ``subprocess.CalledProcessError`` if the tool exits non-zero.
    """
    options = list(args)
    subprocess.check_call(['redexdump'] + options + dexfiles)


@command(name='codesize',
         desc='Prints total and primary dex size')
def codesize(dexfiles, apk_dir, *args):
    """Log the combined size of all dex files and the primary dex size.

    A file counts towards the primary dex size when its path ends with
    ``classes.dex``. Sizes are reported in MB (1024 * 1024 bytes) through
    ``log``. ``apk_dir`` and ``args`` are accepted but not used.
    """
    bytes_per_mb = 1024. * 1024.

    total_bytes = 0
    primary_bytes = 0
    for dex_path in dexfiles:
        size = os.path.getsize(dex_path)
        total_bytes += size
        if dex_path.endswith('classes.dex'):
            primary_bytes += size

    dex_count = len(dexfiles)
    log('Total code size: {:.2f}MB in {} dex files'.format(
        total_bytes / bytes_per_mb, dex_count))
    log('Primary dex size: {:.2f}MB'.format(primary_bytes / bytes_per_mb))


@command(name='classes',
         desc='Prints list of classes. Use -j for java-style names.')
def classes(dexfiles, apk_dir, *args):
    """Print the class descriptors that ``dexdump -f`` reports, one per line.

    Runs ``dexdump -f`` on the given dex files and prints the quoted value of
    every ``Class descriptor`` line to stdout. When ``-j`` is among ``args``,
    the first and last characters of each descriptor are dropped and ``/`` is
    replaced with ``.`` before printing. ``apk_dir`` is not used.

    The child process is always reaped and its output pipe closed, even if
    printing fails part-way through. The exit status of dexdump is not
    checked.
    """
    java_style = '-j' in args
    # The pipe yields bytes, so the pattern is a bytes pattern; compiled once
    # rather than looked up per line.
    descriptor_re = re.compile(b"  Class descriptor  : '([^']+)'")

    proc = subprocess.Popen(['dexdump', '-f'] + dexfiles,
                            stdout=subprocess.PIPE)
    try:
        for line in proc.stdout:
            match = descriptor_re.match(line)
            if not match:
                continue
            classname = match.group(1).decode()
            if java_style:
                classname = classname[1:-1].replace('/', '.')
            print(classname)
    finally:
        # Close the pipe before waiting so a child that is still writing
        # cannot block forever on a full pipe that nobody reads.
        proc.stdout.close()
        proc.wait()


# ------------------------- End Command Definitions ---------------------------

def invoke_command(apk_path, cmd, *command_args):
    """Unpack an apk's dex files and run a command on them.

    The apk at ``apk_path`` is unzipped into a temp directory, the secondary
    dex mode is detected from the extracted tree, and that mode unpackages
    the dex files into a second temp directory. ``cmd`` is then called with
    the list of ``*.dex`` paths found there, ``apk_path`` as its second
    argument, and ``command_args`` forwarded as-is. ``cmd`` must expose a
    ``name`` attribute (used for the progress message).

    NOTE(review): the temp directories are not removed in this function;
    whether ``make_temp_dir`` arranges cleanup is not visible here.
    """
    apk_contents_dir = make_temp_dir('.redex_extracted_apk')

    log('Extracting apk...')
    with zipfile.ZipFile(apk_path) as apk_zip:
        apk_zip.extractall(apk_contents_dir)

    secondary_mode = unpacker.detect_secondary_dex_mode(apk_contents_dir)
    mode_name = type(secondary_mode).__name__
    log('Detected dex mode ' + str(mode_name))
    unpacked_dex_dir = make_temp_dir('.redex_dexen')

    log('Unpacking dex files')
    secondary_mode.unpackage(apk_contents_dir, unpacked_dex_dir)

    dex_paths = list(abs_glob(unpacked_dex_dir, '*.dex'))

    log('Running command ' + cmd.name + '...')
    cmd(dex_paths, apk_path, *command_args)


def show_usage():
    """Log the usage banner followed by one line per registered command.

    Commands are listed in name order, each with its ``desc`` text, via
    ``log``.
    """
    log('Redex APK Utilities')
    log('Usage: apkutil <path_to_apk> <command> [command args...]\n')
    log('Available commands:')

    # Registry keys are unique, so ordering by name alone is sufficient.
    for name in sorted(ALL_COMMANDS):
        log('    {:<12} {}'.format(name, ALL_COMMANDS[name].desc))


# Script entry point: apkutil <path_to_apk> <command> [command args...]
if __name__ == '__main__':
    if len(sys.argv) < 3:
        show_usage()
    else:
        apk_path = sys.argv[1]
        command_name = sys.argv[2]
        if not os.path.isfile(apk_path):
            # Stop here; continuing would only fail later while opening the
            # apk as a zip archive.
            log(apk_path + ' is not a file')
            sys.exit(1)
        if command_name not in ALL_COMMANDS:
            # Report the bad name instead of silently doing nothing.
            log('Unknown command: ' + command_name)
            show_usage()
            sys.exit(1)
        # Everything after the command name is forwarded to the command.
        invoke_command(apk_path, ALL_COMMANDS[command_name], *sys.argv[3:])
