# -*- coding: utf-8 -*-
# ========================================================================================================
# Jose A. Escartin
# 2019.06.03
#
# A python interactive window for use in the Molecfit Reflex workflow.
# This interactive window shows the results of the molecfit_correct recipe.
# From this interactive actor the user can modify pipeline parameters and re-initiate the processing.
#
#
# Pipeline Parameters for inclusion:
#
# 1. Recipe parameters:
#
#    --SUPPRESS_EXTENSION                Suppress arbitrary filename extension : TRUE (apply) or FALSE (don't apply). [FALSE]
#    --USE_ONLY_INPUT_PRIMARY_DATA       In order to use only the primary extension of the input SCIENCE FITS file. [FALSE]
#    --USE_DATA_EXTENSION_AS_DFLUX       If you use only primary extension data, you can define other extension as DFLUX (error flux) [Default = 0, Not used].[0]
#    --USE_DATA_EXTENSION_AS_MASK        If you use only primary extension data, you can define other extension as MASK [Default = 0, Not used].[0]
#    --MAPPING_CORRECT                   Mapping 'SCIENCE' - 'TELLURIC_CORR' [string with ext_number comma separated (int)]. [NULL]
#    --WLC_REF                           If DATA, the output corrected spectrum will have the same wavelengths as the input. If MODEL, the wavelength of the output will
#                                        contain the correction found by molecfit. It has an effect only for input files in BINARY TABLES format. [DATA]
#
#
# Images to plot (TAGs):
#
#    * SCIENCE/SCIENCE_CALCTRANS         The original input data
#    * TELL_CORR                         Telluric correction
#    * SCIENCE_TELLURIC_CORR             For every input spectrum extension generate n-range extensions
#    * SPECTRUM_TELLURIC_CORR            Telluric correction applied for each extension
#
# ========================================================================================================

from __future__ import with_statement
from __future__ import absolute_import
from __future__ import print_function
import sys
import copy as mf_copy
from astropy.io import fits

try:
    import numpy
    import os
    import re
    import reflex
    from pipeline_product import PipelineProduct
    import pipeline_display
    import reflex_plot_widgets
    import matplotlib.gridspec as gridspec
    from matplotlib.text import Text
    from pylab import *
    try:
        from astropy.io import fits as pyfits
    except ImportError:
        import pyfits    
    import_success = True
except ImportError:
    import_success = False
    print("Error importing modules pyfits, wx, matplotlib, numpy")

from molecfit_common import *

#-------------------------------------------------------------------------------------
def mf_cp_hdu(hdu) :
    '''
    Copy an HDU into memory so that the file it was read from can be closed.

    hdu: an astropy.io.fits HDU.

    Returns:
      - BinTableHDU -> a new BinTableHDU built from hdu.columns, keeping the
        header and its NAXIS2 row count;
      - ImageHDU / PrimaryHDU with data -> an ImageHDU holding a numpy copy
        of the data together with the original header;
      - ImageHDU / PrimaryHDU without data -> an empty PrimaryHDU (the
        original header, including any XTENSION keyword, is NOT kept);
      - any other HDU type -> None.
    '''
    if isinstance(hdu, fits.BinTableHDU) :
        return fits.BinTableHDU.from_columns(
            hdu.columns,
            header=hdu.header,
            nrows=hdu.header['NAXIS2'],
        )
    if isinstance(hdu, (fits.ImageHDU, fits.PrimaryHDU)) :
        if hdu.data is None :
            return fits.PrimaryHDU()
        # numpy.array() makes an in-memory copy of the (possibly memory-mapped)
        # data; use the explicitly imported numpy module rather than relying
        # on an 'np' alias leaking in through a star import.
        return fits.ImageHDU(
            data=numpy.array(hdu.data),
            header=hdu.header,
        )
    return None
#-------------------------------------------------------------------------------------

def paragraph(text, width=None):
    """ Reflow a text string into a single paragraph.
       text:  the string to format; common leading indentation is removed.
       width: None (default) collapses newlines into spaces and strips the
              result -- the right choice for tooltips, which wxWidgets wraps
              by itself; otherwise the dedented text is wrapped into lines
              of at most `width` characters.
    """
    import textwrap
    dedented = textwrap.dedent(text)
    if width is not None:
        return textwrap.fill(dedented, width=width)
    return dedented.replace('\n', ' ').strip()


class DataPlotterManager(object):

    # static members (shared by all instances)
    # Recipe whose parameters are edited and whose products are displayed.
    recipe_name = "molecfit_correct"

    # Input recipes
    # Categories that readFitsData treats as science inputs (the spectra to plot).
    science_cat = [
        "SCIENCE",                                          # molecfit
        'REDUCED_IDP_SCI_LSS',                              # fors
        'REDUCED_IDP_SCI_MOS',                              # fors
        'REDUCED_IDP_SCI_MXU',                              # fors
        'REDUCED_IDP_STD_MOS',                              # fors
        'REDUCED_IDP_STD_MOS_SEDCORR',                      # fors
        'SCI_SLIT_FLUX_IDP_VIS','SCI_SLIT_FLUX_IDP_NIR',    # xshoo
        'TELL_SLIT_FLUX_IDP_VIS','TELL_SLIT_FLUX_IDP_NIR',  # xshoo
    ]
    # Additional category that readFitsData also treats as a science input.
    science_caltrans_cat = ["SCIENCE_CALCTRANS",]
    # Telluric-correction categories: readFitsData reads the
    # USE_ONLY_INPUT_PRIMARY_DATA, COLUMN_LAMBDA and COLUMN_FLUX recipe
    # parameters from their primary header.
    # NOTE(review): "REDUCED_IDP_STD_MOS" is also in science_cat and
    # "SCI_LSS_TELLURIC_CORR" is also in science_telluric_cat -- confirm
    # that these overlaps are intended.
    telluric_corr_cat = [
        "TELLURIC_CORR",                                    # molecfit
        "SCI_LSS_TELLURIC_CORR",                            # fors
        #"SEDCORR_REDUCED_IDP_STD_MOS_SEDCORR",              # fors
        "SEDCORR_REDUCED_IDP_STD_MOS",                      # fors
        "REDUCED_IDP_STD_MOS",                              # fors
        "TELLURIC_CORR_VIS", "TELLURIC_CORR_NIR",           # xshoo
    ]

    # Output recipes
    # Telluric-corrected products: readFitsData pairs each science input with
    # the product of one of these categories that has the same "original
    # file" header value, and plots it over the input spectrum.
    science_telluric_cat = [
        "SCIENCE_TELLURIC_CORR",                            # molecfit
        "SCI_LSS_TELLURIC_CORR",                            # fors
        "SCI_MXU_TELLURIC_CORR",                            # fors
        "SCI_MOS_TELLURIC_CORR",                            # fors
        "STD_MOS_TELLURIC_CORR",                            # fors
    ]
    # NOTE(review): not referenced by any active code visible here; the
    # SPECTRUM_TELLURIC_CORR association in readFitsData is disabled.
    spec_telluric_cat = "SPECTRUM_TELLURIC_CORR"

    # Matplotlib colormap name per extension status (same pairing as the
    # legend drawn in _plot_exts_selector: 'summer' = Active, 'copper' = Empty).
    EXT_stat_color = {'Empty':'copper', 'Active':'summer'}


    def setWindowTitle(self):
        """Return the window title: the recipe name followed by '_interactive'."""
        return "".join([self.recipe_name, "_interactive"])

    def setInteractiveParameters(self):
        """
        Return the recipe parameters presented in the window for editing.

        A parameter must also be present in the in_sop port, otherwise it
        won't appear in the window: the candidate list below is filtered
        against the names found in init_sop_dict(interactive_app.inputs.in_sop),
        whose result is kept in self.orig_init_sop.  The descriptions are
        shown as tooltips.

        Returns a list of reflex.RecipeParameter, in the order listed below.
        """
        self.orig_init_sop=init_sop_dict(interactive_app.inputs.in_sop)
        all_RecipeParameters=[

            reflex.RecipeParameter(recipe=self.recipe_name, displayName="THRESHOLD",
                                   group="Correction", description="Use this value when the transmission function is lower than the specified threshold."),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="WLC_REF",
                                   group="Wavelength", description="If MODEL, the wavelength correction found by fitting telluric lines is applied. Default: DATA (correction not applied). It has effect only for inputs in binary table format."),

            reflex.RecipeParameter(recipe=self.recipe_name, displayName="SUPPRESS_EXTENSION",
                                   group="Recipe", description="Suppress arbitrary filename extension : TRUE (apply) or FALSE (don't apply)."),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="USE_ONLY_INPUT_PRIMARY_DATA",
                                   group="Recipe", description="In order to use only the primary extension of the input SCIENCE FITS file."),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="USE_DATA_EXTENSION_AS_DFLUX",
                                   group="Recipe", description="If you use only primary extension data, you can define other extension as DFLUX (error flux) [Default = 0, Not used]."),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="USE_DATA_EXTENSION_AS_MASK",
                                   group="Recipe", description="If you use only primary extension data, you can define other extension as MASK [Default = 0, Not used]."),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="MAPPING_CORRECT",
                                   group="Recipe", description="Mapping 'SCIENCE' - 'TELLURIC_CORR' [string with ext_number comma separated (int)]."),
            # xsh_molecfit_correct
            #reflex.RecipeParameter(recipe=self.recipe_name, displayName="COLUMN_WAVE",
            #                       group="Recipe", description="In the case of fits binary science input: name of the column in the input that identifies the wavelength."),
            #reflex.RecipeParameter(recipe=self.recipe_name, displayName="COLUMN_FLUX",
            #                       group="Recipe", description="In the case of fits binary science input: name of the column in the input that identifies the flux."),
            #reflex.RecipeParameter(recipe=self.recipe_name, displayName="COLUMN_DFLUX",
            #                       group="Recipe", description="In the case of fits binary science input: name of the column in the input that identifies the flux errors."),
        ]
        # Only offer the parameters that the recipe actually exposes in in_sop.
        available = self.orig_init_sop['by_name'].keys()
        return [p for p in all_RecipeParameters if p.displayName in available]

    def readFitsData(self, fitsFiles):
        """
        Read and organize the FITS files produced by the recipes.

        fitsFiles: list of reflex.FitsFile objects (inputs and products).

        Pass 1 classifies every file by category and by an "original file"
        value read from its primary header (see _orig_file_keyword_for), so
        that products can be matched with the science input they came from.
        Telluric-correction files (telluric_corr_cat) also provide the
        USE_ONLY_INPUT_PRIMARY_DATA flag and the wavelength/flux column names.

        Pass 2 pairs every science file with the telluric-corrected product
        (science_telluric_cat) sharing its original-file value and copies
        each usable extension of the pair into
          self.files_input[filename][ext]                    extracted spectrum
          self.files_output_science_telluric[filename][ext]  corrected spectrum
          self.files_output_spec_telluric[filename][ext]     EXT_STATUS of empty extensions
        plus self.files_orig[filename] (ORIG_KEYWORD, ORIG_FILE, SCIENCE_TELLURIC).

        Processing stops and self.input_errors is incremented when a file
        cannot be opened or a science file has no corrected counterpart.
        Every file opened here is closed again (the data is first copied
        with mf_cp_hdu).  Finally the plotting callbacks are switched to the
        data or the no-data variants.
        """
        try:
            from builtins import str
        except ImportError:
            from __builtin__ import str

        # Initialize
        self.files_orig = dict()
        self.files_input = dict()
        self.files_output_science_telluric = dict()
        self.files_output_spec_telluric = dict()
        self.input_errors = 0
        self.use_only_first_ext = False
        self.telluric_corr_files = dict()
        self.science_files = dict()

        # orig_files[category][original-file value] -> reflex FitsFile
        orig_files = dict()
        # Column names of BINTABLE spectra, taken from the (last) telluric
        # correction header seen; None until such a file has been read.
        column_lambda = None
        column_flux = None

        # ---- Pass 1: classify -------------------------------------------
        print("Classifying %d files..." %(len(fitsFiles)))
        for f in fitsFiles:

            if f.category not in orig_files:
                orig_files[f.category] = dict()
            try:
                input_file = PipelineProduct(f)
                filename = os.path.basename(f.name)
            except Exception:
                self.input_errors += 1
                break

            primary_header = input_file.all_hdu[0].header
            orig_file_keyword = self._orig_file_keyword_for(primary_header)
            orig_file_val = orig_file_keyword[1] %(primary_header.get(orig_file_keyword[0], f.name))
            orig_files[f.category][orig_file_val] = f

            if f.category in self.telluric_corr_cat:
                # Same file as input_file: read the recipe parameters from the
                # header already open instead of opening the file a second time.
                self.use_only_first_ext = self._drs_param(primary_header, 'USE_ONLY_INPUT_PRIMARY_DATA')
                column_lambda = self._drs_param(primary_header, 'COLUMN_LAMBDA')
                column_flux = self._drs_param(primary_header, 'COLUMN_FLUX')
                self.telluric_corr_files[orig_file_val] = f

            if f.category in self.science_cat or f.category in self.science_caltrans_cat:
                # Create a Dictionary per file
                self.files_orig[filename] = dict()
                self.files_input[filename] = dict()
                self.files_output_science_telluric[filename] = dict()
                self.files_output_spec_telluric[filename] = dict()

                self.files_orig[filename]["ORIG_KEYWORD"] = orig_file_keyword
                self.files_orig[filename]["ORIG_FILE"] = orig_file_val
                self.science_files[orig_file_val] = f
            input_file.all_hdu.close()

        # ---- Pass 2: associate science inputs with corrected products ----
        print("Associating %d files..." %(len(self.science_files)))
        for k, f in self.science_files.items():
            filename = os.path.basename(f.name)
            input_file = PipelineProduct(f)
            output_science_telluric_file = None
            try:
                filesFound = 0
                for sci_tell_cat in self.science_telluric_cat:
                    candidates = orig_files.get(sci_tell_cat, {})
                    if k not in candidates:
                        continue
                    f2 = candidates[k]
                    # When several categories match, the last one is used;
                    # release the product opened for an earlier match.
                    if output_science_telluric_file is not None:
                        output_science_telluric_file.all_hdu.close()
                    output_science_telluric_file = PipelineProduct(f2)
                    self.files_orig[filename]["SCIENCE_TELLURIC"] = os.path.basename(f2.name)
                    filesFound += 1

                if self.input_errors > 0 or filesFound == 0:
                    self.input_errors += 1
                    break

                for ext_num, input_hdu in enumerate(input_file.all_hdu):
                    # Explicit comparisons on purpose: the flag comes from a
                    # FITS header and may be None when the keyword is absent,
                    # in which case neither condition holds.
                    if (self.use_only_first_ext == True and ext_num == 0) or self.use_only_first_ext == False:
                        self._load_extension_pair(
                            filename, ext_num, input_hdu,
                            output_science_telluric_file.all_hdu[ext_num],
                            column_lambda, column_flux)
            finally:
                # Everything kept was copied by mf_cp_hdu, so the files can go.
                input_file.all_hdu.close()
                if output_science_telluric_file is not None:
                    output_science_telluric_file.all_hdu.close()

        # If proper files_input are there...
        if len(self.files_input) > 0:
            # Set the plotting functions
            self._add_subplots = self._add_subplots
            self._plot = self._data_plot
        else:
            self._add_subplots = self._add_nodata_subplots
            self._plot = self._nodata_plot

    @staticmethod
    def _orig_file_keyword_for(header):
        """
        Return [keyword, format] naming the primary-header value that links a
        product to its original input file.  Keywords are tried in priority
        order; ['_UNKNOWN_', '%s'] is returned when none is present (the
        caller then falls back to the file name).
        """
        for keyword, fmt in (
            ('HIERARCH ESO DRS MOLECFIT PRO FILE', '%s'),
            ('HIERARCH ESO DRS MF PRO FILE', '%s'),
            ('ARCFILE', '%s'),
            ('MJD-OBS', '%13.10f'),   # cpl MJD_OBS format...
            ('DATE-OBS', '%s'),
        ):
            if keyword in header:
                return [keyword, fmt]
        return ['_UNKNOWN_', '%s']

    @staticmethod
    def _drs_param(header, name):
        """Return recipe parameter `name` from header, trying the
        'HIERARCH ESO DRS MOLECFIT PARAM' prefix first, then
        'HIERARCH ESO DRS MF PARAM'; None when neither is present."""
        return header.get('HIERARCH ESO DRS MOLECFIT PARAM ' + name,
                          header.get('HIERARCH ESO DRS MF PARAM ' + name))

    def _load_extension_pair(self, filename, ext_num, input_hdu, output_hdu,
                             column_lambda, column_flux):
        """
        Copy extension `ext_num` of a science file (input_hdu) and of its
        telluric-corrected product (output_hdu) into the per-file
        dictionaries filled by readFitsData.  A non-primary extension whose
        copy has no XTENSION keyword is skipped without creating an entry.
        """
        input_ext = mf_cp_hdu(input_hdu)
        output_science_telluric_ext = mf_cp_hdu(output_hdu)

        if self.use_only_first_ext and ext_num == 0:
            # Primary HDU: normally no XTENSION keyword, treat it as an image.
            if 'XTENSION' in input_ext.header:
                xtension = input_ext.header['XTENSION']
            else:
                xtension = "IMAGE"
        else:
            # XTENSION is missing in the primary header - Skip it anyway
            try:
                xtension = input_ext.header['XTENSION']
            except KeyError:
                return

        # Create Entry for the extension
        ext_num_str = str(ext_num)
        input_entry = dict()
        output_entry = dict()
        spec_entry = dict()
        self.files_input[filename][ext_num_str] = input_entry
        self.files_output_science_telluric[filename][ext_num_str] = output_entry
        self.files_output_spec_telluric[filename][ext_num_str] = spec_entry

        naxis = input_ext.header['NAXIS']
        input_entry["NAXIS"] = naxis
        input_entry["XTENSION"] = xtension
        input_entry["WAVE_UNIT"] = "N/A"
        input_entry["FLUX_UNIT"] = "N/A"

        if naxis <= 0:
            # No data array: the entry is kept, without an EXT_STATUS.
            return

        if xtension == "BINTABLE":
            input_entry["EXT_STATUS"] = 'Active'
            # TODO: CHECK column_lambda == NULL -> If yes, default 'lambda' and column_flux == NULL
            # A single row with NAXIS1 > 0 is the "1-row array" layout (each
            # cell holds a whole spectrum); otherwise one row per pixel.
            one_row = input_ext.header['NAXIS1'] > 0 and input_ext.header['NAXIS2'] == 1
            input_entry["Lambda"] = self._table_column(input_ext.data, column_lambda, one_row)
            input_entry["Spectrum"] = self._table_column(input_ext.data, column_flux, one_row)
            units = self._column_units(input_ext)
            input_entry["WAVE_UNIT"] = units.get(column_lambda)
            input_entry["FLUX_UNIT"] = units.get(column_flux)
            output_entry["Lambda"] = self._table_column(output_science_telluric_ext.data, column_lambda, one_row)
            output_entry["Spectrum"] = self._table_column(output_science_telluric_ext.data, column_flux, one_row)

        elif xtension == "IMAGE" and naxis in (1, 2, 3):
            input_entry["EXT_STATUS"] = 'Active'
            self._store_spectral_wcs(input_entry, input_ext.header, naxis)
            if naxis == 1:
                input_entry["Spectrum"] = input_ext.data
                output_entry["Spectrum"] = output_science_telluric_ext.data
            else:
                # 2D/3D: one NaN-ignoring mean per slice along the first
                # numpy axis (the last FITS axis).
                input_entry["Spectrum"] = self._nanmean_per_plane(input_ext.data)
                output_entry["Spectrum"] = self._nanmean_per_plane(output_science_telluric_ext.data)

        else:
            input_entry["EXT_STATUS"] = 'Empty'
            output_entry["EXT_STATUS"] = 'Empty'
            spec_entry["EXT_STATUS"] = 'Empty'

    @staticmethod
    def _table_column(data, name, one_row):
        """Column `name` of BINTABLE data: the first cell in the 1-row array
        layout, the whole column otherwise."""
        if one_row:
            return data[name][0]
        return data.field(name)

    @staticmethod
    def _column_units(hdu):
        """Map column name -> unit for the columns of a BINTABLE HDU that
        have a `unit` attribute."""
        units = {}
        for column in hdu.columns.columns:
            if hasattr(column, 'unit'):
                units[column.name] = column.unit
        return units

    @staticmethod
    def _store_spectral_wcs(entry, header, axis):
        """Copy the linear WCS of FITS axis `axis` (CRPIXn, CRVALn and either
        CDn_n or CDELTn) plus the wavelength (CUNITn) and flux (BUNIT) units
        from header into entry."""
        for key in ("CRPIX%d" % axis, "CRVAL%d" % axis):
            entry[key] = header[key]
        cd_key = "CD%d_%d" % (axis, axis)
        if cd_key in header:
            entry[cd_key] = header[cd_key]
        else:
            cdelt_key = "CDELT%d" % axis
            entry[cdelt_key] = header[cdelt_key]
        entry["WAVE_UNIT"] = header.get("CUNIT%d" % axis, r"N/A")
        entry["FLUX_UNIT"] = header.get('BUNIT', r"N/A")

    @staticmethod
    def _nanmean_per_plane(data):
        """Return a list with the mean of the non-NaN values of each
        data[i]; NaN for slices that contain only NaNs."""
        means = []
        for plane in data:
            valid = plane[~numpy.isnan(plane)]
            means.append(valid.mean() if len(valid) > 0 else numpy.nan)
        return means

    def addSubplots(self, figure):
        """
        Create the subplots of the GUI on `figure` (see the matplotlib
        documentation for subplots).

        Delegates to whichever layout function readFitsData selected
        (self._add_subplots); returns None.
        """
        create_layout = self._add_subplots
        create_layout(figure)

    def plotProductsGraphics(self):
        """
        Draw the data onto the subplots.

        Delegates to whichever plotting function readFitsData selected
        (self._plot); returns None.
        """
        draw = self._plot
        draw()

    def plotWidgets(self) :
        """
        Return the list of interactive widgets for the window.

        Only when the inputs were read without errors and there is more than
        one input file are two widgets created (and kept on self): a radio
        button selecting the file (drawn in self.files_selector, handled by
        setFSCallback) and a clickable extension strip (self.exts_selector,
        handled by setEXTSCallback).  Otherwise the list is empty.
        """
        if self.input_errors != 0 or len(self.files_input) <= 1:
            return list()

        file_names = sorted(self.files_input.keys())
        # Files Selector radiobutton
        self.radiobutton = reflex_plot_widgets.InteractiveRadioButtons(
            self.files_selector, self.setFSCallback, file_names, 0,
            title='Input files selection (Left Mouse Button)')
        self.clickable_exts = reflex_plot_widgets.InteractiveClickableSubplot(
            self.exts_selector, self.setEXTSCallback)
        return [self.radiobutton, self.clickable_exts]

    def extension_has_spectrum(self, filename, ext_num_str):
        """
        Return True if extension `ext_num_str` of input file `filename` has a
        "Spectrum" entry in self.files_input.

        Returns False when the extension was never stored (e.g. the user
        clicked an extension slot that readFitsData skipped) instead of
        raising KeyError.  Raises KeyError if `filename` itself is unknown.
        """
        # Use the file passed in (previously self.selected_file was always
        # consulted, silently ignoring this argument).
        return "Spectrum" in self.files_input[filename].get(ext_num_str, {})

    def setEXTSCallback(self, point) :
        """
        Click handler for the extension selector strip.

        Only clicks with 1 < y < 3 are considered; the extension index is
        int(x / 2 + 0.5).  If that extension of the selected file holds a
        spectrum it becomes the selected extension and both the selector
        and the spectrum plot are redrawn; otherwise nothing happens.
        """
        if not (1 < point.ydata < 3):
            return
        candidate = str(int(point.xdata / 2 + 0.5))
        if not self.extension_has_spectrum(self.selected_file, candidate):
            return
        self.selected_extension = candidate
        self._plot_exts_selector(self.selected_file)
        self._plot_spectrum()

    def setFSCallback(self, filename) :
        """
        Radio-button callback: make `filename` the selected input file.

        If the currently selected extension has no spectrum in the new file,
        the first valid extension (_get_first_valid_ext) is selected instead
        and the extension selector is redrawn.  The spectrum is always
        redrawn.
        """
        self.selected_file = filename

        extension_still_valid = self.extension_has_spectrum(filename, self.selected_extension)
        if not extension_still_valid:
            self.selected_extension = self._get_first_valid_ext(filename)
            self._plot_exts_selector(filename)

        self._plot_spectrum()

    def _add_subplots(self, figure):
        """
        Create the axes of the data view on `figure`.

        When there is something to choose -- more than one input file, or a
        file with more than one stored extension -- a 6-row grid is used:
        rows 0-1 file selector, row 2 extension selector, rows 3-5 spectrum.
        Otherwise the spectrum takes the whole figure.
        """
        # Pull-downs would be nicer than radio buttons for large selections.
        extension_counts = [len(extensions) for extensions in self.files_input.values()]
        max_files_exts = max([len(self.files_input)] + extension_counts)

        if max_files_exts <= 1:
            gs = gridspec.GridSpec(1, 1)
            self.spec_plot = figure.add_subplot(gs[0, 0])
            return

        gs = gridspec.GridSpec(6, 1)
        self.files_selector = figure.add_subplot(gs[0:2, :])
        self.exts_selector = figure.add_subplot(gs[2, :])
        self.spec_plot = figure.add_subplot(gs[3:6, :])

    def _data_plot_get_tooltip(self):
        """Return the spectrum tooltip: '<selected file> [<selected extension>];'."""
        return "".join([self.selected_file, " [", self.selected_extension, "];"])

    def _plot_spectrum(self):
        """
        Redraw self.spec_plot for the selected file and extension.

        The extracted spectrum (self.files_input) is drawn with
        SpectrumDisplay.display and the telluric-corrected spectrum
        (self.files_output_science_telluric) is overplotted in 'red', both
        against the same wavelength array:
          - BINTABLE: the stored "Lambda" column of the input;
          - IMAGE with 1-3 axes: rebuilt from the linear WCS keywords stored
            by readFitsData (see _image_wavelengths);
          - anything else: the "Lambda" entry of the corrected product.
        """
        extension_input = self.files_input[self.selected_file][self.selected_extension]
        extension_science_telluric = (
            self.files_output_science_telluric[self.selected_file][self.selected_extension]
        )

        self.spec_plot.clear()
        specdisp = pipeline_display.SpectrumDisplay()
        specdisp.setLabels(
            r"$\lambda$["
            +self._process_label(extension_input["WAVE_UNIT"] or 'N/A')
            +r"] (blue: extracted, red: corrected)",
            r"Flux ["
            +self._process_label(extension_input["FLUX_UNIT"] or 'N/A')
            +r"]",
        )

        naxis = extension_input["NAXIS"]
        xtension = extension_input['XTENSION']

        if xtension == "BINTABLE":
            wave = extension_input["Lambda"]
        elif xtension == "IMAGE" and naxis in (1, 2, 3):
            wave = self._image_wavelengths(extension_input, naxis)
        else:
            # Fall back to the wavelengths of the corrected product (an
            # array, not a 1-tuple).
            wave = extension_science_telluric["Lambda"]

        # Plot Spectrum
        specdisp.display(
            self.spec_plot,
            "Spectrum",
            self._data_plot_get_tooltip(),
            wave,
            extension_input["Spectrum"]
        )
        specdisp.overplot(
            self.spec_plot,
            wave,
            extension_science_telluric["Spectrum"],
            'red'
        )

    @staticmethod
    def _image_wavelengths(extension_input, axis):
        """
        Return the wavelength of every element of extension_input["Spectrum"]
        from the linear WCS of FITS axis `axis`.

        Uses CDn_n when present, CDELTn otherwise.  FITS pixel coordinates
        are 1-based: world = CRVALn + (p - CRPIXn) * step with p = index + 1,
        which reduces to CRVALn + index * step when CRPIXn == 1.
        """
        pix = numpy.arange(len(extension_input["Spectrum"]))
        cd_key = "CD%d_%d" % (axis, axis)
        if cd_key in extension_input:
            step = extension_input[cd_key]
        else:
            step = extension_input["CDELT%d" % axis]
        crpix = extension_input["CRPIX%d" % axis]
        return extension_input["CRVAL%d" % axis] + (pix + 1 - crpix) * step

    def _process_label(self, in_label):
        if in_label is None : return None
        pretty_units={
            "angstrom": "$\AA$",
            "micron": "$\mu$m",
            "adu": "ADU",
            "erg.s**(-1).cm**(-2).angstrom**(-1)": r"erg sec$^{-1}$cm$^{-2}$$\AA^{-1}$",
            'erg.cm**(-2).s**(-1).angstrom**(-1)': r"erg sec$^{-1}$cm$^{-2}$$\AA^{-1}$",
            "10**(-16)erg.cm**(-2).s**(-1).angstrom**(-1)": r"$10^{-16}$erg sec$^{-1}$cm$^{-2}$$\AA^{-1}$",
            "10^-16 erg/cm^2/angstrom/s": r"$10^{-16}$erg sec$^{-1}$cm$^{-2}$$\AA^{-1}$",
            "J*radian/m^3/s": r"J radian m$^{-3}s$^{-1}",
            "e-": r"e$^-$",
        }
        return pretty_units.get(in_label.lower(),in_label)

    def _get_ext_status_from_file(self, filename, ext_num_str):
        cur_status = self.files_input[filename][ext_num_str]["EXT_STATUS"]
        if cur_status == "Active" :
            return 'Active'
        if cur_status == "Empty" :
            return 'Empty'
        return "Unknown"

    def _plot_exts_selector(self, filename):
       if len(self.files_input) > 1: 
        self.exts_selector.clear()

        # Loop on the different kind of Status to Print the Legend
        self.exts_selector.imshow(numpy.zeros((1,1)), extent=(1, 9, 4, 6), cmap='summer')
        self.exts_selector.text(2, 4.5, 'Active', fontsize=11,color='white')
        self.exts_selector.imshow(numpy.zeros((1,1)), extent=(11, 19, 4, 6), cmap='copper')
        self.exts_selector.text(12, 4.5, 'Empty', fontsize=11,color='white')

        # Display the EXTs selection squares
        box_y_start = 1
        box_y_stop  = 3
        box_xwidth  = 1.5
        for ext_num_str in self.files_input[filename].keys():

            # Compute the EXT number
            ext_number = int(ext_num_str)

            ext_0 = 0
            if self.use_only_first_ext :
                ext_0 = 1

            # Draw the little EXT image
            box_xstart = 2 * (ext_number + ext_0) - 1
            box_xstop  = box_xstart + box_xwidth

            ext_status = self._get_ext_status_from_file(filename, ext_num_str)
            self.exts_selector.imshow(numpy.zeros((1,1)), extent=(box_xstart,box_xstop,box_y_start,box_y_stop), 
                    cmap=self.EXT_stat_color[ext_status])

            # Write the EXT number in the image
            self.exts_selector.text(2 * (ext_number + ext_0 - 1) + 1.2, 1.5, str(ext_number), fontsize=13,color='white')

            # Mark the selected EXT
            if (ext_num_str == self.selected_extension):
                self.exts_selector.imshow(numpy.zeros((1,1)), extent=(box_xstart,box_xstop,0.5,0.7), 
                        cmap=self.EXT_stat_color[ext_status])

        self.exts_selector.axis([0,50,0,7])
        self.exts_selector.set_title("Extension Selection (Mouse middle button)")
        self.exts_selector.get_xaxis().set_ticks([])
        self.exts_selector.get_yaxis().set_ticks([])

    # Get the first valid extension name (ie containing a spectrum) in a file - "" if None 
    def _get_first_valid_ext(self, filename):
        for ext_num_str in sorted(self.files_input[filename].keys()):
            if (self.extension_has_spectrum(filename, ext_num_str)) :
                return ext_num_str
        return ""

    def _data_plot(self):

        # Initial file is the first one
        try: 
            self.selected_file = self.files_input.keys()[0]
        except TypeError: 
            self.selected_file = list(self.files_input.keys())[0]
        self.selected_extension = self._get_first_valid_ext(self.selected_file)

        # Plot the EXTS selection
        self._plot_exts_selector(self.selected_file)

        # Draw Spectrum
        if self.input_errors == 0:
            self._plot_spectrum()
        else:
            self.spec_plot.set_title("INPUT DATA ERROR: Missing files ")


    def _add_nodata_subplots(self, figure):
        self.img_plot = figure.add_subplot(1,1,1)

    def _nodata_plot(self):
        # could be moved to reflex library?
        self.img_plot.set_axis_off()
        text_nodata = "INPUT ERROR: Data not found." # Input files_input should contain this" \
                       #" type:\n%s" % self.single_cubes
        self.img_plot.text(0.1, 0.6, text_nodata, color='#11557c',
                      fontsize=18, ha='left', va='center', alpha=1.0)
        self.img_plot.tooltip = 'No data found'

# Script entry point, run when this file is executed directly.
if __name__ == '__main__':
    from reflex_interactive_app import PipelineInteractiveApp

    # Build the interactive application and read the command-line inputs
    interactive_app = PipelineInteractiveApp(enable_init_sop=True)
    interactive_app.parse_args()

    # import_success is set elsewhere in this file; when it is False the
    # GUI is disabled.
    if not import_success:
        interactive_app.setEnableGUI(False)

    if not interactive_app.isGUIEnabled():
        interactive_app.set_continue_mode()
    else:
        # Attach this window's plot manager and open the GUI
        plot_manager = DataPlotterManager()
        plot_manager.recipe_name = interactive_app.inputs.in_sop[0].recipe
        interactive_app.setPlotManager(plot_manager)
        interactive_app.showGUI()

    # The Reflex python actor parses this output to get the results:
    # do not remove.
    interactive_app.print_outputs()
    sys.exit()
