# ========================================================================================================
# Jose A. Escartin
# 2019.03.22
#
# A python interactive window for use in the KMOS Reflex workflow.
# This interactive window shows the results of the kmos_molecfit_correct recipe.   
# From this interactive Actor the user can modify pipeline parameters and re-initiate the processing.
#
#
# Pipeline Parameters for inclusion:
#
#    --suppress_extension    : Suppress arbitrary filename extension (TRUE (apply) or FALSE (don't apply)). [FALSE]
#    --min_threshold         : Minimum threshold in the telluric correction. If min_threshold > 0 then elements in the telluric correction that are smaller than the threshold are set to the threshold. [0.01]
#
# Images to plot (TAG's): 
#
#    * STAR_SPEC/EXTRACT_SPEC/SCIENCE/SCI_RECONSTRUCTED     The input data
#    * TELLURIC_CORR/RESPONSE                               kmos_molecfit_calctrans - Telluric/Response correction for each IFU. Output fits file with 48 ext. (24-data and 24-error, in IMAGE format).
#    * SINGLE_SPECTRA                                       The 1D input data corrected file.
#    * SINGLE_CUBES                                         The 3D input data corrected file.
#
# ========================================================================================================
    
from __future__ import with_statement
from __future__ import absolute_import
from __future__ import print_function
import sys

try:
    import numpy
    import os
    import re
    import reflex
    from pipeline_product import PipelineProduct
    import pipeline_display
    import reflex_plot_widgets
    import matplotlib.gridspec as gridspec
    from matplotlib.text import Text
    from pylab import *

    import_success = True
except ImportError:
    import_success = False
    print("Error importing modules pyfits, wx, matplotlib, numpy")

def paragraph(text, width=None):
    """ Format a (possibly indented, multi-line) text string.

       text:  string to format; its common leading indentation is removed
       width: None (default) flattens the text onto one line, replacing
              newlines by spaces and stripping both ends -- the preferred
              form for tooltips, which wxWidgets wraps by itself.
              Otherwise the text is wrapped into lines of at most `width`
              characters.
    """
    import textwrap
    dedented = textwrap.dedent(text)
    if width is not None:
        return textwrap.fill(dedented, width=width)
    return dedented.replace('\n', ' ').strip()

class DataPlotterManager(object):
    """Plot manager for the kmos_molecfit_correct interactive window.

    For every SCI_RECONSTRUCTED input cube, the SINGLE_CUBES product with the
    same ARCFILE is looked up. Both cubes are collapsed plane by plane into a
    mean spectrum per IFU, and the extracted and corrected spectra are
    overplotted. The user picks the input file with a radio button and the
    IFU by clicking in the IFU selection strip.
    """
    # static members
    recipe_name = "kmos_molecfit_correct"
    sci_reconstructed_cat = "SCI_RECONSTRUCTED"
    single_cubes_cat = "SINGLE_CUBES"

    # Colormap used to paint an IFU box, keyed by IFU status
    IFU_stat_color = {'Empty':'copper', 'Active':'summer'}

    def setWindowTitle(self):
        """Return the title of the interactive window."""
        return self.recipe_name+"_interactive"

    def setInteractiveParameters(self):
        """Return the recipe parameters that can be edited from the window."""
        return [

            reflex.RecipeParameter(recipe=self.recipe_name, displayName="suppress_extension",
                    group="Inputs", description="Suppress arbitrary filename extension.(TRUE (apply) or FALSE (don't apply)"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="min_threshold",
                    group="Inputs", description="Minimum threshold in the telluric correction. If min_threshold > 0 then elements in the telluric correction that are smaller than the threshold are set to the threshold"),

        ]

    def readFitsData(self, fitsFiles):
        """Load the input and corrected cubes and collapse them into spectra.

        fitsFiles: files handed over by Reflex. Only the categories
                   SCI_RECONSTRUCTED (input cubes) and SINGLE_CUBES
                   (corrected cubes) are used.

        Fills self.files_input and self.files_output. Both are keyed by the
        input file basename, then by EXTNAME. Also selects the data or no-data
        plotting functions.
        """
        self.files_input = dict()
        self.files_output = dict()

        # Open every corrected cube once and index it by ARCFILE, so that each
        # input cube can be paired with its output without re-reading files.
        # As before, the last product with a given ARCFILE wins.
        output_by_arcfile = dict()
        for f in fitsFiles:
            if f.category == self.single_cubes_cat:
                single_cube_file = PipelineProduct(f)
                arcfile = single_cube_file.all_hdu[0].header.get('ARCFILE')
                if arcfile is not None:
                    output_by_arcfile[arcfile] = single_cube_file

        for f in fitsFiles:
            if f.category != self.sci_reconstructed_cat:
                continue

            input_file = PipelineProduct(f)
            filename = os.path.basename(f.name)
            self.files_input[filename] = dict()
            self.files_output[filename] = dict()

            # Matching corrected cube. None when the recipe produced none; only
            # the extracted spectrum is shown then (it used to reuse the
            # previous file's output, or crash).
            output_file = output_by_arcfile.get(input_file.all_hdu[0].header.get('ARCFILE'))

            for num_ext, input_ext in enumerate(input_file.all_hdu):
                # EXTNAME is missing in the primary header - Skip it anyway
                try:
                    extname = input_ext.header['EXTNAME']
                except KeyError:
                    continue

                # The IFU number is the first integer in EXTNAME (e.g. IFU.3.DATA);
                # an extension without one cannot be placed in the IFU selector
                m = re.search(r"\d+", extname)
                if m is None:
                    continue
                ifu_number = m.group()

                entry_in = dict()
                entry_out = dict()
                self.files_input[filename][extname] = entry_in
                self.files_output[filename][extname] = entry_out
                entry_in["IFU_NUMBER"] = int(ifu_number)

                header = input_ext.header
                if header['NAXIS'] != 3:
                    entry_in["IFU_STATUS"] = 'Empty'
                    entry_out["IFU_STATUS"] = 'Empty'
                    continue

                entry_in["IFU_STATUS"] = 'Active'
                # Keyword infos
                entry_in["CRPIX3"] = header['CRPIX3']
                entry_in["CRVAL3"] = header['CRVAL3']
                entry_in["CDELT3"] = header['CDELT3']
                entry_in["UNIT"] = header['ESO QC CUBE_UNIT']
                entry_in["NAME"] = header['ESO OCS ARM' + ifu_number + ' NAME']

                # Extracted spectrum
                entry_in["Spectrum"] = self._collapse_cube(input_ext.data)

                # Corrected spectrum, taken from the HDU at the same index of the
                # output cube (assumes the output mirrors the input HDU layout)
                if output_file is not None and num_ext < len(output_file.all_hdu):
                    output_data = output_file.all_hdu[num_ext].data
                    if output_data is not None:
                        entry_out["Spectrum"] = self._collapse_cube(output_data)

        # Choose the plotting functions depending on whether inputs were found
        if self.files_input:
            self._add_subplots = self._add_data_subplots
            self._plot = self._data_plot
        else:
            self._add_subplots = self._add_nodata_subplots
            self._plot = self._nodata_plot

    @staticmethod
    def _collapse_cube(cube):
        """Return the NaN-ignoring mean of every plane of `cube`, as a list.

        Planes containing only NaNs yield numpy.nan.
        """
        spectrum = []
        for cube_plane in cube:
            cube_plane_nan_free = cube_plane[~numpy.isnan(cube_plane)]
            if len(cube_plane_nan_free) > 0:
                spectrum.append(cube_plane_nan_free.mean())
            else:
                spectrum.append(numpy.nan)
        return spectrum

    def addSubplots(self, figure):
        """Create the axes of `figure` (layout chosen by readFitsData)."""
        self._add_subplots(figure)

    def plotProductsGraphics(self):
        """Draw the data (or the no-data message)."""
        self._plot()

    def plotWidgets(self) :
        """Create the file selector and the clickable IFU strip.

        Returns the list of widgets. It is empty when no input file was found,
        because there is nothing to select and the callbacks would have no
        selected file to work on.
        """
        widgets = list()
        if not self.files_input:
            return widgets

        # Files Selector radiobutton (sorted order; index 0 pre-selected)
        self.radiobutton = reflex_plot_widgets.InteractiveRadioButtons(self.files_selector, self.setFSCallback, sorted(self.files_input), 0,
                title='Input files selection (Left Mouse Button)')
        widgets.append(self.radiobutton)

        self.clickable_ifus = reflex_plot_widgets.InteractiveClickableSubplot(self.ifus_selector, self.setIFUSCallback)
        widgets.append(self.clickable_ifus)

        return widgets

    def extension_has_spectrum(self, filename, extname):
        """Return True if extension `extname` of file `filename` holds a spectrum.

        Unknown files or extensions return False. This covers a click beyond
        the last IFU, and the "" sentinel for a file without valid extensions.
        """
        return "Spectrum" in self.files_input.get(filename, {}).get(extname, {})

    def setIFUSCallback(self, point) :
        """Select the IFU whose box was clicked and redraw."""
        # Clicks outside the data area carry no coordinates
        if point.xdata is None or point.ydata is None:
            return
        # Only the row of IFU boxes (1 < y < 3) is clickable
        if not (1 < point.ydata < 3):
            return
        # IFU n's box starts at x = 2n-1; any x in [2n-1, 2n+1) maps back to n
        extname = "IFU."+str(int((point.xdata/2)+0.5))+".DATA"
        if self.extension_has_spectrum(self.selected_file, extname):
            self.selected_extension = extname
            self._plot_ifus_selector(self.selected_file)
            self._plot_spectrum()

    def setFSCallback(self, filename) :
        """Switch to input file `filename` and redraw."""
        self.selected_file = filename

        # Keep the current IFU if the new file has a spectrum for it,
        # otherwise fall back to the file's first valid IFU
        if not self.extension_has_spectrum(self.selected_file, self.selected_extension):
            self.selected_extension = self._get_first_valid_extname(self.selected_file)
        # Always redraw the IFU strip: IFU statuses differ between files
        self._plot_ifus_selector(self.selected_file)
        self._plot_spectrum()

    def _add_data_subplots(self, figure):
        """Create the file selector (top), IFU strip and spectrum (bottom) axes."""
        gs = gridspec.GridSpec(6, 1)
        self.files_selector = figure.add_subplot(gs[0:2,:])
        self.ifus_selector = figure.add_subplot(gs[2,:])
        self.spec_plot = figure.add_subplot(gs[3:6,:])

    # Default layout until readFitsData() has chosen one
    _add_subplots = _add_data_subplots

    def _data_plot_get_tooltip(self):
        """Tooltip of the spectrum plot: file, extension and object name."""
        return self.selected_file+" ["+self.selected_extension+"];" +" Object NAME = "+str(  self.files_input[self.selected_file][self.selected_extension]["NAME"]  )

    @staticmethod
    def _wavelength_axis(extension):
        """Return the wavelength of every plane of `extension`'s spectrum.

        Uses the linear FITS WCS of axis 3. Pixel coordinates are 1-based, so
        0-based plane i has wavelength CRVAL3 + (i + 1 - CRPIX3) * CDELT3.
        """
        pix = numpy.arange(len(extension["Spectrum"]))
        return extension["CRVAL3"] + (pix + 1 - extension["CRPIX3"]) * extension["CDELT3"]

    def _plot_spectrum(self):
        """Plot the extracted spectrum of the selected IFU, with the corrected one overplotted in red."""
        self.spec_plot.clear()

        # A file without any active IFU has nothing to plot
        if not self.extension_has_spectrum(self.selected_file, self.selected_extension):
            self.spec_plot.text(0.5, 0.5, "No active IFU in this file", ha='center', va='center',
                                transform=self.spec_plot.transAxes)
            return

        extension_input = self.files_input[self.selected_file][self.selected_extension]
        extension_output = self.files_output[self.selected_file][self.selected_extension]

        specdisp = pipeline_display.SpectrumDisplay()
        specdisp.setLabels(r"$\lambda$[$\mu$m]" + " (blue: extracted, red: corrected)",
                            self._process_label(extension_input["UNIT"]) )

        wave = self._wavelength_axis(extension_input)

        specdisp.display(self.spec_plot, "Spectrum", self._data_plot_get_tooltip(), wave, extension_input["Spectrum"])
        # The corrected spectrum is missing when no matching SINGLE_CUBES was found
        if "Spectrum" in extension_output:
            specdisp.overplot(self.spec_plot, wave, extension_output["Spectrum"], 'red')

    def _process_label(self, in_label):
        """Return a pretty-printed axis label for known units, else `in_label`."""
        if (in_label == "erg.s**(-1).cm**(-2).angstrom**(-1)"):
            return "Flux [erg sec" + r"$^{-1}$"+"cm" + r"$^{-2}$" + r"$\AA^{-1}$]"
        else:
            return in_label

    def _get_ifu_status_from_file(self, filename, extname):
        """Return 'Active', 'Empty' or 'Unknown' for an extension."""
        cur_status = self.files_input[filename][extname]["IFU_STATUS"]
        if cur_status in ('Active', 'Empty'):
            return cur_status
        return "Unknown"

    def _plot_ifus_selector(self, filename):
        """Draw the IFU selection strip for `filename`.

        A colour legend sits at y in [4, 6]. IFU n is drawn as a box spanning
        x in [2n-1, 2n+0.5] and y in [1, 3], coloured by its status. The
        selected IFU is underlined by a thin bar at y in [0.5, 0.7].
        """
        self.ifus_selector.clear()

        # Legend
        self.ifus_selector.imshow(numpy.zeros((1,1)), extent=(1, 9, 4, 6), cmap=self.IFU_stat_color['Active'])
        self.ifus_selector.text(2, 4.5, 'Active', fontsize=11,color='white')
        self.ifus_selector.imshow(numpy.zeros((1,1)), extent=(11, 19, 4, 6), cmap=self.IFU_stat_color['Empty'])
        self.ifus_selector.text(12, 4.5, 'Empty', fontsize=11,color='white')

        # IFU selection boxes
        box_y_start = 1
        box_y_stop  = 3
        box_xwidth  = 1.5
        for extname, extension in self.files_input[filename].items():
            ifu_number = extension['IFU_NUMBER']
            box_xstart = 2 * ifu_number - 1
            box_xstop = box_xstart + box_xwidth
            cmap = self.IFU_stat_color[self._get_ifu_status_from_file(filename, extname)]

            self.ifus_selector.imshow(numpy.zeros((1,1)), extent=(box_xstart,box_xstop,box_y_start,box_y_stop),
                    cmap=cmap)
            # IFU number written inside its box
            self.ifus_selector.text(box_xstart + 0.2, 1.5, str(ifu_number), fontsize=13,color='white')
            # Mark the selected IFU
            if extname == self.selected_extension:
                self.ifus_selector.imshow(numpy.zeros((1,1)), extent=(box_xstart,box_xstop,0.5,0.7),
                        cmap=cmap)

        self.ifus_selector.axis([0,50,0,7])
        self.ifus_selector.set_title("IFU Selection (Mouse middle button)")
        self.ifus_selector.get_xaxis().set_ticks([])
        self.ifus_selector.get_yaxis().set_ticks([])

    def _get_first_valid_extname(self, filename):
        """Return the lowest-numbered IFU extension of `filename` that holds a spectrum, or "" if none."""
        extensions = self.files_input.get(filename, {})
        # Order by IFU number (IFU.2 before IFU.10), not lexically
        for extname in sorted(extensions, key=lambda e: (extensions[e]["IFU_NUMBER"], e)):
            if self.extension_has_spectrum(filename, extname):
                return extname
        return ""

    def _data_plot(self):
        """Initial drawing: first file (as pre-selected by the radio button) and its first valid IFU."""
        self.selected_file = sorted(self.files_input)[0]
        self.selected_extension = self._get_first_valid_extname(self.selected_file)

        self._plot_ifus_selector(self.selected_file)
        self._plot_spectrum()

    def _add_nodata_subplots(self, figure):
        """Same layout as with data; only the spectrum axes is used for the message."""
        self._add_data_subplots(figure)

    def _nodata_plot(self):
        """Display a message telling that no input data was found."""
        # could be moved to reflex library?
        self.spec_plot.set_axis_off()
        text_nodata = "Data not found.\nMissing type:\n%s" % self.sci_reconstructed_cat
        self.spec_plot.text(0.1, 0.6, text_nodata, color='#11557c',
                      fontsize=18, ha='left', va='center', alpha=1.0)
        self.spec_plot.tooltip = 'No data found'

# Script entry point, run by the Reflex python actor
if __name__ == '__main__':
    from reflex_interactive_app import PipelineInteractiveApp

    # Build the interactive application and read the command-line inputs
    interactive_app = PipelineInteractiveApp(enable_init_sop=True)
    interactive_app.parse_args()

    # Without the plotting modules the window cannot be shown
    if not import_success:
        interactive_app.setEnableGUI(False)

    if not interactive_app.isGUIEnabled():
        # No window: let the workflow carry on unchanged
        interactive_app.set_continue_mode()
    else:
        # Plug in the plotting logic of this recipe and open the window
        interactive_app.setPlotManager(DataPlotterManager())
        interactive_app.showGUI()

    # The Reflex python actor parses this output to get the results - keep it
    interactive_app.print_outputs()
    sys.exit()
