#!/usr/bin/python3.13 -s

# Copyright (C) 2025 Sumit K., Shichao Wu
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
# One-paragraph summary of what this script produces.
__description__ = (
    "Generate the mock data frame files for a given network of detectors.\n"
    "It contains Gaussian noise, population injections, and glitches.\n"
)

import os
import argparse
import numpy
import pycbc
from pycbc.types import MultiDetOptionAction
import logging


# Command-line interface.  Per-detector options take DETECTOR:VALUE pairs.
parser = argparse.ArgumentParser(description=__description__)
pycbc.add_common_pycbc_options(parser)

parser.add_argument("--ifo-list", nargs="+", type=str, required=True,
                    help="List of detector names, e.g. H1 L1 V1")

parser.add_argument('--low-frequency-cutoff', nargs='+',
                    action=MultiDetOptionAction,
                    metavar='DETECTOR:FREQUENCY', type=float,
                    help='For each detector, provide minimum frequency for '
                         'the data generation.')

parser.add_argument('--channel-name', nargs='+', action=MultiDetOptionAction,
                    metavar='DETECTOR:CHANNEL', type=str,
                    help='For each detector, provide channel names for the '
                         'frame files.')

parser.add_argument('--psd-model', nargs='+', action=MultiDetOptionAction,
                    metavar='DETECTOR:PSD', type=str,
                    help='For each detector, provide PSD option for '
                         'the noise generation.')

parser.add_argument('--fake-strain-seed', nargs='+',
                    action=MultiDetOptionAction,
                    metavar='DETECTOR:SEED', type=int,
                    help='For each detector, provide a random number seed. '
                         'If omitted for every detector, a random seed is '
                         'drawn per detector.')

parser.add_argument('--fake-strain-from-file', nargs='+',
                    action=MultiDetOptionAction,
                    metavar='DETECTOR:FILE', type=str,
                    help='For each detector, provide a file (e.g. txt) to '
                         'load the fake strain from, instead of generating it.')

parser.add_argument('--gps-start-time', type=int, required=True,
                    help='GPS start time for frame files')

parser.add_argument('--gps-end-time', type=int, required=True,
                    help='GPS end time for frame files')

parser.add_argument('--sample-rate', type=float, required=True,
                    help='Sample rate of the frames to be generated')

parser.add_argument('--fake-strain-sample-rate', type=float,
                    help='Sample rate for the fake strain generation. For '
                         'space-borne detectors it defaults to --sample-rate '
                         'if not set; for ground-based detectors it is only '
                         'passed on when given.')

parser.add_argument('--fake-strain-filter-duration', type=float,
                    help='Duration of the filter used for coloring the noise')

parser.add_argument('--injection-file', type=str,
                    help='Injection file containing parameters of '
                         'injection population')

parser.add_argument('--output-path', default='.', type=str,
                    help="Path for the output frame files")

parser.add_argument('--tag', default=None,
                    help='Provide your tag for naming frame file')

parser.add_argument('--len-arm', type=float, default=None,
                    help="The arm length of LISA, in the unit of 'm' "
                         "(required for space-borne detectors, i.e. detector "
                         "names containing 'LISA')")

parser.add_argument('--acc-noise-level', type=float,
                    help="The level of acceleration noise")

parser.add_argument('--oms-noise-level', type=float,
                    help="The level of OMS noise")

parser.add_argument('--tdi', type=str, choices=['1.5', '2.0'],
                    help="The version of TDI. Choose from '1.5' or '2.0'")

parser.add_argument('--duration', type=float,
                    help="The duration of observation, between 0 and 10, "
                         "in the unit of years")

# parse the command line
opts = parser.parse_args()
# Apply the --verbose level from the common PyCBC options; otherwise the root
# logger keeps its default WARNING level and logging.info output is dropped.
pycbc.init_logging(opts.verbose)


# Function to generate command line script
def create_command_opts(ifo, rseed, output_file, options=None):
    """Build the ``pycbc_condition_strain`` command fragments for one detector.

    Parameters
    ----------
    ifo : str
        Detector name.  Names containing ``'LISA'`` are treated as
        space-borne detectors, everything else as ground-based.
    rseed : int
        Random seed passed as ``--fake-strain-seed``.
    output_file : str
        Path passed as ``--output-strain-file``.
    options : argparse.Namespace, optional
        Parsed command-line options.  Defaults to the module-level ``opts``.

    Returns
    -------
    list of str
        The program name followed by ``'--option value'`` fragments, meant
        to be joined with spaces into a single shell command line.

    Raises
    ------
    ValueError
        If ``ifo`` has neither a strain file nor a PSD model, if a
        space-borne detector is requested without ``--len-arm``, if
        ``--sample-rate`` is missing, or if ``ifo`` has no channel name.
    """
    if options is None:
        options = opts

    def for_ifo(per_detector):
        # Value of a per-detector option for this ifo, or None when the
        # option or the detector entry is absent.  Indexing (rather than
        # ``in``) lets dict types that supply a default for missing keys
        # provide it.
        if per_detector is None:
            return None
        try:
            return per_detector[ifo]
        except KeyError:
            return None

    is_space = 'LISA' in ifo

    # Strain source: an explicit file takes priority over a PSD model.
    strain_file = for_ifo(options.fake_strain_from_file)
    psd_model = for_ifo(options.psd_model)
    if strain_file is not None:
        strain_source = '--fake-strain-from-file %s' % strain_file
    elif psd_model is not None:
        strain_source = '--fake-strain %s' % psd_model
    else:
        # The top-level sanity checks should prevent this; kept for safety.
        raise ValueError(
            f"No strain source (PSD or File) provided for detector {ifo}")

    if is_space and options.len_arm is None:
        raise ValueError("Space-borne missions require len-arm")

    if options.sample_rate is None:
        raise ValueError("--sample-rate is required")

    channel = for_ifo(options.channel_name)
    if channel is None:
        # Without this the frame would silently get a channel "<ifo>:None".
        raise ValueError(f"No channel name provided for detector {ifo}")

    # Ground-based rates are passed as integers.  Space-borne rates are
    # passed unrounded (presumably because they can be below 1 Hz); str()
    # keeps full float precision, whereas '%f' rounds to 6 decimals.
    if is_space:
        sample_rate_arg = '--sample-rate %s' % options.sample_rate
    else:
        sample_rate_arg = '--sample-rate %d' % options.sample_rate

    args_opt = [
        'pycbc_condition_strain',
        '--fake-strain-seed %d' % rseed,
        sample_rate_arg,
        '--gps-start-time %d' % options.gps_start_time,
        '--gps-end-time %d' % options.gps_end_time,
        '--channel-name %s:%s' % (ifo, channel),
        '--output-strain-file %s' % output_file,
        strain_source,
    ]

    # Only space-borne detectors fall back to --sample-rate for the fake
    # strain rate; ground-based runs pass it through only when it was given.
    fake_rate = options.fake_strain_sample_rate
    if fake_rate is None and is_space:
        fake_rate = options.sample_rate
    if fake_rate is not None:
        args_opt.append('--fake-strain-sample-rate %s' % fake_rate)

    if options.fake_strain_filter_duration is not None:
        args_opt.append('--fake-strain-filter-duration %s'
                        % options.fake_strain_filter_duration)

    if is_space:
        # LISA-specific settings travel as "key:value" pairs in one option.
        extra_args = [
            '%s:%s' % (key, value)
            for key, value in (
                ('len_arm', options.len_arm),
                ('acc_noise_level', options.acc_noise_level),
                ('oms_noise_level', options.oms_noise_level),
                ('tdi', options.tdi),
                ('duration', options.duration),
            )
            if value is not None
        ]
        if extra_args:
            args_opt.append('--fake-strain-extra-args %s'
                            % ' '.join(extra_args))

    return args_opt


# Sanity Checks
def _ifos_without(per_detector, ifos):
    """Return the detectors in *ifos* that have no value for an option.

    *per_detector* is the parsed value of a per-detector option.  A detector
    counts as unset when the option is None, lacks that detector, or maps it
    to None.  Indexing is used so that dicts supplying a default for missing
    keys are honoured.
    """
    if per_detector is None:
        return list(ifos)
    missing = []
    for ifo in ifos:
        try:
            value = per_detector[ifo]
        except KeyError:
            value = None
        if value is None:
            missing.append(ifo)
    return missing


if not opts.ifo_list:
    raise ValueError("At least one detector must be given with --ifo-list")

if opts.gps_start_time is None or opts.gps_end_time is None:
    raise ValueError("Both --gps-start-time and --gps-end-time are required")
if opts.gps_end_time <= opts.gps_start_time:
    raise ValueError("--gps-end-time must be later than --gps-start-time")

if _ifos_without(opts.low_frequency_cutoff, opts.ifo_list):
    raise ValueError("Low frequency cutoff for each detector is required")

if _ifos_without(opts.channel_name, opts.ifo_list):
    raise ValueError("Channel name for each detector is required")

# Each detector needs EITHER a PSD model OR a strain file.
lacking_psd = set(_ifos_without(opts.psd_model, opts.ifo_list))
lacking_file = set(_ifos_without(opts.fake_strain_from_file, opts.ifo_list))
if lacking_psd & lacking_file:
    raise ValueError("For each detector, you must provide either a --psd-model or a --fake-strain-from-file.")


# Seeds are given for every detector or for none of them.  Checked once, up
# front, so the outcome does not depend on the order of --ifo-list.
seeded_ifos = {ifo for ifo in opts.ifo_list
               if opts.fake_strain_seed[ifo] is not None}
if seeded_ifos and seeded_ifos != set(opts.ifo_list):
    raise ValueError("Fake strain seed for each detector is required")

# Loop-invariant pieces of the output file name.
time_duration = opts.gps_end_time - opts.gps_start_time
tag_part = '' if opts.tag is None else '%s-' % opts.tag
if opts.output_path:
    os.makedirs(opts.output_path, exist_ok=True)

for ifo in opts.ifo_list:
    logging.info("Generating strain for the Detector:%s", ifo)
    rseed = opts.fake_strain_seed[ifo]
    if rseed is None:
        # Integer bound rather than 1e8, so no float-to-int coercion is needed.
        rseed = numpy.random.randint(10**8)

    output_file = os.path.join(
        opts.output_path,
        '%s-SIMULATED_STRAIN-%s%d-%d.gwf' % (
            ifo, tag_part, opts.gps_start_time, time_duration))

    # Get the base command options
    args_opt = create_command_opts(ifo, rseed, output_file)
    if opts.injection_file is not None:
        args_opt.append('--injection-file %s' % opts.injection_file)
    if opts.low_frequency_cutoff is not None:
        # str() keeps full precision; '%f' would round to 6 decimal places.
        args_opt.append('--fake-strain-flow %s'
                        % opts.low_frequency_cutoff[ifo])

    cmd = ' '.join(args_opt)
    # Flush so the echoed command precedes the child process's own output.
    print(cmd, flush=True)
    status = os.system(cmd)
    if status != 0:
        raise RuntimeError(
            "Strain generation failed for detector %s (os.system status %d): %s"
            % (ifo, status, cmd))
