/*
 * This file is part of the QMOST Pipeline
 * Copyright (C) 2002-2022 European Southern Observatory
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 */

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

/*----------------------------------------------------------------------------*/
/*
 *                              Includes
 */
/*----------------------------------------------------------------------------*/

#include <cpl.h>

#include "qmost_gaincor.h"
#include "qmost_stats.h"
#include "qmost_utils.h"

/*----------------------------------------------------------------------------*/
/**
 * @defgroup qmost_gaincor  qmost_gaincor
 *
 * Gain correction for detector flats.
 *
 * @par Synopsis:
 * @code
 *   #include "qmost_gaincor.h"
 * @endcode
 */
/*----------------------------------------------------------------------------*/

/**@{*/

/*----------------------------------------------------------------------------*/
/*
 *                              Function prototypes
 */
/*----------------------------------------------------------------------------*/

/* Forward declaration so qmost_gaincor can call the strip-differencing
 * helper; its definition appears later in this file. */
cpl_error_code qmost_skydiff_subset(cpl_image *img,
                                    int xla, int yla, int xha, int yha,
                                    int xlb, int ylb, int xhb, int yhb,
                                    int flip,
                                    float *skydiff,
                                    float *noisdiff,
                                    float *wt);

/*----------------------------------------------------------------------------*/
/**
 * @brief   Gain correct the amps in a detector flat.
 *
 * The relative gains of the amps in a multi-amp readout detector flat
 * are determined by minimising the discontinuities in the flat count
 * levels across the amp boundaries.  These are determined by
 * differencing the flat levels in thin strips either side of the
 * boundary (of width or height specified by the stripsize argument)
 * and then solving the resulting set of simultaneous equations in the
 * least-squares sense.
 *
 * The gain corrections are then applied to the given image and the
 * values used are recorded in DRS keywords in the fits header.
 *
 * If ESO DRS NAMPS is less than 2 there is nothing to do and the
 * function returns immediately without touching img or hdr.  If the
 * system of equations turns out to be singular a warning is issued
 * and no correction is applied (every AMPn GAINCOR is written as 1).
 *
 * Amps are numbered from 1 at the lower left, running along the
 * bottom row in +x and then back along the top row in -x.
 *
 * @param   img              (Modified) The image to gain correct.
 *                                      Will be modified in place.
 *                                      The data type must be
 *                                      CPL_TYPE_FLOAT.
 * @param   hdr              (Modified) The extension FITS header for
 *                                      the image with the DRS
 *                                      keywords written by
 *                                      qmost_ccdproc.  Will be
 *                                      modified to add DRS headers
 *                                      giving the gain correction
 *                                      applied.
 * @param   stripsize        (Given)    Strip size to use when
 *                                      computing the overlap
 *                                      corrections.  Typically 16
 *                                      pixels.
 *
 * @return  cpl_error_code
 *
 * @retval  CPL_ERROR_NONE                If everything is OK.
 * @retval  CPL_ERROR_DATA_NOT_FOUND      If one of the required input
 *                                        FITS header keywords was not
 *                                        found.
 * @retval  CPL_ERROR_DIVISION_BY_ZERO    If the reference sky level
 *                                        measured in amp 1 is zero.
 * @retval  CPL_ERROR_NULL_INPUT          If one of the required
 *                                        inputs was NULL.
 * @retval  CPL_ERROR_UNSUPPORTED_MODE    If the number of amps is 2
 *                                        or more but odd.
 * @retval  CPL_ERROR_TYPE_MISMATCH       If one of the required input
 *                                        FITS header keyword values
 *                                        had an incorrect data type,
 *                                        or the image is not of type
 *                                        CPL_TYPE_FLOAT.
 *
 * Other error codes raised by the CPL or qmost routines called here
 * (for example when a strip falls outside the image) are propagated.
 *
 * @par Input FITS Header Information:
 *   - <b>ESO DRS NAMPS</b>
 *
 * @par Output DRS Headers:
 *   - <b>AMPn GAINCOR</b>: The ratio of the gain (CONAD) of
 *     amplifier n to the gain of amplifier 1.
 *
 * @author  Jonathan Irwin, CASU
 */
/*----------------------------------------------------------------------------*/

cpl_error_code qmost_gaincor (
    cpl_image *img,
    cpl_propertylist *hdr,
    int stripsize)
{
    cpl_errorstate prestate;
    int nampsx, nampsy, namps;

    int nximg, nyimg;
    int nxamp, nyamp;

    int ieq, neq;

    cpl_image *tmpimg = NULL;
    float refsky, refnoise;

    cpl_matrix *p = NULL;
    cpl_matrix *q = NULL;
    cpl_matrix *r = NULL;

    int iampx, iampy, iampa, iampb;
    int xl, xm, xh, yl, ym, yh;
    float skydiff, noisdiff, wt;

    float *data;
    float relgain;
    int x, y;

    char *key = NULL;

    cpl_ensure_code(img != NULL, CPL_ERROR_NULL_INPUT);
    cpl_ensure_code(hdr != NULL, CPL_ERROR_NULL_INPUT);

#undef TIDY
#define TIDY                                    \
    if(tmpimg != NULL) {                        \
        cpl_image_delete(tmpimg);               \
        tmpimg = NULL;                          \
    }                                           \
    if(p != NULL) {                             \
        cpl_matrix_delete(p);                   \
        p = NULL;                               \
    }                                           \
    if(q != NULL) {                             \
        cpl_matrix_delete(q);                   \
        q = NULL;                               \
    }                                           \
    if(r != NULL) {                             \
        cpl_matrix_delete(r);                   \
        r = NULL;                               \
    }                                           \
    if(key != NULL) {                           \
        cpl_free(key);                          \
        key = NULL;                             \
    }

    /* Determine number of amps from image header */
    if(qmost_cpl_propertylist_get_int(hdr,
                                      "ESO DRS NAMPS",
                                      &namps) != CPL_ERROR_NONE) {
        TIDY;
        return cpl_error_set_message(cpl_func, cpl_error_get_code(),
                                     "could not read ESO DRS NAMPS "
                                     "from IMAGE extension header");
    }

    /* Check if there's any work to do */
    if(namps < 2) {
        /* Nope */
        return CPL_ERROR_NONE;
    }

    /* Determine amp layout based on number of amps.  This ends up
     * having to assume more about the amp layout than I'd like, but
     * it should be good enough for CCDs with one serial register and
     * two amps or with two serial registers and four (or more) amps,
     * which includes 4MOST.  The case not handled is two serial
     * registers each read through a single amp, but this could be
     * implemented if we did a bit more interrogation of the FITS
     * headers.  The assumptions here are similar to the ones I had
     * to make in qmost_ccdproc.c:qmost_detector_regions(). */
    nampsy = namps > 2 ? 2 : 1;
    nampsx = namps / nampsy;

    if(nampsx * nampsy != namps) {
        TIDY;
        return cpl_error_set_message(cpl_func, CPL_ERROR_UNSUPPORTED_MODE,
                                     "unsupported number of amps: %d, "
                                     "must be 1 or an even number",
                                     namps);
    }

    /* Size of the image, and size of an amp */
    nximg = cpl_image_get_size_x(img);
    nyimg = cpl_image_get_size_y(img);

    nxamp = nximg / nampsx;
    nyamp = nyimg / nampsy;

    /* Number of equations = number of amp boundaries available */
    neq = (nampsx - 1) * nampsy + (nampsy - 1) * nampsx;

    /* Create matrices, they are initialized to zero */
    p = cpl_matrix_new(neq, namps-1);
    r = cpl_matrix_new(neq, 1);
    if(p == NULL || r == NULL) {
        TIDY;
        return cpl_error_set_message(cpl_func, cpl_error_get_code(),
                                     "could not create matrices "
                                     "for gain solution");
    }

    /* Reference background level = median of the entire first amp.
     * This should match the region we use to measure gain of the
     * first amp in qmost_findgain. */
    tmpimg = cpl_image_extract(img, 1, 1, nxamp, nyamp);
    if(tmpimg == NULL) {
        TIDY;
        return cpl_error_set_message(cpl_func, cpl_error_get_code(),
                                     "failed to extract image "
                                     "section for amp 1: "
                                     "[%d:%d,%d:%d]",
                                     1, nxamp, 1, nyamp);
    }

    if(qmost_skylevel_image(tmpimg,
                            -1000,
                            65535,
                            -FLT_MAX, FLT_MAX, 0,
                            &refsky,
                            &refnoise) != CPL_ERROR_NONE) {
        TIDY;
        return cpl_error_set_message(cpl_func, cpl_error_get_code(),
                                     "could not compute sky level "
                                     "for amp 1");
    }

    cpl_image_delete(tmpimg);
    tmpimg = NULL;

    /* refsky normalises every right hand side below; a zero value
     * would fill r (and hence the corrected image) with inf/NaN. */
    if(refsky == 0) {
        TIDY;
        return cpl_error_set_message(cpl_func, CPL_ERROR_DIVISION_BY_ZERO,
                                     "reference sky level for amp 1 "
                                     "is zero");
    }

    /* Form simultaneous equations.  The sky difference is first
     * calculated across each amp boundary, and each resulting
     * measurement is the right hand side of an equation of the form:
     * sky_b - sky_a = diff_ab
     * for amps a, b in [0,namps-1].  We further define sky_0 = 0.
     * Seeking the least-squares solution to these simultaneous
     * equations we can write them down in matrix form P q = r, where
     * the elements of these matrices are;
     * p_ak = -1
     * p_bk =  1
     * r_k = diff_ab
     * The constraint sky_0 = 0 removes the first column of P so its
     * size, the size of r, and all column indices are decreased by 1
     * accordingly where there is no entry for the first amp.  Each
     * row is scaled by the weight returned by qmost_skydiff_subset,
     * so a boundary with no valid pixels contributes a zero row. */
    ieq = 0;

    /* Loop over amp boundaries in x */
    for(iampy = 0; iampy < nampsy; iampy++) {
        yl = iampy * nyamp + 1;
        yh = (iampy + 1) * nyamp;

        for(iampx = 1; iampx < nampsx; iampx++) {
            /* Coordinate of the pixel before the amp boundary */
            xm = iampx * nxamp;

            /* Define strips either side */
            xl = xm - stripsize + 1;
            xh = xm + stripsize;

            /* Compute sky difference */
            if(qmost_skydiff_subset(img,
                                    xl, yl, xm, yh,
                                    xm+1, yl, xh, yh,
                                    2,
                                    &skydiff,
                                    &noisdiff,
                                    &wt) != CPL_ERROR_NONE) {
                TIDY;
                return cpl_error_set_message(cpl_func, cpl_error_get_code(),
                                             "could not compute sky "
                                             "difference between regions "
                                             "[%d:%d,%d:%d] and "
                                             "[%d:%d,%d:%d]",
                                             xl, xm, yl, yh,
                                             xm+1, xh, yl, yh);
            }

            /* What number amp is this?  e2v number them anticlockwise
             * so the top row x needs to be reversed. */
            if(iampy % 2) {
                iampa = iampy * nampsx + nampsx - iampx;
                iampb = iampy * nampsx + nampsx - iampx - 1;
            }
            else {
                iampa = iampy * nampsx + iampx - 1;
                iampb = iampy * nampsx + iampx;
            }

            /* Add equation.  The first amp level is defined to be
             * zero. */
            if(iampa > 0) {
                cpl_matrix_set(p, ieq, iampa - 1, -wt);
            }
            if(iampb > 0) {
                cpl_matrix_set(p, ieq, iampb - 1, wt);
            }
            cpl_matrix_set(r, ieq, 0, wt * skydiff / refsky);

            ieq++;
        }
    }

    /* Loop over amp boundaries in y */
    for(iampy = 1; iampy < nampsy; iampy++) {
        /* Coordinate of the pixel before the amp boundary */
        ym = iampy * nyamp;

        /* Define strips either side */
        yl = ym - stripsize + 1;
        yh = ym + stripsize;

        for(iampx = 0; iampx < nampsx; iampx++) {
            xl = iampx * nxamp + 1;
            xh = (iampx + 1) * nxamp;

            /* Compute sky difference */
            if(qmost_skydiff_subset(img,
                                    xl, yl, xh, ym,
                                    xl, ym+1, xh, yh,
                                    0,
                                    &skydiff,
                                    &noisdiff,
                                    &wt) != CPL_ERROR_NONE) {
                TIDY;
                return cpl_error_set_message(cpl_func, cpl_error_get_code(),
                                             "could not compute sky "
                                             "difference between regions "
                                             "[%d:%d,%d:%d] and "
                                             "[%d:%d,%d:%d]",
                                             xl, xh, yl, ym,
                                             xl, xh, ym+1, yh);
            }

            /* What number amp is this?  e2v number them anticlockwise
             * so the top row x needs to be reversed.  This is a bit
             * subtle because a and b are on different rows. */
            if(iampy % 2) {
                iampa = (iampy - 1) * nampsx + iampx;
                iampb = iampy * nampsx + nampsx - iampx - 1;
            }
            else {
                iampa = (iampy - 1) * nampsx + nampsx - iampx - 1;
                iampb = iampy * nampsx + iampx;
            }

            /* Add equation.  The first amp level is defined to be
             * zero. */
            if(iampa > 0) {
                cpl_matrix_set(p, ieq, iampa - 1, -wt);
            }
            if(iampb > 0) {
                cpl_matrix_set(p, ieq, iampb - 1, wt);
            }
            cpl_matrix_set(r, ieq, 0, wt * skydiff / refsky);

            ieq++;
        }
    }

    /* Solve the linear system */
    prestate = cpl_errorstate_get();

    q = cpl_matrix_solve_normal(p, r);
    if(q == NULL) {
        /* Trap singular matrix and replace with all zeros (no
         * correction). */
        if(cpl_error_get_code() == CPL_ERROR_SINGULAR_MATRIX) {
            cpl_errorstate_set(prestate);

            cpl_msg_warning(cpl_func, "singular matrix");

            q = cpl_matrix_new(namps-1, 1);
        }
        else {
            TIDY;
            return cpl_error_set_message(cpl_func, cpl_error_get_code(),
                                         "could not solve matrix");
        }
    }

    /* Get image */
    data = cpl_image_get_data_float(img);
    if(data == NULL) {
        /* Release the matrices before bailing out */
        TIDY;
        return cpl_error_set_message(cpl_func, cpl_error_get_code(),
                                     "could not get float pointer to image");
    }

    /* Do gain correction, and record the values in the header */
    for(iampy = 0; iampy < nampsy; iampy++) {
        yl = iampy * nyamp + 1;
        yh = (iampy + 1) * nyamp;

        for(iampx = 0; iampx < nampsx; iampx++) {
            xl = iampx * nxamp + 1;
            xh = (iampx + 1) * nxamp;

            /* Amp number (from zero) */
            if(iampy % 2) {
                iampa = iampy * nampsx + nampsx - iampx - 1;
            }
            else {
                iampa = iampy * nampsx + iampx;
            }

            /* Retrieve result; amp 0 is the reference and is left
             * untouched. */
            relgain = 1.0;
            if(iampa > 0) {
                relgain += cpl_matrix_get(q, iampa - 1, 0);

                /* Apply correction */
                for(y = yl; y <= yh; y++) {
                    for(x = xl; x <= xh; x++) {
                        data[(y-1)*nximg+(x-1)] /= relgain;
                    }
                }
            }

            /* Record in header */
            key = cpl_sprintf("ESO DRS AMP%d GAINCOR", iampa + 1);
            if(key == NULL) {
                TIDY;
                return cpl_error_set_message(cpl_func, cpl_error_get_code(),
                                             "could not format string "
                                             "for AMP%d GAINCOR keyword",
                                             iampa + 1);
            }

            if(cpl_propertylist_update_float(hdr, key,
                                             relgain) != CPL_ERROR_NONE ||
               cpl_propertylist_set_comment(hdr, key,
                                            "Relative gain for amp")
               != CPL_ERROR_NONE) {
                /* TIDY frees key, so the message uses the amp number */
                TIDY;
                return cpl_error_set_message(cpl_func, cpl_error_get_code(),
                                             "could not write "
                                             "AMP%d GAINCOR keyword",
                                             iampa + 1);
            }

            cpl_free(key);
            key = NULL;
        }
    }

    cpl_matrix_delete(p);
    p = NULL;

    cpl_matrix_delete(q);
    q = NULL;

    cpl_matrix_delete(r);
    r = NULL;

    return CPL_ERROR_NONE;
}

/*----------------------------------------------------------------------------*/
/**
 * @brief   Compute sky difference between two regions of an image.
 *
 * The specified image regions "a" and "b" are extracted, "a" is
 * flipped, and the results are then differenced (b minus flipped a)
 * and the sky background of the difference image is computed and
 * returned.  The purpose of the axis flip is so the regions
 * differenced are mirrored across the amp boundary.  The two regions
 * must have the same dimensions.
 *
 * @param   img              (Given)    The image to process.  The
 *                                      data type must be
 *                                      CPL_TYPE_FLOAT.
 * @param   xla              (Given)    The lower left x coordinate
 *                                      (numbering from 1) for region
 *                                      a.
 * @param   yla              (Given)    The lower left y coordinate
 *                                      (numbering from 1) for region
 *                                      a.
 * @param   xha              (Given)    The upper right x coordinate
 *                                      (numbering from 1) for region
 *                                      a.
 * @param   yha              (Given)    The upper right y coordinate
 *                                      (numbering from 1) for region
 *                                      a.
 * @param   xlb              (Given)    The lower left x coordinate
 *                                      (numbering from 1) for region
 *                                      b.
 * @param   ylb              (Given)    The lower left y coordinate
 *                                      (numbering from 1) for region
 *                                      b.
 * @param   xhb              (Given)    The upper right x coordinate
 *                                      (numbering from 1) for region
 *                                      b.
 * @param   yhb              (Given)    The upper right y coordinate
 *                                      (numbering from 1) for region
 *                                      b.
 * @param   flip             (Given)    Mirror axis passed to
 *                                      cpl_image_flip for region a:
 *                                      0: mirror about the horizontal
 *                                         axis (reverses y), for
 *                                         strips below/above a
 *                                         boundary.
 *                                      2: mirror about the vertical
 *                                         axis (reverses x), for
 *                                         strips left/right of a
 *                                         boundary.
 * @param   skydiff          (Returned) The resulting median sky level
 *                                      of the difference image.
 * @param   noisdiff         (Returned) The resulting sigma
 *                                      ("1.48*MAD") of the difference
 *                                      image.
 * @param   wt               (Returned) Weight to be used in fitting.
 *                                      This is 1.0 if at least one
 *                                      pixel of the difference image
 *                                      is not flagged as bad and 0.0
 *                                      otherwise.
 *
 * @return  cpl_error_code
 *
 * @retval  CPL_ERROR_NONE                If everything is OK.
 * @retval  CPL_ERROR_NULL_INPUT          If img or one of the output
 *                                        pointers is NULL.
 *
 * Other error codes raised while extracting, flipping or differencing
 * the regions, or while computing the sky level, are propagated.
 *
 * @author  Jonathan Irwin, CASU
 */
/*----------------------------------------------------------------------------*/

cpl_error_code qmost_skydiff_subset (
    cpl_image *img,
    int xla,
    int yla,
    int xha,
    int yha,
    int xlb,
    int ylb,
    int xhb,
    int yhb,
    int flip,
    float *skydiff,
    float *noisdiff,
    float *wt)
{
    cpl_image *aimg = NULL;
    cpl_image *bimg = NULL;
    cpl_size npix;

    cpl_ensure_code(img != NULL, CPL_ERROR_NULL_INPUT);
    cpl_ensure_code(skydiff != NULL, CPL_ERROR_NULL_INPUT);
    cpl_ensure_code(noisdiff != NULL, CPL_ERROR_NULL_INPUT);
    cpl_ensure_code(wt != NULL, CPL_ERROR_NULL_INPUT);

#undef TIDY
#define TIDY                                    \
    if(aimg != NULL) {                          \
        cpl_image_delete(aimg);                 \
        aimg = NULL;                            \
    }                                           \
    if(bimg != NULL) {                          \
        cpl_image_delete(bimg);                 \
        bimg = NULL;                            \
    }

    aimg = cpl_image_extract(img, xla, yla, xha, yha);
    if(aimg == NULL) {
        TIDY;
        return cpl_error_set_message(cpl_func, cpl_error_get_code(),
                                     "failed to extract image "
                                     "section: "
                                     "[%d:%d,%d:%d]",
                                     xla, xha, yla, yha);
    }

    bimg = cpl_image_extract(img, xlb, ylb, xhb, yhb);
    if(bimg == NULL) {
        TIDY;
        return cpl_error_set_message(cpl_func, cpl_error_get_code(),
                                     "failed to extract image "
                                     "section: "
                                     "[%d:%d,%d:%d]",
                                     xlb, xhb, ylb, yhb);
    }

    /* Mirror region a across the boundary so each pixel of b is
     * differenced against its counterpart on the other side. */
    if(cpl_image_flip(aimg, flip) != CPL_ERROR_NONE) {
        TIDY;
        return cpl_error_set_message(cpl_func, cpl_error_get_code(),
                                     "failed to flip image section "
                                     "[%d:%d,%d:%d]",
                                     xla, xha, yla, yha);
    }

    /* If this fails (e.g. the regions differ in size) bimg would be
     * left undifferenced, so the error must not be ignored. */
    if(cpl_image_subtract(bimg, aimg) != CPL_ERROR_NONE) {
        TIDY;
        return cpl_error_set_message(cpl_func, cpl_error_get_code(),
                                     "failed to difference image "
                                     "sections [%d:%d,%d:%d] and "
                                     "[%d:%d,%d:%d]",
                                     xlb, xhb, ylb, yhb,
                                     xla, xha, yla, yha);
    }

    /* Zero weight if every pixel of the difference is flagged bad */
    npix = cpl_image_get_size_x(bimg) * cpl_image_get_size_y(bimg);

    if(cpl_image_count_rejected(bimg) < npix) {
        *wt = 1.0;
    }
    else {
        *wt = 0.0;
    }

    *skydiff = 0.0;

    if(qmost_skylevel_image(bimg,
                            -65535,
                            65535,
                            -FLT_MAX, FLT_MAX, 0,
                            skydiff,
                            noisdiff) != CPL_ERROR_NONE) {
        TIDY;
        return cpl_error_set_message(cpl_func, cpl_error_get_code(),
                                     "could not compute sky level");
    }

    TIDY;

    return CPL_ERROR_NONE;
}

/**@}*/
