from __future__ import with_statement
from __future__ import absolute_import
from __future__ import print_function
import sys
import os.path

try:
    import numpy
    from astropy.table import Table
    import reflex
    from pipeline_product import PipelineProduct
    import pipeline_display
    import reflex_plot_widgets

    from matplotlib import gridspec, pylab, pyplot, transforms
    import pdb  # for debugging
    from collections import defaultdict  # to make dictionary of lists

    import_success = True

except ImportError:
    import_success = False
    print("Error importing modules pyfits, wx, matplotlib, numpy")


def paragraph(text, width=None):
    """ format a text string as a single paragraph
       text:  text to format; common leading indentation is removed
       width: when given, the text is wrapped to this many columns with
              textwrap.fill; otherwise newlines are turned into spaces and
              the result is stripped. Wrapping is not recommended for
              tooltips, which wxWidgets already wraps by default.
    """
    import textwrap

    dedented = textwrap.dedent(text)
    if width is not None:
        return textwrap.fill(dedented, width=width)
    return dedented.replace("\n", " ").strip()


class DataPlotterManager(object):
    """
    Plot manager for the eris_nix_img_cal_phot interactive window.

    This class must be added to the PipelineInteractiveApp with setPlotManager
    It must have following member functions which will be called by the app:
     - setInteractiveParameters(self)
     - readFitsData(self, fitsFiles):
     - addSubplots(self, figure):
     - plotProductsGraphics(self)
    Following members are optional:
     - setWindowHelp(self)
     - setWindowTitle(self)
    It also provides plotWidgets(self), which returns the radio-button
    widgets of the window.
    """

    # static members

    recipe_name = "eris_nix_img_cal_phot"
    img_catg_list = ["CAL_PHOT_OBJECT_JITTER", "CAL_PHOT_STD_JITTER"]
    cat_catg = "CAL_PHOT_CATALOGUE"
    refcat_catg = "CAL_PHOT_REFCAT"
    matchcat_catg = "CAL_PHOT_MATCHCAT"

    nameroot = "cal_phot"

    # maximum number of rows shown in the matched-standards table
    max_table_rows = 10

    def setWindowTitle(self):
        return self.recipe_name + "_interactive"

    def setInteractiveParameters(self):
        """
        This function specifies which are the parameters that should be presented
        in the window to be edited.  Note that the parameter has to also be in the
        in_sop port (otherwise it won't appear in the window). The descriptions are
        used to show a tooltip. They should match one to one with the parameter
        list.
        """

        # Only selected Recipe parameters are shown because list is too long
        # The selection shown ones likely to be wanted to be changed by user
        return [
            reflex.RecipeParameter(
                recipe=self.recipe_name,
                displayName="catalogue.obj.min-pixels",
                group="catalogue",
                description="Minimum pixel area for each detected object.",
            ),
            reflex.RecipeParameter(
                recipe=self.recipe_name,
                displayName="catalogue.obj.threshold",
                group="catalogue",
                description="Detection threshold in sigma above sky.",
            ),
            reflex.RecipeParameter(
                recipe=self.recipe_name,
                displayName="catalogue.obj.deblending",
                group="catalogue",
                description="Use deblending?.",
            ),
            reflex.RecipeParameter(
                recipe=self.recipe_name,
                displayName="catalogue.obj.core-radius",
                group="catalogue",
                description="Value of Rcore in pixels.",
            ),
            reflex.RecipeParameter(
                recipe=self.recipe_name,
                displayName="catalogue.bkg.estimate",
                group="catalogue",
                description="Estimate background from input, if false it is "
                            "assumed input is already background corrected "
                            "with median 0.",
            ),
            reflex.RecipeParameter(
                recipe=self.recipe_name,
                displayName="catalogue.bkg.mesh-size",
                group="catalogue",
                description="Background smoothing box size.",
            ),
            reflex.RecipeParameter(
                recipe=self.recipe_name,
                displayName="catalogue.bkg.smooth-gauss-fwhm",
                group="catalogue",
                description="The FWHM of the Gaussian kernel used in "
                            "convolution for object detection.",
            ),
            reflex.RecipeParameter(
                recipe=self.recipe_name,
                displayName="catalogue.det.effective-gain",
                group="catalogue",
                description="Detector gain value to rescale convert "
                            "intensity to electrons.",
            ),
            reflex.RecipeParameter(
                recipe=self.recipe_name,
                displayName="catalogue.det.saturation",
                group="catalogue",
                description="Detector saturation value.",
            ),
            reflex.RecipeParameter(
                recipe=self.recipe_name,
                group="source matching",
                displayName="cdssearch_photom",
                description="CDS photometric catalogue. <none | 2MASS>",
            ),
            reflex.RecipeParameter(
                recipe=self.recipe_name,
                displayName="pixel-radius",
                group="source matching",
                description="Max. distance between associated object and "
                            "catalogue entry (pixels).",
            ),
            reflex.RecipeParameter(
                recipe=self.recipe_name,
                displayName="minphotom",
                description="Min number of matched stars for "
                            "photometric calibration.",
            ),
            reflex.RecipeParameter(
                recipe=self.recipe_name,
                displayName="magerrcut",
                description="Matched stars with magnitude "
                            "error above this value will not be used.",
            ),
            reflex.RecipeParameter(
                recipe=self.recipe_name,
                displayName="strict-classification",
                group="source matching",
                description="Match only low-ellipticity objects classified as "
                            "stellar.",
            ),
        ]

    def readFitsData(self, fitsFiles):
        """
        Read and organise the FITS products produced by the recipe.

        fitsFiles: list of reflex.FitsFiles. Each is wrapped in a
        PipelineProduct and grouped by category in self.frames (a dict of
        category -> list of PipelineProduct).

        The first category of img_catg_list that is present becomes the
        image category; its products are sorted by the PIPEFILE keyword.
        Depending on whether any image was found, either the data or the
        NODATA plotting functions are selected.
        """
        self.frames = defaultdict(list)
        for f in fitsFiles:
            self.frames[f.category].append(PipelineProduct(f))

        # use the first image category (in priority order) that is present
        self.img_catg = next(
            (catg for catg in self.img_catg_list if catg in self.frames), None
        )
        # .get() avoids inserting a spurious key into the defaultdict
        images = self.frames.get(self.img_catg, [])

        self.n_frames = len(images)
        if self.n_frames > 0:
            self.cur_frame = 0
            # number of extensions, assumed to be same for all fitsFiles;
            # at most 4 of them are inspected
            self.n_extn = min(len(images[0].hdulist()) - 1, 4)

            # Don't read the individual calibrated science images in here
            #   for memory/performance reasons. Read them in below as needed.
            #   There could be a lot of frames and user may not want to see all of them.

            # Sort the frames (in place) using PIPEFILE keyword
            images.sort(key=lambda product: product.all_hdu[0].header["PIPEFILE"])

            # Select the data plotting functions explicitly, so that a
            # previous NODATA selection on this instance is overridden.
            self._add_subplots = self._add_data_subplots
            self._plot = self._data_plot

            # Define radio button options
            self.left_opts = ["Click to\ngo back"]
            self.right_opts = ["Click to\nadvance"]
        else:
            # Set the plotting functions to NODATA ones
            self._add_subplots = self._add_nodata_subplots
            self._plot = self._nodata_plot

    def addSubplots(self, figure):
        self._add_subplots(figure)

    def plotProductsGraphics(self):
        self._plot()

    def plotWidgets(self):
        """
        Create the radio-button widgets and return them as a list.

        The back/forward buttons are only created when there is more than
        one frame. No widgets at all are created when no image data was
        found, because the NODATA layout has no axes for the buttons.
        """
        widgets = []
        if self.n_frames < 1:
            return widgets

        # forward/back buttons
        # Only show them if more than one frame was found
        if self.n_frames > 1:
            self.radiobutton_left = reflex_plot_widgets.InteractiveRadioButtons(
                self.axradiobutton_left,
                self.setRadioCallback_left,
                self.left_opts,
                0,
                title="",
            )
            widgets.append(self.radiobutton_left)

            self.radiobutton_right = reflex_plot_widgets.InteractiveRadioButtons(
                self.axradiobutton_right,
                self.setRadioCallback_right,
                self.right_opts,
                0,
                title="",
            )
            widgets.append(self.radiobutton_right)

        # Radio button for plot overlays
        self.radiobutton_overlay = reflex_plot_widgets.InteractiveRadioButtons(
            self.axradiobutton_overlay,
            self.setRadioCallback_overlay,
            ("overlay", "no overlay"),
            0,
            title="",
        )
        widgets.append(self.radiobutton_overlay)

        # Adjust size of button boxes and font size of labels
        for widget in widgets:
            box = widget.rbuttons.ax
            pos = box.get_position()
            box.set_position(
                transforms.Bbox([[pos.x0, pos.y0 - 0.01], [pos.x1, 0.97]])
            )
            for label in widget.rbuttons.labels:
                label.set_fontsize(11)

        return widgets

    def setRadioCallback_left(self, label):
        """Step back to the previous frame (wrapping to the last) and replot."""
        self.cur_frame = (self.cur_frame - 1) % self.n_frames
        self._plot()

    def setRadioCallback_right(self, label):
        """Advance to the next frame (wrapping to the first) and replot."""
        self.cur_frame = (self.cur_frame + 1) % self.n_frames
        self._plot()

    def setRadioCallback_overlay(self, label):
        """Toggle the catalogue overlays and replot."""
        self.overlay = not self.overlay
        self._plot()

    def _add_data_subplots(self, figure):
        """Create the axes of the data layout (buttons, filename, 3 panels)."""
        self.img_plot = []
        if self.n_frames > 0:
            gs = gridspec.GridSpec(nrows=9, ncols=4)
            gs.update(hspace=0.7)  # make space so axis labels dont overlap

            # buttons
            self.axradiobutton_overlay = figure.add_subplot(gs[0, 0])
            # don't blot subplot border around the buttons or filename
            for spine in self.axradiobutton_overlay.spines.values():
                spine.set_visible(False)
            self.overlay = True

            if self.n_frames > 1:
                self.axradiobutton_left = figure.add_subplot(gs[0, 1])
                self.axradiobutton_right = figure.add_subplot(gs[0, 2])

                # don't blot subplot border around the buttons or filename
                for spine in self.axradiobutton_left.spines.values():
                    spine.set_visible(False)
                for spine in self.axradiobutton_right.spines.values():
                    spine.set_visible(False)

            # filename
            self.filename_plot = figure.add_subplot(gs[1, 0:4])
            for spine in self.filename_plot.spines.values():
                spine.set_visible(False)

            # plots: image, results summary, matched-standards table
            self.img_plot.append(figure.add_subplot(gs[2:6, 0:2]))
            self.img_plot.append(figure.add_subplot(gs[2:6, 2:4]))
            self.img_plot.append(figure.add_subplot(gs[6:9, 0:4]))

            # Move ticks to rhs for readability
            self.img_plot[1].yaxis.tick_right()

        else:
            gs = gridspec.GridSpec(2, 2)
            self.img_plot.append(figure.add_subplot(gs[0, 0]))
            self.img_plot.append(figure.add_subplot(gs[0, 1]))
            self.img_plot.append(figure.add_subplot(gs[1, 0]))
            self.img_plot.append(figure.add_subplot(gs[1, 1]))

    def _data_plot(self):
        """
        Redraw all panels for the current frame. These are the filename
        line, the DATA image with optional catalogue overlays, the results
        summary and (for 2MASS calibrations) the matched-standards table.
        """
        product = self.frames[self.img_catg][self.cur_frame]
        pname = os.path.basename(product.fits_file.name)

        self.filename_plot.clear()
        self.filename_plot.set_axis_off()
        self.filename_plot.text(0.0, 0.5, pname)

        # set by _plot_image when an overlay is drawn; used by the legend
        self.overplot_matched = False
        self.overplot_ref = False
        self.overplot_image = False

        for iext in range(self.n_extn):
            ext_name = product.all_hdu[iext + 1].header["EXTNAME"]
            # only the DATA extension is displayed
            if ext_name == "DATA":
                self._plot_image(product, pname, iext, ext_name)

        # The results panels depend only on the primary header and the
        # overlay flags, so they are drawn once all extensions are handled.
        self._plot_results(product, pname)

    def _plot_image(self, product, pname, iext, ext_name):
        """
        Display extension iext + 1 of product in the image panel.

        If overlays are enabled, the matched, reference and image catalogues
        belonging to this product are overplotted, and the corresponding
        overplot_* flags record which were found.
        """
        ax = self.img_plot[0]
        title = "{} {}/{} {}".format(
            self.nameroot, self.cur_frame + 1, self.n_frames, ext_name
        )

        # clear image frame and make it visible
        ax.cla()
        ax.tooltip = ""
        ax.set_visible(True)
        pylab.setp(ax.get_yticklabels(), visible=True)
        pylab.setp(ax.get_xticklabels(), visible=True)

        # Setup the selected image and display it
        imgdisp = pipeline_display.ImageDisplay()
        imgdisp.setAspect("equal")
        imgdisp.setLabels("X", "Y")

        if self.overlay:
            self.overplot_matched = self._overplot_catalogue(
                imgdisp, pname, self.matchcat_catg,
                "X_coordinate", "Y_coordinate", size=60, color="lime",
            )
            self.overplot_ref = self._overplot_catalogue(
                imgdisp, pname, self.refcat_catg,
                "xpredict", "ypredict", size=25, color="red",
            )
            self.overplot_image = self._overplot_catalogue(
                imgdisp, pname, self.cat_catg,
                "X_coordinate", "Y_coordinate", size=10, color="dodgerblue",
            )

        # try reading image to plot
        try:
            product.readImage(iext + 1)
            imgdisp.display(
                ax, title, "Frame:\n" + product.fits_file.name, product.image
            )
        except IndexError:
            ax.set_axis_off()
            ax.text(
                0.1,
                0.5,
                "No data found.",
                color="#11557c",
                fontsize=18,
                ha="left",
                va="center",
                alpha=1.0,
                transform=ax.transAxes,
            )
            ax.tooltip = "No data found"

    def _find_product(self, catg, pname):
        """Return the first catg product whose file name contains pname, or None."""
        for frame in self.frames.get(catg, []):
            if pname in os.path.basename(frame.fits_file.name):
                return frame
        return None

    def _overplot_catalogue(self, imgdisp, pname, catg, xcol, ycol, size, color):
        """
        Overplot the xcol/ycol positions (table extension 1) of the catg
        product matching pname. Returns True if such a product was found.
        """
        frame = self._find_product(catg, pname)
        if frame is None:
            return False
        frame.readTableXYColumns(1, xcol, ycol)
        imgdisp.overplotScatter(
            frame.x_column, frame.y_column, marker="o", size=size, color=color
        )
        return True

    @staticmethod
    def _text(ax, x, y, text):
        """Draw black, size-10, left/centre-aligned text in axes coordinates."""
        ax.text(
            x,
            y,
            text,
            color="black",
            fontsize=10,
            ha="left",
            va="center",
            transform=ax.transAxes,
        )

    def _hide_table_panel(self):
        """Clear and hide the matched-standards panel."""
        self.img_plot[2].cla()
        self.img_plot[2].set_visible(False)

    def _plot_results(self, product, pname):
        """Summarise the photometric calibration from the primary header."""
        ax = self.img_plot[1]
        ax.cla()
        ax.set_visible(True)
        ax.set_axis_off()
        ax.tooltip = "Results:"

        header = product.all_hdu[0].header
        method = header["ZPMETHOD"]
        fluxcal = header["FLUXCAL"]
        magzpt = header["HIERARCH ESO QC MAGZPT"]
        magzerr = header["HIERARCH ESO QC MAGZERR"]

        if method == "DEFAULT":
            self._text(ax, 0.0, 0.9, "Phot calibration read from defaults file")
            self._text(ax, 0.0, 0.8, "MAGZPT: %2.2f +- %1.2f" % (magzpt, magzerr))
            # the table pane is blank for default results
            self._hide_table_panel()

        elif method == "2MASS":
            catname = header["ZP_CAT"]
            self._text(ax, 0.0, 0.9, "Overall result from matched")
            self._text(ax, 0.0, 0.8, "photometric standards:")
            self._text(ax, 0.1, 0.7, "FLUXCAL: %s" % (fluxcal))
            self._text(ax, 0.1, 0.6, "MAGZPT: %2.2f +- %1.2f" % (magzpt, magzerr))
            self._text(ax, 0.1, 0.5, "Catalogue: %s" % catname)
            self._text(ax, 0.0, 0.35, "For this jitter:")
            self._plot_legend(ax)
            self._plot_matched_table(pname)

        else:
            # unrecognised method: don't leave a stale table from a previous frame
            self._hide_table_panel()

    def _plot_legend(self, ax):
        """List which catalogue overlays were drawn, with matching markers."""
        entries = (
            (self.overplot_matched, 60, "lime",
             "image objects matched in reference catalogue",
             "No matched catalogue sources"),
            (self.overplot_ref, 30, "red",
             "objects in reference catalogue",
             "No reference catalogue sources"),
            (self.overplot_image, 10, "dodgerblue",
             "objects in image",
             "No sources in image"),
        )
        ypos = 0.25
        for present, size, color, label, missing_label in entries:
            if present:
                ax.scatter(
                    [0.05], [ypos], marker="o", s=size, c=color,
                    transform=ax.transAxes,
                )
                self._text(ax, 0.1, ypos, label)
            else:
                self._text(ax, 0.05, ypos, missing_label)
            ypos = ypos - 0.1

    @staticmethod
    def _format_table_rows(rows, columns, max_rows=10):
        """
        Format up to max_rows rows as lists of strings for a table.

        rows:    sequence of rows supporting row[column] (e.g. an astropy
                 Table) that also supports len() and slicing
        columns: column names to extract, in order
        Numeric cells are formatted with "%.2f"; others fall back to str().
        If rows has more than max_rows entries, the last displayed row is
        replaced by "truncated" markers.
        """
        truncated = len(rows) > max_rows
        cell_text = []
        for irow, row in enumerate(rows[:max_rows]):
            if truncated and irow == max_rows - 1:
                cell_text.append(["truncated"] * len(columns))
                break
            row_text = []
            for col in columns:
                value = row[col]
                try:
                    row_text.append("%.2f" % value)
                except (TypeError, ValueError):
                    # non-numeric cell, e.g. a catalogue identifier
                    row_text.append(str(value))
            cell_text.append(row_text)
        return cell_text

    def _plot_matched_table(self, pname):
        """Show the calibration formula and the matched standards for pname."""
        ax = self.img_plot[2]
        ax.cla()
        ax.set_visible(True)
        ax.set_axis_off()
        ax.tooltip = "Matched standards"

        catalogue = self._find_product(self.cat_catg, pname)
        if catalogue is not None:
            # some info to help understand the numbers
            self._text(
                ax, 0.1, 0.81,
                "dm3 = refmag + 2.5log10(Aper_flux_3) + apcor3 + extinct",
            )
            self._text(ax, 0.1, 0.73, "extinct = atm_extcoeff * (airmass - 1.0)")
            # get APCOR3 from image catalogue
            apcor3 = catalogue.all_hdu[1].header["APCOR3"]
            self._text(ax, 0.1, 0.65, "apcor3 = %f" % apcor3)

        matchcat = self._find_product(self.matchcat_catg, pname)
        if matchcat is None:
            return

        table = Table.read(matchcat.fits_file.name, hdu=1)
        cell_text = self._format_table_rows(
            table, ["_2MASS", "refmag", "Aper_flux_3", "dm3"], self.max_table_rows
        )
        if cell_text:
            ax.table(
                cellText=cell_text,
                colLabels=["_2MASS", "refmag", "Aper_flux_3 [DN/s]", "dm3"],
                bbox=[0.0, 0.0, 1.0, len(cell_text) * 0.1],
            )
        else:
            self._text(ax, 0.1, 0.5, "No matched standards in this field")

    def _add_nodata_subplots(self, figure):
        self.img_plot = figure.add_subplot(1, 1, 1)

    def _nodata_plot(self):
        # could be moved to reflex library?
        self.img_plot.set_axis_off()
        text_nodata = "Data not found. Expected input file types:\n%s" % "\n".join(
            self.img_catg_list
        )
        self.img_plot.text(
            0.1,
            0.6,
            text_nodata,
            color="#11557c",
            fontsize=18,
            ha="left",
            va="center",
            alpha=1.0,
        )
        self.img_plot.tooltip = "No data found"

    def setWindowHelp(self):
        help_text = """
This is an interactive window which helps assess the quality of the execution of a recipe.
"""
        return help_text


# Script entry point: optionally show the interactive window, then report
# the results back to Reflex.
if __name__ == "__main__":
    from reflex_interactive_app import PipelineInteractiveApp

    # Create interactive application and read the command-line inputs
    interactive_app = PipelineInteractiveApp(enable_init_sop=True)
    interactive_app.parse_args()

    # The GUI cannot be shown if the plotting modules failed to import
    if not import_success:
        interactive_app.setEnableGUI(False)

    if not interactive_app.isGUIEnabled():
        interactive_app.set_continue_mode()
    else:
        # Attach the plot manager specific to this recipe and open the window
        interactive_app.setPlotManager(DataPlotterManager())
        interactive_app.showGUI()

    # Print outputs. This is parsed by the Reflex python actor to
    # get the results. Do not remove
    interactive_app.print_outputs()
    sys.exit()
