import traceback

import hvplot.pandas  # noqa
import panel as pn
import param
from astropy.visualization import ZScaleInterval, PercentileInterval, MinMaxInterval, ManualInterval, BaseInterval
from panel.viewable import Viewer, Viewable

from edpsgui.domain.utils import HDUInfo, get_table
from edpsplot import Image, FitsFile, SDPSpectrum


class DataViewer(Viewer):
    """Base viewer for one HDU of a file.

    Holds the file name, category and HDU metadata shared by all concrete
    viewers. On its own it only renders a placeholder message; subclasses
    override ``__panel__`` to display the HDU contents.
    """

    # Path of the file whose HDU is displayed (may be supplied as a reference).
    filename = param.String(allow_refs=True)
    # Category string forwarded to the plotting helpers by subclasses.
    category = param.String(allow_refs=True)
    # Metadata describing the HDU to display.
    hdu_info = param.ClassSelector(class_=HDUInfo, allow_refs=True)

    def __init__(self, **params):
        super().__init__(**params)

    def __panel__(self) -> Viewable:
        """Return a Markdown pane stating that no viewer exists for this HDU."""
        placeholder = f"No data viewer for HDU {self.hdu_info}"
        return pn.pane.Markdown(placeholder)


class ImageViewer(DataViewer):
    """Viewer that renders an image HDU with a user-selectable scaling interval.

    The user picks an interval type with a radio group. Only the numeric inputs
    relevant to that type are enabled. The image is drawn when the button is
    pressed.
    """

    interval = param.Selector(objects=['zscale', 'minmax', 'percentile', 'custom'],
                              label='Select Interval')

    def __init__(self, **params):
        super().__init__(**params)
        self.image_button = pn.widgets.Button(name=f'Display Image #{self.hdu_info.index}', button_type='primary')
        # Built from the param, so selecting a radio option updates ``self.interval``.
        self.interval_selector = pn.widgets.RadioBoxGroup.from_param(self.param.interval, inline=True)
        self.percentile_input = pn.widgets.FloatInput(name='Percentile', value=99.5, start=0, end=100, format='0.0',
                                                      width=100, disabled=True)
        self.min_input = pn.widgets.FloatInput(name='Min', value=0, format='0.0', width=100, disabled=True)
        self.max_input = pn.widgets.FloatInput(name='Max', value=0, format='0.0', width=100, disabled=True)
        # The ``watch=True`` dependency below only fires on *changes*. Sync once here
        # so an initial ``interval`` passed via ``params`` enables the right inputs.
        self.enable_relevant_inputs()

    @pn.depends('interval', watch=True)
    def enable_relevant_inputs(self):
        """Enable only the numeric inputs used by the currently selected interval."""
        is_custom = self.interval == 'custom'
        self.min_input.disabled = not is_custom
        self.max_input.disabled = not is_custom
        self.percentile_input.disabled = self.interval != 'percentile'

    def get_interval(self) -> BaseInterval:
        """Build the astropy interval matching the current selection.

        Raises:
            ValueError: if a required numeric input is empty, if the custom Min
                is not less than Max, or if the interval type is unknown.
        """
        kind = self.interval
        if kind == 'zscale':
            return ZScaleInterval()
        if kind == 'minmax':
            return MinMaxInterval()
        if kind == 'percentile':
            # A cleared FloatInput yields None; report it instead of passing it on.
            if self.percentile_input.value is None:
                raise ValueError('Percentile value must be set')
            return PercentileInterval(self.percentile_input.value)
        if kind == 'custom':
            vmin, vmax = self.min_input.value, self.max_input.value
            # Comparing None values would raise an opaque TypeError.
            if vmin is None or vmax is None:
                raise ValueError('Min and Max values must both be set')
            if vmin >= vmax:
                raise ValueError('Min value must be less than Max value')
            return ManualInterval(vmin, vmax)
        raise ValueError(f'Unknown interval type: {kind}')

    def image_viewer(self, event):
        """Render the image HDU. Bound to the button via ``pn.bind``.

        Returns None until the button is pressed. Any failure is returned as a
        text pane with the traceback rather than raised. This includes an
        invalid interval selection from ``get_interval``.
        """
        if not event:
            return
        with self.image_button.param.update(loading=True, disabled=True):
            try:
                # Construct inside the try so input/constructor errors are shown to the user.
                plotter = Image(category=self.category, ext=self.hdu_info.index, interval=self.get_interval(),
                                width=600)
                fits_file = FitsFile(name=self.filename, category=self.category)
                return plotter.plot([fits_file])
            except Exception as e:
                return pn.pane.Str(f'Error plotting image: {e}\n{traceback.format_exc()}')

    def __panel__(self) -> Viewable:
        """Lay out the interval controls, the display button and the bound image output."""
        layout = pn.Column(
            pn.Row(self.interval_selector),
            pn.Row(self.percentile_input, self.min_input, self.max_input),
            self.image_button,
            pn.Spacer(height=50),
            pn.bind(self.image_viewer, self.image_button)
        )
        return layout


class TableViewer(DataViewer):
    """Viewer that renders a table HDU in a read-only Tabulator when the button is pressed."""

    # Tables longer than this are randomly down-sampled before display.
    # The default matches the previously hard-coded limit.
    max_rows = param.Integer(default=1000, bounds=(1, None))

    def __init__(self, **params):
        super().__init__(**params)
        self.table_button = pn.widgets.Button(name=f'Display Table #{self.hdu_info.index}', button_type='primary')

    @staticmethod
    def _make_tabulator(df):
        """Build the read-only, paginated Tabulator used for both full and sampled tables."""
        return pn.widgets.Tabulator(df, show_index=False, disabled=True, page_size=200, theme='default')

    def table_viewer(self, event):
        """Load and display the table HDU. Bound to the button via ``pn.bind``.

        Returns None until the button is pressed. If the table has more than
        ``max_rows`` rows, a random sample of ``max_rows`` rows is shown with a
        notice. Any failure is returned as a text pane with the traceback.
        """
        if not event:
            return
        with self.table_button.param.update(loading=True, disabled=True):
            try:
                df = get_table(self.filename, self.hdu_info.index)
                if len(df) <= self.max_rows:
                    return self._make_tabulator(df)
                sample = df.sample(self.max_rows)
                return pn.Column(
                    pn.pane.Str(f'Table is too large, displaying a sample of {self.max_rows} rows'),
                    self._make_tabulator(sample),
                )
            except Exception as e:
                # Include the traceback, consistent with the image and spectrum viewers.
                return pn.pane.Str(f'Error plotting table: {e}\n{traceback.format_exc()}')

    def __panel__(self) -> Viewable:
        """Lay out the display button above the bound table output."""
        layout = pn.Column(
            self.table_button,
            pn.bind(self.table_viewer, self.table_button)
        )
        return layout


class SpectrumViewer(DataViewer):
    """Viewer that plots a spectrum from the file when the button is pressed."""

    # Column names handed to the SDPSpectrum plotter. The defaults match the
    # previously hard-coded 'WAVE'/'FLUX'.
    x_column = param.String(default='WAVE')
    y_column = param.String(default='FLUX')

    def __init__(self, **params):
        super().__init__(**params)
        self.spectrum_button = pn.widgets.Button(name=f'Display Spectrum #{self.hdu_info.index}', button_type='primary')

    def spectrum_viewer(self, event):
        """Plot the spectrum. Bound to the button via ``pn.bind``.

        Returns None until the button is pressed. Any failure, including one
        raised while constructing the plotter, is returned as a text pane with
        the traceback.
        """
        if not event:
            return
        with self.spectrum_button.param.update(loading=True, disabled=True):
            try:
                # NOTE(review): hdu_info.index is shown on the button but not passed to the
                # plotter — confirm SDPSpectrum selects the intended HDU.
                plotter = SDPSpectrum(category=self.category, x_column=self.x_column, y_column=self.y_column,
                                      width=600, color='blue', line_width=1)
                fits_file = FitsFile(name=self.filename, category=self.category)
                return plotter.plot([fits_file])
            except Exception as e:
                return pn.pane.Str(f'Error plotting spectrum: {e}\n{traceback.format_exc()}')

    def __panel__(self) -> Viewable:
        """Lay out the display button above the bound spectrum output."""
        layout = pn.Column(
            self.spectrum_button,
            pn.bind(self.spectrum_viewer, self.spectrum_button)
        )
        return layout
