#!/usr/bin/python
# -*- coding: utf-8 -*-

# newclass.py
from __future__ import absolute_import
from __future__ import print_function
from astropy.io import fits
import string
from reflex import *
import json
import numpy as np
import sys
import wx
import matplotlib as mp
from pylab import *
from matplotlib.figure import Figure
from matplotlib.backends.backend_wxagg import FigureCanvasWxAgg as FigureCanvas
import readFits
from os import walk
import wx.lib.scrolledpanel

class FrameCalEstimate(wx.Frame):
    """Interactive selection of calibrators for the MATISSE workflow.

    Calibrator transfer functions (TF2 extension) and closure phases (OI_T3)
    are plotted in blue. Science-target squared visibilities (OI_VIS2) and
    closure phases (OI_T3) are plotted in red. Both are shown against time in
    hours since the earliest MJD found in all input files.

    Left-click near a calibrator on the first transfer-function plot to flag
    it as bad (its points turn grey); right-click to restore it (blue).
    "Continue" removes from the SOF every file whose OI_TARGET name matches a
    flagged calibrator, then closes the window.
    """

    # Rows preallocated per file in the data arrays; a table with more rows
    # makes the numpy slice assignment in _loadBlocks raise ValueError.
    NROW_MAX = 600
    # Number of transfer-function (baseline) and closure-phase (triplet) plots.
    NB_PLOT_TF = 6
    NB_PLOT_CP = 3

    def __init__(self, *args, **kwds):
        """Read the input files, then build and show the window.

        The keyword arguments ``file`` and ``sof`` are consumed here. Every
        other argument is forwarded to ``wx.Frame.__init__``.

        file : dict with keys ``'calib'`` and ``'target'``, each a non-empty
            list of FITS file names.
        sof : SOF object whose ``files`` list is pruned of flagged
            calibrators when "Continue" is pressed.
        """
        fileDict = kwds.pop('file')
        self.calib_in = fileDict['calib']
        self.target_in = fileDict['target']
        self.sof = kwds.pop('sof')

        self.calibNames = [self._readHeaderKey(fic, 'HIERARCH ESO OBS TARG NAME')
                           for fic in self.calib_in]
        self.targetNames = [self._readHeaderKey(fic, 'HIERARCH ESO OBS TARG NAME')
                            for fic in self.target_in]

        print("TARGET FILENAME:", self.target_in[0])
        self._readConfiguration(self.target_in[0])

        self.listBaseline()
        print("Baseline:", self.baseline)
        self.listTriplet()
        print("Triplet:", self.triplet)
        self.getMjdLimit()
        print("MJD LIMIT:", self.mjd0, self.mjd1)
        self.getListStation()
        print("Station:", self.station)
        self.loadDataTf()
        self.loadDataCp()
        self.idxBadCalib = np.zeros((len(self.calib_in)), dtype=np.int32)
        # onClickTF1 reads these flags even if no fit has been drawn yet.
        self.flagFitDone = 0
        self.flagFitCpDone = 0
        self.InitUI(*args, **kwds)
        self.Centre()
        self.Show()

    @staticmethod
    def _readHeaderKey(fileName, key):
        """Return keyword ``key`` from the primary header of a FITS file."""
        with fits.open(fileName) as hdu:
            return hdu[0].header[key]

    def _readConfiguration(self, fileName):
        """Set the instrument configuration attributes from a primary header.

        Sets mode, detector, dit, resolution, spectralfilter, polarizer and
        spatialfilter. ``mode`` is '' when the PIL/PIN pair is not recognised.
        """
        with fits.open(fileName) as hdulist:
            prihdr = hdulist[0].header
            self.detector = prihdr['HIERARCH ESO DET CHIP NAME']
            self.dit = prihdr['HIERARCH ESO DET SEQ1 DIT']
            pil = prihdr['HIERARCH ESO INS PIL ID']
            pin = prihdr['HIERARCH ESO INS PIN ID']
            modes = {('PHOTO', 'PHOTO'): 'SIPHOT',
                     ('PHOTO', 'INTER'): 'HYBRID',
                     ('INTER', 'INTER'): 'HIGHSENS'}
            self.mode = modes.get((pil, pin), '')
            # The optics keywords have different names depending on the detector.
            if self.detector == 'HAWAII-2RG':
                self.resolution = prihdr['HIERARCH ESO INS DIL ID']
                self.spectralfilter = prihdr['HIERARCH ESO INS FIL ID']
                self.polarizer = prihdr['HIERARCH ESO INS POL ID']
                self.spatialfilter = prihdr['HIERARCH ESO INS SFL ID']
            else:
                self.resolution = prihdr['HIERARCH ESO INS DIN ID']
                self.spectralfilter = prihdr['HIERARCH ESO INS FIN ID']
                self.polarizer = prihdr['HIERARCH ESO INS PON ID']
                self.spatialfilter = prihdr['HIERARCH ESO INS SFN ID']

    def InitUI(self, *args, **kwds):
        """Create the frame and lay out the configuration, plots and buttons.

        ``args``/``kwds`` are forwarded unchanged to ``wx.Frame.__init__``.
        """
        wx.Frame.__init__(self, *args, **kwds)

        panel = wx.lib.scrolledpanel.ScrolledPanel(self, -1, size=(400, 200))
        panel.SetupScrolling()
        sizer = wx.GridBagSizer(5, 4)
        titleFont = wx.Font(16, wx.DECORATIVE, wx.NORMAL, wx.BOLD)
        smallFont = wx.Font(8, wx.NORMAL, wx.NORMAL, wx.NORMAL)

        # CONFIGURATION ZONE
        text1 = wx.StaticText(panel, label="Configuration")
        text1.SetFont(titleFont)
        sizer.Add(text1, pos=(0, 0), flag=wx.TOP | wx.LEFT | wx.BOTTOM,
                  border=15)
        strStation = '-'.join(self.telesc[self.station[i]]
                              for i in range(len(self.station)))
        str1 = ('MODE : ' + self.mode + '\nDETECTOR : ' + self.detector
                + '\nDIT : ' + str(self.dit) + 's' + '\nTELESCOPES :' + strStation)
        str2 = ('SPECTRAL RESOLUTION : ' + self.resolution + '\nPOLARIZER : '
                + self.polarizer + '\nSPATIAL FILTER : ' + self.spatialfilter)
        gridBox = wx.GridSizer(1, 6, 10, 10)
        # The first cell is a blank spacer.
        for label in (' ', str1, str2):
            textConf = wx.StaticText(panel, -1, label)
            textConf.SetFont(smallFont)
            gridBox.Add(textConf, 0, flag=wx.EXPAND)
        sizer.Add(gridBox, pos=(0, 1), span=(1, 3), flag=wx.EXPAND, border=20)

        line = wx.StaticLine(panel)
        sizer.Add(line, pos=(2, 0), span=(1, 4), flag=wx.EXPAND | wx.BOTTOM, border=10)

        # TRANSFER FUNCTION ZONE
        text4 = wx.StaticText(panel, label="Transfer Function")
        text4.SetFont(titleFont)
        sizer.Add(text4, pos=(3, 0), flag=wx.LEFT, border=10)

        vboxFigTf = wx.GridSizer(2, 3, 10, 10)
        self.listFigTf = []
        self.listCanvasTf = []
        self.listCid = []
        for j in range(self.NB_PLOT_TF):
            fig = Figure(figsize=(4, 2))
            fig.set_facecolor('white')
            canvas = FigureCanvas(panel, -1, fig)
            self.listFigTf.append(fig)
            self.listCanvasTf.append(canvas)
            vboxFigTf.Add(canvas, 0, wx.EXPAND | wx.CENTER)
        # Calibrators are (un)flagged by clicking on the first TF plot only.
        self.listCid.append(self.listCanvasTf[0].mpl_connect('button_press_event',
                                                             self.onClickTF1))
        sizer.Add(vboxFigTf, pos=(4, 0), span=(15, 4), flag=wx.EXPAND, border=0)

        line = wx.StaticLine(panel)
        sizer.Add(line, pos=(20, 0), span=(1, 4), flag=wx.EXPAND | wx.BOTTOM, border=10)

        # CLOSURE PHASE ZONE
        text5 = wx.StaticText(panel, label="Closure Phase")
        text5.SetFont(titleFont)
        sizer.Add(text5, pos=(21, 0), flag=wx.LEFT, border=10)

        vboxFigCp = wx.GridSizer(1, 3, 10, 10)
        self.listFigCp = []
        self.listCanvasCp = []
        for j in range(self.NB_PLOT_CP):
            fig = Figure(figsize=(4, 2))
            fig.set_facecolor('white')
            canvas = FigureCanvas(panel, -1, fig)
            self.listFigCp.append(fig)
            self.listCanvasCp.append(canvas)
            vboxFigCp.Add(canvas, 0, wx.EXPAND | wx.CENTER)
        sizer.Add(vboxFigCp, pos=(22, 0), span=(6, 4), flag=wx.EXPAND, border=0)

        # BOTTOM BUTTONS
        line = wx.StaticLine(panel)
        sizer.Add(line, pos=(29, 0), span=(1, 4), flag=wx.EXPAND | wx.BOTTOM, border=10)

        hboxButton = wx.BoxSizer(wx.HORIZONTAL)
        buttonCont = wx.Button(panel, label="Continue")
        hboxButton.Add(buttonCont, 0, wx.CENTER)
        sizer.Add(hboxButton, pos=(30, 3), flag=wx.CENTER, border=5)

        sizer.AddGrowableCol(2)
        panel.SetSizer(sizer)

        buttonCont.Bind(wx.EVT_BUTTON, self.buttonContClick)
        self.drawTf()
        self.drawCp()

    def onClickTF1(self, event):
        """Flag (left click) or restore (right click) the nearest calibrator."""
        if event.xdata is None:
            return
        if event.button == 1:
            flag, color = 1, 'grey'
        elif event.button == 3:
            flag, color = 0, 'blue'
        else:
            return
        cpt = self.getNearestCalib(event.xdata)
        if cpt is None:
            return
        self._setCalibFlag(cpt, flag, color)

    def _setCalibFlag(self, cpt, flag, color):
        """Set idxBadCalib[cpt] to ``flag``, recolour its points and redraw.

        Any fit line currently drawn is removed, because it no longer matches
        the selected calibrators.
        """
        self.idxBadCalib[cpt] = flag

        for j in range(self.nbBlockCalib[cpt]):
            k = self.indexListCurve[j, cpt]
            if k >= 0:
                self.listCurve[k][0].set_color(color)
        if self.flagFitDone == 1:
            self._removeLines(self.listCurveFitTF, self.listCurveFitTFPlus,
                              self.listCurveFitTFMinus)
            # Lines are gone: do not try to remove them again next click.
            self.flagFitDone = 0
        for canvas in self.listCanvasTf:
            canvas.draw()

        for j in range(self.nbBlockCpCalib[cpt]):
            k = self.indexListCurveCp[j, cpt]
            if k >= 0:
                self.listCurveCp[k][0].set_color(color)
        if self.flagFitCpDone == 1:
            self._removeLines(self.listCurveFitCp, self.listCurveFitCpPlus,
                              self.listCurveFitCpMinus)
            self.flagFitCpDone = 0
        for canvas in self.listCanvasCp:
            canvas.draw()

    @staticmethod
    def _removeLines(*lineLists):
        """Remove the first line of every ``Axes.plot`` result in the given lists."""
        for lines in lineLists:
            for line in lines:
                line[0].remove()

    def getNearestCalib(self, t):
        """Return the index of the calibrator with a TF point closest in time to ``t``.

        Ties keep the lowest index. Returns None when no calibrator point lies
        within 1e5 hours of ``t`` (e.g. when there is no calibrator data).
        """
        valmin = 1.E5
        idx = None
        for i in range(len(self.calib_in)):
            nb = self.nbBlockCalib[i]
            if nb == 0:
                continue
            dist = np.min(np.abs(t - self.valTime[0:nb, i]))
            if dist < valmin:
                valmin = dist
                idx = i
        return idx

    def buttonContClick(self, event):
        """Remove the files of flagged calibrators from the SOF and close the window.

        A SOF file is removed when the first OI_TARGET name matches that of a
        flagged calibrator.
        """
        for fig in self.listFigTf:
            fig.clf()
        for fig in self.listFigCp:
            fig.clf()

        for fic, bad in zip(self.calib_in, self.idxBadCalib):
            if bad != 1:
                continue
            name = self._readOiTarget(fic)
            # Walk backwards so deleting an entry does not shift those still to visit.
            for i in range(len(self.sof.files) - 1, -1, -1):
                if str(self._readOiTarget(self.sof.files[i].name)) == name:
                    del self.sof.files[i]

        self.Close()
        event.Skip()

    def getMjdLimit(self):
        """Set ``self.mjd0``/``self.mjd1`` to the min/max MJD over all input files.

        Calibrator MJDs are read from TF2 and target MJDs from OI_VIS2.
        """
        self.mjd0 = 1.E5
        self.mjd1 = -1.E5
        sources = ([(f, 'TF2') for f in self.calib_in]
                   + [(f, 'OI_VIS2') for f in self.target_in])
        for fileName, extname in sources:
            with fits.open(fileName) as hdu:
                mjd = hdu[extname].data['MJD']
                if len(mjd) == 0:
                    continue
                low = np.min(mjd)
                high = np.max(mjd)
            if low < self.mjd0:
                self.mjd0 = low
            if high > self.mjd1:
                self.mjd1 = high

    def listBaseline(self):
        """Set ``self.baseline`` to the distinct station pairs of OI_VIS2.

        Read from the first calibrator. Pairs keep their order of first
        appearance, and (a, b) is the same baseline as (b, a).
        """
        with fits.open(self.calib_in[0]) as hdu:
            listStation = hdu['OI_VIS2'].data['STA_INDEX']
            self.baseline = []
            for sta in listStation:
                # Compare with the baselines already kept, not with raw rows.
                if self.defineIndexBaseline(sta) < 0:
                    self.baseline.append([sta[0], sta[1]])

    def listTriplet(self):
        """Set ``self.triplet`` to the distinct (ordered) station triplets of OI_T3.

        Read from the first calibrator, in order of first appearance.
        """
        with fits.open(self.calib_in[0]) as hdu:
            listStation = hdu['OI_T3'].data['STA_INDEX']
            self.triplet = []
            for sta in listStation:
                if self.defineIndexTriplet(sta) < 0:
                    self.triplet.append([sta[0], sta[1], sta[2]])

    def defineIndexBaseline(self, station):
        """Set and return ``self.indexBaseline``.

        This is the position of the pair ``station[0:2]`` in ``self.baseline``
        (either orientation), or -1 when absent.
        """
        self.indexBaseline = -1
        for i, b in enumerate(self.baseline):
            if ((station[0] == b[0] and station[1] == b[1]) or
                    (station[0] == b[1] and station[1] == b[0])):
                self.indexBaseline = i
        return self.indexBaseline

    def defineIndexTriplet(self, station):
        """Set and return ``self.indexTriplet``.

        This is the position of ``station[0:3]`` in ``self.triplet`` (same
        order required), or -1 when absent.
        """
        self.indexTriplet = -1
        for i, tr in enumerate(self.triplet):
            if station[0] == tr[0] and station[1] == tr[1] and station[2] == tr[2]:
                self.indexTriplet = i
        return self.indexTriplet

    def getListStation(self):
        """Read OI_ARRAY of the first calibrator.

        Sets ``self.telesc`` (station index -> station name) and
        ``self.station`` (row number -> station index).
        """
        with fits.open(self.calib_in[0]) as hdu:
            oiArray = hdu['OI_ARRAY'].data
            telname = oiArray['STA_NAME']
            staIndex = oiArray['STA_INDEX']
            self.telesc = {}
            self.station = {}
            for i in range(len(telname)):
                self.telesc[staIndex[i]] = telname[i]
                self.station[i] = staIndex[i]

    def _loadBlocks(self, fileNames, extname, valCol, errCol, nSta):
        """Read table ``extname`` of every file into per-file columns.

        Returns (times, vals, errs, stations, nbBlock). Column i holds file i:
        times in hours since self.mjd0, the mean of ``valCol`` along axis 1,
        the median of ``errCol`` along axis 1, and the first ``nSta``
        STA_INDEX entries. nbBlock[i] is the number of rows filled.
        """
        nrowMax = self.NROW_MAX
        nFile = len(fileNames)
        times = np.zeros((nrowMax, nFile), dtype=np.float32)
        vals = np.zeros((nrowMax, nFile), dtype=np.float32)
        errs = np.zeros((nrowMax, nFile), dtype=np.float32)
        stations = np.zeros((nrowMax, nSta, nFile), dtype=np.int32)
        nbBlock = np.zeros(nFile, dtype=np.int32)
        for i, fileName in enumerate(fileNames):
            with fits.open(fileName) as hdu:
                data = hdu[extname].data
                mjd = data['MJD']
                valMean = np.mean(data[valCol], axis=1)
                errMedian = np.median(data[errCol], axis=1)
                nbrow = len(valMean)
                nbBlock[i] = nbrow
                times[0:nbrow, i] = (mjd[0:nbrow] - self.mjd0) * 24.
                vals[0:nbrow, i] = valMean
                errs[0:nbrow, i] = errMedian
                stations[0:nbrow, 0:nSta, i] = data['STA_INDEX'][0:nbrow, 0:nSta]
        return times, vals, errs, stations, nbBlock

    def loadDataTf(self):
        """Load calibrator TF2 and target OI_VIS2 data.

        Sets valTime, valTF, valTFerr, stationTF and nbBlockCalib for the
        calibrators, and valTimeVis2, valVis2, valVis2err, stationVis2 and
        nbBlockTarget for the targets.
        """
        (self.valTime, self.valTF, self.valTFerr, self.stationTF,
         self.nbBlockCalib) = self._loadBlocks(self.calib_in, 'TF2', 'TF2', 'TF2ERR', 2)
        (self.valTimeVis2, self.valVis2, self.valVis2err, self.stationVis2,
         self.nbBlockTarget) = self._loadBlocks(self.target_in, 'OI_VIS2',
                                                'VIS2DATA', 'VIS2ERR', 2)

    def loadDataCp(self):
        """Load OI_T3 closure phases of the calibrators and the targets.

        Sets valTimeCp, valCp, valCperr, stationCp and nbBlockCpCalib for the
        calibrators, and valTimeCpTarget, valCpTarget, valCpTargeterr,
        stationCpTarget and nbBlockCpTarget for the targets.
        """
        (self.valTimeCp, self.valCp, self.valCperr, self.stationCp,
         self.nbBlockCpCalib) = self._loadBlocks(self.calib_in, 'OI_T3',
                                                 'T3PHI', 'T3PHIERR', 3)
        (self.valTimeCpTarget, self.valCpTarget, self.valCpTargeterr,
         self.stationCpTarget, self.nbBlockCpTarget) = self._loadBlocks(
            self.target_in, 'OI_T3', 'T3PHI', 'T3PHIERR', 3)

    @staticmethod
    def _readOiTarget(fileName):
        """Return the first TARGET entry of the OI_TARGET table of a FITS file."""
        with fits.open(fileName) as hdu:
            return hdu['OI_TARGET'].data['TARGET'][0]

    @staticmethod
    def _timeRange(*arrays):
        """Return (xmin, xmax): the extent of all given arrays padded by 0.5."""
        allTimes = np.concatenate([np.ravel(a) for a in arrays])
        return np.min(allTimes) - 0.5, np.max(allTimes) + 0.5

    @staticmethod
    def _styleAxes(ax, ylabel, ylim, xlim, title):
        """Apply the common labels, limits, title and small fonts to a plot."""
        ax.set_ylabel(ylabel)
        ax.set_ylim(*ylim)
        ax.set_xlim(*xlim)
        ax.yaxis.label.set_fontsize(8)
        ax.set_xlabel('Time (in hour)')
        ax.xaxis.label.set_fontsize(8)
        ax.set_title(title)
        ax.title.set_fontsize(8)
        ax.tick_params(axis='both', labelsize=6)

    def drawTf(self):
        """Plot calibrator TF (blue) and target squared visibility (red) per baseline.

        Fills listAxesTf, listCurve (one errorbar container per plotted
        calibrator row) and indexListCurve. indexListCurve[j, i] is the
        listCurve index for row j of calibrator i, or -1 if that row was not
        plotted.
        """
        nbPlot = self.NB_PLOT_TF
        for fig in self.listFigTf:
            fig.clf()
        self.listAxesTf = [fig.add_subplot(111) for fig in self.listFigTf]

        self.listCurve = []
        self.indexListCurve = np.full(self.valTime.shape, -1, dtype=np.int32)
        xmin, xmax = self._timeRange(self.valTime, self.valTimeVis2)

        flagNameCalib = np.zeros((nbPlot, len(self.calib_in)), dtype=np.int32)
        for i, fileName in enumerate(self.calib_in):
            name = self._readOiTarget(fileName)
            for j in range(self.nbBlockCalib[i]):
                sta = self.stationTF[j, :, i]
                idx = self.defineIndexBaseline(sta)
                if not 0 <= idx < nbPlot:
                    continue
                ax = self.listAxesTf[idx]
                self._styleAxes(ax, 'Squared Visibility', (-0.1, 1.0), (xmin, xmax),
                                'Baseline ' + self.telesc[sta[0]] + self.telesc[sta[1]])
                self.indexListCurve[j, i] = len(self.listCurve)
                self.listCurve.append(ax.errorbar(self.valTime[j, i], self.valTF[j, i],
                                                  self.valTFerr[j, i], fmt='', color='blue',
                                                  marker='o', linestyle='', markersize=4))
                if flagNameCalib[idx, i] == 0:
                    # Put the name on the side of the plot away from the point.
                    ypos = 0.4 if self.valTF[j, i] > 0.5 else 0.9
                    ax.text(self.valTime[j, i], ypos, name, fontsize=7,
                            color='blue', rotation=90)
                    flagNameCalib[idx, i] = 1

        flagNameTarget = np.zeros((nbPlot, len(self.target_in)), dtype=np.int32)
        for i, fileName in enumerate(self.target_in):
            name = self._readOiTarget(fileName)
            for j in range(self.nbBlockTarget[i]):
                idx = self.defineIndexBaseline(self.stationVis2[j, :, i])
                if not 0 <= idx < nbPlot:
                    continue
                ax = self.listAxesTf[idx]
                ax.errorbar(self.valTimeVis2[j, i], self.valVis2[j, i],
                            self.valVis2err[j, i], fmt='', color='red', marker='o',
                            linestyle='', markersize=4)
                if flagNameTarget[idx, i] == 0:
                    ypos = 0.4 if self.valVis2[j, i] > 0.5 else 0.9
                    ax.text(self.valTimeVis2[j, i], ypos, name, fontsize=7,
                            color='red', rotation=90)
                    flagNameTarget[idx, i] = 1

        for fig, canvas in zip(self.listFigTf, self.listCanvasTf):
            fig.tight_layout()
            canvas.draw()

    def _fitPerPlot(self, nbPlot, nbBlock, station, times, vals, errs, indexOf, listAxes):
        """Fit and draw a weighted straight line on each of ``nbPlot`` axes.

        Row j of calibrator i is used for plot k when the calibrator is not
        flagged and ``indexOf(station[j, :, i]) == k``. Plots with fewer than
        two such rows are skipped.

        Returns (flag, polys, medianErrs, fitLines, plusLines, minusLines).
        flag is 0 (and the lists are empty) when fewer than two calibrators
        are kept, 1 otherwise. Plus/minus lines are fit +/- median error.
        """
        polys, medErrs, curves, curvesPlus, curvesMinus = [], [], [], [], []
        if self.idxBadCalib.sum() > (len(self.calib_in) - 2):
            print("Not enough Calibrators to fit a line")
            return 0, polys, medErrs, curves, curvesPlus, curvesMinus
        x = (25. * np.arange(100) / 100.) - 1.
        for k in range(nbPlot):
            rows = [(j, i)
                    for i in range(len(self.calib_in)) if self.idxBadCalib[i] == 0
                    for j in range(nbBlock[i]) if indexOf(station[j, :, i]) == k]
            if len(rows) < 2:
                continue
            jj = np.array([r[0] for r in rows])
            ii = np.array([r[1] for r in rows])
            err = errs[jj, ii]
            # The small offset keeps the weights finite for zero error bars.
            poly = np.poly1d(np.polyfit(times[jj, ii], vals[jj, ii], 1, w=1. / (1.E-3 + err)))
            median = np.median(err)
            y = poly(x)
            ax = listAxes[k]
            polys.append(poly)
            medErrs.append(median)
            curves.append(ax.plot(x, y, color='black', linewidth=0.5))
            curvesPlus.append(ax.plot(x, y + median, '--', color='black', linewidth=0.5))
            curvesMinus.append(ax.plot(x, y - median, '--', color='black', linewidth=0.5))
        return 1, polys, medErrs, curves, curvesPlus, curvesMinus

    def fitTf(self):
        """Fit and draw a line through the good-calibrator TF of each baseline.

        Sets flagFitDone, listCoefPolynome, listErrPolynome, listCurveFitTF,
        listCurveFitTFPlus and listCurveFitTFMinus. The canvases are not
        redrawn here. Not called from within this class.
        """
        (self.flagFitDone, self.listCoefPolynome, self.listErrPolynome,
         self.listCurveFitTF, self.listCurveFitTFPlus,
         self.listCurveFitTFMinus) = self._fitPerPlot(
            self.NB_PLOT_TF, self.nbBlockCalib, self.stationTF, self.valTime,
            self.valTF, self.valTFerr, self.defineIndexBaseline, self.listAxesTf)

    def fitCp(self):
        """Fit and draw a line through the good-calibrator closure phase of each triplet.

        Sets flagFitCpDone, listCoefPolynomeCp, listErrPolynomeCp,
        listCurveFitCp, listCurveFitCpPlus and listCurveFitCpMinus. The
        canvases are not redrawn here. Not called from within this class.
        """
        (self.flagFitCpDone, self.listCoefPolynomeCp, self.listErrPolynomeCp,
         self.listCurveFitCp, self.listCurveFitCpPlus,
         self.listCurveFitCpMinus) = self._fitPerPlot(
            self.NB_PLOT_CP, self.nbBlockCpCalib, self.stationCp, self.valTimeCp,
            self.valCp, self.valCperr, self.defineIndexTriplet, self.listAxesCp)

    def drawCp(self):
        """Plot calibrator (blue) and target (red) closure phases per triplet.

        Fills listAxesCp, listCurveCp and indexListCurveCp. indexListCurveCp[j, i]
        is the listCurveCp index for row j of calibrator i, or -1 if that row
        was not plotted.
        """
        nbPlot = self.NB_PLOT_CP
        for fig in self.listFigCp:
            fig.clf()
        self.listAxesCp = [fig.add_subplot(111) for fig in self.listFigCp]

        self.listCurveCp = []
        self.indexListCurveCp = np.full(self.valTimeCp.shape, -1, dtype=np.int32)
        xmin, xmax = self._timeRange(self.valTimeCp, self.valTimeCpTarget)

        flagNameCalib = np.zeros((nbPlot, len(self.calib_in)), dtype=np.int32)
        for i, fileName in enumerate(self.calib_in):
            name = self._readOiTarget(fileName)
            for j in range(self.nbBlockCpCalib[i]):
                sta = self.stationCp[j, :, i]
                idx = self.defineIndexTriplet(sta)
                if not 0 <= idx < nbPlot:
                    continue
                ax = self.listAxesCp[idx]
                self._styleAxes(ax, 'Closure Phase', (-180, 180.0), (xmin, xmax),
                                'Triplet ' + self.telesc[sta[0]] + self.telesc[sta[1]]
                                + self.telesc[sta[2]])
                self.indexListCurveCp[j, i] = len(self.listCurveCp)
                self.listCurveCp.append(ax.errorbar(self.valTimeCp[j, i], self.valCp[j, i],
                                                    self.valCperr[j, i], fmt='', color='blue',
                                                    marker='o', linestyle='', markersize=4))
                if flagNameCalib[idx, i] == 0:
                    # Put the name on the side of the plot away from the point.
                    ypos = -50. if self.valCp[j, i] > 0. else 160.
                    ax.text(self.valTimeCp[j, i], ypos, name, fontsize=7,
                            color='blue', rotation=90)
                    flagNameCalib[idx, i] = 1

        flagNameTarget = np.zeros((nbPlot, len(self.target_in)), dtype=np.int32)
        for i, fileName in enumerate(self.target_in):
            name = self._readOiTarget(fileName)
            for j in range(self.nbBlockCpTarget[i]):
                idx = self.defineIndexTriplet(self.stationCpTarget[j, :, i])
                if not 0 <= idx < nbPlot:
                    continue
                ax = self.listAxesCp[idx]
                ax.errorbar(self.valTimeCpTarget[j, i], self.valCpTarget[j, i],
                            self.valCpTargeterr[j, i], fmt='', color='red', marker='o',
                            linestyle='', markersize=4)
                if flagNameTarget[idx, i] == 0:
                    ypos = -50. if self.valCpTarget[j, i] > 0. else 160.
                    ax.text(self.valTimeCpTarget[j, i], ypos, name, fontsize=7,
                            color='red', rotation=90)
                    flagNameTarget[idx, i] = 1

        for fig, canvas in zip(self.listFigCp, self.listCanvasCp):
            fig.tight_layout()
            canvas.draw()



if __name__ == '__main__':

    parser = ReflexIOParser()
    parser.add_input("--enable")
    parser.add_input("--sof_in")
    parser.add_output("--sof_out")
    # input files
    inputs = parser.get_inputs()
    outputs = parser.get_outputs()
    # The window prunes this SOF in place when calibrators are rejected.
    outputs.sof_out = inputs.sof_in

    if inputs.enable == "true":
        # Split the SOF into calibrator and science interferometric products.
        listCalib = [f.name for f in inputs.sof_in.files
                     if f.category == "CALIB_RAW_INT"]
        listTarget = [f.name for f in inputs.sof_in.files
                      if f.category == "TARGET_RAW_INT"]
        listFiles = {'calib': listCalib, 'target': listTarget}

        # The window needs at least one file of each kind; without them it
        # would crash before the outputs are written.
        if listCalib and listTarget:
            app = wx.App()
            frameReflex = FrameCalEstimate(None, title="MATISSE WORKFLOW CAL ESTIMATE",
                                           size=(1500, 850), file=listFiles,
                                           sof=outputs.sof_out)
            app.SetTopWindow(frameReflex)
            app.MainLoop()
        else:
            print("No CALIB_RAW_INT or TARGET_RAW_INT file in the SOF: "
                  "skipping calibrator selection")

    parser.write_outputs()
    sys.exit()
