import asyncio
import dataclasses as dc
import logging
import os
import time
from datetime import datetime

import panel as pn
import param
from panel.viewable import Viewer, Viewable
from panel_jstree import Tree
from panel_modal import Modal

from edpsgui.domain.dataset import *
from edpsgui.domain.edpsw import edps_banner
from edpsgui.domain.utils import get_keywords
from edpsgui.gui.edps_ctl import get_edps
from edpsgui.gui.reduction_config_editor import ReductionConfiguration, ReductionConfigurationEditor
from edpsgui.gui.reduction_table import ReductionTable

# Column labels of the per-row action buttons in the dataset table;
# Reducer.dataset_table_callback compares the clicked column against them.
SHOW_DATASET_TREE = 'Show Dataset Tree'
EDIT_CONFIG = 'Edit Configuration'

# Header keyword (ESO HIERARCH style) read from a dataset's main file.
OBS_TARG = 'HIERARCH ESO OBS TARG NAME'


@dc.dataclass
class ReductionStatus:
    """Number of jobs in each state for the reduction of one dataset.

    As filled in by ``get_reduction_status``, ``created`` also counts FAILED jobs
    whose failure predates the current reduction (they are still to be re-run).
    """
    created: int
    running: int
    completed: int
    failed: int

    def total(self) -> int:
        """Return the number of jobs across all states."""
        return self.created + self.running + self.completed + self.failed

    def get_completion_percentage(self) -> int:
        """Return the share of finished (completed or failed) jobs as a rounded 0-100 value.

        A status without any job has nothing left to do and is reported as 100% instead of
        raising ``ZeroDivisionError``.
        """
        total = self.total()
        if total == 0:
            return 100
        return round((self.completed + self.failed) / total * 100)

    def __str__(self):
        return f"created={self.created}, running={self.running}, completed={self.completed}, failed={self.failed}"


def get_reduction_status(jobs, reduction_start_date) -> ReductionStatus:
    """Tally the states of ``jobs`` into a ``ReductionStatus``.

    A FAILED job whose ``completion_date`` is earlier than ``reduction_start_date`` failed
    in a previous run, so it is counted as CREATED (pending) rather than FAILED. Jobs in
    any state other than CREATED/RUNNING/COMPLETED/FAILED are not counted.

    NOTE(review): the dates are compared with ``<``; this assumes both are ISO-8601 strings
    in the same format (or otherwise mutually comparable) — confirm against the EDPS job model.
    """
    tally = {'CREATED': 0, 'RUNNING': 0, 'COMPLETED': 0, 'FAILED': 0}
    for job in jobs:
        state = job.status.value
        if state == 'FAILED' and job.completion_date < reduction_start_date:
            # stale failure from an earlier reduction: still waiting to be re-run
            state = 'CREATED'
        if state in tally:
            tally[state] += 1
    return ReductionStatus(created=tally['CREATED'],
                           running=tally['RUNNING'],
                           completed=tally['COMPLETED'],
                           failed=tally['FAILED'])


class Reducer(Viewer):
    """Panel component to select datasets, edit their reduction configuration and reduce them with EDPS.

    Datasets are shown in a Tabulator table (expandable rows list past reductions of each dataset).
    A configuration editor modal stores one ``ReductionConfiguration`` per dataset name. Reductions
    are submitted through the client returned by ``get_edps()``, and job status is polled to update
    each dataset's progress column. Log messages are written to an embedded terminal widget.
    """
    edps_status = param.Boolean(default=None, allow_refs=True)
    workflow = param.String(default=None, allow_refs=True)
    dataset_df = param.DataFrame(allow_refs=True)
    # dataset ID -> dataset object exposing ``.dataset_name`` and ``.dataset``
    datasets = param.Dict(allow_refs=True)
    # IDs of the datasets checked in the table
    selected_datasets = param.List(default=[])
    # ID of the dataset whose row was last clicked
    current_dataset_id = param.String()
    # files of the selected datasets, one entry per dataset
    selected_files = param.List(default=[])
    reduction_running = param.Boolean()
    # incremented on every configuration change so that dependent views refresh
    config_counter = param.Integer()

    def __init__(self, **params):
        super().__init__(**params)
        self.edps = get_edps()
        # dataset name -> configuration used when reducing that dataset
        self.dataset_config: Dict[str, ReductionConfiguration] = {}
        self.terminal = pn.widgets.Terminal(edps_banner, height=300, sizing_mode='stretch_width')
        self.logger = self.create_logger()

        self.reduce_datasets_btn = pn.widgets.Button(name='Reduce Selected Datasets', button_type='success',
                                                     disabled=self.reduction_disabled)
        self.reduce_datasets_btn.on_click(self.reduce_datasets)

        self.stop_reduction_btn = pn.widgets.Button(name='Stop Reduction', button_type='danger',
                                                    disabled=self.reduction_not_running)
        self.stop_reduction_btn.on_click(self.stop_reduction)

        self.dataset_table = self.create_dataset_table()
        self.dataset_pane = pn.Column(width=800, height=600, scroll=True)
        self.dataset_modal = Modal(self.dataset_pane)
        self.config_editor = None
        self.configuration_pane = self.create_configuration_pane()
        self.configuration_modal = Modal(self.configuration_pane)

    @pn.depends('dataset_df', watch=True)
    def update_dataset_table(self):
        """Force a refresh of the dataset table when ``dataset_df`` becomes an empty frame."""
        # NOTE(review): the table is already bound to dataset_df through ``from_param``; this explicit
        # trigger only fires for an *empty* frame, presumably as a workaround for the table not
        # clearing itself. Confirm this is intended rather than a missing ``not``.
        if self.dataset_df is not None and self.dataset_df.empty:
            self.dataset_table.param.trigger('value')

    def get_dataset_reductions(self, row):
        """Build the expanded-row content: the past reductions of the dataset in ``row``."""
        dataset = row['Dataset']
        dataset_reductions = self.edps.get_reductions_by_dataset(dataset)
        if dataset_reductions:
            return pn.Row(pn.Spacer(width=50), ReductionTable(reductions=dataset_reductions, job_icon=True))
        return pn.pane.HTML("No reductions")

    def create_dataset_table(self):
        """Create the dataset table bound to ``dataset_df``, with action buttons and checkbox selection."""
        formatters = {
            'Progress': {'type': 'progress', 'max': 100},
            'Complete': {'type': 'tickCross'}
        }
        buttons = {SHOW_DATASET_TREE: "<i class='fa-regular fa-file-lines'></i>",
                   EDIT_CONFIG: "<i class='fa-solid fa-gear'></i>"}
        dataset_table = pn.widgets.Tabulator.from_param(self.param.dataset_df,
                                                        show_index=False, hidden_columns=['ID'],
                                                        selectable='checkbox', buttons=buttons,
                                                        pagination='local', page_size=20, formatters=formatters,
                                                        disabled=True, row_content=self.get_dataset_reductions)
        dataset_table.on_click(self.dataset_table_callback)
        pn.bind(self.update_selection, dataset_table.param.selection, watch=True)
        return dataset_table

    def create_configuration_pane(self):
        """Create the content of the configuration modal.

        The last element is a placeholder that ``open_config_editor`` replaces with the editor.
        """
        apply_current_btn = pn.widgets.Button(name='Apply to current dataset', button_type='primary')
        apply_current_btn.on_click(self.apply_config_to_current_dataset)
        apply_selected_btn = pn.widgets.Button(name='Apply to selected datasets', button_type='primary')
        apply_selected_btn.on_click(self.apply_config_to_selected_datasets)
        apply_all_btn = pn.widgets.Button(name='Apply to all datasets', button_type='primary')
        apply_all_btn.on_click(self.apply_config_to_all_datasets)
        return pn.Column(
            "# Data Reduction Configuration",
            pn.widgets.StaticText(name='Current dataset', value=self.current_dataset_name),
            pn.widgets.StaticText(name='Selected datasets', value=self.selected_dataset_names),
            pn.Row(apply_current_btn, apply_selected_btn, apply_all_btn),
            pn.layout.Divider(),
            pn.layout.Divider(),  # this is a placeholder!
            width=700, height=700, scroll=True
        )

    def create_logger(self):
        """Return the ``"terminal"`` logger with a DEBUG handler writing to this instance's terminal."""
        # NOTE(review): "terminal" is a process-wide logger, so every Reducer instance (e.g. one per
        # server session) adds another handler to it. Messages then reach every terminal ever
        # created, and old terminals stay referenced. Consider a per-instance logger if no other
        # module relies on the shared name.
        logger = logging.getLogger("terminal")
        logger.setLevel(logging.DEBUG)
        stream_handler = logging.StreamHandler(self.terminal)
        stream_handler.terminator = "  \n"
        formatter = logging.Formatter("%(asctime)s [%(levelname)s]: %(message)s")
        stream_handler.setFormatter(formatter)
        stream_handler.setLevel(logging.DEBUG)
        logger.addHandler(stream_handler)
        return logger

    def apply_config_to_current_dataset(self, event):
        """Store the editor's configuration for the current dataset and close the modal."""
        if event:
            self.update_dataset_config(self.current_dataset_name())
            self.configuration_modal.is_open = False

    def apply_config_to_selected_datasets(self, event):
        """Store the editor's configuration for every selected dataset and close the modal."""
        if event:
            for dataset_name in self.selected_dataset_names():
                self.update_dataset_config(dataset_name)
            self.configuration_modal.is_open = False

    def apply_config_to_all_datasets(self, event):
        """Store the editor's configuration for every known dataset and close the modal."""
        if event:
            for dataset in self.datasets.values():
                self.update_dataset_config(dataset.dataset_name)
            self.configuration_modal.is_open = False

    def update_dataset_config(self, dataset: str):
        """Store a fresh copy of the editor's current configuration under dataset name ``dataset``."""
        config = self.config_editor.get_reduction_configuration()
        # copy, so that datasets sharing a configuration do not share one mutable object
        self.dataset_config[dataset] = ReductionConfiguration(comment=config.comment,
                                                              report_type=config.report_type,
                                                              parameter_set=config.parameter_set,
                                                              workflow_parameters=config.workflow_parameters,
                                                              recipe_parameters=config.recipe_parameters)
        self.config_counter += 1

    @pn.depends('edps_status')
    def edps_not_running(self):
        return not self.edps.is_running()

    @pn.depends('edps_status', 'selected_datasets')
    def reduction_disabled(self):
        """Disable the reduce button when EDPS is down or no dataset is selected."""
        return not self.edps.is_running() or not self.selected_datasets

    @pn.depends('reduction_running')
    def reduction_not_running(self):
        return not self.reduction_running

    @pn.depends('selected_datasets')
    def selected_dataset_names(self):
        return [self.datasets[dataset_id].dataset_name for dataset_id in self.selected_datasets]

    @pn.depends('current_dataset_id')
    def current_dataset_name(self):
        """Return the name of the current dataset, or ``''`` when there is none."""
        if self.datasets and self.current_dataset_id:
            return self.datasets[self.current_dataset_id].dataset_name
        return ''

    def dataset_table_callback(self, event):
        """Handle a table click: remember the row's dataset and run the clicked action button, if any."""
        self.current_dataset_id = self.dataset_table.value.iloc[event.row]['ID']
        dataset = self.datasets[self.current_dataset_id]
        if event.column == SHOW_DATASET_TREE:
            self.show_dataset_files(dataset)
        elif event.column == EDIT_CONFIG:
            self.open_config_editor()

    def show_dataset_files(self, dataset):
        """Show the file tree of ``dataset`` in the dataset modal."""
        self.dataset_pane.clear()
        dataset_tree = Tree(data=[dc.asdict(astree(dataset.dataset, opened=True))], checkbox=False, height=600)
        self.dataset_pane.append(f"## Dataset {dataset.dataset_name}")
        self.dataset_pane.append(dataset_tree)
        self.dataset_modal.is_open = True

    def get_dataset_keywords(self, dataset: NamedDataset) -> Dict[str, object]:
        """Read selected header keywords from the dataset's main file.

        The main file is the single file whose basename starts with the dataset name. If there is
        not exactly one such file, an error is logged and every keyword maps to ``None``.
        """
        keywords = [OBS_TARG]
        main_file = [filename for filename in get_dataset_files(dataset.dataset) if
                     os.path.basename(filename).startswith(dataset.dataset_name)]
        if len(main_file) != 1:
            self.logger.error('Could not find main file for dataset %s', dataset.dataset_name)
            return {kw: None for kw in keywords}
        return get_keywords(main_file[0], keywords)

    def open_config_editor(self):
        """Open the configuration modal with an editor for the current dataset.

        The editor is pre-filled with the dataset's stored configuration (or a default
        ``ReductionConfiguration``) and replaces the last element of the configuration pane.
        """
        self.configuration_pane.pop(-1)  # drop the placeholder or the previous editor
        config = self.dataset_config.get(self.current_dataset_name(), ReductionConfiguration())
        current_dataset_tasks = get_dataset_tasks(self.datasets[self.current_dataset_id].dataset)
        self.config_editor = ReductionConfigurationEditor(comment=config.comment, report_type=config.report_type,
                                                          param_set=config.parameter_set,
                                                          workflow_parameters=config.workflow_parameters,
                                                          recipe_parameters=config.recipe_parameters,
                                                          tasks=current_dataset_tasks,
                                                          workflow=self.workflow)
        self.configuration_pane.append(self.config_editor)
        self.configuration_modal.is_open = True

    def _dataset_rows(self, dataset_id):
        """Return the index of the table rows whose ``ID`` column equals ``dataset_id``."""
        table = self.dataset_table.value
        return table.loc[table['ID'] == dataset_id].index

    def update_reduction_count(self, dataset_id):
        """Increment the ``Reductions`` column of the rows of ``dataset_id``."""
        self.dataset_table.patch({
            'Reductions': [(row, self.dataset_table.value.iloc[row]['Reductions'] + 1)
                           for row in self._dataset_rows(dataset_id)]
        })

    async def reduce_datasets(self, event):
        """Submit a reduction for every selected dataset, then wait until all of them finish.

        The reduction start time is recorded *before* any job is submitted. Jobs that fail while
        later datasets are still being submitted are therefore counted as failed, not as stale
        failures from an earlier run (which would keep the monitor polling forever).
        ``reduction_running`` is reset even if submission or monitoring raises.
        """
        with self.reduce_datasets_btn.param.update(loading=True, disabled=True):
            self.reduction_running = True
            try:
                reduction_start_date = datetime.now().isoformat(timespec='milliseconds')
                start_time = time.time()
                jobs_to_monitor = {}
                for dataset_id in self.selected_datasets:
                    dataset = self.datasets[dataset_id]
                    config = self.dataset_config.get(dataset.dataset_name, ReductionConfiguration())
                    response = self.edps.reduce_dataset(workflow=self.workflow, dataset=dataset, config=config)
                    self.update_reduction_count(dataset_id)
                    jobs_to_monitor[dataset_id] = [job.job_id for job in response.jobs]
                await self.monitor_jobs(jobs_to_monitor, reduction_start_date)
                self.logger.info('Processed %d datasets in %.3fs', len(jobs_to_monitor), time.time() - start_time)
            finally:
                self.reduction_running = False

    def should_update_progress(self, index, value):
        """Return True when ``value`` is greater than the row's current progress (progress never decreases)."""
        old_value = self.dataset_table.value.iloc[index]['Progress']
        return value > old_value

    def update_progress(self, index, value, force=False):
        """Set the progress of row ``index`` to ``value`` if it increases it, or unconditionally when ``force``."""
        if force or self.should_update_progress(index, value):
            self.dataset_table.patch({
                'Progress': [(index, value)]
            })

    async def monitor_jobs(self, jobs_to_monitor: Dict[str, List[str]], reduction_start_date=None):
        """Poll EDPS until every monitored job has finished, updating each dataset's progress.

        :param jobs_to_monitor: dataset ID -> IDs of the jobs submitted for that dataset.
        :param reduction_start_date: ISO timestamp (millisecond precision) of when the reduction
            started. FAILED jobs completed before it are treated as still pending. Defaults to now.

        Polls every 5 seconds. Stops when ``edps_status`` is falsy, when ``reduction_running`` is
        cleared (e.g. by the Stop button), or when no job is left in the CREATED/RUNNING state.
        """
        if reduction_start_date is None:
            reduction_start_date = datetime.now().isoformat(timespec='milliseconds')
        # show a little progress right away so the user sees the reduction has started
        for dataset_id in jobs_to_monitor:
            self.update_progress(self._dataset_rows(dataset_id)[0], 5, force=True)
        while self.edps_status:
            created_or_running = 0
            for dataset_id, job_ids in jobs_to_monitor.items():
                jobs = [self.edps.get_job_details(job_id) for job_id in job_ids]
                status = get_reduction_status(jobs, reduction_start_date)
                percentage = status.get_completion_percentage()
                if status.running > 0:
                    dataset_name = self.datasets[dataset_id].dataset_name
                    self.logger.debug("%s %s, progress %d%%", dataset_name, status, percentage)
                self.update_progress(self._dataset_rows(dataset_id)[0], percentage)
                created_or_running += status.created + status.running
            if not self.reduction_running or created_or_running == 0:
                break  # done: no need to wait for another polling interval
            await asyncio.sleep(5)

    def update_selection(self, event):
        """Sync ``selected_datasets`` / ``selected_files`` with the table's checkbox selection."""
        # skip spurious events generated when updating the progress column during execution
        if self.reduction_running:
            raise pn.param.Skip
        dataset_ids = self.dataset_table.value.iloc[self.dataset_table.selection]['ID'].to_list()
        datasets = [self.datasets[dataset_id] for dataset_id in dataset_ids]
        dataset_names = [ds.dataset_name for ds in datasets]
        self.selected_datasets = dataset_ids
        self.selected_files = [get_dataset_files(dataset.dataset) for dataset in datasets]
        self.logger.debug('Selected datasets: %s', dataset_names)
        self.logger.debug('Selected files: %s', self.selected_files)

    def stop_reduction(self, event):
        """Ask EDPS to stop all reductions and end the monitoring loop."""
        with self.stop_reduction_btn.param.update(loading=True, disabled=True):
            self.edps.stop_reductions()
            self.reduction_running = False

    @pn.depends('config_counter')
    def dataset_configurations(self):
        """Render the stored per-dataset configurations as a Markdown list."""
        if not self.dataset_config:
            return pn.pane.Markdown('No dataset configurations')
        items = [f'- **{key}**: {value}' for key, value in self.dataset_config.items()]
        return pn.pane.Markdown('\n'.join(items))

    @pn.depends('dataset_table.expanded')
    def all_collapsed(self):
        return self.dataset_table.expanded == []

    def collapse_rows(self, event):
        """Collapse all expanded table rows."""
        self.dataset_table.expanded = []

    def __panel__(self) -> Viewable:
        collapse_btn = pn.widgets.Button(name='Collapse', button_type='primary', disabled=self.all_collapsed)
        collapse_btn.on_click(self.collapse_rows)
        return pn.Column(
            pn.Row(self.reduce_datasets_btn, self.stop_reduction_btn, collapse_btn),
            pn.Row(self.dataset_table),
            self.configuration_modal,
            self.dataset_modal,
            self.terminal,
            "## Dataset Configurations",
            self.dataset_configurations
        )
