import os.path
from dataclasses import dataclass
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from astropy.io import fits
from astropy.table import Table
from astropy.visualization import ZScaleInterval, BaseInterval, hist


def plot(data, title: str, interval: BaseInterval):
    """Render an array as a grayscale image with a colorbar and a title.

    The display range is whatever ``interval.get_limits(data)`` returns. The
    figure is closed in pyplot (``plt.close``) before it is returned, so
    pyplot no longer tracks it and the caller is the sole owner.

    Args:
        data: Array passed to ``imshow``.
        title: Text used as the axes title.
        interval: Object whose ``get_limits(data)`` yields ``(vmin, vmax)``.

    Returns:
        The figure holding the image.
    """
    vmin, vmax = interval.get_limits(data)
    fig = plt.figure()
    try:
        # Draw on explicit figure/axes objects rather than pyplot's implicit
        # "current" figure so the result cannot land on an unrelated figure.
        ax = fig.add_subplot(111)
        image = ax.imshow(data, vmin=vmin, vmax=vmax, cmap='gray')
        fig.colorbar(image, ax=ax)
        ax.set_title(title)
    finally:
        # Always unregister the figure, even when drawing fails; otherwise
        # every failed call would leave one more open figure behind in pyplot.
        plt.close(fig)
    return fig


def is_image_hdu(header):
    """Tell whether a header describes a non-empty two-dimensional image.

    Args:
        header: Mapping-like object with a ``get(key, default)`` method.
            Missing ``naxis*`` keys count as 0 and a missing ``xtension``
            counts as ``'PRIMARY'``.

    Returns:
        True when ``naxis`` is 2, both axis lengths are positive and the
        extension type is ``'IMAGE'`` or ``'PRIMARY'``; False otherwise.
    """
    hdu_kind = header.get('xtension', 'PRIMARY')
    axis_count, width, height = (
        header.get(key, 0) for key in ('naxis', 'naxis1', 'naxis2')
    )

    if axis_count != 2:
        return False
    if not (width > 0 and height > 0):
        return False
    return hdu_kind in ('IMAGE', 'PRIMARY')


def plot_image(filename, interval=ZScaleInterval(), hdunum=0):
    """Plot one HDU of a FITS file as an image, if it is a 2-D image HDU.

    Args:
        filename: Path of the FITS file; its base name appears in the title.
        interval: Interval used to pick the display limits.
        hdunum: Index of the HDU to display.

    Returns:
        The figure produced by ``plot``, or ``None`` when ``is_image_hdu``
        rejects the selected HDU's header.
    """
    base_name = os.path.basename(filename)
    caption = f'{base_name} HDU#{hdunum}'
    with fits.open(filename) as hdul:
        selected = hdul[hdunum]
        if not is_image_hdu(selected.header):
            return None
        # Plot while the file is still open so the pixel data is readable.
        return plot(selected.data, caption, interval)


def plot_hist(filename, hdunum=0):
    """Plot a 50-bin histogram of one HDU's values on a logarithmic y-axis.

    Args:
        filename: Path of the FITS file to open.
        hdunum: Index of the HDU whose data is histogrammed.

    Returns:
        The figure holding the histogram (already closed in pyplot), or
        ``None`` when the selected HDU carries no data.
    """
    with fits.open(filename) as hdul:
        data = hdul[hdunum].data
        if data is None:
            # A header-only HDU has nothing to histogram; report that with
            # None (as plot_image does for unsuitable HDUs) instead of
            # failing with AttributeError on ``None.reshape``.
            return None
        values = data.reshape(-1)
        if np.issubdtype(values.dtype, np.floating):
            # NaN/inf entries make the automatic bin-range detection raise,
            # so only finite values take part in the histogram.
            values = values[np.isfinite(values)]
        fig = plt.figure()
        try:
            plt.semilogy()
            hist(values, bins=50)
        finally:
            # Unregister the figure from pyplot even if plotting fails, so a
            # failed call does not leave an open figure behind.
            plt.close(fig)
        return fig


def get_table(filename: str, hdunum: int) -> pd.DataFrame:
    """Load one HDU of a FITS file as a pandas DataFrame.

    Args:
        filename: Path of the FITS file to open.
        hdunum: Index of the HDU to read.

    Returns:
        The HDU read through ``Table.read`` and converted with ``to_pandas``.
    """
    with fits.open(filename) as hdul:
        selected_hdu = hdul[hdunum]
        astropy_table = Table.read(selected_hdu)
        # Convert while the file is still open, then hand back the result.
        frame = astropy_table.to_pandas()
    return frame


@dataclass
class LabelledArray:
    """Array-like values bundled with a unit and a display name."""

    # The values themselves; get_spectrum stores a table cell or a NumPy array here.
    data: Any
    # Unit label for the values.
    # NOTE(review): get_spectrum passes the FITS column's ``unit`` through
    # unchanged, which is presumably None when the column declares no unit —
    # confirm before relying on this always being a str.
    unit: str
    # Human-readable name, e.g. 'Wavelength' or 'S/N'.
    name: str


@dataclass
class Spectrum:
    """The four labelled arrays describing one spectrum, as built by get_spectrum."""

    # Values of the 'WAVE' column.
    wave: LabelledArray
    # Values of the 'FLUX' column.
    flux: LabelledArray
    # Values of the 'ERR' column.
    err: LabelledArray
    # Signal-to-noise ratio: flux divided by err (NaN where err is 0).
    snr: LabelledArray


def get_spectrum(filename: str, hdunum: int = 1) -> Spectrum:
    """Read a spectrum from the first row of a FITS table HDU.

    The 'WAVE', 'FLUX' and 'ERR' cells of row 0 are wrapped together with the
    unit declared on the corresponding column. A signal-to-noise array is
    derived as flux / err.

    Args:
        filename: Path of the FITS file to open.
        hdunum: Index of the table HDU holding the spectrum.

    Returns:
        A ``Spectrum`` with wavelength, flux, error and S/N arrays.
    """
    with fits.open(filename) as hdul:
        table_hdu = hdul[hdunum]
        records = table_hdu.data
        column_defs = table_hdu.columns

        def labelled(column, label):
            # Values come from the first table row; the unit comes from the
            # column definition of the same name.
            return LabelledArray(
                data=records[0][column],
                unit=column_defs[column].unit,
                name=label,
            )

        wavelength = labelled('WAVE', 'Wavelength')
        flux_values = labelled('FLUX', 'Flux')
        flux_error = labelled('ERR', 'Error')

        # Pre-fill with NaN so entries where the error is exactly zero (and
        # the division is skipped) end up as NaN rather than uninitialised.
        ratio = np.full_like(flux_values.data, np.nan, dtype=float)
        np.divide(
            flux_values.data,
            flux_error.data,
            out=ratio,
            where=flux_error.data != 0,
        )
        signal_to_noise = LabelledArray(data=ratio, unit='', name='S/N')

        return Spectrum(
            wave=wavelength,
            flux=flux_values,
            err=flux_error,
            snr=signal_to_noise,
        )
