/*
 * =======================================================================================
 *
 *      Author:   Jan Eitzinger (je), jan.eitzinger@fau.de
 *      Copyright (c) 2019 RRZE, University Erlangen-Nuremberg
 *
 *      Permission is hereby granted, free of charge, to any person obtaining a copy
 *      of this software and associated documentation files (the "Software"), to deal
 *      in the Software without restriction, including without limitation the rights
 *      to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 *      copies of the Software, and to permit persons to whom the Software is
 *      furnished to do so, subject to the following conditions:
 *
 *      The above copyright notice and this permission notice shall be included in all
 *      copies or substantial portions of the Software.
 *
 *      THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 *      IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 *      FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 *      AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 *      LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 *      OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 *      SOFTWARE.
 *
 * =======================================================================================
 */
#define _GNU_SOURCE
#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <time.h>
#include <limits.h>
#include <float.h>
#include <math.h>
#include <omp.h>

#include <likwid-marker.h>

#define ARRAY_ALIGNMENT 64
#define HLINE "----------------------------------------------------------------------------\n"

/*
 * Optional per-kernel timing. With -DTIMING:
 *   T(TYPE) expr   adds the value of the following expression (the seconds
 *                  returned by a kernel) to timer[TYPE].val
 *   TIMER_START / TIMER_STOP store timestamps in the caller's locals S / E
 * Without -DTIMING all three expand to nothing, so S and E keep their
 * initial values and the kernels return E-S == 0.0.
 */
#ifdef TIMING
#define T(TYPE) timer[TYPE].val +=
#define TIMER_START S=getTimeStamp()
#define TIMER_STOP E=getTimeStamp()
#else
#define T(TYPE)
#define TIMER_START
#define TIMER_STOP
#endif

/* ANSI terminal colour escape sequences used by CHECK_LESS_THAN. */
#define KNRM  "\x1B[0m"
#define KRED  "\x1B[31m"
#define KGRN  "\x1B[32m"

/*
 * CHECK_LESS_THAN(a, b, OP): print a coloured pass/fail line for the test
 * named OP; the test passes when a < b. The arguments are parenthesised so
 * that expressions such as `x & mask` are compared as a whole. It stays a
 * plain if/else (not do { } while (0)) because existing call sites use it
 * without a trailing semicolon.
 */
#define CHECK_LESS_THAN(a, b, OP)\
    if((a) >= (b)) {\
        printf("%sTest %s failed\n%s", KRED, (OP), KNRM);\
    } else {\
        printf("%sTest %s success\n%s", KGRN, (OP), KNRM);\
    }

/* Indices into the timer array; NUMTIMERS is the number of timers. */
typedef enum component {
    AXPBY = 0,
    DOT_PRODUCT,
    APPLY_STENCIL,
    CG_SOLVER,
    NUMTIMERS
} component;

/* One named timing accumulator. */
typedef struct {
    const char* label; /* display name; initialised from string literals, hence const */
    double val;        /* accumulated time in seconds */
} timerType;

/* A 2-D grid stored row-major: entry (i, j) lives at ptr[j*size_x + i]. */
typedef struct {
    double* ptr;
    int size_x; //number of grid points in x dimension
    int size_y; //number of grid points in y dimension
} gridType;

/* Discretisation and stopping parameters for the CG solver. */
typedef struct {
    double h_x;       /* grid spacing in x */
    double h_y;       /* grid spacing in y */
    int niter;        /* maximum number of CG iterations */
    double tolerance; /* CG stops once the residual L2 norm is <= tolerance */
} solverType;

 /********************/
 /* Helper functions */
 /********************/
/*
 * Return a timestamp in seconds from CLOCK_MONOTONIC.
 * Only differences between two calls are meaningful. The timespec is
 * zero-initialised so that a failing clock_gettime() yields 0.0 instead of
 * reading uninitialised memory.
 */
double getTimeStamp(void)
{
    struct timespec ts = {0, 0};
    clock_gettime(CLOCK_MONOTONIC, &ts);
    return (double)ts.tv_sec + (double)ts.tv_nsec * 1.e-9;
}

/*
 * Allocate a size_x * size_y grid of words of bytesPerWord bytes each,
 * aligned to ARRAY_ALIGNMENT bytes, zero-fill it and record its dimensions
 * in *grid.
 *
 * The zero fill matters: applyStencil() only writes interior points, so
 * without it the boundary entries of freshly allocated work grids (p and v
 * in CG) would later be read uninitialised by axpby()/dotProduct(). The fill
 * runs as a static-schedule OpenMP loop so that pages are first touched by
 * the worker threads rather than the master thread (NUMA first-touch);
 * presumably this roughly matches the partitioning of the element-wise
 * kernels, assuming their default schedule is static — TODO confirm.
 *
 * Terminates the program with EXIT_FAILURE if the size is negative, does not
 * fit in size_t, or the allocation fails.
 */
void allocGrid(
        int size_x,
        int size_y,
        size_t bytesPerWord,
        gridType* grid)
{
    if ( size_x < 0 || size_y < 0 ) {
        fprintf(stderr, "allocGrid: invalid grid size %d x %d\n", size_x, size_y);
        exit(EXIT_FAILURE);
    }

    /* Multiply in size_t: size_x * size_y in int overflows for large grids. */
    size_t count = (size_t)size_x * (size_t)size_y;
    if ( bytesPerWord != 0 && count > ((size_t)-1) / bytesPerWord ) {
        fprintf(stderr, "allocGrid: grid of %d x %d words is too large\n", size_x, size_y);
        exit(EXIT_FAILURE);
    }
    size_t bytes = count * bytesPerWord;

    /* Allocate through a void* to avoid the aliasing-violating (void**) cast. */
    void* mem = NULL;
    int rc = posix_memalign(&mem, ARRAY_ALIGNMENT, bytes);
    if ( rc != 0 ) {
        /* posix_memalign reports errors via its return value, not errno. */
        fprintf(stderr, "allocGrid: posix_memalign of %zu bytes failed (error %d)\n", bytes, rc);
        exit(EXIT_FAILURE);
    }

    unsigned char* raw = mem;
#pragma omp parallel for schedule(static)
    for ( size_t k = 0; k < bytes; k++ ) {
        raw[k] = 0;
    }

    grid->ptr = mem;
    grid->size_x = size_x;
    grid->size_y = size_y;
}

/*
 * Copy all src->size_x * src->size_y entries of src into dst.
 * dst is assumed to be allocated with at least as many entries — not checked.
 *
 * Parallelised like the other element-wise kernels. In CG this is the first
 * write to the residual grid r, so a serial copy would place all of r's
 * pages on the master thread's NUMA domain (first-touch) and unbalance the
 * memory traffic of the timed iterations.
 */
void copyGrid(
        const gridType* src,
        gridType* dst)
{
    const int size = src->size_x * src->size_y;
    const double* from = src->ptr;
    double* to = dst->ptr;

#pragma omp parallel for
    for ( int i = 0; i < size; i++ ) {
        to[i] = from[i];
    }
}

/*
 * Release the storage of a grid allocated with allocGrid() and reset the
 * descriptor (NULL pointer, 0 x 0). This leaves no dangling pointer behind,
 * and a second freeGrid() on the same grid is a harmless free(NULL).
 */
void freeGrid(gridType* grid)
{
    free(grid->ptr);
    grid->ptr = NULL;
    grid->size_x = 0;
    grid->size_y = 0;
}

/*
 * Fill the solver input grids.
 *
 *   x        pseudo-random values in [0, 1] from rand_r(); each OpenMP
 *            thread seeds with its thread number, so the values depend on
 *            the number of threads
 *   u_sine   sin(pi * i * h_x) * sin(pi * j * h_y)
 *   rhs_sine 2 * pi^2 * sin(pi * i * h_x) * sin(pi * j * h_y)
 *
 * All three grids are assumed to have the dimensions of x (only x's sizes
 * are read) — callers must guarantee this.
 */
void init(gridType* x,
        gridType* u_sine,
        gridType* rhs_sine,
        solverType* solver)
{
    const int nx = x->size_x;
    const int ny = x->size_y;
    const int total = nx * ny;
    const double hx = solver->h_x;
    const double hy = solver->h_y;
    double* xs = x->ptr;
    double* us = u_sine->ptr;
    double* fs = rhs_sine->ptr;
    unsigned int seed = 0;

    /* Random initial guess: every thread gets its own rand_r() state. */
#pragma omp parallel firstprivate(seed)
    {
        seed += omp_get_thread_num();
#pragma omp for
        for ( int k = 0; k < total; k++ ) {
            xs[k] = rand_r(&seed) / (double) RAND_MAX;
        }
    }

    /* Analytic solution and matching right-hand side, one row per iteration. */
#pragma omp parallel for
    for ( int j = 0; j < ny; j++ ) {
        const double sy = sin(M_PI * j * hy);
        double* urow = us + j * nx;
        double* frow = fs + j * nx;

        for ( int i = 0; i < nx; i++ ) {
            const double sx = sin(M_PI * i * hx);
            urow[i] = sx * sy;
            frow[i] = 2.0 * M_PI * M_PI * sx * sy;
        }
    }
}

 /********************/
 /* Basic operations */
 /********************/
/*
 * axpby - element-wise linear combination res[:] = a*x[:] + b*y[:] over all
 * x->size_x * x->size_y entries.
 *
 * res may be the same grid as x or y (CG updates vectors in place): each
 * entry k is read before entry k is written, with no cross-entry dependence.
 * Every thread enters the "axpby" LIKWID marker region around its share of
 * the loop. Returns the elapsed seconds when built with -DTIMING, else 0.0.
 */
double axpby(
        gridType* res,
        double a,
        gridType* x,
        double b,
        gridType* y)
{
    /* S and E are the names TIMER_START / TIMER_STOP assign to. */
    double S=0.0, E=0.0;
    const int n = x->size_x * x->size_y;
    const double* xv = x->ptr;
    const double* yv = y->ptr;
    double* out = res->ptr;

    TIMER_START;
#pragma omp parallel
    {
        LIKWID_MARKER_START("axpby");
#pragma omp for
        for ( int k = 0; k < n; ++k ) {
            out[k] = (a * xv[k]) + (b * yv[k]);
        }
        LIKWID_MARKER_STOP("axpby");
    }
    TIMER_STOP;

    return E - S;
}

/*
 * Dot product of x and y over all x->size_x * x->size_y entries, stored in
 * *l2sq (called with x == y it yields the squared L2 norm).
 * Uses an OpenMP sum reduction. Returns the elapsed seconds when built with
 * -DTIMING, else 0.0.
 */
double dotProduct(
        gridType* x,
        gridType* y,
        double* l2sq)
{
    /* S and E are the names TIMER_START / TIMER_STOP assign to. */
    double S=0.0, E=0.0;
    const double* xv = x->ptr;
    const double* yv = y->ptr;
    const int n = x->size_x * x->size_y;
    double sum = 0.0;

    TIMER_START;
#pragma omp parallel for reduction(+:sum)
    for ( int k = 0; k < n; ++k ) {
        sum += xv[k] * yv[k];
    }
    TIMER_STOP;

    *l2sq = sum;
    return E - S;
}


/*
 * Apply the 5-point finite-difference operator to u, i.e. res = A*u.
 *
 * For every interior point (1 <= i < size_x-1, 1 <= j < size_y-1):
 *   res(i,j) = (2/h_x^2 + 2/h_y^2) * u(i,j)
 *            - (u(i,j+1) + u(i,j-1)) / h_y^2
 *            - (u(i+1,j) + u(i-1,j)) / h_x^2
 * The first/last row and column of res are NOT written; they keep whatever
 * res held before the call. Dimensions are taken from res; u is assumed to
 * have the same shape. Returns the elapsed seconds when built with -DTIMING,
 * else 0.0.
 */
double applyStencil(
        gridType* res,
        gridType* u,
        solverType* solver)
{
    /* S and E are the names TIMER_START / TIMER_STOP assign to. */
    double S=0.0, E=0.0;
    const int nx = res->size_x;
    const int ny = res->size_y;
    const double* U = u->ptr;
    double* R = res->ptr;

    const double w_x = 1.0 / (solver->h_x * solver->h_x);
    const double w_y = 1.0 / (solver->h_y * solver->h_y);
    const double w_c = 2.0 * (w_x + w_y);

    TIMER_START;
#pragma omp parallel for
    for ( int j = 1; j < ny - 1; j++ ) {
        const double* rowPrev = U + (j - 1) * nx;
        const double* rowCur  = U + j * nx;
        const double* rowNext = U + (j + 1) * nx;
        double* out = R + j * nx;

        for ( int i = 1; i < nx - 1; i++ ) {
            out[i] = w_c * rowCur[i]
                   - w_y * (rowNext[i] + rowPrev[i])
                   - w_x * (rowCur[i + 1] + rowCur[i - 1]);
        }
    }
    TIMER_STOP;

    return E - S;
}

/*
 * residual := rhs_sine - A*x, and *res_start := <residual, residual>
 * (the squared L2 norm of the residual).
 *
 * NOTE(review): applyStencil() does not write the boundary entries of
 * `residual`, so on the boundary the result is rhs_sine minus whatever
 * `residual` held on entry.
 */
void computeResidual(
        gridType* residual,
        gridType* x,
        gridType* rhs_sine,
        solverType* solver,
        double* res_start)
{
    /* residual := A*x (interior points) */
    applyStencil(residual, x, solver);

    /* residual := 1.0*rhs_sine + (-1.0)*residual */
    axpby(residual, 1.0, rhs_sine, -1.0, residual);

    /* squared norm of the residual */
    dotProduct(residual, residual, res_start);
}

/*
 * Conjugate gradient solver for A*x = b, where A is the 5-point operator
 * implemented by applyStencil().
 *
 *   x        in: initial guess; out: approximate solution (updated in place)
 *   b        right-hand side
 *   solver   grid spacing, iteration limit (niter) and tolerance
 *   iter_end out: number of iterations performed
 *   timer    per-kernel accumulators, only updated when built with -DTIMING
 *
 * Iterates while fewer than solver->niter iterations have run and the
 * squared residual norm exceeds tolerance^2, printing the squared residual
 * norm after every iteration. Returns the wall-clock seconds spent in the
 * iteration loop (work-grid setup and teardown are excluded).
 */
double CG( gridType* x,
        gridType* b,
        solverType* solver,
        int* iter_end,
        timerType* timer)
{
    double S=0.0, E=0.0;
    size_t bytesPerWord = sizeof(double);
    gridType p;   /* search direction */
    gridType v;   /* v = A*p */
    gridType r;   /* residual b - A*x */
    int iter = 0;
    double lambda = 0; double alpha_0 = 0, alpha_1 = 0;
    double tolSquared = solver->tolerance * solver->tolerance;

    allocGrid(x->size_x, x->size_y, bytesPerWord, &p);
    allocGrid(x->size_x, x->size_y, bytesPerWord, &v);
    allocGrid(x->size_x, x->size_y, bytesPerWord, &r);

    /* p = r = b - A*x and alpha_0 = <r, r>; r is allocated exactly once. */
    computeResidual(&p, x, b, solver, &alpha_0);
    copyGrid(&p, &r);

    S = getTimeStamp();
    while( (iter < solver->niter) && (alpha_0 > tolSquared) ) {
        /* The pragma covers only the next statement: every thread enters the
           "cg" marker region, then the master continues with the kernels,
           which open their own parallel regions. */
#pragma omp parallel
        LIKWID_MARKER_START("cg");
        T(APPLY_STENCIL) applyStencil(&v, &p, solver);
        T(DOT_PRODUCT) dotProduct(&v, &p, &lambda);
        /* step length: <r, r> / <p, A p> */
        lambda = alpha_0/lambda;
        //Update x
        T(AXPBY) axpby(x, 1.0, x, lambda, &p);
        //Update r
        T(AXPBY) axpby(&r, 1.0, &r, -lambda, &v);
        T(DOT_PRODUCT) dotProduct(&r, &r, &alpha_1);
        //Update p
        T(AXPBY) axpby(&p, 1.0, &r, alpha_1/alpha_0, &p);
        alpha_0 = alpha_1;
        printf("iter = %d, res = %.15e\n", iter, alpha_0);
        ++iter;
#pragma omp parallel
        LIKWID_MARKER_STOP("cg");
    }
    E = getTimeStamp();

    freeGrid(&p);
    freeGrid(&v);
    freeGrid(&r);
    (*iter_end) = iter;

    return E-S;
}

int main (int argc, char** argv)
{
    size_t bytesPerWord = sizeof(double);
    int size_x = 0;
    int size_y = 0;
    double res_start, err_start;
    double res_sine_cg, err_sine_cg;
    int itermax = 50;
    int iter_sine_cg;
    gridType u_sine, rhs_sine, x, residual;
    timerType timer[NUMTIMERS] = {
        {"axpby         ", 0.0},
        {"dot_product   ", 0.0},
        {"apply_stencil ", 0.0},
        {"cg            ", 0.0}
    };

    if ( argc > 2 ) {
        size_x = atoi(argv[1]);
        size_y = atoi(argv[2]);
        if ( argc == 4 ) {
            itermax = atoi(argv[3]);
        }
    } else {
        printf("Usage: %s <outer dimension y> <inner dimension x>\n", argv[0]);
        exit(EXIT_SUCCESS);
    }

    solverType solver = {(double)1.0/(size_x-1.0), (double)1.0/(size_y-1.0), itermax, 1e-8};

    allocGrid(size_x, size_y, bytesPerWord, &u_sine);
    allocGrid(size_x, size_y, bytesPerWord, &rhs_sine);
    allocGrid(size_x, size_y, bytesPerWord, &x);
    allocGrid(size_x, size_y, bytesPerWord, &residual);

    LIKWID_MARKER_INIT;
#pragma omp parallel
    {
        LIKWID_MARKER_REGISTER("cg");
        LIKWID_MARKER_REGISTER("axpby");
    }


    init(&x, &u_sine, &rhs_sine, &solver);

    /* compute start error and residual */
    axpby(&residual, 1.0, &u_sine, -1.0, &x);
    dotProduct(&residual, &residual, &err_start);
    computeResidual(&residual, &x, &rhs_sine, &solver, &res_start);

    timer[CG_SOLVER].val = CG( &x,
            &rhs_sine,
            &solver,
            &iter_sine_cg,
            timer);
    printf("CG iterations = %d\n", iter_sine_cg);
    printf("Performance CG = %f [MLUP/s]\n",
            (double)(iter_sine_cg)*size_x*size_y*
            1e-6/timer[CG_SOLVER].val);

    /* compute end error and residual */
    axpby(&residual, 1.0, &u_sine, -1.0, &x);
    dotProduct(&residual, &residual, &err_sine_cg);
    computeResidual(&residual, &x, &rhs_sine, &solver, &res_sine_cg);

    CHECK_LESS_THAN(res_sine_cg,res_start,"Solver::CG - residual check")
    CHECK_LESS_THAN(err_sine_cg,err_start,"Solver::CG - error check")

#ifdef TIMING
    printf(HLINE);
    for (int i=0; i<NUMTIMERS; i++){
        printf("%s%11.2fs\n",timer[i].label, timer[i].val);
    }
    printf(HLINE);
#endif

#ifdef DEBUG
    printf("Initial residual = %.9e, curr residual CG = %.9e\n", sqrt(res_start), sqrt(res_sine_cg));
    printf("Initial error = %.9e, curr error CG = %.9e\n", sqrt(err_start), sqrt(err_sine_cg));
#endif
    LIKWID_MARKER_CLOSE;
    return EXIT_SUCCESS;
}
