# ========================================================================================================
# Jose A. Escartin
# 2019.04.23
#
# A python interactive window for use in the KMOS Reflex workflow.
# This interactive window shows the results of the kmos_level_correct recipe.   
# From this interactive Actor the user can modify pipeline parameters and re-initiate the processing.
#
#
# Pipeline Parameters for inclusion:
#
#    --lcmethod            : Method to use for the level correction ["OSCAN" (overscan), "SLICES_MEAN" (intra slices with average), "SLICES_MEDIAN" (intra slices with median)]. [OSCAN]
#
#
# Images with possible plots (TAG's): 
#
#    * LEVEL_CORRECTED     LEVEL_CORRECTED
#
#
# ========================================================================================================

from __future__ import with_statement
from __future__ import absolute_import
from __future__ import print_function
import sys

try:
    import numpy
    import os
    import re
    import reflex
    from pipeline_product import PipelineProduct
    import pipeline_display
    import reflex_plot_widgets
    import matplotlib.gridspec as gridspec
    from matplotlib.text import Text
    from pylab import *

    import_success = True
except ImportError:
    import_success = False
    print("Error importing modules pyfits, wx, matplotlib, numpy")

def paragraph(text, width=None):
    """ Normalise a (possibly indented, multi-line) string into a paragraph

       text:  string to format; the common leading indentation is removed
       width: when None the newlines are turned into spaces and the result
              is stripped (one single line); otherwise the text is wrapped
              to that many columns. Wrapping is not recommended for
              tooltips as they are wrapped by wxWidgets by default
    """
    import textwrap
    dedented = textwrap.dedent(text)
    if width is not None:
        return textwrap.fill(dedented, width=width)
    single_line = dedented.replace('\n', ' ')
    return single_line.strip()

class DataPlotterManager(object):
    """ Plot manager of the kmos_level_correct interactive window.

        The manager reads the LEVEL_CORRECTED products, and draws three
        panels: a file selector (radio buttons), an extension (CHIP)
        selector and the image of the selected extension.  When no
        LEVEL_CORRECTED product is available a single "no data" panel is
        drawn instead.

        Data model filled by readFitsData():
            self.lc_files[<file basename>][<EXTNAME>] = {
                "EXT_NUMBER" : first integer found in EXTNAME,
                "EXT_STATUS" : 'Active' (NAXIS == 2) or 'Empty',
                "NAME", "IMAGE", "IMAGE_AVG", "IMAGE_STDEV" : 'Active' only
            }
    """
    # static members
    recipe_name = "kmos_level_correct"
    level_corrected_cat = "LEVEL_CORRECTED"

    # Colormap used to draw an extension box, keyed by the extension status
    extension_stat_color = {'Empty':'copper', 'Active':'summer'}

    # Layout function chosen by readFitsData(). It is kept in its own
    # attribute so that the _add_subplots method is never shadowed.
    _subplot_builder = None

    def setWindowTitle(self):
        """ Return the title of the interactive window """
        return self.recipe_name+"_interactive"

    def setInteractiveParameters(self):
        """ Return the list of recipe parameters editable from the window """
        return [

            reflex.RecipeParameter(recipe=self.recipe_name, displayName="lcmethod",
                    group="Inputs", description="Method to use for the level correction [\"OSCAN\" (overscan), \"SLICES_MEAN\" (intra slices with average), \"SLICES_MEDIAN\" (intra slices with median)]"),

        ]

    def readFitsData(self, fitsFiles):
        """ Load the LEVEL_CORRECTED products and select the plot functions

            fitsFiles: iterable of objects exposing 'category' and 'name'
                       (they are handed over to PipelineProduct)
        """
        # Initialise
        self.lc_files = dict()

        # Loop on all FITS files
        for f in fitsFiles:

            # Only the level corrected images are of interest
            if f.category != self.level_corrected_cat :
                continue

            lc_file = PipelineProduct(f)
            filename = os.path.basename(f.name)

            # Create a Dictionary per file
            extensions = dict()
            self.lc_files[filename] = extensions

            # Take arcfile - fall back on the file name if the keyword is missing
            try:
                arcfile = lc_file.all_hdu[0].header['ARCFILE']
            except KeyError:
                arcfile = filename

            # Loop on extensions
            for lc_ext in lc_file.all_hdu:

                # EXTNAME is missing in the primary header - Skip it anyway
                try:
                    extname = lc_ext.header['EXTNAME']
                except KeyError:
                    continue

                # Get the EXT number from extname. An extension without any
                # number cannot be placed in the selector: skip it before
                # creating its entry, so that no partial entry is stored.
                m = re.search(r"\d+", extname)
                if m is None:
                    continue

                # Create Entry for the extension
                entry = dict()
                entry["EXT_NUMBER"] = int(m.group())

                if (lc_ext.header['NAXIS'] == 2):
                    entry["EXT_STATUS"] = 'Active'

                    # Get Keyword infos
                    entry["NAME"] = arcfile

                    # Fill image data
                    entry["IMAGE"]       = lc_ext.data
                    entry["IMAGE_AVG"]   = numpy.average(lc_ext.data)
                    entry["IMAGE_STDEV"] = numpy.std(lc_ext.data)
                else:
                    entry["EXT_STATUS"] = 'Empty'

                extensions[extname] = entry

        # If proper files are there...
        if self.lc_files:
            # Set the plotting functions
            self._subplot_builder = self._add_subplots
            self._plot = self._data_plot
        else:
            self._subplot_builder = self._add_nodata_subplots
            self._plot = self._nodata_plot

    def addSubplots(self, figure):
        """ Create the subplots in figure with the layout chosen by readFitsData() """
        # Default on the data layout when readFitsData() was not called yet
        builder = self._subplot_builder or self._add_subplots
        builder(figure)

    def plotProductsGraphics(self):
        """ Draw the plots with the function chosen by readFitsData() """
        self._plot()

    def plotWidgets(self) :
        """ Return the list of interactive widgets (empty when there is no data) """
        # The selector axes only exist in the data layout
        if not self.lc_files:
            return []

        widgets = list()

        # Files Selector radiobutton
        self.radiobutton = reflex_plot_widgets.InteractiveRadioButtons(self.lc_files_selector, self.set_FS_Callback, sorted(self.lc_files.keys()), 0,
                title='Input files selection (Left Mouse Button)')
        widgets.append(self.radiobutton)

        self.clickable_exts = reflex_plot_widgets.InteractiveClickableSubplot(self.ext_selector, self.set_EXTS_Callback)
        widgets.append(self.clickable_exts)

        return widgets

    def extension_has_image(self, filename, extname):
        """ Tell if the extension extname of the file filename holds an image

            Returns False (instead of raising) for an unknown file or
            extension, e.g. the "" returned by _get_first_valid_extname().
        """
        return "IMAGE" in self.lc_files.get(filename, {}).get(extname, {})

    def set_EXTS_Callback(self, point) :
        """ Callback of the extension selector

            point: object exposing the xdata / ydata of the click
                   (presumably a matplotlib mouse event - to be confirmed)
        """
        # No data coordinates available: nothing to select
        if point.xdata is None or point.ydata is None:
            return

        # The extension boxes are drawn between y=1 and y=3
        if not (1 < point.ydata < 3) :
            return

        # The box of the extension n starts at x = 2n-1
        str_num = str(int((point.xdata/2)+0.5))
        extname = "CHIP" + str_num + ".INT1"

        if (self.extension_has_image(self.selected_file, extname)):
            # Update selected extension
            self.selected_extension = extname
            # Redraw EXTs selection
            self._plot_exts_selector(self.selected_file)
            # Redraw image
            self._plot_image()

    def set_FS_Callback(self, filename) :
        """ Callback of the file selector: display the file filename """
        # Keep track of the selected file
        self.selected_file = filename

        # Check that the new file currently selected extension is valid
        if (not self.extension_has_image(self.selected_file, self.selected_extension)):
            self.selected_extension = self._get_first_valid_extname(self.selected_file)

        # The extension statuses are per file: always redraw the selector
        self._plot_exts_selector(self.selected_file)
        # Redraw image
        self._plot_image()

    def _add_subplots(self, figure):
        """ Data layout: file selector, extension selector and image panels """
        gs = gridspec.GridSpec(6, 1)
        self.lc_files_selector = figure.add_subplot(gs[0:2,:])
        self.ext_selector = figure.add_subplot(gs[2,:])
        self.img_plot = figure.add_subplot(gs[3:6,:])

    def _data_plot_get_tooltip(self):
        """ Return the tooltip text of the currently displayed image """
        return self.selected_file+" ["+self.selected_extension+"];" +" Object NAME = "+str(  self.lc_files[self.selected_file][self.selected_extension]["NAME"]  )

    def _plot_image(self):
        """ Draw the image of the selected extension of the selected file """
        # Clean previous
        self.img_plot.clear()

        # Nothing to draw if there is no valid extension selected
        extension_input = self.lc_files[self.selected_file].get(self.selected_extension)
        if extension_input is None or 'IMAGE' not in extension_input:
            return

        imgdisp = pipeline_display.ImageDisplay()
        imgdisp.setAspect('equal')

        # Plot Image extension - cuts are [avg - stdev, avg + 2 * stdev]
        avg = extension_input["IMAGE_AVG"]
        stdev = extension_input["IMAGE_STDEV"]
        imgdisp.z_lim = (avg - stdev, avg + 2 * stdev)
        imgdisp.display(self.img_plot, extension_input["NAME"], self._data_plot_get_tooltip(), extension_input["IMAGE"])
        self.img_plot.set_xlabel("pixels")
        self.img_plot.set_ylabel("pixels")

    def _get_ext_status_from_file(self, filename, extname):
        """ Return the status ('Active' / 'Empty') of an extension, "Unknown" otherwise """
        cur_status = self.lc_files[filename][extname]["EXT_STATUS"]
        if cur_status in self.extension_stat_color:
            return cur_status
        return "Unknown"

    def _plot_exts_selector(self, filename):
        """ Draw the extension selector (legend and one box per extension) of filename """
        self.ext_selector.clear()

        # Print the Legend, one entry per Status
        self.ext_selector.imshow(numpy.zeros((1,1)), extent=(1, 9, 4, 6), cmap=self.extension_stat_color['Active'])
        self.ext_selector.text(2, 4.5, 'Active', fontsize=11,color='white')
        self.ext_selector.imshow(numpy.zeros((1,1)), extent=(11, 19, 4, 6), cmap=self.extension_stat_color['Empty'])
        self.ext_selector.text(12, 4.5, 'Empty', fontsize=11,color='white')

        # Display the EXTs selection squares
        box_y_start = 1
        box_y_stop  = 3
        box_xwidth  = 1.5
        for extname, ext_entry in self.lc_files[filename].items():
            # Compute the EXT number
            ext_number  = ext_entry['EXT_NUMBER']
            # Draw the little EXT image
            box_xstart  = 2 * ext_number - 1
            box_xstop = box_xstart + box_xwidth

            # An unexpected status is drawn like an Empty extension
            ext_status = self._get_ext_status_from_file(filename, extname)
            ext_cmap = self.extension_stat_color.get(ext_status, self.extension_stat_color['Empty'])
            self.ext_selector.imshow(numpy.zeros((1,1)), extent=(box_xstart,box_xstop,box_y_start,box_y_stop),
                    cmap=ext_cmap)
            # Write the EXT number in the image
            self.ext_selector.text(2 * (ext_number-1) + 1.2, 1.5, str(ext_number), fontsize=13,color='white')
            # Mark the selected EXT with a thin bar below its box
            if (extname == self.selected_extension):
                self.ext_selector.imshow(numpy.zeros((1,1)), extent=(box_xstart,box_xstop,0.5,0.7),
                        cmap=ext_cmap)

        self.ext_selector.axis([0,50,0,7])
        self.ext_selector.set_title("CHIP Selection (Mouse middle button)")
        self.ext_selector.get_xaxis().set_ticks([])
        self.ext_selector.get_yaxis().set_ticks([])

    # Get the first valid extension name (ie containing an image) in a file - "" if None
    def _get_first_valid_extname(self, filename):
        for extname in sorted(self.lc_files[filename].keys()):
            if (self.extension_has_image(filename, extname)) :
                return extname
        return ""

    def _data_plot(self):
        """ Initial drawing when LEVEL_CORRECTED products are available """
        # Initial file is the first one of the sorted list, i.e. the one
        # initially active (index 0) in the radio buttons of plotWidgets()
        self.selected_file = sorted(self.lc_files.keys())[0]
        self.selected_extension = self._get_first_valid_extname(self.selected_file)

        # Plot the EXTs selection
        self._plot_exts_selector(self.selected_file)

        # Draw Image
        self._plot_image()

    def _add_nodata_subplots(self, figure):
        """ No data layout: one single panel """
        self.img_plot = figure.add_subplot(1,1,1)

    def _nodata_plot(self):
        """ Write a message naming the product category that was expected """
        # could be moved to reflex library?
        self.img_plot.set_axis_off()
        text_nodata = "Data not found. Input files_input should contain this" \
                       " type:\n%s" % self.level_corrected_cat
        self.img_plot.text(0.1, 0.6, text_nodata, color='#11557c',
                      fontsize=18, ha='left', va='center', alpha=1.0)
        self.img_plot.tooltip = 'No data found'

# Script entry point: run the interactive window (or skip it) and report
# the outputs to the Reflex python actor
if __name__ == '__main__':
    from reflex_interactive_app import PipelineInteractiveApp

    # Build the application and read the inputs given on the command line
    app = PipelineInteractiveApp(enable_init_sop=True)
    app.parse_args()

    # Without the plotting modules the GUI cannot be used
    if not import_success:
        app.setEnableGUI(False)

    if not app.isGUIEnabled():
        # No window: let the workflow continue
        app.set_continue_mode()
    else:
        # Show the window with the plots specific to this recipe
        plot_manager = DataPlotterManager()
        app.setPlotManager(plot_manager)
        app.showGUI()

    # Print outputs. This is parsed by the Reflex python actor to
    # get the results. Do not remove
    app.print_outputs()
    sys.exit()
