#!/usr/bin/python3.13 -s

# Copyright (C) 2017 Alex Nitz, Duncan Macleod
#               2022 Shichao Wu
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Generate a bank of templates using a brute force stochastic method.
"""
import numpy
import logging
import signal
import tqdm
import os
import argparse
import numpy.random
from scipy.stats import gaussian_kde

import pycbc.waveform, pycbc.filter, pycbc.types, pycbc.psd, pycbc.fft, pycbc.conversions
import pycbc.pool
from pycbc import transforms
from pycbc.waveform.spa_tmplt import spa_length_in_time
from pycbc.distributions import read_params_from_config
from pycbc.distributions.utils import draw_samples_from_config, prior_from_config
from pycbc.io import HFile

# Command-line interface. Options without a default are left as None by
# argparse when they are not given on the command line.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument('--output-file', required=True,
    help='Output file name for template bank.')
parser.add_argument('--input-file', nargs='*', default=[],
    help='Bank to use as an initial set of starting samples.')
parser.add_argument('--keep-entire-input-file', help="All of the templates in the input file will be kept in the output file, even if they do not pass normal mismatch criteria.", action='store_true')
parser.add_argument('--input-config', required=True,
    help='Draw parameters from the given configure file.')
parser.add_argument('--minimal-match', default=0.97, type=float)
parser.add_argument('--buffer-length', default=2, type=float,
    help='size of waveform buffer in seconds')
parser.add_argument('--use-td-waveform', action='store_true',
    help='Generate waveform in the time domain (default is frequency domain).')
parser.add_argument('--full-resolution-buffer-length', default=None, type=float,
    help='Size of the waveform buffer in seconds for generating time-domain signals at full resolution before conversion to the frequency domain.')
parser.add_argument('--max-signal-length', type=float,
                    help="When specified, it cuts the maximum length of the waveform model to the length provided")
# The sample rate is used as 1 / delta_t further down, i.e. it is a rate in
# Hz, not a duration.
parser.add_argument('--sample-rate', default=2048, type=float,
    help='Sample rate in Hz')
parser.add_argument('--low-frequency-cutoff', default=20.0, type=float)
parser.add_argument('--enable-sigma-bound', action='store_true')
parser.add_argument('--tau0-threshold', type=float)
parser.add_argument('--permissive', action='store_true',
    help='Allow waveform generator to fail.')
parser.add_argument('--placement-iterations', default=1000, type=int,
    help='Specify the number of attempts the bank should make when placing points. Use this option if the bank fails to place any points.')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--tolerance', type=float)
parser.add_argument('--use-cross', action='store_true')
parser.add_argument('--max-q', type=float)
parser.add_argument('--tau0-crawl', type=float)
parser.add_argument('--tau0-start', type=float)
parser.add_argument('--tau0-end', type=float)
parser.add_argument('--tau0-cutoff-frequency', type=float, default=15.0)
parser.add_argument('--nprocesses', type=int, default=1,
    help='Number of processes to use for waveform generation parallelization. If not given then only a single core will be used.')
parser.add_argument('--parallel-check', action='store_true', help="Do bank checking parallel, note that this means that proposals WILL NOT be checked against each other.")
# The default of infinity means "no limit": it is only ever compared against
# a length before being used as a slice bound.
parser.add_argument('--max-connections', type=int, help="Maximum number of matches to store with each template", default=numpy.inf)
pycbc.psd.insert_psd_option_group(parser)
parser.add_argument('--use-trimmed-buffer', action='store_true',
    help=('When specified, the match calculation will use only the first and last 100 samples '
        'of the waveform instead of the full buffer. Enabling this option makes match computation faster.'))


args = parser.parse_args()

pycbc.init_logging(args.verbose)
# Seed the global numpy RNG so that runs are reproducible for a given --seed.
numpy.random.seed(args.seed)

# Read the configuration file; the context manager guarantees the handle is
# closed even if parsing raises.
config_parser = pycbc.types.config.InterpolatingConfigParser()
with open(args.input_config, 'r') as config_file:
    config_parser.read_file(config_file)

variable_args, static_args = read_params_from_config(
    config_parser, prior_section='prior',
    vargs_section='variable_params',
    sargs_section='static_params')

# Waveform transforms are optional: only load them when the config file
# defines at least one `waveform_transforms` subsection.
if any(config_parser.get_subsections('waveform_transforms')):
    waveform_transforms = transforms.read_transforms_from_config(
            config_parser, 'waveform_transforms')
else:
    waveform_transforms = None

# Joint prior, used later to reject drawn samples that fall outside it.
dists_joint = prior_from_config(cp=config_parser)

def get_stable_match(p_inj, p_tmplt):
    """Iteratively double the buffer length until the match stabilizes.

    Parameters
    ----------
    p_inj : dict
        Keyword arguments passed to ``GenUniformWaveform.generate`` for the
        first waveform.
    p_tmplt : dict
        Keyword arguments passed to ``GenUniformWaveform.generate`` for the
        second waveform.

    Returns
    -------
    match : float
        The match between the two waveforms.
    buffer_length : float
        The buffer length (seconds) at which ``match`` was evaluated.

    Notes
    -----
    When ``--full-resolution-buffer-length`` is given, a single evaluation at
    that length is returned. Otherwise the buffer starts at
    ``--buffer-length`` and is doubled until two successive matches differ by
    less than 1e-4, or until the buffer would exceed 64 s.
    """
    frbl = args.full_resolution_buffer_length
    current_buflen = args.buffer_length if frbl is None else frbl
    last_match = -1.0
    # Buffer length that produced `last_match`; kept separately because
    # `current_buflen` is doubled before the loop condition is re-checked.
    last_buflen = current_buflen

    # The sentinel guarantees at least one evaluation even when the starting
    # buffer is already above the 64 s safety cap.
    while (last_match == -1.0) or (current_buflen <= 64):
        # Create a transient generator for this specific resolution
        tmp_gen = GenUniformWaveform(current_buflen, args.sample_rate,
                                     args.low_frequency_cutoff)
        h_inj = tmp_gen.generate(**p_inj)
        h_tmplt = tmp_gen.generate(**p_tmplt)
        current_match = tmp_gen.match(h_inj, h_tmplt)

        # With a fixed full-resolution buffer there is nothing to iterate
        # over; otherwise stop once the change is sub-leading.
        if frbl is not None or abs(current_match - last_match) < 1e-4:
            return current_match, current_buflen

        last_match = current_match
        last_buflen = current_buflen
        current_buflen *= 2

    # Cap reached without convergence: report the last match together with
    # the buffer length it was actually computed at.
    return last_match, last_buflen
    
    
class Shrinker(object):
    """Sequence wrapper that hands out its elements from the tail.

    ``data`` is a public attribute and may be replaced by the caller between
    calls to ``pop``.
    """

    def __init__(self, data):
        self.data = data

    def pop(self):
        """Remove and return the last element, or None when nothing is left.

        The stored sequence is re-bound to a slice without the last element
        rather than modified in place.
        """
        remaining = self.data
        if len(remaining) > 0:
            tail, self.data = remaining[-1], remaining[:-1]
            return tail
        return None

class TriangleBank(object):
    """ A bank of templates that uses the triangle inequality to estimate
    matches based on prior ones.

    Attributes
    ----------
    waveforms : list
        Stored templates, in insertion order.
    tbins : dict
        Maps an integer tau0 bin to the list of template indices registered
        under that bin.
    max_matches : list
        For every proposal that passed the bank check, the largest match
        found against the bank at the time it was checked.
    """
    def __init__(self, p=None):
        self.waveforms = p if p is not None else []
        self.tbins = {}
        self.max_matches = []

    def __len__(self):
        return len(self.waveforms)

    def activelen(self):
        """Return how many stored entries are still FrequencySeries.

        Entries replaced by placeholders in `culltau0` are not counted.
        """
        i = 0
        for w in self.waveforms:
            if isinstance(w, pycbc.types.FrequencySeries):
                i += 1
        return i

    def insert(self, hp):
        """Append `hp` and register its index under its tau0 bins.

        The index is registered under `hp.tbin` and both neighbouring bins,
        so a later lookup of a single bin also finds templates from the
        adjacent ones.
        """
        self.waveforms.append(hp)

        for b in [hp.tbin - 1, hp.tbin, hp.tbin + 1]:
            if b in self.tbins:
                self.tbins[b].append(len(self)-1)
            else:
                self.tbins[b] = [len(self)-1]

    def __getitem__(self, index):
        return self.waveforms[index]

    def keys(self):
        """Return the parameter names of the first stored template.

        Raises IndexError when the bank is empty.
        """
        return self.waveforms[0].params.keys()

    def key(self, k):
        """Return an array with parameter `k` of every stored template."""
        return numpy.array([p.params[k] for p in self.waveforms])

    def sigma_match_bound(self, sig):
        """Return min(sig / sigma_i, sigma_i / sig) for every template.

        The per-template `s` values are cached and rebuilt whenever the
        number of stored templates changes.
        """
        if not hasattr(self, 'sigma'):
            self.sigma = None
        if self.sigma is None or len(self.sigma) != len(self):
            # Read this instance's own templates, not a module-level bank.
            self.sigma = numpy.array([h.s for h in self.waveforms])
        return numpy.minimum(sig / self.sigma, self.sigma / sig)

    def range(self):
        """Return a cached index array ``0 .. len(self) - 1``."""
        if not hasattr(self, 'r'):
            self.r = None
        if self.r is None or len(self.r) != len(self):
            self.r = numpy.arange(0, len(self))
        return self.r

    def culltau0(self, threshold):
        """Replace templates with tau0 below `threshold` by placeholders.

        The placeholder keeps only `tau0`, `params` and `s`, so the bank
        length and all stored indices stay valid while the waveform data
        itself is no longer referenced from the bank.
        """
        cull = numpy.where(self.tau0() < threshold)[0]

        class dumb(object):
            pass
        for c in cull:
            d = dumb()
            d.tau0 = self.waveforms[c].tau0
            d.params = self.waveforms[c].params
            d.s = self.waveforms[c].s
            self.waveforms[c] = d

    def tau0(self):
        """Return a cached array with the `tau0` of every stored template."""
        if not hasattr(self, 't0'):
            self.t0 = None
        if self.t0 is None or len(self.t0) != len(self):
            self.t0 = numpy.array([h.tau0 for h in self])
        return self.t0

    def __contains__(self, hp):
        """Return True if a stored template matches `hp` above `hp.threshold`.

        `hp` must carry `threshold`, `s`, `params` and `gen` (an object with
        a ``match(a, b)`` method). When ``--tau0-threshold`` is set, `hp.tau0`
        and `hp.tbin` are assigned here. When no match above threshold is
        found, `hp.matches`, `hp.indices` and `hp.max_match` are assigned
        before returning False.
        """
        mmax = 0
        mnum = 0
        #Apply sigmas maximal match.
        if args.enable_sigma_bound:
            matches = self.sigma_match_bound(hp.s)
            r = self.range()[matches > hp.threshold]
        else:
            # No prior information: every match is bounded above by 1.
            matches = numpy.ones(len(self))
            r = self.range()

        msig = len(r)

        #Apply tau0 threshold
        if args.tau0_threshold:
            hp.tau0 = pycbc.conversions.tau0_from_mass1_mass2(
                                            hp.params['mass1'],
                                            hp.params['mass2'],
                                            args.tau0_cutoff_frequency)
            hp.tbin = int(hp.tau0 / args.tau0_threshold)

            # Only templates registered under this tau0 bin are candidates;
            # this replaces the candidate list selected above.
            if hp.tbin in self.tbins:
                r = numpy.array(self.tbins[hp.tbin])
            else:
                r = r[:0]

        mtau = len(r)

        # Try to do some actual matches
        inc = Shrinker(r*1)
        while 1:
            j = inc.pop()
            if j is None:
                # Candidates exhausted without exceeding the threshold:
                # store the (bounded or computed) matches, sorted ascending,
                # truncated to at most --max-connections entries.
                msort = matches[r].argsort()

                msorted = matches[r][msort]
                rsorted = r[msort]
                keep = numpy.ones(len(msorted), dtype=bool)
                if args.max_connections < len(keep):
                    keep[args.max_connections:] = False

                hp.matches = msorted[keep].copy()
                hp.indices = rsorted[keep].copy()

                logging.info("TADD MaxMatch:%0.3f Size:%i "
                             "AfterSigma:%i AfterTau0:%i Matches:%i"
                              % (mmax, len(self), msig, mtau, mnum))
                hp.max_match = mmax
                return False

            hc = self[j]

            m = hp.gen.match(hp, hc)
            matches[j] = m
            mnum += 1

            # Update bounding match values, apply triangle inequality
            # NOTE(review): the 1.10 offset looks like a deliberately loose
            # bound rather than the exact inequality - confirm.
            maxmatches = hc.matches - m + 1.10
            update = numpy.where(maxmatches < matches[hc.indices])[0]
            matches[hc.indices[update]] = maxmatches[update]

            # Update where to calculate matches: drop candidates whose upper
            # bound is already below twice the allowed mismatch.
            skip_threshold = 1 - (1 - hp.threshold) * 2.0
            inc.data = inc.data[matches[inc.data] > skip_threshold]

            if m > hp.threshold:
                return True
            if m > mmax:
                mmax = m

    def check_params(self, gen, params, threshold,
                           force_add=False,
                           parallel_check=False, progress=False):
        """Generate waveforms for `params` and add the accepted ones.

        Parameters
        ----------
        gen : object
            Waveform generator attached to each proposal as `hp.gen` for the
            in-process bank check.
        params : dict
            Maps parameter names to equal-length sequences, one entry per
            proposal.
        threshold : float
            Match threshold forwarded to the module-level `wf_wrapper`.
        force_add : bool, optional
            Insert every successfully generated proposal, even when the bank
            check rejected it.
        parallel_check : bool, optional
            Forwarded to `wf_wrapper`; proposals that come back with
            `checked` already set are not re-checked here.
        progress : bool, optional
            Show a tqdm progress bar per chunk.

        Returns
        -------
        bank : TriangleBank
            This instance.
        fraction : float
            Number of inserted proposals divided by the number of proposals;
            0.0 when `params` holds no proposals.
        """
        num_added = 0
        total_num = len(tuple(params.values())[0])

        # Nothing to check: avoid creating a pool and dividing by zero.
        if total_num == 0:
            return self, 0.0

        idxs = numpy.arange(0, total_num)
        chunks = numpy.array_split(idxs, max(total_num // args.nprocesses // 4, 1))

        for chunk in chunks:
            waveform_cache = []
            pool = pycbc.pool.choose_pool(args.nprocesses)

            # Each chunk gets its own progress bar, so its total is the
            # chunk size rather than the overall number of proposals.
            for return_wf in tqdm.tqdm(pool.imap_unordered(
                    wf_wrapper,
                    (({k: params[k][idx] for k in params}, parallel_check, threshold) for idx in chunk)),
                    total=len(chunk), disable=not progress):
                waveform_cache += [return_wf]

            for hp in waveform_cache:
                if hp is not None:
                    if hp.checked is None:
                        hp.gen = gen
                        hp.checked = hp not in self

                    if hp.checked:
                        self.max_matches.append(hp.max_match)

                    if hp.checked or force_add:
                        num_added += 1
                        self.insert(hp)
                else:
                    logging.info("Waveform generation failed!")
                    continue

            pool.close_pool()
            del pool

        return self, num_added / total_num

def decimate_frequency_domain(template, target_df):
    """
    Returns a frequency-domain waveform resampled to a lower frequency resolution
    (delta_f) by decimation.

    Parameters
    ----------
    template : pycbc.types.FrequencySeries
        The input frequency-domain signal to be decimated.
    target_df : float
        The target frequency resolution (delta_f) for the decimated signal.

    Returns
    ----------
    decimated_template : pycbc.types.FrequencySeries
        A new FrequencySeries object with the decimated data and the specified
        target delta_f.

    Raises
    ------
    ValueError
        If the decimation factor derived from ``target_df / template.delta_f``
        is smaller than one.
    """
    # Calculate the decimation factor. Both resolutions are typically
    # computed as 1 / length, so a quotient that should be an exact integer
    # can land a hair below it; plain truncation would then pick a factor
    # that is one too small. Snap to the nearest integer when the quotient is
    # integral up to rounding error, and truncate as before otherwise.
    ratio = target_df / template.delta_f
    nearest = int(round(ratio))
    if nearest >= 1 and abs(ratio - nearest) <= 1e-9 * nearest:
        decimation_factor = nearest
    else:
        decimation_factor = int(ratio)

    if decimation_factor < 1:
        raise ValueError("Target delta_f must be greater than or equal to the original delta_f.")

    # Decimate the data by selecting every 'decimation_factor'-th point
    decimated_signal = template.data[::decimation_factor]

    # Create a new FrequencySeries object with the decimated data and the target delta_f
    decimated_template = pycbc.types.FrequencySeries(decimated_signal, delta_f=target_df)
    return decimated_template

def handle_exit_signal(signum, frame):
    """Signal handler: save the current bank, then terminate the process.

    Parameters
    ----------
    signum : int
        Number of the received signal (only used for logging).
    frame : frame object
        Interrupted stack frame (unused).

    The process exits with status 0 after a successful save and status 1 if
    the save raised.
    """
    logging.warning("Signal %s received. Triggering emergency save...", signum)
    try:
        save_bank()
    except Exception:
        # Without this, the exception would propagate into whatever code the
        # signal interrupted, where it could be caught and the run resumed.
        logging.exception("Emergency save failed.")
        os._exit(1)
    # Force exit so the script doesn't try to resume the loops
    os._exit(0)

def save_bank():
    """Write the module-level ``bank`` to ``args.output_file``.

    The file is opened in write mode (replacing any existing file) and
    receives a ``minimal_match`` attribute, one dataset per template
    parameter and a ``max_matches`` dataset. For an empty bank only the
    attribute and an empty ``max_matches`` dataset are written.
    """
    logging.info("Saving current bank to %s", args.output_file)
    with HFile(args.output_file, 'w') as o:
        o.attrs['minimal_match'] = args.minimal_match
        # bank.keys() reads the first stored template, so it cannot be
        # called on an empty bank.
        if len(bank) > 0:
            keys = bank.keys()
        else:
            logging.warning("Bank is empty; no template parameters written.")
            keys = []
        for k in keys:
            val = bank.key(k)
            # Unicode arrays are converted to byte strings before storing.
            if val.dtype.char == 'U':
                val = val.astype('bytes')
            o[k] = val
        o['max_matches'] = numpy.array(bank.max_matches)
    logging.info("Save complete.")

class GenUniformWaveform(object):
    """Generate PSD-weighted, normalized frequency-domain waveforms on a
    fixed frequency grid and compute matches between them.

    Parameters
    ----------
    buffer_length : float
        Length of the buffer in seconds; sets ``delta_f = 1 / buffer_length``.
    sample_rate : float
        Sample rate; ``buffer_length * sample_rate`` is the number of
        time-domain samples.
    f_lower : float
        Low-frequency cutoff; samples below ``int(f_lower * buffer_length)``
        are excluded from the weighting and the match.
    """
    def __init__(self, buffer_length, sample_rate, f_lower):
        self.f_lower = f_lower
        self.delta_f = 1.0 / buffer_length
        tlen = int(buffer_length * sample_rate)
        self.flen = tlen // 2 + 1
        psd = pycbc.psd.from_cli(args, self.flen, self.delta_f, self.f_lower)
        self.kmin = int(f_lower * buffer_length)
        # Weights 1/sqrt(PSD) over the band [kmin, flen - 1).
        self.w = ((1.0 / psd[self.kmin:-1]) ** 0.5).astype(numpy.float32)
        # Work buffers for the correlation and its inverse FFT; `match`
        # writes into `qtilde_view` and reads the result from `q`.
        qtilde = pycbc.types.zeros(tlen, numpy.complex64)
        q = pycbc.types.zeros(tlen, numpy.complex64)
        self.qtilde_view = qtilde[self.kmin:self.flen - 1]
        self.ifft = pycbc.fft.IFFT(qtilde, q)

        # By default the maximum is searched over the whole output buffer;
        # `md2` is then an all-zero dummy that never wins the maximum.
        self.md = q
        self.md2 = numpy.zeros(20)

        if args.use_trimmed_buffer:
            # Only look at the last and first 100 samples of the output.
            self.md = q._data[-100:]
            self.md2 = q._data[0:100]

    def generate(self, **kwds):
        """Generate one weighted, normalized waveform.

        Parameters
        ----------
        **kwds
            Waveform parameters. Must contain ``approximant``; ``mass1`` and
            ``mass2`` are required when ``--max-signal-length`` is set.
            ``f_lower`` is filled in when absent. An optional ``fratio``
            mixes the plus and cross polarizations of frequency-domain
            approximants.

        Returns
        -------
        hp : pycbc.types.FrequencySeries
            The waveform, with extra attributes ``params`` (the keyword
            arguments actually used), ``view`` (the slice ``[kmin:-1]``) and
            ``s`` (the sigma-squared of the weighted waveform before
            normalization).
        """
        if args.max_signal_length is not None:
            # Pick the starting frequency from a 0.1 Hz grid (highest
            # frequency first) using the SPA duration at each frequency.
            # NOTE(review): if even the first grid entry is longer than
            # --max-signal-length the index becomes -1 and wraps around to
            # the last grid entry (self.f_lower) - confirm this is intended.
            flow = numpy.arange(self.f_lower, 100, .1)[::-1]
            length = spa_length_in_time(mass1=kwds['mass1'],
                                        mass2=kwds['mass2'],
                                        f_lower=flow, phase_order=-1)
            maxlen = args.max_signal_length
            x = numpy.searchsorted(length, maxlen) - 1
            f = flow[x]
        else:
            f = self.f_lower

        if 'f_lower' not in kwds:
            kwds['f_lower'] = f

        # Approximant names read from file may be byte strings.
        if hasattr(kwds['approximant'], 'decode'):
            kwds['approximant'] = kwds['approximant'].decode()

        # Length (seconds) that sets the resolution at which the waveform is
        # first generated.
        if args.full_resolution_buffer_length is not None:
            buff_len = args.full_resolution_buffer_length
        else:
            buff_len = 1.0 / self.delta_f

        if args.use_td_waveform and kwds['approximant'] in pycbc.waveform.td_approximants():

            hp, hc = pycbc.waveform.get_td_waveform(delta_t=1.0 / args.sample_rate,
                                                **kwds)

            hp = hp.to_frequencyseries(delta_f = 1.0 / buff_len)
            hc = hc.to_frequencyseries(delta_f = 1.0 / buff_len)

        elif kwds['approximant'] in pycbc.waveform.fd_approximants():

            hp, hc = pycbc.waveform.get_fd_waveform(delta_f = 1.0 / buff_len,
                                                **kwds)

            if args.use_cross:
                hp = hc

            if 'fratio' in kwds:
                hp = hc * kwds['fratio'] + hp * (1 - kwds['fratio'])

        else:
            dt = 1.0 / args.sample_rate
            # The output length must be an integer; buff_len and the sample
            # rate are floats, so convert before the integer division.
            out_len = int(buff_len * args.sample_rate) // 2 + 1
            hp = pycbc.waveform.get_waveform_filter(
                        pycbc.types.zeros(out_len, dtype=numpy.complex64),
                        delta_f=1.0 / buff_len, delta_t=dt,
                        **kwds)

        if args.full_resolution_buffer_length is not None:
            # Decimate the generated signal to a reduced frequency resolution
            hp = decimate_frequency_domain(hp, self.delta_f)

        hp.resize(self.flen)
        hp = hp.astype(numpy.complex64)

        # Apply the 1/sqrt(PSD) weights, then scale by 1/sqrt(sigmasq).
        hp[self.kmin:-1] *= self.w
        s = float(1.0 / pycbc.filter.sigmasq(hp,
                                             low_frequency_cutoff=f) ** 0.5)
        hp *= s
        hp.params = kwds
        hp.view = hp[self.kmin:-1]
        hp.s = (1.0 / s) ** 2.0
        return hp

    def match(self, hp, hc):
        """Return the match between two waveforms made by `generate`.

        The ``view`` slices of both waveforms are correlated into the shared
        work buffer, inverse-transformed, and the largest absolute value in
        the searched part of the output is scaled by ``4 * delta_f``.
        """
        pycbc.filter.correlate(hp.view, hc.view, self.qtilde_view)
        self.ifft.execute()
        m = max(abs(self.md).max(), abs(self.md2).max())
        return m * 4.0 * self.delta_f

# Round counter; reset at the start of every tau0 region further down.
r = 0

# Convergence tolerance: taken from the command line when given (and
# non-zero), otherwise a tenth of the allowed mismatch.
tolerance = args.tolerance if args.tolerance else (1 - args.minimal_match) / 10

# Number of samples requested per draw.
size = int(1.0 / tolerance)

# Shared waveform generator and the (initially empty) template bank.
gen = GenUniformWaveform(
    args.buffer_length, args.sample_rate, args.low_frequency_cutoff)
bank = TriangleBank()

def silent_child_exit(signum, frame):
    """Signal handler that terminates the current process immediately.

    Installed by ``wf_wrapper`` for SIGINT and SIGTERM. Both arguments
    (signal number and interrupted frame) are ignored.
    """
    # The child workers execute this branch
    # os._exit(0) ensures they die instantly and quietly
    os._exit(0)

def wf_wrapper(args):
    """Generate one waveform; optionally check it against the bank.

    Parameters
    ----------
    args : tuple
        ``(p, parallel_check, threshold)`` where ``p`` is a dict of waveform
        parameters, ``parallel_check`` selects whether the bank check is done
        here, and ``threshold`` is stored on the waveform as ``threshold``.
        Note that this parameter shadows the module-level ``args`` inside
        this function.

    Returns
    -------
    hp or None
        The generated waveform with ``checked`` set to None (not yet
        checked) or to the result of ``hp not in bank``; None if anything
        raised.
    """
    import multiprocessing

    p, parallel_check, threshold = args

    # Only pool workers should die silently on SIGINT/SIGTERM. With a
    # single-process pool this function runs in the main process, and
    # installing the handlers there would replace the emergency-save
    # handlers of the main script.
    # NOTE(review): assumes workers are multiprocessing children; workers
    # started by another mechanism (e.g. MPI) are not detected here.
    if multiprocessing.parent_process() is not None:
        signal.signal(signal.SIGINT, silent_child_exit)
        signal.signal(signal.SIGTERM, silent_child_exit)
    try:
        hp = gen.generate(**p)
        hp.checked = None
        hp.threshold = threshold
        if parallel_check:
            # The generator is attached only for the duration of the check
            # and cleared again before the waveform is returned.
            hp.gen = gen
            hp.checked = hp not in bank
            hp.gen = None
        return hp
    except Exception as e:
        # Deliberately best-effort: a failed waveform is reported and
        # skipped by the caller.
        print(e)
        return None

# Seed the bank from any input banks given on the command line. Each file is
# run through the normal bank check, so templates that overlap with ones
# already accepted are dropped unless --keep-entire-input-file is set.
if args.input_file:
    total_raw_templates = 0
    for input_file in args.input_file:
        logging.info("Loading: %s", input_file)
        with HFile(input_file, 'r') as f:
            # Read every dataset except the stored match values.
            params = {}
            for k in f.keys():
                if k != 'max_matches':
                    params[k] = f[k][:]
            total_raw_templates += len(tuple(params.values())[0])
            bank, _ = bank.check_params(
                gen, params, args.minimal_match,
                force_add=args.keep_entire_input_file,
                parallel_check=args.parallel_check,
                progress=True)
            reduction = total_raw_templates - len(bank)
            logging.info("Combined Size: %i (Naive Total: %i, Removed: %i)", 
                         len(bank), total_raw_templates, reduction)

def draw(rtype):
    """Draw a set of proposal parameters.

    Parameters
    ----------
    rtype : {'uniform', 'kde'}
        'uniform' draws from the configuration file; 'kde' resamples from a
        Gaussian KDE built on the most recent (at most 300) bank templates.

    Returns
    -------
    dict
        Maps parameter names to arrays, restricted to the samples that the
        joint prior contains. The arrays may be shorter than ``size``.

    Raises
    ------
    ValueError
        If `rtype` is not one of the supported values.
    """
    def _add_static_args(params):
        # Static arguments are replicated so every sample carries them.
        if static_args is not None:
            for k in static_args.keys():
                params[k] = numpy.array([static_args[k]]*size)

    if rtype == 'uniform':
        # `draw_samples_from_config` has its own fixed seed, so must overwrite it.
        random_seed = numpy.random.randint(low=0, high=2**32-1)
        samples = draw_samples_from_config(args.input_config, size, random_seed)
        params = {name: samples[name] for name in samples.fieldnames}
        # Add `static_args` back.
        _add_static_args(params)

    elif rtype == 'kde':
        trail = min(300, len(bank))
        p = variable_args
        bdata = numpy.array([bank.key(k)[-trail:] for k in p])
        kde = gaussian_kde(bdata)
        points = kde.resample(size=size)
        params = {k: v for k, v in zip(p, points)}

        # Add `static_args` back, some transformations may need them.
        _add_static_args(params)

        # Apply `waveform_transforms` defined in the .ini file to samples.
        if waveform_transforms is not None:
            params = transforms.apply_transforms(params, waveform_transforms)

    else:
        raise ValueError("Unknown draw type %r; expected 'uniform' or 'kde'"
                         % (rtype,))

    # Filter out stuff (kde method may also generate samples outside boundaries).
    l = dists_joint.contains(params)
    params = {k: params[k][l] for k in params}
    return params

def cdraw(rtype, ts, te):
    """Draw proposals whose tau0 lies strictly between `ts` and `te`.

    Draws are repeated and concatenated until at least ``size`` samples have
    been collected or the number of extra draws exceeds
    ``--placement-iterations``. Returns the dict of parameter arrays (which
    may hold more than ``size`` samples), or None if no sample was found.
    """
    from pycbc.conversions import tau0_from_mass1_mass2

    def _count(samples):
        # Number of samples, read off the first parameter array.
        return len(samples[list(samples.keys())[0]])

    def _in_window(samples):
        # Keep only samples with ts < tau0 < te; an empty draw is passed
        # through untouched.
        if _count(samples) > 0:
            t = tau0_from_mass1_mass2(samples['mass1'], samples['mass2'],
                                      args.tau0_cutoff_frequency)
            inside = (t < te) & (t > ts)
            samples = {k: samples[k][inside] for k in samples}
        return samples

    p = _in_window(draw(rtype))

    i = 0
    while _count(p) < size:
        extra = _in_window(draw(rtype))
        p = {k: numpy.concatenate([p[k], extra[k]]) for k in p}

        i += 1
        if i > args.placement_iterations:
            break

    return p if _count(p) != 0 else None

# From here on, SIGINT/SIGTERM save the current bank before exiting.
signal.signal(signal.SIGINT, handle_exit_signal)
signal.signal(signal.SIGTERM, handle_exit_signal)

# Current tau0 window [tau0s, tau0e]; it advances by half its width per
# region, so consecutive regions overlap.
tau0s = args.tau0_start
tau0e = tau0s + args.tau0_crawl

# NOTE(review): `go` is assigned below but never read in this file.
go = True

region = 0
while tau0s < args.tau0_end:

    tau0e = min(tau0e, args.tau0_end)

    conv = 1
    r = 0
    while conv > tolerance:
        # Standard Round
        r += 1
        params = cdraw('uniform', tau0s, tau0e)
        if params is None:
            # No proposals fall inside this window: move on to the next one.
            if len(bank) > 0:
                go = False
            break

        blen = len(bank)
        bank, uconv = bank.check_params(gen, params, args.minimal_match,
                                        parallel_check=args.parallel_check)
        logging.info("%s: Round (U): %s Size: %s conv: %s added: %s",
                     region, r, len(bank), uconv, len(bank) - blen)
        # The acceptance fraction only counts as a convergence measure after
        # the first ten rounds of a region.
        if r > 10:
            conv = uconv
        kloop = 0

        # KDE rounds: keep going while they stay at least half as productive
        # as the first KDE round of this pass.
        while ((kloop == 0) or (kconv / okconv) > .5) and len(bank) > 10:
            r += 1
            kloop += 1
            params = cdraw('kde', tau0s, tau0e)
            if params is None:
                # The KDE produced nothing inside the window; fall back to
                # the uniform rounds instead of passing None on.
                logging.info("%s: Round (K) (%s): no proposals in window",
                             region, kloop)
                break
            blen = len(bank)
            bank, kconv = bank.check_params(gen, params, args.minimal_match,
                                            parallel_check=args.parallel_check)

            # Statistics over the most recent 10% of accepted templates.
            trail_matches = numpy.array(bank.max_matches[int(len(bank.max_matches)*0.9):])
            ave = numpy.mean(trail_matches)
            logging.info("%s: Round (K) (%s): %s Size: %s conv: %0.4f added: %s Trail Ave: %0.4f Trail Min: %0.4f",
                         region, kloop, r, len(bank), kconv, len(bank) - blen,
                         ave, trail_matches.min())


            if uconv:
                logging.info('Ratio of convergences: %2.3f' % (kconv / (uconv)))
                logging.info('Progress: {:.0%} completed'.format(tau0e/args.tau0_end))

            if kloop == 1:
                okconv = kconv

            if kconv <= tolerance:
                conv = kconv
                break

    # Templates well below the current window are replaced by placeholders.
    bank.culltau0(tau0s - args.tau0_threshold * 2.0)
    logging.info("Region Done %3.1f-%3.1f, %s stored",
                 tau0s, tau0e, bank.activelen())
    region += 1
    tau0s += args.tau0_crawl / 2
    tau0e += args.tau0_crawl / 2

save_bank()
