/*
 * This file is part of the QMOST Pipeline
 * Copyright (C) 2002-2022 European Southern Observatory
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 */

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

/*-----------------------------------------------------------------------------
                                Includes
 -----------------------------------------------------------------------------*/

#include "qmost_pca.h"

/*----------------------------------------------------------------------------*/
/**
 * @defgroup qmost_pca_test  Unit test of qmost_pca
 *
 */
/*----------------------------------------------------------------------------*/

/**@{*/

/*----------------------------------------------------------------------------*/
/**
  @brief    Unit test of qmost_pca_form_covar

  Passes a fixed sample of 4 spectra of 10 pixels each to
  qmost_pca_form_covar and compares every element of the returned
  4x4 matrix with reference values.  This is done twice: once with
  all pixels enabled in the sky mask, and once with a single pixel
  disabled, in which case that pixel's product term is removed from
  the reference before comparing.
 */
/*----------------------------------------------------------------------------*/

static void test_qmost_pca_form_covar(void)
{
    /* Input spectra stored back to back, 10 pixels per spectrum */
    double sample[40] = {
        1.7, -1.3,  2.7, -5.3, -0.3,  1.7, -2.3, -2.3,  0.7,  4.7,
        2.4,  0.4,  4.4, -2.6, -0.6,  0.4, -3.6, -3.6, -2.6,  5.4,
        2.5,  1.5, -1.5, -1.5, -1.5,  2.5, -2.5, -3.5,  1.5,  2.5,
        5.2, -2.8,  0.2,  0.2, -0.8, -0.8,  1.2, -1.8,  3.2, -3.8
    };

    /* Reference result with every pixel enabled */
    double answer[4][4] = {
        {  76.1,  70.2,  37.5,  -3.4 },
        {  70.2,  94.4,  37.0, -14.8 },
        {  37.5,  37.0,  48.5,   6.0 },
        {  -3.4, -14.8,   6.0,  65.6 }
    };

    /* Reference result with one pixel masked, derived from "answer" */
    double masked_answer[4][4];

    float skymask[10];
    double **covar = NULL;

    int nspec = 4;
    int npix = 10;
    int masked_pix = 5;
    int row, col, pix;

    /* Case 1: every pixel enabled in the mask */
    for(pix = 0; pix < npix; pix++) {
        skymask[pix] = 1.0;
    }

    covar = qmost_pca_form_covar(sample, skymask, nspec, npix);

    for(row = 0; row < nspec; row++) {
        for(col = 0; col < nspec; col++) {
            cpl_test_abs(answer[row][col], covar[row][col], 1.0e-10);
        }
    }

    cpl_free(covar[0]);
    cpl_free(covar);

    /* Case 2: disable a single pixel */
    skymask[masked_pix] = 0;

    /* The masked pixel is expected to drop out of every element, so
     * take its product term away from the unmasked reference. */
    for(row = 0; row < nspec; row++) {
        for(col = 0; col < nspec; col++) {
            masked_answer[row][col] = answer[row][col]
                - sample[npix*row + masked_pix] * sample[npix*col + masked_pix];
        }
    }

    covar = qmost_pca_form_covar(sample, skymask, nspec, npix);

    for(row = 0; row < nspec; row++) {
        for(col = 0; col < nspec; col++) {
            cpl_test_abs(masked_answer[row][col], covar[row][col], 1.0e-10);
        }
    }

    cpl_free(covar[0]);
    cpl_free(covar);
}

/*----------------------------------------------------------------------------*/
/**
  @brief    Make a heap copy of a flat n x n matrix with row pointers
  @param    flat   Matrix elements stored row after row, n*n values.
  @param    n      Number of rows and columns.
  @return   Array of n row pointers into one buffer of n*n doubles.
            Element [0] points at the start of that buffer, so the
            copy is released with cpl_free on element [0] followed
            by cpl_free on the array itself.

  A new copy is made for every call of the routine under test so that
  each case starts from the pristine matrix.
 */
/*----------------------------------------------------------------------------*/

static double **copy_test_matrix(const double *flat, int n)
{
    double *buf = cpl_malloc(n*n * sizeof(double));
    double **rows = cpl_malloc(n * sizeof(double *));
    int i;

    memcpy(buf, flat, n*n * sizeof(double));

    for(i = 0; i < n; i++) {
        rows[i] = buf + i*n;
    }

    return rows;
}

/*----------------------------------------------------------------------------*/
/**
  @brief    Release the input matrix and the outputs of one test case
  @param    covar         Matrix made by copy_test_matrix.
  @param    eigenvalues   Eigenvalue array returned by the routine.
  @param    eigenfrac     Eigenvalue fraction array returned by the
                          routine.
  @param    eigenvectors  Eigenvector matrix returned by the routine;
                          element [0] is released before the array.
 */
/*----------------------------------------------------------------------------*/

static void free_eigen_case(double **covar,
                            double *eigenvalues,
                            double *eigenfrac,
                            double **eigenvectors)
{
    cpl_free(covar[0]);
    cpl_free(covar);
    cpl_free(eigenvalues);
    cpl_free(eigenfrac);
    cpl_free(eigenvectors[0]);
    cpl_free(eigenvectors);
}

/*----------------------------------------------------------------------------*/
/**
  @brief    Unit test of qmost_pca_get_eigen

  Runs qmost_pca_get_eigen on a fixed symmetric 4x4 matrix with known
  eigenvalues, in three cases:
   - diagvar = 0: the outputs must satisfy the eigenvalue equation,
     the eigenvector columns must be mutually orthogonal, and the
     eigenvalues and their fractions must match the references.
   - diagvar = 1: smaller than 1 percent of the eigenvalue sum, so
     the eigenvalues are expected to be reduced by diagvar itself.
   - diagvar = 1000: larger than 1 percent of the eigenvalue sum, so
     the eigenvalues are expected to be reduced by that 1 percent
     cap instead.
 */
/*----------------------------------------------------------------------------*/

static void test_qmost_pca_get_eigen(void)
{
    /* Test matrix */
    double a[16] = {   4,  -30,    60,    -35,
                     -30,  300,  -675,    420,
                      60, -675,  1620,  -1050,
                     -35,  420, -1050,    700 };

    /* Sorted eigenvalues of test matrix */
    double answers[4] = { 2.585253810928920302e+03,
                          3.710149136512757195e+01,
                          1.478054844778103449e+00,
                          1.666428611718406561e-01 };

    int n = 4;

    double **covar = NULL;
    double *eigenvalues = NULL;
    double *eigenfrac = NULL;
    double **eigenvectors = NULL;

    int i, j, k;
    double sum, chk, maxerr, tmp;
    double diagvar, applied;

    /* Case 1: no diagonal variance */
    covar = copy_test_matrix(a, n);

    qmost_pca_get_eigen(covar, n,
                        &eigenvalues, &eigenfrac, &eigenvectors,
                        0);

    /* Check solution satisfies the eigenvalue equation:
     * A V - lambda * V = 0 where V = "eigenvectors" */
    maxerr = 0;

    for(i = 0; i < n; i++) {
        for(j = 0; j < n; j++) {
            sum = 0;

            for(k = 0; k < n; k++) {
                /* Aik Vkj */
                sum += a[i*n+k] * eigenvectors[k][j];
            }

            chk = fabs(sum - eigenvalues[j] * eigenvectors[i][j]);
            if(chk > maxerr) {
                maxerr = chk;
            }
        }
    }

    cpl_test_lt(maxerr, 1.0e-10);

    /* Check columns of V are orthogonal by calculating V^T V, this
     * should be diagonal. */
    maxerr = 0;

    for(i = 0; i < n; i++) {
        for(j = 0; j < n; j++) {
            sum = 0;

            for(k = 0; k < n; k++) {
                /* V^Tik Vkj = Vki Vkj */
                sum += eigenvectors[k][i] * eigenvectors[k][j];
            }

            /* Only the off-diagonal elements are required to vanish */
            if(i != j) {
                chk = fabs(sum);
                if(chk > maxerr) {
                    maxerr = chk;
                }
            }
        }
    }

    cpl_test_lt(maxerr, 1.0e-12);

    /* Check eigenvalues and eigenfrac */
    sum = 0;

    for(i = 0; i < n; i++) {
        cpl_test_abs(eigenvalues[i], answers[i], 1.0e-10*fabs(answers[i]));
        sum += answers[i];
    }

    for(i = 0; i < n; i++) {
        cpl_test_abs(eigenfrac[i], answers[i]/sum, 1.0e-10);
    }

    free_eigen_case(covar, eigenvalues, eigenfrac, eigenvectors);

    /* Case 2: nonzero diagvar.  The eigenvalues are expected to be
     * reduced by the given amount provided it is smaller than
     * 0.01*sum of the eigenvalues (about 26.24 for this matrix),
     * which here it is. */
    diagvar = 1.0;

    covar = copy_test_matrix(a, n);

    qmost_pca_get_eigen(covar, n,
                        &eigenvalues, &eigenfrac, &eigenvectors,
                        diagvar);

    sum = 0;

    for(i = 0; i < n; i++) {
        tmp = answers[i] - diagvar;
        cpl_test_abs(eigenvalues[i], tmp, 1.0e-10*fabs(tmp));

        /* Only positive shifted eigenvalues count towards the
         * normalisation of the fractions. */
        if(tmp > 0) {
            sum += tmp;
        }
    }

    for(i = 0; i < n; i++) {
        cpl_test_abs(eigenfrac[i], (answers[i]-diagvar)/sum, 1.0e-10);
    }

    free_eigen_case(covar, eigenvalues, eigenfrac, eigenvectors);

    /* Case 3: diagvar larger than 0.01*sum, so the cap is expected
     * to be applied in place of the requested value. */
    diagvar = 1000.0;

    covar = copy_test_matrix(a, n);

    qmost_pca_get_eigen(covar, n,
                        &eigenvalues, &eigenfrac, &eigenvectors,
                        diagvar);

    /* Apply the limit */
    sum = 0;
    for(i = 0; i < n; i++) {
        sum += answers[i];
    }

    applied = 0.01 * sum;

    sum = 0;

    /* NOTE(review): this case uses a looser tolerance (1e-6) than the
     * previous one, carried over unchanged; the reason is not visible
     * here. */
    for(i = 0; i < n; i++) {
        tmp = answers[i] - applied;
        cpl_test_abs(eigenvalues[i], tmp, 1.0e-6*fabs(tmp));

        if(tmp > 0) {
            sum += tmp;
        }
    }

    for(i = 0; i < n; i++) {
        cpl_test_abs(eigenfrac[i], (answers[i]-applied)/sum, 1.0e-6);
    }

    free_eigen_case(covar, eigenvalues, eigenfrac, eigenvectors);
}

/*----------------------------------------------------------------------------*/
/**
  @brief    Unit test of qmost_pca_trans_eigen

  Calls qmost_pca_trans_eigen with a 2x2 eigenvector matrix and two
  data rows of 4 pixels, once with every pixel enabled in the sky
  mask and once with alternate pixels disabled, and compares each
  element of the 2x4 result with the reference for that mask.
 */
/*----------------------------------------------------------------------------*/

static void test_qmost_pca_trans_eigen(void)
{
    double eigen_buf[4] = { 1, -1,
                            1,  1 };
    double *eigen[2];

    double data[8] = { 1, -1,  1, -1,
                       1,  1, -1, -1 };

    /* Sky masks: all pixels enabled, then alternate pixels disabled */
    float mask_full[4] = { 1, 1, 1, 1 };
    float mask_alt[4] = { 1, 0, 1, 0 };

    /* Reference results matching the two masks above */
    double expect_full[2][4] = {
        { CPL_MATH_SQRT1_2, 0, 0, -CPL_MATH_SQRT1_2 },
        { 0, CPL_MATH_SQRT1_2, -CPL_MATH_SQRT1_2, 0 }
    };
    double expect_alt[2][4] = {
        { 1, 0,  0,  0 },
        { 0, 0, -1,  0 }
    };

    /* Case table, filled in below */
    float *masks[2];
    double (*expected[2])[4];

    int ns = 2;
    int np = 4;
    int ncases = 2;
    int icase, row, col;

    double **result = NULL;

    /* Row pointers into the flat eigenvector buffer */
    eigen[0] = eigen_buf;
    eigen[1] = eigen_buf + 2;

    masks[0] = mask_full;
    expected[0] = expect_full;

    masks[1] = mask_alt;
    expected[1] = expect_alt;

    for(icase = 0; icase < ncases; icase++) {
        result = qmost_pca_trans_eigen(eigen, data, masks[icase], ns, np);

        for(row = 0; row < ns; row++) {
            for(col = 0; col < np; col++) {
                cpl_test_abs(result[row][col],
                             expected[icase][row][col],
                             DBL_EPSILON);
            }
        }

        cpl_free(result[0]);
        cpl_free(result);
    }
}

/*----------------------------------------------------------------------------*/
/**
  @brief    Unit test of qmost_pca_recon_spec

  Calls qmost_pca_recon_spec three times on the same 4 pixel input
  spectrum with two eigenvectors and compares the reconstruction
  with a reference each time:
   - all pixels enabled in the sky mask, all-ones variance array,
   - pixel 1 disabled in the sky mask,
   - all pixels enabled, but with entry 3 of the variance array set
     to zero.
 */
/*----------------------------------------------------------------------------*/

static void test_qmost_pca_recon_spec(void)
{
    double eigen_buf[8] = {
        1, 0,  0, -1,
        0, 1, -1,  0
    };
    double *eigen[2];

    /* Input spectrum: 2 * first + 1 * second eigenvector */
    double indata[4] = { 2, 1, -1, -2 };

    float recon[4];

    /* Sky masks */
    float mask_full[4] = { 1, 1, 1, 1 };
    float mask_gap[4] = { 1, 0, 1, 1 };

    /* Variance arrays */
    float invar_ones[8] = { 1, 1, 1, 1,
                            1, 1, 1, 1 };
    float invar_gap[8] = { 1, 1, 1, 0,
                           1, 1, 1, 1 };

    /* Reference reconstructions */
    float expect_nominal[4] = { 4, 2, -2, -4 };
    float expect_mask[4] = { 4, 1, -1, -4 };
    float expect_invar[4] = { 2, 2, -2, -2 };

    /* Case table, filled in below */
    float *masks[3];
    float *invars[3];
    float *expected[3];

    int ncases = 3;
    int icase, pix;

    /* Row pointers into the flat eigenvector buffer */
    eigen[0] = eigen_buf;
    eigen[1] = eigen_buf + 4;

    /* Nominal case without masking */
    masks[0] = mask_full;
    invars[0] = invar_ones;
    expected[0] = expect_nominal;

    /* Answer modified by the sky mask */
    masks[1] = mask_gap;
    invars[1] = invar_ones;
    expected[1] = expect_mask;

    /* Answer modified by the variance array */
    masks[2] = mask_full;
    invars[2] = invar_gap;
    expected[2] = expect_invar;

    for(icase = 0; icase < ncases; icase++) {
        qmost_pca_recon_spec(eigen, indata, 4,
                             masks[icase], invars[icase], 2,
                             recon);

        for(pix = 0; pix < 4; pix++) {
            cpl_test_abs(recon[pix], expected[icase][pix], DBL_EPSILON);
        }
    }
}

/*----------------------------------------------------------------------------*/
/**
  @brief    Unit tests of qmost_pca module
  @return   The value returned by cpl_test_end.

  Initialises the CPL test framework, runs each test routine of this
  file in turn and hands the final status back to the caller.
 */
/*----------------------------------------------------------------------------*/

int main(void)
{
    int status;

    cpl_test_init(PACKAGE_BUGREPORT, CPL_MSG_WARNING);

    /* One routine per function under test */
    test_qmost_pca_form_covar();
    test_qmost_pca_get_eigen();
    test_qmost_pca_trans_eigen();
    test_qmost_pca_recon_spec();

    status = cpl_test_end(0);

    return status;
}

/**@}*/
