import numpy as np
import copy
from astropy.table import Table

# PIPE-9556 workaround switch: when True, auto_molecules_set_WL_ranges_to_include()
# keeps only the first entry of its 'wl_ranges' result.
use_PIPE9556_workaround = False

# ------------------------------------------------------------------------------------------
def auto_molecules_set_WL_include_ranges(
    filenames, wl_range, buff=0., wl_scale=1.,
    ) :
    """Collect per-molecule wavelength intervals from the AUTO_MOLECULES table.

    Parameters
    ----------
    filenames : dict
        Maps file category to file name; only the 'AUTO_MOLECULES' entry is used.
    wl_range : sequence of two floats
        [lower, upper] window, in the units obtained after applying wl_scale
        to the table.
    buff : float
        Subtracted from the lower and added to the upper limit of every
        selected interval (selection itself uses the unbuffered limits).
    wl_scale : float
        Factor applied to the table's 'wl_range' column before comparing.

    Returns
    -------
    dict or None
        None when there is no 'AUTO_MOLECULES' entry. Otherwise, a dict that
        maps each molecule to the list of its buffered [lower, upper] intervals
        overlapping wl_range (possibly an empty dict).
    """
    if 'AUTO_MOLECULES' not in filenames :
        return None

    table=Table.read(filenames['AUTO_MOLECULES'],hdu=1, format='fits')
    table['wl_range']*=wl_scale
    padding=np.array([-1.,1.])*buff
    win_lo=wl_range[0]
    win_hi=wl_range[1]

    ranges_by_molecule={}
    for row in table :
        # Each 'wl_range' cell is indexed as [0][0] / [0][1]: one [lower, upper]
        # pair wrapped in an extra leading dimension.
        lower=row['wl_range'][0][0]
        upper=row['wl_range'][0][1]
        lower_inside=win_lo <= lower and lower <= win_hi
        upper_inside=win_lo <= upper and upper <= win_hi
        covers_window=lower <= win_lo and win_hi <= upper
        if not (lower_inside or upper_inside or covers_window) :
            continue
        ranges_by_molecule.setdefault(row['molecule'],[]).append(
            (row['wl_range']+padding)[0]
        )
    return ranges_by_molecule
# ------------------------------------------------------------------------------------------
def merge_overlapping_wl_ranges( wl_ranges ) :
    """Merge overlapping or touching [lower, upper] intervals.

    Based on https://stackoverflow.com/questions/43600878/merging-overlapping-intervals

    Parameters
    ----------
    wl_ranges : sequence of [lower, upper] pairs
        Intervals in any order (lists, tuples or rows of a 2-D array).
        Neither the sequence nor its elements are modified.

    Returns
    -------
    list of [lower, upper] lists
        New, non-overlapping intervals sorted by lower limit. Intervals that
        merely touch (next lower == previous upper) are merged as well.
    """
    # Copy every interval into a fresh list so the caller's data is never
    # re-ordered or overwritten (and tuples can be merged too).
    intervals=sorted(
        ([interval[0],interval[1]] for interval in wl_ranges),
        key=lambda interval: interval[0],
    )
    merged=[]
    for current in intervals :
        if merged and current[0] <= merged[-1][1] :
            merged[-1][1]=max(merged[-1][1],current[1])
        else :
            merged.append(current)
    return merged
# ------------------------------------------------------------------------------------------
def _read_wl_limits_table(filename, wl_scale) :
    """Read a FITS table with LOWER_LIMIT/UPPER_LIMIT columns.

    Both limit columns are multiplied by wl_scale, and the rows are sorted by
    LOWER_LIMIT before the table is returned.
    """
    t=Table.read(filename, format='fits')
    t['LOWER_LIMIT']*=wl_scale
    t['UPPER_LIMIT']*=wl_scale
    t.sort('LOWER_LIMIT')
    return t


def _subtract_wl_gaps(wl_ranges, wl_gaps) :
    """Return the parts of the [lo, hi] ranges in wl_ranges lying outside every gap.

    Effect of each gap on each range:
    - a range straddling the gap is split in two;
    - a range partly inside it is trimmed to the gap edge;
    - a range entirely inside it is dropped.
    Gaps are applied one after another, and a new list is built for each, so no
    list is modified while it is being iterated.
    """
    remaining=[[wl_r[0],wl_r[1]] for wl_r in wl_ranges]
    for gap in wl_gaps :
        gap_lo=gap[0]
        gap_hi=gap[1]
        pieces=[]
        for (lo,hi) in remaining :
            if hi <= gap_lo or lo >= gap_hi :
                # No overlap with this gap: keep the range untouched.
                pieces.append([lo,hi])
                continue
            if lo < gap_lo :
                pieces.append([lo,gap_lo])
            if hi > gap_hi :
                pieces.append([gap_hi,hi])
        remaining=pieces
    return remaining


def _union_wl_range(wl_ranges, new_range) :
    """Return a new list holding wl_ranges plus new_range.

    Every range overlapping (or touching) new_range is absorbed into it. The
    other ranges keep their order, and the merged range is appended last.
    """
    merged=[new_range[0],new_range[1]]
    kept=[]
    for _r in wl_ranges :
        if _r[0] <= merged[1] and merged[0] <= _r[1] :
            merged=[np.min([_r[0],merged[0]]),np.max([_r[1],merged[1]])]
        else :
            kept.append(_r)
    kept.append(merged)
    return kept


def _format_range_pairs(ranges, fmt, scale=1.) :
    """Format [lo, hi] pairs, ordered by lo, as one flat comma-separated string.

    Each pair is divided by scale and rendered with fmt. Whole rows are
    ordered; sorting the two columns independently would pair limits from
    different ranges whenever ranges overlap.
    """
    return ",".join(
        fmt %(_x[0]/scale,_x[1]/scale) for _x in sorted(ranges, key=lambda _x: _x[0])
    )


def auto_molecules_set_WL_ranges_to_include(
        files,
        wl_ranges,
        wl_span,
        wl_resolutions,
        wl_scale=1.,
        buff=0.,
        fallback_WL_include_ranges=None,
        wl_exclude_gaps=None,
    ) :
    """Work out which molecules and wavelength/pixel ranges to fit.

    The returned dict carries comma-separated strings under 'list_mol',
    'fit_mol', 'rel_col', 'wave_include', 'wave_exclude' and 'pixel_exclude'
    (these look like molecfit recipe parameters -- confirm against the caller),
    plus the intermediate lists they were built from ('species', 'wl_ranges',
    'wl_exclude_ranges', ...).

    Two modes:

    * 'MOLECULES' and 'WAVE_INCLUDE' files both given: species, fit flags and
      relative columns come from MOLECULES. The wavelength ranges are the
      WAVE_INCLUDE rows lying strictly inside their chip's wl_range, merged.
    * otherwise, per chip:
      - candidate ranges come from AUTO_MOLECULES (via
        auto_molecules_set_WL_include_ranges) or fallback_WL_include_ranges;
      - they are clipped to the chip's wl_range, trimmed by 1% of wl_span at
        each end;
      - only ranges wider than 100*wl_resolution are kept, then merged;
      - the chip's wl_exclude_gaps are cut out.
      WAVE_INCLUDE / WAVE_EXCLUDE tables, when given, then extend the included
      ranges and populate the excluded ones.

    Parameters
    ----------
    files : iterable or None
        Objects with ``category`` and ``name`` attributes. Categories used:
        AUTO_MOLECULES, MOLECULES, WAVE_INCLUDE, WAVE_EXCLUDE, PIXEL_EXCLUDE.
    wl_ranges : [lo, hi] or list of [lo, hi]
        Wavelength coverage of one chip, or one pair per chip (multi-chip).
    wl_span : float
        Width used for the 1% trimming at each end of a chip's range.
    wl_resolutions : float or sequence of float
        Per-chip resolution. A single value is used for every chip.
    wl_scale : float
        Table wavelengths are multiplied by it; output strings are divided by it.
    buff : float
        Padding forwarded to auto_molecules_set_WL_include_ranges.
    fallback_WL_include_ranges : dict or None
        Molecule -> list of [lo, hi]. Used when AUTO_MOLECULES yields nothing.
    wl_exclude_gaps : list or None
        Per chip, one [lo, hi] gap or a list of gaps to cut out of the
        automatic ranges. Chips without an entry get no gaps.

    Returns
    -------
    dict
    """
    if fallback_WL_include_ranges is None :
        fallback_WL_include_ranges={}
    if wl_exclude_gaps is None :
        wl_exclude_gaps=[[],]

    filenames={}
    if files is not None :
        for file in files:
            filenames[file.category]=file.name

    # Optional WAVE_INCLUDE columns: output key -> table column name.
    wave_inc_opt_params={
        'FIT_CONTINUUM': 'CONT_FIT_FLAG',
        'CONTINUUM_N': 'CONT_POLY_ORDER',
        'MAP_REGIONS_TO_CHIP': 'MAPPED_TO_CHIP',
    }

    # Normalise single-chip input to the per-chip list form.
    multi_chip=True
    if len(np.shape(wl_ranges)) == 1 :
        wl_ranges=[wl_ranges,]
        multi_chip=False
    if len(np.shape(wl_resolutions)) == 0 :
        wl_resolutions=[wl_resolutions,]
    if len(wl_exclude_gaps) == 0 :
        wl_exclude_gaps=[wl_exclude_gaps,]

    WL_ranges_to_include={'species':[], 'wl_ranges': []}

    if 'MOLECULES' in filenames and 'WAVE_INCLUDE' in filenames :
        # MOLECULES
        t=Table.read(filenames['MOLECULES'], format='fits')
        fm=[]
        rc=[]
        for r in t :
            WL_ranges_to_include['species']+=[r['LIST_MOLEC'].strip(),]
            fm+=[r['FIT_MOLEC']]
            rc+=[r['REL_COL']]
        WL_ranges_to_include['fit_mol']=",".join(["%d" %(_x) for _x in fm])
        WL_ranges_to_include['rel_col']=",".join(["%f" %(_x) for _x in rc])
        # WAVE_INCLUDE
        t=_read_wl_limits_table(filenames['WAVE_INCLUDE'], wl_scale)
        WL_ranges_to_include['wl_ranges']=[]
        for _p in wave_inc_opt_params :
            WL_ranges_to_include[_p]=[]
        for r in t :
            extn=0
            for (_p,_col) in wave_inc_opt_params.items() :
                if _col in t.columns :
                    # NOTE(review): collected for every row, including rows dropped
                    # below, so these lists may not line up with 'wl_ranges'.
                    WL_ranges_to_include[_p]+=[r[_col],]
                    if _p == 'MAP_REGIONS_TO_CHIP' :
                        # The row is indexed by its column name. The chip number is
                        # taken as 1-based, like the extn+1 values this function
                        # emits for MAP_REGIONS_TO_CHIP.
                        extn=int(r[_col])-1
            if not 0 <= extn < len(wl_ranges) :
                # Mapped to a chip we have no wl_range for: nothing to keep.
                continue
            if r['LOWER_LIMIT'] > wl_ranges[extn][0] and r['UPPER_LIMIT'] < wl_ranges[extn][1] :
                WL_ranges_to_include['wl_ranges']+=[[r['LOWER_LIMIT'],r['UPPER_LIMIT']]]
        WL_ranges_to_include['wl_ranges']=merge_overlapping_wl_ranges(WL_ranges_to_include['wl_ranges'])

        if 'WAVE_EXCLUDE' in filenames :
            t=_read_wl_limits_table(filenames['WAVE_EXCLUDE'], wl_scale)
            WL_ranges_to_include['wl_exclude_ranges']=[
                [r['LOWER_LIMIT'],r['UPPER_LIMIT']] for r in t
            ]

    else :
        if multi_chip :
            WL_ranges_to_include['MAP_REGIONS_TO_CHIP']=[]
            WL_ranges_to_include['wl_range_multi_chip']=[]
        for (extn,wl_range) in enumerate(wl_ranges) :
            WL_ranges_to_include['wl_ranges']=[]
            # A single resolution value applies to every chip.
            wl_resolution=wl_resolutions[extn] if len(wl_resolutions) > 1 else wl_resolutions[0]
            # Consider the wl_range trimmed by 1%*wl_span at each end:
            # we don't want intervals right to the end of the wl_range.
            wl_lo=wl_range[0]+0.01*wl_span
            wl_hi=wl_range[1]-0.01*wl_span
            # Pieces must be wider than 100*wl_resolution to be kept.
            min_width=wl_resolution*100.

            WL_include_ranges=auto_molecules_set_WL_include_ranges(
                filenames,
                wl_range,
                buff=buff,
                wl_scale=wl_scale,
            ) or fallback_WL_include_ranges
            for M in WL_include_ranges :
                for _r in WL_include_ranges[M] :
                    if wl_lo < _r[1] and wl_hi > _r[0] :
                        # Some part of this _r is inside the trimmed wl_range.
                        wl_range_to_include=[np.max([wl_lo,_r[0]]),np.min([wl_hi,_r[1]])]
                        if wl_range_to_include[1]-wl_range_to_include[0] > min_width :
                            if M not in WL_ranges_to_include['species'] :
                                WL_ranges_to_include['species'].append(M)
                            WL_ranges_to_include['wl_ranges']=_union_wl_range(
                                WL_ranges_to_include['wl_ranges'],wl_range_to_include
                            )

            # Cut this chip's exclusion gaps out of the automatic ranges.
            wl_gaps=wl_exclude_gaps[extn] if extn < len(wl_exclude_gaps) else []
            if np.size(wl_gaps) > 0 :
                if len(np.shape(wl_gaps)) == 1 :
                    wl_gaps=[wl_gaps,]
                WL_ranges_to_include['wl_ranges']=_subtract_wl_gaps(WL_ranges_to_include['wl_ranges'],wl_gaps)

            if 'WAVE_INCLUDE' in filenames :
                t=_read_wl_limits_table(filenames['WAVE_INCLUDE'], wl_scale)
                # Ensure the wl-ranges in the table are included in the already
                # defined ranges, and if not add them...
                WL_ranges_to_include['wl_include_table']=[]
                WL_ranges_to_include['wl_include_table_valid']=[]
                for r in t :
                    WL_ranges_to_include['wl_include_table']+=[[r['LOWER_LIMIT'],r['UPPER_LIMIT']]]
                    if (
                        ( wl_lo <= r['LOWER_LIMIT'] and r['LOWER_LIMIT'] < wl_hi )
                        or
                        ( wl_lo <= r['UPPER_LIMIT'] and r['UPPER_LIMIT'] < wl_hi )
                    ) :
                        wl_range_to_include=[np.max([wl_lo,r['LOWER_LIMIT']]),np.min([wl_hi,r['UPPER_LIMIT']])]
                        if wl_range_to_include[1]-wl_range_to_include[0] > min_width :
                            WL_ranges_to_include['wl_include_table_valid']+=[wl_range_to_include]
                            r_included=False
                            r_pos=0
                            for (i,_r) in enumerate(WL_ranges_to_include['wl_ranges']) :
                                if _r[0] <= wl_range_to_include[0] :
                                    r_pos=i
                                    if wl_range_to_include[0] < _r[1] :
                                        # Starts inside an existing range: extend it if needed.
                                        r_included=True
                                        if _r[1] < wl_range_to_include[1] :
                                            WL_ranges_to_include['wl_ranges'][i][1]=wl_range_to_include[1]
                            if not r_included :
                                WL_ranges_to_include['wl_ranges'].insert(r_pos+1,wl_range_to_include)
                # And now exclude all but any *VALID* WAVE_INCLUDE wl-ranges: start
                # from this chip's included ranges and carve the table ranges out.
                if len(WL_ranges_to_include['wl_include_table_valid']) > 0 :
                    chip_exclude_ranges=copy.deepcopy(WL_ranges_to_include['wl_ranges'])
                    for r in t :
                        # The piece inserted at i+1 starts at UPPER_LIMIT, so when the
                        # loop reaches it the test below is false.
                        for (i,_r) in enumerate(chip_exclude_ranges) :
                            if _r[0] < r['LOWER_LIMIT'] and r['LOWER_LIMIT'] < _r[1] :
                                if _r[1] > r['UPPER_LIMIT'] :
                                    chip_exclude_ranges.insert(i+1,[r['UPPER_LIMIT'],_r[1]])
                                chip_exclude_ranges[i][1]=r['LOWER_LIMIT']
                    # Accumulate across chips instead of overwriting earlier chips.
                    WL_ranges_to_include.setdefault('wl_exclude_ranges',[]).extend(chip_exclude_ranges)

            if 'WAVE_EXCLUDE' in filenames :
                t=_read_wl_limits_table(filenames['WAVE_EXCLUDE'], wl_scale)
                exclude_ranges=WL_ranges_to_include.setdefault('wl_exclude_ranges',[])
                for r in t :
                    if (
                        ( wl_range[0] <= r['LOWER_LIMIT'] and r['LOWER_LIMIT'] < wl_range[1] )
                        or
                        ( wl_range[0] <= r['UPPER_LIMIT'] and r['UPPER_LIMIT'] < wl_range[1] )
                    ) :
                        # Add the excluded range, merging it with any it overlaps.
                        exclude_ranges=_union_wl_range(exclude_ranges,[r['LOWER_LIMIT'],r['UPPER_LIMIT']])
                WL_ranges_to_include['wl_exclude_ranges']=exclude_ranges

            if multi_chip :
                # Chip numbers are 1-based.
                WL_ranges_to_include['MAP_REGIONS_TO_CHIP']+=[extn+1]*len(WL_ranges_to_include['wl_ranges'])
                WL_ranges_to_include['wl_range_multi_chip']+=WL_ranges_to_include['wl_ranges']

    # Only the per-chip branch builds 'wl_range_multi_chip'.
    if multi_chip and 'wl_range_multi_chip' in WL_ranges_to_include :
        # Order all chips' ranges (and their chip numbers) by lower limit.
        order=np.argsort([x[0] for x in WL_ranges_to_include['wl_range_multi_chip']])
        WL_ranges_to_include['wl_ranges']=np.array(WL_ranges_to_include['wl_range_multi_chip'])[order]
        WL_ranges_to_include['MAP_REGIONS_TO_CHIP']=np.array(WL_ranges_to_include['MAP_REGIONS_TO_CHIP'])[order]

    if 'PIXEL_EXCLUDE' in filenames :
        t=Table.read(filenames['PIXEL_EXCLUDE'], format='fits')
        t.sort('LOWER_LIMIT')
        # Pixel limits: no wl_scale applied.
        WL_ranges_to_include['pixel_exclude_ranges']=[
            [r['LOWER_LIMIT'],r['UPPER_LIMIT']] for r in t
        ]

    # PIPE-9556 [PIPE9556]
    # Temporary hack to limit wl_ranges to a single interval because 4.1.0b
    # can't handle more than one interval...
    if use_PIPE9556_workaround and 'wl_ranges' in WL_ranges_to_include :
        if len(WL_ranges_to_include['wl_ranges']) > 1 :
            WL_ranges_to_include['wl_ranges']=[
                WL_ranges_to_include['wl_ranges'][0]
            ]
            #WL_ranges_to_include['wl_ranges']=[
            #    list(np.sort(np.array(WL_ranges_to_include['wl_ranges']).flatten())[[0,-1]])
            #]

    WL_ranges_to_include['wave_exclude']="NULL"
    WL_ranges_to_include['pixel_exclude']="NULL"
    if len(WL_ranges_to_include['wl_ranges']) < 1 :
        # Fall-backs in case we could not find any reasonable WL ranges...
        # These are not unreasonable in the VIS range...
        WL_ranges_to_include['list_mol']="H2O,O2"
        WL_ranges_to_include['fit_mol']="1,1"
        WL_ranges_to_include['rel_col']="1.0,1.0"
        WL_ranges_to_include['wave_include']=",".join(
            ["%f,%f" %(
                    (_wr[0]+0.01*wl_span)/wl_scale,
                    (_wr[1]-0.01*wl_span)/wl_scale,
                ) for _wr in wl_ranges
            ]
        )
        WL_ranges_to_include['MAP_REGIONS_TO_CHIP']=",".join("%d" %(x+1) for x in range(len(wl_ranges)))
    else :
        n_species=len(WL_ranges_to_include['species'])
        WL_ranges_to_include['list_mol']=",".join(WL_ranges_to_include['species'])
        if 'fit_mol' not in WL_ranges_to_include :
            WL_ranges_to_include['fit_mol']=",".join(["%d" %(_x) for _x in np.ones(n_species)])
        if 'rel_col' not in WL_ranges_to_include :
            WL_ranges_to_include['rel_col']=",".join(["%f" %(_x) for _x in np.ones(n_species)])
        WL_ranges_to_include['wave_include']=_format_range_pairs(
            WL_ranges_to_include['wl_ranges'],"%f,%f",wl_scale
        )
        # An empty exclusion list keeps the "NULL" placeholder rather than "".
        if len(WL_ranges_to_include.get('wl_exclude_ranges',[])) > 0 :
            WL_ranges_to_include['wave_exclude']=_format_range_pairs(
                WL_ranges_to_include['wl_exclude_ranges'],"%f,%f",wl_scale
            )
        if len(WL_ranges_to_include.get('pixel_exclude_ranges',[])) > 0 :
            WL_ranges_to_include['pixel_exclude']=_format_range_pairs(
                WL_ranges_to_include['pixel_exclude_ranges'],"%d,%d",1
            )
        for _p in wave_inc_opt_params :
            if _p in WL_ranges_to_include and len(WL_ranges_to_include[_p]) > 0 :
                WL_ranges_to_include[_p]=",".join("%d" %(x) for x in WL_ranges_to_include[_p])

    return WL_ranges_to_include
# ------------------------------------------------------------------------------------------
