#!/usr/bin/python3 -s
import argparse, logging, numpy as np
from ligo.segments import infinity
from pycbc.events import veto, coinc, stat
import pycbc.conversions as conv
import pycbc.version
from pycbc import io
from pycbc.events import trigger_fits as trfits
from pycbc import init_logging

# Command-line interface: a single-detector trigger file, a template bank,
# optional SNR / reduced-chisq cuts, optional vetoes and clustering, plus the
# ranking-statistic options added by stat.insert_statistic_option_group.
parser = argparse.ArgumentParser()
parser.add_argument("--verbose", action='count')
parser.add_argument("--version", action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument("--veto-files", nargs='+',
                    help="Optional veto file. Triggers within veto segments "
                         "contained in the file are ignored. Required if "
                         "--segment-names given.")
parser.add_argument("--segment-names", nargs='+',
                    help="Optional, name of veto segment in veto file. "
                         "Required if --veto-files given.")
# The trigger and output files are always opened by this script, so let
# argparse reject a missing value instead of failing later on a None path.
parser.add_argument("--trigger-file", type=str, required=True,
                    help="File containing single-detector triggers")
parser.add_argument("--template-bank", required=True,
                    help="Template bank file in HDF format")
parser.add_argument('--trigger-snr-cut', type=float,
                    help='Only consider triggers with SNR at or above the '
                         'given value.')
parser.add_argument('--cluster-window', type=float,
                    help='Window (seconds) during which to keep the trigger '
                         'with the loudest statistic value. '
                         'Default=do not cluster')
parser.add_argument('--reduced-chisq-cut', type=float,
                    help='Only consider triggers with reduced chisquared at '
                         'or below the given value.')
parser.add_argument("--output-file", required=True,
                    help="File to store the candidate triggers")
stat.insert_statistic_option_group(parser)
args = parser.parse_args()

# --veto-files and --segment-names are paired positionally (each veto file is
# read with the segment name at the same index), so either both are given
# or neither is.
if bool(args.veto_files) != bool(args.segment_names):
    raise RuntimeError('--veto-files and --segment-names are mutually required')

# Only compare lengths when the options were actually given: when both are
# omitted they are None and len(None) would raise TypeError.
if args.veto_files and len(args.veto_files) != len(args.segment_names):
    raise RuntimeError('--segment-names are required for each --veto-files')

init_logging(args.verbose)

logging.info('Opening trigger file: %s', args.trigger_file)
trigf = io.HFile(args.trigger_file, 'r')

# The first top-level group name is used as the detector (ifo) name for all
# dataset paths below. keys() may return a non-indexable view (as h5py does
# on Python 3), so materialise it before taking the first entry.
top_level_groups = list(trigf.keys())
if not top_level_groups:
    trigf.close()
    raise RuntimeError("Trigger file %s contains no detector group"
                       % args.trigger_file)
ifo = top_level_groups[0]

# Analysed segments, rebuilt from the stored start/end times.
starts = trigf[ifo + '/search/start_time'][:]
ends = trigf[ifo + '/search/end_time'][:]
segments = veto.start_end_to_segments(starts, ends)

n_tot_trigs = trigf[ifo + '/snr'].size
logging.info("%d triggers in file", n_tot_trigs)

# keep_idx holds indices (into the full per-trigger arrays) of triggers that
# survive every cut so far. Start with all triggers so each cut is optional;
# previously keep_idx was undefined when --trigger-snr-cut was not given.
keep_idx = np.arange(n_tot_trigs)

if args.trigger_snr_cut:
    keep_idx = np.flatnonzero(trigf[ifo + '/snr'][:] >= args.trigger_snr_cut)
    n_snr_cut = n_tot_trigs - keep_idx.size
    snr_cut_f = ("%f" % args.trigger_snr_cut).rstrip("0").rstrip(".")
    # max(..., 1) avoids ZeroDivisionError for an empty trigger file.
    logging.info("Cutting %d triggers with SNR < %s (%.2f%%)",
                 n_snr_cut, snr_cut_f,
                 100. * n_snr_cut / max(n_tot_trigs, 1))

if args.reduced_chisq_cut:
    n_before_chisq = keep_idx.size
    chisq = trigf[ifo + '/chisq'][:][keep_idx]
    chisq_dof = trigf[ifo + '/chisq_dof'][:][keep_idx]
    # NOTE(review): chisq_dof == 1 makes the denominator zero (inf/nan); such
    # values fail the <= test below and are cut. Confirm this is intended.
    reduced_chisq = chisq / (2 * chisq_dof - 2)
    chisq_keep_idx = np.flatnonzero(reduced_chisq <= args.reduced_chisq_cut)
    n_chisq_cut = n_before_chisq - chisq_keep_idx.size
    chisq_cut_f = ("%f" % args.reduced_chisq_cut).rstrip("0").rstrip(".")
    # chisq_cut_f is a string, so it must be formatted with %s (not %.f).
    logging.info("Cutting %d triggers with \\chi^2 > %s (%.2f%%)",
                 n_chisq_cut, chisq_cut_f,
                 100. * n_chisq_cut / max(n_before_chisq, 1))
    # chisq_keep_idx indexes into keep_idx, not into the full arrays.
    keep_idx = keep_idx[chisq_keep_idx]

# Foreground segments: the analysed segments minus the vetoed time of EVERY
# veto file. Previously each iteration recomputed this from `segments`, so
# only the last veto file's segments were actually removed.
fg_segs = segments
if args.veto_files:
    # Read the end times once; each pass re-indexes with the current keep_idx.
    all_end_times = trigf[ifo + '/end_time'][:]
    for veto_file, segment_name in zip(args.veto_files, args.segment_names):
        logging.info("Getting vetoed indices from file %s", veto_file)
        end_time = all_end_times[keep_idx]
        veto_keep_idx, _ = veto.indices_outside_segments(end_time,
                                                         [veto_file],
                                                         segment_name=segment_name,
                                                         ifo=ifo)
        n_vetoed = keep_idx.size - veto_keep_idx.size
        # max(..., 1) avoids ZeroDivisionError once all triggers are cut.
        logging.info("Cutting %d triggers in vetoed segments (%.2f%%)",
                     n_vetoed, 100. * n_vetoed / max(keep_idx.size, 1))
        # veto_keep_idx indexes into keep_idx; select the unvetoed triggers.
        keep_idx = keep_idx[veto_keep_idx]
        veto_segs = veto.select_segments_by_definer(veto_file, ifo=ifo,
                                                    segment_name=segment_name)
        # Non-mutating subtraction, so `segments` itself is left untouched.
        fg_segs = fg_segs - veto_segs

# Nothing left to rank: stop before reading any datasets.
n_kept = len(keep_idx)
if n_kept == 0:
    raise RuntimeError("All triggers removed by vetoes or cuts")

logging.info("Loading %d triggers", n_kept)

# Per-trigger datasets copied (restricted to the surviving triggers) into
# the in-memory DictArray.
all_dsets = ['sigmasq', 'chisq', 'chisq_dof', 'coa_phase', 'end_time',
             'snr', 'template_id', 'sg_chisq']
data_init = {dset_name: trigf[ifo + "/" + dset_name][:][keep_idx]
             for dset_name in all_dsets}
# Position of each surviving trigger within the input file's arrays.
n_in_file = trigf[ifo + '/snr'].size
data_init['trigger_id'] = np.arange(n_in_file)[keep_idx]

logging.info("Putting data into DictArray")
trigs = io.DictArray(data=data_init)
trigf.close()

logging.info('Setting up ranking method')
# Validate --statistic-keywords (KEY:VALUE pairs) so a malformed entry gives
# a clear error. The option is presumably None when not given (argparse
# default), so fall back to an empty list instead of iterating None.
# NOTE(review): extra_kwargs is not used below; get_statistic_from_opts
# receives `args` directly.
extra_kwargs = {}
for inputstr in args.statistic_keywords or []:
    try:
        key, value = inputstr.split(':')
    except ValueError:
        err_txt = "--statistic-keywords must take input in the " \
                  "form KWARG1:VALUE1 KWARG2:VALUE2 KWARG3:VALUE3 ... " \
                  "Received {}".format(args.statistic_keywords)
        # The unpacking error adds nothing beyond this message.
        raise ValueError(err_txt) from None
    extra_kwargs[key] = value

# Stat class instance to calculate the ranking statistic
rank_method = stat.get_statistic_from_opts(args, [ifo])

logging.info("Computing single-detector statistic")
# Named stat_vals so it does not rebind (shadow) the imported `stat` module.
stat_vals = rank_method.rank_stat_single((ifo, trigs.data))

if args.cluster_window:
    logging.info("Clustering")
    # Keep only the loudest-statistic trigger within each time window.
    cid = coinc.cluster_over_time(stat_vals, trigs.data['end_time'],
                                  args.cluster_window)
    trigs = trigs.select(cid)
    stat_vals = stat_vals[cid]
    logging.info("%d triggers after clustering", stat_vals.size)

# Total duration of the (veto-reduced) foreground segments.
fg_time = abs(fg_segs)

data = {"stat": stat_vals,
        "decimation_factor": np.ones_like(stat_vals),
        "timeslide_id": np.zeros_like(stat_vals),
        "template_id": trigs.data['template_id'],
        "%s/time" % ifo: trigs.data['end_time'],
        "%s/trigger_id" % ifo: trigs.data['trigger_id']}

logging.info("saving triggers")
# Context manager guarantees the output file is flushed and closed.
with io.HFile(args.output_file, 'w') as f:
    for key, values in data.items():
        f.create_dataset(key, data=values,
                         compression="gzip",
                         compression_opts=9,
                         shuffle=True)

    # Store segments
    f['segments/%s/start' % ifo], f['segments/%s/end' % ifo] = \
        veto.segments_to_start_end(fg_segs)
    # The same live time is recorded for foreground and background
    # (timeslide_id above is all zeros).
    f.attrs['foreground_time'] = fg_time
    f.attrs['background_time'] = fg_time
    f.attrs['num_of_ifos'] = 1
    f.attrs['pivot'] = ifo
    f.attrs['fixed'] = ifo
    f.attrs['ifos'] = ifo

# TODO: hierarchical removal (cf. a --max-hierarchical-removal option) is
# not implemented here.

logging.info("Done")
