#!/usr/bin/python3.13 -s

# Copyright 2024 Max Trevor
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

"""
Calculate the noise trigger rate adjustment as a function of data-quality state
for a day of PyCBC Live triggers.
"""

import logging
import argparse
import hashlib
import os
import glob

import numpy as np

from igwn_segments import segmentlist, segment

import pycbc
import pycbc.dq
from pycbc.io import HFile
from pycbc.frame.frame import read_frame, query_and_read_frame


def find_frames(ifo, frame_dir, frame_type, start, end):
    """
    Collect the frame files on disk whose folders overlap a time range.

    Frames are searched for in folders laid out as
    ``<frame_dir>/<type>/<I>-<type>_llhoft-<supertime>``, where ``<type>``
    is ``frame_type`` with ``{ifo}`` substituted, ``<I>`` is the first
    character of ``ifo`` and ``<supertime>`` is a time divided by 10000 and
    truncated to an integer.

    Parameters
    ----------
    ifo : str
        Detector name; substituted into ``frame_type`` and its first
        character is used as the folder prefix.
    frame_dir : str
        Top-level directory under which the frame folders live.
    frame_type : str
        Frame type, optionally containing an ``{ifo}`` placeholder.
    start : int or float
        Start of the time range.
    end : int or float
        End of the time range.

    Returns
    -------
    list of str
        Every ``*.gwf`` path found in the folders whose supertime lies
        between those of ``start`` and ``end`` (both included). Individual
        files are not checked against the requested range, and folders that
        do not exist simply contribute nothing.
    """
    resolved_type = frame_type.format(ifo=ifo)
    # The '{supertime}' placeholder is left unformatted here and filled in
    # once per folder below.
    folder_template = os.path.join(
        frame_dir,
        resolved_type,
        f'{ifo[0]}-{resolved_type}_llhoft-' + '{supertime}',
    )

    first_supertime = int(start / 10000)
    last_supertime = int(end / 10000)

    matches = []
    for supertime in range(first_supertime, last_supertime + 1):
        folder = folder_template.format(supertime=supertime)
        matches.extend(glob.glob(os.path.join(folder, '*.gwf')))
    return matches


# ---- command-line interface ----
parser = argparse.ArgumentParser(description=__doc__)

# inputs and the time range being analysed
parser.add_argument('--trigger-file', required=True)
parser.add_argument('--ifo', required=True)
parser.add_argument('--gps-start-time', type=int, required=True)
parser.add_argument('--gps-end-time', type=int, required=True)
parser.add_argument('--template-bin-file', required=True)

# where the frame data comes from
parser.add_argument(
    '--use-gwdatafind',
    action='store_true',
    help='Use gwdatafind to find frame files. If provided, '
         'frame-directory argument will be ignored.',
)
parser.add_argument(
    '--frame-directory',
    help='Directory containing frame files. Required if '
         'not using gwdatafind.',
)
parser.add_argument('--frame-type', required=True)

# optional analysis flag; when omitted the whole time range is treated as
# observing time further down
parser.add_argument('--analysis-flag-name')

# data-quality channel settings
parser.add_argument('--dq-channel', required=True)
parser.add_argument('--dq-ok-channel', required=True)
parser.add_argument('--dq-thresh', type=float, required=True)
parser.add_argument('--dq-padding', type=float, default=1.0)

parser.add_argument('--replay-offset', type=int, default=0)
parser.add_argument('--output', required=True)
pycbc.add_common_pycbc_options(parser)
args = parser.parse_args()

pycbc.init_logging(args.verbose)

# A frame directory is only optional when frames are located via gwdatafind
if args.frame_directory is None and not args.use_gwdatafind:
    raise ValueError('Must provide frame-directory if not using gwdatafind.')

# Work out which segments of the requested time range count as observing time
if args.analysis_flag_name is not None:
    # Get observing segs by querying the named flag over the requested range
    ar_flag_name = args.analysis_flag_name.format(ifo=args.ifo)
    observing_flag = pycbc.dq.query_flag(
        args.ifo, ar_flag_name, args.gps_start_time, args.gps_end_time
    )
else:
    # If we are not querying the database for segments, just assume that the
    # full time is in observing mode
    observing_flag = segmentlist(
        [segment(args.gps_start_time, args.gps_end_time)]
    )

# Shift observing flag to current time. shift() adds the offset to both ends
# of every segment, so segment durations (and hence the livetime) are
# unchanged. protract() must not be used here: it widens every segment by the
# offset on each side, which inflates the livetime and does not move the
# segments to the replayed times.
observing_flag = observing_flag.shift(args.replay_offset)

# abs() of a segmentlist is the summed duration of its segments
livetime = abs(observing_flag)
logging.info(f'Found {livetime} seconds of observing time at {args.ifo}.')

# Walk over the observing segments and add up how much of that time the
# data-quality channel marks as bad
flagged_time = 0
dq_channel = args.dq_channel.format(ifo=args.ifo)
dq_ok_channel = args.dq_ok_channel.format(ifo=args.ifo)
for seg in observing_flag:
    seg_start, seg_end = seg[0], seg[1]
    channels = [dq_channel, dq_ok_channel]

    if not args.use_gwdatafind:
        # locate the frame files on disk ourselves
        frames = find_frames(
            args.ifo,
            args.frame_directory,
            args.frame_type,
            seg_start,
            seg_end,
        )
        dq_ts, dq_ok_ts = read_frame(
            frames, channels, start_time=seg_start, end_time=seg_end
        )
    else:
        # let gwdatafind locate the frames for this frame type
        dq_ts, dq_ok_ts = query_and_read_frame(
            args.frame_type.format(ifo=args.ifo),
            channels,
            start_time=seg_start,
            end_time=seg_end,
        )

    # Times when the dq channel itself is working: cast the "ok" channel to
    # booleans in place, then convert the series to a segmentlist
    dq_ok_ts._data = dq_ok_ts._data.astype(bool)
    dq_ok_segs = dq_ok_ts.bool_to_segmentlist()

    # Times when the dq channel is at or above threshold: replace the series
    # data with the boolean mask, then convert it in the same way
    dq_ts._data = (dq_ts >= args.dq_thresh)
    dq_segs = dq_ts.bool_to_segmentlist()

    # Only trust above-threshold times while the channel is working
    valid_bad_dq = dq_segs & dq_ok_segs

    # Expand the flagged segments outwards by dq_padding seconds and merge
    # any that now overlap
    valid_bad_dq.protract(args.dq_padding)
    valid_bad_dq.coalesce()

    # Count only the part that lies inside observing time
    flagged_time += abs(valid_bad_dq & observing_flag)

logging.info(f'Found {flagged_time} seconds of dq flagged time at {args.ifo}.')

# Time spent in each dq state: index 0 is unflagged, index 1 is flagged
bg_livetime = livetime - flagged_time
state_time = np.array([bg_livetime, flagged_time])

# Load the template ids belonging to each bin, plus the bank settings that
# are copied into the output file later on
logging.info(f'Reading template bins from {args.template_bin_file}')
with HFile(args.template_bin_file, 'r') as bin_file:
    bin_attrs = bin_file.attrs
    f_lower = bin_attrs['f_lower']
    bank_file = bin_attrs['bank_file']
    bin_string = bin_attrs['background_bins']

    # only ever one group in this file, so take the first one
    group_names = list(bin_file.keys())
    grp = bin_file[group_names[0]]

    # map each bin name to an in-memory copy of its template ids
    template_bins = {name: grp[name]['tids'][:] for name in grp.keys()}
    num_bins = len(template_bins)
    num_templates = sum(len(tids) for tids in template_bins.values())

# Per-bin counters: all triggers, and the triggers in the flagged dq state
bin_total_triggers = np.zeros(num_bins)
bin_dq_triggers = np.zeros(num_bins)

logging.info(f'Reading triggers from {args.trigger_file}')
template_key = f'{args.ifo}/dq_state_template'
state_key = f'{args.ifo}/dq_state'
with HFile(args.trigger_file, 'r') as trigf:
    for bin_name, template_ids in template_bins.items():
        # bins are named as 'bin{bin_num}', so strip the 3-character prefix
        bin_num = int(bin_name[3:])
        for template_id in template_ids:
            # The per-template entry is used as the selection that picks this
            # template's triggers out of the dq_state dataset; each selected
            # dq state is either 0 or 1
            dq_key = trigf[template_key][template_id]
            dq_states = trigf[state_key][dq_key]

            # every trigger counts towards the total; summing the 0/1 states
            # counts the flagged ones
            bin_total_triggers[bin_num] += len(dq_states)
            bin_dq_triggers[bin_num] += np.sum(dq_states)

# Store the livetimes, per-bin trigger counts and normalized rates
logging.info(f'Writing results to {args.output}')
with HFile(args.output, 'w') as f:
    ifo_group = f.create_group(args.ifo)
    ifo_group.create_dataset('observing_livetime', data=livetime)
    ifo_group.create_dataset('dq_flag_livetime', data=flagged_time)

    bin_group = ifo_group.create_group('bins')
    for bin_name, tids in template_bins.items():
        # bins are named as 'bin{bin_num}'
        bin_num = int(bin_name[3:])
        bgrp = bin_group.create_group(bin_name)
        bgrp.create_dataset('tids', data=tids)

        n_total = bin_total_triggers[bin_num]
        n_flagged = bin_dq_triggers[bin_num]
        bgrp.create_dataset('total_triggers', data=n_total)
        bgrp.create_dataset('dq_triggers', data=n_flagged)

        # Trigger rate in each dq state (index 0 unflagged, index 1 flagged),
        # expressed relative to the bin's mean rate over all observing time
        counts_by_state = np.array([n_total - n_flagged, n_flagged])
        normalized_rates = (counts_by_state / state_time) \
            / (n_total / livetime)
        bgrp.create_dataset('dq_rates', data=normalized_rates)

    # Record the settings this file was made with, in a fixed order
    file_attrs = {
        'dq_thresh': args.dq_thresh,
        'dq_channel': dq_channel,
        'dq_ok_channel': dq_ok_channel,
        'gps_start_time': args.gps_start_time,
        'gps_end_time': args.gps_end_time,
        'f_lower': f_lower,
        'bank_file': bank_file,
        'background_bins': bin_string,
    }
    for attr_name, attr_value in file_attrs.items():
        f.attrs[attr_name] = attr_value

    # hash is used to check if different files have compatible settings;
    # the GPS times are deliberately not part of it
    hashed_settings = (args.dq_thresh, dq_channel, dq_ok_channel,
                       f_lower, bank_file, bin_string)
    setting_str = ' '.join(str(s) for s in hashed_settings)
    f.attrs['settings_hash'] = hashlib.sha256(setting_str.encode()).hexdigest()

logging.info('Done')