from __future__ import with_statement
from __future__ import absolute_import
from __future__ import print_function
import sys
import os.path

try:
    import numpy
    import reflex
    from pipeline_product import PipelineProduct
    import pipeline_display
    import reflex_plot_widgets

    from matplotlib import gridspec, pylab, pyplot, transforms
    import pdb  # for debugging
    from collections import defaultdict  # to make dictionary of lists

    import_success = True

except ImportError:
    import_success = False
    print("Error importing modules pyfits, wx, matplotlib, numpy")


# Median absolute deviation function; used to scale the images
def MAD(x):
    """Return the median absolute deviation of *x*.

    x: any sequence or array accepted by numpy.array.

    The result is the median of the absolute differences between each
    element and the median of all elements.
    """
    values = numpy.array(x)
    centre = numpy.median(values)
    deviations = numpy.abs(values - centre)
    return numpy.median(deviations)


def paragraph(text, width=None):
    """ Turn an indented (triple-quoted) text into a paragraph.

       text:  text to format; the common leading indentation is removed
       width: if None (default) all newlines become spaces and the result
              is stripped, giving a single line; this is the form to use
              for tooltips, which wxWidgets wraps by default.
              Otherwise the dedented text is wrapped to this many columns.
    """
    import textwrap

    dedented = textwrap.dedent(text)
    if width is not None:
        return textwrap.fill(dedented, width=width)

    single_line = dedented.replace("\n", " ")
    return single_line.strip()


class DataPlotterManager(object):
    """
    Plot manager for the interactive window of the eris_nix_lss_skysub recipe.

    This class must be added to the PipelineInteractiveApp with setPlotManager.
    It provides the following member functions for the app:
     - setInteractiveParameters(self)
     - readFitsData(self, fitsFiles)
     - addSubplots(self, figure)
     - plotProductsGraphics(self)
     - plotWidgets(self)
    Following members are optional:
     - setWindowHelp(self)
     - setWindowTitle(self)

    NOTE(review): earlier documentation of this class listed
    plotProductsGraphics(self, figure, canvas); the method defined here takes
    no extra arguments - confirm against the reflex version in use.
    """

    # static members

    recipe_name = "eris_nix_lss_skysub"
    # Product categories that can be displayed, in order of preference
    img_cat_list = [
        "SKYSUB_OBJECT_LSS_JITTER",
        "SKYSUB_STD_LSS_JITTER",
    ]
    nameroot = "skysub"

    # EXTNAME of a FITS extension -> index of the panel (in self.img_plot)
    # it is drawn in. Extensions with any other name are not displayed.
    _EXT_PANELS = {
        "DATA": 0,
        "ERR": 1,
        "CONFIDENCE": 2,
        "BKG_DATA": 3,
    }

    def setWindowTitle(self):
        """Return the title of the interactive window."""
        return self.recipe_name + "_interactive"

    def setInteractiveParameters(self):
        """
        This function specifies which are the parameters that should be presented
        in the window to be edited.  Note that the parameter has to also be in the
        in_sop port (otherwise it won't appear in the window). The descriptions are
        used to show a tooltip. They should match one to one with the parameter
        list.

        Returns a list of reflex.RecipeParameter, one per editable parameter.
        """

        # Only selected Recipe parameters are shown because list is too long
        # The selection shows the ones likely to be changed by the user
        # Each entry is (displayName, group, description)
        shown_parameters = [
            (
                "catalogue.obj.min-pixels",
                "catalogue",
                "Minimum pixel area for each detected object.",
            ),
            (
                "catalogue.obj.threshold",
                "catalogue",
                "Detection threshold in sigma above sky.",
            ),
            (
                "catalogue.bkg.mesh-size",
                "catalogue",
                "Background smoothing box size.",
            ),
            (
                "catalogue.bkg.smooth-gauss-fwhm",
                "catalogue",
                "The FWHM of the Gaussian kernel used in "
                "convolution for object detection.",
            ),
            (
                "sky-source",
                "sky",
                "data to be used for "
                "calculation of sky background. "
                "<target | offset>",
            ),
            (
                "sky-selector",
                "sky",
                "method for selecting sky frames. <bracket>",
            ),
            (
                "sky-bracket-time",
                "sky",
                "selector bracket width. (seconds)",
            ),
            (
                "sky-method",
                "sky",
                "method for combining sky frames. "
                "<collapse-median | median-median>",
            ),
        ]

        return [
            reflex.RecipeParameter(
                recipe=self.recipe_name,
                displayName=display_name,
                group=group,
                description=description,
            )
            for display_name, group, description in shown_parameters
        ]

    def readFitsData(self, fitsFiles):
        """
        This function should be used to read and organize the raw fits files
        produced by the recipes.
        It receives as input a list of reflex.FitsFiles

        Sets self.frames, self.img_cat and self.img_found, and selects the
        subplot-layout and plotting functions used by addSubplots and
        plotProductsGraphics. When at least one displayable frame is found it
        also sets self.n_frames, self.cur_frame, self.n_extn and the button
        labels.
        """

        # frames is a dict of keyword/list pairs where elements of list are
        # PipelineProducts. For each category the list contains all matching FITS
        # files in the input parameter list
        self.frames = defaultdict(list)
        for f in fitsFiles:
            self.frames[f.category].append(PipelineProduct(f))

        # Pick the first displayable category that is present. Only membership
        # is tested here, so no empty entry gets added to the defaultdict when
        # nothing matches.
        self.img_cat = next(
            (cat for cat in self.img_cat_list if cat in self.frames), None
        )
        # A category only exists in frames once a file was appended to it
        self.img_found = self.img_cat is not None

        if not self.img_found:
            # Set the plotting functions to NODATA ones
            self._add_subplots = self._add_nodata_subplots
            self._plot = self._nodata_plot
            return

        products = self.frames[self.img_cat]
        self.n_frames = len(products)
        self.cur_frame = 0
        # number of extensions of the first file; kept for reference only,
        # the plotting code counts the extensions of each frame itself
        self.n_extn = len(products[0].hdulist()) - 1

        # Don't read the individual calibrated science images in here
        #   for memory/performance reasons. Read them in below as needed.
        #   There could be a lot of frames and user may not want to see all of them.

        # Sort the frames using PIPEFILE keyword
        products.sort(key=lambda product: product.all_hdu[0].header["PIPEFILE"])

        # Use the data layout and data plot. An earlier call without data may
        # have stored the no-data layout on this instance; remove that
        # override so the class's own _add_subplots method is used again.
        self.__dict__.pop("_add_subplots", None)
        self._plot = self._data_plot

        # Define radio button options
        self.left_opts = ["Click to\ngo back"]
        self.right_opts = ["Click to\nadvance"]

    def addSubplots(self, figure):
        """Create the subplots in figure using the layout chosen by readFitsData."""
        self._add_subplots(figure)

    def plotProductsGraphics(self):
        """Draw the products using the plot function chosen by readFitsData."""
        self._plot()

    def plotWidgets(self):
        """
        Create the "go back" / "advance" buttons.

        Returns the list of widgets; it is empty when no frame was found.
        """
        widgets = []

        # Radio buttons
        # Only show them if at least one frame was found
        if not self.img_found:
            return widgets

        self.radiobutton_left = reflex_plot_widgets.InteractiveRadioButtons(
            self.axradiobutton_left,
            self.setRadioCallback_left,
            self.left_opts,
            0,
            title="",
        )
        widgets.append(self.radiobutton_left)

        self.radiobutton_right = reflex_plot_widgets.InteractiveRadioButtons(
            self.axradiobutton_right,
            self.setRadioCallback_right,
            self.right_opts,
            0,
            title="",
        )
        widgets.append(self.radiobutton_right)

        # Adjust size of button boxes and font size of labels
        for widget in widgets:
            button_axes = widget.rbuttons.ax
            pos = button_axes.get_position()
            button_axes.set_position(
                transforms.Bbox([[pos.x0, pos.y0 - 0.01], [pos.x1, 0.97]])
            )
            for button_label in widget.rbuttons.labels:
                button_label.set_fontsize(11)

        return widgets

    def setRadioCallback_left(self, label):
        """Step back one frame (wrapping to the last one) and redraw.

        label: label of the clicked button (not used).
        """
        self.cur_frame = (self.cur_frame - 1) % self.n_frames
        self._plot()

    def setRadioCallback_right(self, label):
        """Advance one frame (wrapping to the first one) and redraw.

        label: label of the clicked button (not used).
        """
        self.cur_frame = (self.cur_frame + 1) % self.n_frames
        self._plot()

    def _add_subplots(self, figure):
        """
        Create the axes in figure and store them on self.

        With data: two button axes and a filename axis in the top row, and a
        2x2 grid of image axes (self.img_plot) below. Without data: a plain
        2x2 grid of axes.
        """
        self.img_plot = []
        if self.img_found:  # at least one frame found
            gs = gridspec.GridSpec(nrows=9, ncols=4)
            gs.update(hspace=0.7)  # make space so axis labels dont overlap

            # buttons
            self.axradiobutton_left = figure.add_subplot(gs[0, 0])
            self.axradiobutton_right = figure.add_subplot(gs[0, 1])

            # filename
            self.filename_plot = figure.add_subplot(gs[0, 2:3])

            # don't draw subplot border around the buttons or filename
            borderless = (
                self.axradiobutton_left,
                self.axradiobutton_right,
                self.filename_plot,
            )
            for axes in borderless:
                for spine in axes.spines.values():
                    spine.set_visible(False)

            # plots: index 0 top-left, 1 top-right, 2 bottom-right, 3 bottom-left
            self.img_plot.append(figure.add_subplot(gs[1:5, 0:2]))
            self.img_plot.append(figure.add_subplot(gs[1:5, 2:4]))
            self.img_plot.append(figure.add_subplot(gs[5:9, 2:4]))
            self.img_plot.append(figure.add_subplot(gs[5:9, 0:2]))

            # Move ticks to rhs for readability
            self.img_plot[1].yaxis.tick_right()
            self.img_plot[2].yaxis.tick_right()

        else:
            gs = gridspec.GridSpec(2, 2)
            self.img_plot.append(figure.add_subplot(gs[0, 0]))
            self.img_plot.append(figure.add_subplot(gs[0, 1]))
            self.img_plot.append(figure.add_subplot(gs[1, 0]))
            self.img_plot.append(figure.add_subplot(gs[1, 1]))

    def _reset_shared_axis_labels(self):
        """Blank the x labels of panels 0 and 1 and the y labels of panels 1 and 2."""
        for j, axes in enumerate(self.img_plot):
            pylab.setp(axes.get_yticklabels(), visible=True)
            if j in (0, 1):
                axes.set_xlabel(" ")
                pylab.setp(axes.get_xticklabels(), visible=False)
            if j in (1, 2):
                axes.set_ylabel(" ")
                pylab.setp(axes.get_yticklabels(), visible=False)

    def _data_plot(self):
        """
        Display the extensions of the current frame (self.cur_frame).

        The DATA, ERR, CONFIDENCE and BKG_DATA extensions are drawn in panels
        0 to 3 of self.img_plot; other extensions are skipped. If a DQ
        extension is present, pixels where it equals 1 are set to NaN in the
        displayed images.
        """
        frame = self.frames[self.img_cat][self.cur_frame]

        self.filename_plot.clear()
        self.filename_plot.set_axis_off()
        self.filename_plot.text(-0.5, 1.5, os.path.basename(frame.fits_file.name))

        # List the named extensions of this particular frame. The count is
        # taken per frame so that files with differing numbers of extensions
        # neither raise IndexError nor lose extensions.
        extensions = []
        for hdu_index in range(1, len(frame.hdulist())):
            try:
                ext_name = frame.all_hdu[hdu_index].header["EXTNAME"]
            except KeyError:
                # an extension without a name cannot be assigned to a panel
                continue
            extensions.append((hdu_index, ext_name))

        # first look for data quality DQ - necessary for sensible
        # autoscaling because so much of the image is unused
        data_qual = None
        for hdu_index, ext_name in extensions:
            if ext_name == "DQ":
                frame.readImage(hdu_index)
                data_qual = frame.image

        for hdu_index, ext_name in extensions:
            panel = self._EXT_PANELS.get(ext_name)
            if panel is None:
                continue
            axes = self.img_plot[panel]

            # clear image frame and make it visible
            axes.cla()
            axes.tooltip = ""
            axes.set_visible(True)
            # Done on every pass, before the panel is drawn, exactly as the
            # layout logic always did; moving it would change which panels
            # end up with axis labels.
            self._reset_shared_axis_labels()

            # Setup the selected image and display it
            imgdisp = pipeline_display.ImageDisplay()
            imgdisp.setAspect("equal")
            imgdisp.setLabels("X", "Y")

            title = "{} {}/{} {}".format(
                self.nameroot, self.cur_frame + 1, self.n_frames, ext_name
            )
            # Try reading image
            try:
                frame.readImage(hdu_index)
                image = frame.image

                if data_qual is not None:
                    # blank empty parts of image so that autoscaling can work.
                    # NaN can only be stored in a floating point array, so
                    # integer images are converted (to a copy) first.
                    if not numpy.issubdtype(image.dtype, numpy.floating):
                        image = image.astype(float)
                    image[data_qual == 1] = numpy.nan

                imgdisp.display(
                    axes,
                    title,
                    "Frame:\n" + frame.fits_file.name,
                    image,
                )
            except IndexError:
                axes.set_axis_off()
                text_nodata = "No data found."
                axes.text(
                    0.1,
                    0.5,
                    text_nodata,
                    color="#11557c",
                    fontsize=18,
                    ha="left",
                    va="center",
                    alpha=1.0,
                    transform=axes.transAxes,
                )
                axes.tooltip = "No data found"

    def _add_nodata_subplots(self, figure):
        """Create the single axes used to show the "data not found" message."""
        self.img_plot = figure.add_subplot(1, 1, 1)

    def _nodata_plot(self):
        """Show a message listing the product categories that were expected."""
        # could be moved to reflex library?
        self.img_plot.set_axis_off()
        text_nodata = "Data not found. Expected input file types:\n%s" % "\n".join(
            self.img_cat_list
        )
        self.img_plot.text(
            0.1,
            0.6,
            text_nodata,
            color="#11557c",
            fontsize=18,
            ha="left",
            va="center",
            alpha=1.0,
        )
        self.img_plot.tooltip = "No data found"

    def setWindowHelp(self):
        """Return the help text of the interactive window."""
        help_text = """
This is an interactive window which help asses the quality of the execution of a recipe.
"""
        return help_text


# Script entry point: show the interactive window (when possible) and
# report the results back to Reflex
if __name__ == "__main__":
    from reflex_interactive_app import PipelineInteractiveApp

    # Build the application and let it read its inputs from the command line
    app = PipelineInteractiveApp(enable_init_sop=True)
    app.parse_args()

    # The GUI cannot be used when the plotting modules failed to import
    if not import_success:
        app.setEnableGUI(False)

    if not app.isGUIEnabled():
        # No window: let the workflow carry on
        app.set_continue_mode()
    else:
        # Hand the window-specific plotting functions to the app and show it
        plot_manager = DataPlotterManager()
        app.setPlotManager(plot_manager)
        app.showGUI()

    # Print outputs. This is parsed by the Reflex python actor to
    # get the results. Do not remove
    app.print_outputs()
    sys.exit()
