#!/usr/bin/python3.13 -s
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Produce the sky grid plot for the triggered search (PyGRB)."""

# =============================================================================
# Preamble
# =============================================================================

import sys
import os
import logging
import numpy
import h5py
from matplotlib import pyplot as plt
import matplotlib.colors as colors
from matplotlib import rc
from matplotlib.ticker import MaxNLocator
import pycbc.version
from pycbc import init_logging
from pycbc.results import save_fig_with_metadata
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.detector import Detector
import pycbc.distributions

# Program metadata
__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__program__ = "pycbc_pygrb_plot_skygrid"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date

# Use a non-interactive backend (the figure is only saved to file) and a
# default font size of 14 for all figure text.
plt.switch_backend('Agg')
rc('font', size=14)


def define_rows_cols_subplot(nplots):
    """Choose a near-square subplot grid that can hold ``nplots`` panels.

    The number of columns is ``ceil(sqrt(nplots))`` and the number of rows
    is the smallest integer for which ``rows * cols >= nplots``.

    Parameters
    ----------
    nplots : int
        Number of panels to lay out.

    Returns
    -------
    tuple of int
        ``(cols, rows)`` -- note the order is columns first, despite the
        function name.
    """
    n_cols = int(numpy.ceil(numpy.sqrt(nplots)))
    # Integer ceiling division: -(-a // b) == ceil(a / b) for positive b
    n_rows = int(-(-nplots // n_cols))
    return n_cols, n_rows

def ra_to_ra_mollweide(ra):
    """Wrap right ascension into the (-pi, pi] range used by mollweide axes.

    Matplotlib's mollweide projection expects longitudes in [-pi, pi]
    radians, whereas right ascension is usually stored in [0, 2*pi).

    Parameters
    ----------
    ra : float or array_like
        Right ascension in radians.  Any real value is accepted and wrapped.

    Returns
    -------
    numpy.ndarray or numpy scalar
        The wrapped angles in (-pi, pi], with the same shape as ``ra``
        (a numpy scalar when ``ra`` is a scalar).
    """
    two_pi = 2 * numpy.pi
    # numpy.remainder takes the sign of the divisor, so the result is already
    # non-negative for every input (negative angles included); no offset is
    # needed before taking the remainder.
    wrapped = numpy.remainder(numpy.asarray(ra), two_pi)
    # Shift the upper half of the circle down by a full turn.  numpy.where
    # also works on 0-d input, unlike in-place boolean-mask assignment.
    mollweide_ra = numpy.where(wrapped > numpy.pi, wrapped - two_pi, wrapped)
    # ``[()]`` unwraps a 0-d array into a scalar and returns the whole array
    # unchanged for any other dimensionality.
    return mollweide_ra[()]

# =============================================================================
# Main script starts here
# =============================================================================
parser = ppu.pygrb_initialize_plot_parser(description=__doc__)

# Options specific to this plot, on top of the common PyGRB plot options
_option_specs = (
    ("--sky-grid",
     dict(required=True, help="The location of the sky grid file")),
    ("--num-density-bins",
     dict(default=1200, type=int,
          help="Bins for the input distribution plotting")),
    ("--num-density-samples",
     dict(default=2_000_000, type=int,
          help="Samples for plotting the input distribution")),
)
for _flag, _kwargs in _option_specs:
    parser.add_argument(_flag, **_kwargs)
opts = parser.parse_args()

init_logging(opts.verbose, format="%(asctime)s:%(levelname)s : %(message)s")

# Fall back to a generic title when none was given on the command line
if opts.plot_title is None:
    opts.plot_title = 'PyGRB sky grid'
outfile = opts.output_file
sky_grid = os.path.abspath(opts.sky_grid)

logging.info("Imported and ready to go.")

# Make sure the directory that will contain the output plot exists.
# exist_ok=True avoids the check-then-create race when several jobs share
# the same output directory.
outdir = os.path.dirname(os.path.abspath(outfile))
os.makedirs(outdir, exist_ok=True)

# Extract all information from the sky grid file
with h5py.File(sky_grid, "r") as f:
    ra, dec = f['ra'][:], f['dec'][:]
    dist = f.attrs['input_distribution']
    # h5py may return string attributes as bytes (e.g. fixed-length
    # strings); the string concatenation below needs str.
    if isinstance(dist, bytes):
        dist = dist.decode()
    # NOTE(review): eval() executes arbitrary code stored in the sky grid
    # file.  Only run this script on sky grid files from a trusted source.
    input_dist = eval("pycbc.distributions."+dist)
    # Samples of the input distribution, used to draw its density map
    samples = input_dist.rvs(opts.num_density_samples)
    input_ra, input_dec = samples['ra'], samples['dec']
    ifos = f.attrs["detectors"]
    detectors = [Detector(d.decode() if isinstance(d, bytes) else d)
                 for d in ifos]
    gps_time = f.attrs['ref_gps_time']

# Axis labels shared by all panels
xlabel = "Right ascension [deg]"
ylabel = "Declination [deg]"

# Sky grid and input-distribution samples, with ra converted from [0, 2*pi]
# to [-pi, pi] for the mollweide plot
grb_ra = ra_to_ra_mollweide(ra)
num_points = len(grb_ra)
in_ra = ra_to_ra_mollweide(input_ra)

# Isotropic reference points on which the antenna patterns are evaluated
uni_points = pycbc.distributions.UniformSky().rvs(1_000_000)
uni_ra, uni_dec = ra_to_ra_mollweide(uni_points["ra"]), uni_points["dec"]

# 2-D histogram of the input-distribution samples over the full sky, used as
# the input density for the mollweide plot
full_sky_range = [[-numpy.pi, numpy.pi], [-numpy.pi/2, numpy.pi/2]]
input_density, ra_edge, dec_edge = numpy.histogram2d(
    in_ra, input_dec, bins=opts.num_density_bins, range=full_sky_range
)

# One panel per detector plus two panels for the input distribution (full
# sky and zoom).  squeeze=False keeps ``ax`` two-dimensional even if the
# layout collapses to a single row, so ``ax[row, col]`` is always valid.
cols, rows = define_rows_cols_subplot(len(detectors)+2)
fig, ax = plt.subplots(nrows=rows, ncols=cols, squeeze=False,
                       subplot_kw=dict(projection="mollweide"),
                       figsize=(20, 20))

cmap = 'Oranges'
skygrid_color = 'black'

# numpy.histogram2d returns bin edges, but contourf needs the coordinates at
# which the density values are defined, i.e. the bin centres.  Using the left
# edges would shift the density map by half a bin.
ra_centers = 0.5 * (ra_edge[:-1] + ra_edge[1:])
dec_centers = 0.5 * (dec_edge[:-1] + dec_edge[1:])
levels = MaxNLocator(nbins=100).tick_values(0, input_density.max())

# Sky grid over input distribution plot
ax[0, 0].contourf(ra_centers, dec_centers, input_density.T, levels=levels,
                  cmap=cmap)
ax[0, 0].plot(grb_ra, dec, 'x', c=skygrid_color)
fig.colorbar(
    plt.cm.ScalarMappable(colors.Normalize(vmin=0, vmax=input_density.max()),
                          cmap=cmap),
    location="bottom", ax=ax[0, 0], label="Probability density [a.u.]"
)
ax[0, 0].set_xlabel(xlabel)
ax[0, 0].set_ylabel(ylabel)
ax[0, 0].set_title(f"Sky grid ({num_points} points) over input distribution")
ax[0, 0].grid(True)

# Zoomed-in plot: the second slot was created with a mollweide projection,
# so remove that placeholder and put a rectilinear axes (degrees) in its
# place instead of stacking a new axes on top of a hidden one.
ax[0, 1].remove()
zoom_ax = fig.add_subplot(rows, cols, 2, projection="rectilinear")
ax[0, 1] = zoom_ax
zoom_ax.set_xlim(numpy.degrees(grb_ra.min())-5, numpy.degrees(grb_ra.max())+5)
zoom_ax.set_ylim(numpy.degrees(dec.min())-5, numpy.degrees(dec.max())+5)
zoom_ax.contourf(numpy.degrees(ra_centers), numpy.degrees(dec_centers),
                 input_density.T, levels=levels, cmap=cmap)
zoom_ax.plot(numpy.degrees(grb_ra), numpy.degrees(dec), 'x', c=skygrid_color)
zoom_ax.set_xlabel(xlabel)
zoom_ax.set_ylabel(ylabel)
zoom_ax.set_title(f"Sky grid ({num_points} points) over input distribution (zoom)")
zoom_ax.grid(True)

# Sky grid over antenna pattern plots.  Detector panels fill the grid in
# row-major order, starting right after the two input-distribution panels
# (flat index 2): with two columns that is the start of the second row,
# with three or more columns it is the third slot of the first row.
for panel, det in enumerate(detectors, start=2):
    idx_row, idx_col = divmod(panel, cols)
    det_ax = ax[idx_row, idx_col]
    logging.info('Plotting sky grid over %s antenna pattern', det.name)
    # Third argument is presumably the polarization angle (fixed to 0) --
    # confirm against pycbc Detector.antenna_pattern
    ant_pat = det.antenna_pattern(uni_points["ra"], uni_points["dec"], 0,
                                  t_gps=gps_time)
    # Quadrature sum of the two responses, sqrt(F+^2 + Fx^2)
    quad_ant = numpy.sqrt(ant_pat[0]**2 + ant_pat[1]**2)
    sc = det_ax.scatter(uni_ra, uni_dec, c=quad_ant, marker='.', cmap=cmap)
    det_ax.scatter(grb_ra, dec, c=skygrid_color, marker='x')
    det_ax.set_title(f"Sky grid over {det.name} antenna pattern")
    det_ax.set_xlabel(xlabel)
    det_ax.set_ylabel(ylabel)
    plt.colorbar(sc, ax=det_ax, label=r"$\sqrt{F_+^2 + F_\times^2}$",
                 location="bottom")
    det_ax.grid(True)

# Hide every grid slot that did not receive a panel.  Panels are placed in
# row-major order, so the unused slots are the trailing ones.
for unused_ax in ax.flat[len(detectors)+2:]:
    unused_ax.axis("off")

# Wrap up: save the figure with its caption and metadata, then free it
plot_caption = (
    'First panel: search sky grid points (in black) associated to the input '
    'external trigger skymap (color scale). Second panel: zoom on the external '
    'trigger sky region. Other panels: sky grid points over the corresponding '
    'antenna pattern of the interferometers used for the search. Number of sky '
    f'grid points: {num_points}'
)

save_fig_with_metadata(fig, outfile, cmd=' '.join(sys.argv),
                       title=opts.plot_title, caption=plot_caption)
# Close the figure that was saved explicitly, rather than whichever figure
# happens to be current
plt.close(fig)
