#include <torch/torch.h>
#include <iostream>
#include <cmath>
#include <cstdio>
#include <vector>

// project-local headers
#include "nanovoid_app.h"


auto options =
    torch::TensorOptions()
        .dtype(torch::kFloat32)
        .layout(torch::kStrided)
        .device(torch::kCPU)
        .requires_grad(true);


struct LaplacianOpImpl : torch::nn::Module {
    torch::Tensor conv_kernel;
    LaplacianOpImpl() {
        double kernel[1][1][3][3] = {{{{0.0, 1.0, 0.0}, {1.0, -4.0, 1.0}, {0.0, 1.0, 0.0}}}};
        torch::Tensor _kernel = torch::zeros({3, 3}, torch::dtype(torch::kFloat32).requires_grad(false));
        _kernel.unsqueeze_(0).unsqueeze_(0); // 1 1 3 3
        for (int i = 0; i < 3; ++i) {
            for (int j = 0; j < 3; ++j) {
                _kernel.index_put_({0, 0, i, j}, kernel[0][0][i][j]);
            }
        }
        conv_kernel = register_parameter("conv_kernel", _kernel);
    }

    torch::Tensor forward(torch::Tensor input, double dx=1.0, double dy=1.0) {
        bool unsqueezed = false;
        if(input.dim() == 2) {
            input.unsqueeze_(0);
            unsqueezed = true;
        }
        using namespace torch::indexing;
        torch::Tensor input1 = torch::cat({input.index({Slice(None, None, None), Slice(-1, None, None), Slice(None, None, None)}),
                                           input, 
                                           input.index({Slice(None, None, None), Slice(None, 1, None), Slice(None, None, None)})},
                                          1);
        torch::Tensor input2 = torch::cat({input1.index({Slice(None, None, None), Slice(None, None, None), Slice(-1, None, None)}),
                                           input1, 
                                           input1.index({Slice(None, None, None), Slice(None, None, None), Slice(None, 1, None)})},
                                          2);
        torch::Tensor conv_input = input2.unsqueeze_(1);
        torch::Tensor result = torch::nn::functional::conv2d(conv_input, conv_kernel.squeeze_(1)) / (dx * dy);
        if(unsqueezed) {
            result.squeeze_(0);
        }
        return result;
    }
};

TORCH_MODULE_IMPL(LaplacianOp, LaplacianOpImpl);

struct IrradiationSingleTimeStepImpl : torch::nn::Module {

    torch::Tensor energy_v0;
    torch::Tensor energy_i0;
    torch::Tensor kBT0;
    torch::Tensor kappa_v0;
    torch::Tensor kappa_i0;
    torch::Tensor kappa_eta0;
    torch::Tensor r_bulk0;
    torch::Tensor r_surf0;
    torch::Tensor p_casc0;
    torch::Tensor bias0;
    torch::Tensor vg0;
    torch::Tensor diff_i0;
    torch::Tensor diff_v0;
    torch::Tensor L0;
    double dt, dx, dy, eps;
    int _N;

    IrradiationSingleTimeStepImpl(double _dt, double _dx, double _dy, double _eps, int __N, ParameterSet p) {

        // energy_v0 = register_parameter("energy_v0", torch::randn(1, torch::dtype(torch::kFloat32).requires_grad(true))); // * 5 + 0.1);
        // energy_i0 = register_parameter("energy_i0", torch::randn(1, torch::dtype(torch::kFloat32).requires_grad(true))); // * 5 + 0.1);
        // kBT0 = register_parameter("kBT0", torch::randn(1, torch::dtype(torch::kFloat32).requires_grad(true))); // * 5 + 0.1);
        // kappa_v0 = register_parameter("kappa_v0", torch::randn(1, torch::dtype(torch::kFloat32).requires_grad(true))); // / 2.0 + 0.1);
        // kappa_i0 = register_parameter("kappa_i0", torch::randn(1, torch::dtype(torch::kFloat32).requires_grad(true))); // / 2.0 + 0.1);
        // kappa_eta0 = register_parameter("kappa_eta0", torch::randn(1, torch::dtype(torch::kFloat32).requires_grad(true))); // / 2.0 + 0.1);
        
        // // r_bulk0 = register_parameter("r_bulk0", torch::ones(1, torch::dtype(torch::kFloat32).requires_grad(false)) * (5.0 - 0.001));
        // // r_surf0 = register_parameter("r_surf0", torch::ones(1, torch::dtype(torch::kFloat32).requires_grad(false)) * (10.0 - 0.001));
        
        // // p_casc0 = register_parameter("p_casc0", torch::ones(1, torch::dtype(torch::kFloat32).requires_grad(false)) * (0.01 - 0.001));
        // // bias0 = register_parameter("bias0", torch::ones(1, torch::dtype(torch::kFloat32).requires_grad(false)) * (0.3 - 0.001));
        // // vg0 = register_parameter("vg0", torch::ones(1, torch::dtype(torch::kFloat32).requires_grad(false)) * (0.01 - 0.001));

        // diff_v0 = register_parameter("diff_v0", torch::randn(1, torch::dtype(torch::kFloat32).requires_grad(true))); // / 2.0 + 0.1);
        // diff_i0 = register_parameter("diff_i0", torch::randn(1, torch::dtype(torch::kFloat32).requires_grad(true))); // / 2.0 + 0.1);
        // L0 = register_parameter("L0", torch::randn(1, torch::dtype(torch::kFloat32).requires_grad(true))); // / 2.0 + 0.1);

        // torch::from_blob() version
        valueType p_energy_v0 = p.energy_v0;
        torch::Tensor energy_v0_value = torch::from_blob(&p_energy_v0, {1,}, torch::dtype(torch::kFloat32)).clone();
        energy_v0_value.index_put_({0,}, p.energy_v0);
        energy_v0 = register_parameter("energy_v0", energy_v0_value.clone().requires_grad_(true)); // * 5 + 0.1);
        // std::cout << "p.ev: " << p.energy_v0 << std::endl;
        // std::cout << "ev0 value: " << energy_v0_value << std::endl;
        // std::cout << "register ev: " << energy_v0 << std::endl;

        torch::Tensor energy_i0_value = torch::from_blob(&(p.energy_i0), {1}, torch::dtype(torch::kFloat32));
        energy_i0_value.index_put_({0,}, p.energy_i0);
        energy_i0 = register_parameter("energy_i0", energy_i0_value.clone().requires_grad_(true)); // * 5 + 0.1);

        torch::Tensor kBT0_value = torch::from_blob(&(p.kBT0), {1}, torch::dtype(torch::kFloat32));
        kBT0_value.index_put_({0,}, p.kBT0);
        kBT0 = register_parameter("kBT0", kBT0_value.clone().requires_grad_(true)); // * 5 + 0.1);

        torch::Tensor kappa_v0_value = torch::from_blob(&(p.kappa_v0), {1}, torch::dtype(torch::kFloat32));
        kappa_v0_value.index_put_({0,}, p.kappa_v0);
        kappa_v0 = register_parameter("kappa_v0", kappa_v0_value.clone().requires_grad_(true)); // / 2.0 + 0.1);

        torch::Tensor kappa_i0_value = torch::from_blob(&(p.kappa_i0), {1}, torch::dtype(torch::kFloat32));
        kappa_i0_value.index_put_({0,}, p.kappa_i0);
        kappa_i0 = register_parameter("kappa_i0", kappa_i0_value.clone().requires_grad_(true)); // / 2.0 + 0.1);

        torch::Tensor kappa_eta0_value = torch::from_blob(&(p.kappa_eta0), {1}, torch::dtype(torch::kFloat32));
        kappa_eta0_value.index_put_({0,}, p.kappa_eta0);
        kappa_eta0 = register_parameter("kappa_eta0", kappa_eta0_value.clone().requires_grad_(true)); // / 2.0 + 0.1);
        
        // r_bulk0 = register_parameter("r_bulk0", torch::ones(1, torch::dtype(torch::kFloat32).requires_grad(false)) * (5.0 - 0.001));
        // r_surf0 = register_parameter("r_surf0", torch::ones(1, torch::dtype(torch::kFloat32).requires_grad(false)) * (10.0 - 0.001));
        
        // p_casc0 = register_parameter("p_casc0", torch::ones(1, torch::dtype(torch::kFloat32).requires_grad(false)) * (0.01 - 0.001));
        // bias0 = register_parameter("bias0", torch::ones(1, torch::dtype(torch::kFloat32).requires_grad(false)) * (0.3 - 0.001));
        // vg0 = register_parameter("vg0", torch::ones(1, torch::dtype(torch::kFloat32).requires_grad(false)) * (0.01 - 0.001));

        torch::Tensor diff_v0_value = torch::from_blob(&(p.diff_v0), {1}, torch::dtype(torch::kFloat32));
        diff_v0_value.index_put_({0,}, p.diff_v0);
        diff_v0 = register_parameter("diff_v0", diff_v0_value.clone().requires_grad_(true)); // / 2.0 + 0.1);

        torch::Tensor diff_i0_value = torch::from_blob(&(p.diff_i0), {1}, torch::dtype(torch::kFloat32));
        diff_i0_value.index_put_({0,}, p.diff_i0);
        diff_i0 = register_parameter("diff_i0", diff_i0_value.clone().requires_grad_(true)); // / 2.0 + 0.1);

        torch::Tensor L0_value = torch::from_blob(&(p.L0), {1}, torch::dtype(torch::kFloat32));
        L0_value.index_put_({0,}, p.L0);
        L0 = register_parameter("L0", L0_value.clone().requires_grad_(true)); // / 2.0 + 0.1);
        
        dt = _dt;
        dx = _dx;
        dy = _dy;
        eps = _eps;
        _N = __N;
    }

    torch::Tensor log_with_mask(torch::Tensor mat, double eps=1e-6) {
        torch::Tensor mask = mat.le(eps).detach();
        torch::Tensor mat_masked = mat.masked_fill(mask, eps).clone();
        return torch::log(mat_masked);
    }

    torch::Tensor lap(torch::Tensor mat, double dx=1.0, double dy=1.0) {
        return mat * dx * dy;
    }

    torch::Tensor fix_deviations(torch::Tensor mat, double lb=0.0, double ub=1.0) {
        mat.masked_fill_(torch::ge(mat, ub).detach(), ub);
        mat.masked_fill_(torch::le(mat, lb).detach(), lb);
        return mat;
    }

    torch::Tensor forward(torch::Tensor cv, torch::Tensor ci, torch::Tensor eta) {
        torch::Tensor energy_v = torch::abs(energy_v0) + 0.001;
        // std::cout << "ev: " << energy_v << std::endl; 
        torch::Tensor energy_i = torch::abs(energy_i0) + 0.001;
        // std::cout << "ei: " << energy_i << std::endl; 
        torch::Tensor kBT = torch::abs(kBT0) + 0.001;
        // std::cout << "kBT: " << kBT << std::endl; 
        torch::Tensor kappa_v = torch::abs(kappa_v0) + 0.001;
        // std::cout << "kappa v: " << kappa_v << std::endl; 
        torch::Tensor kappa_i = torch::abs(kappa_i0) + 0.001;
        // std::cout << "kappa_i: " << kappa_i << std::endl; 
        torch::Tensor kappa_eta = torch::abs(kappa_eta0) + 0.001;
        // std::cout << "kappa_eta: " << kappa_eta << std::endl; 
        // torch::Tensor r_bulk = torch::abs(r_bulk0) + 0.001;
        // torch::Tensor r_surf = torch::abs(r_surf0) + 0.001;
        // torch::Tensor p_casc = torch::abs(p_casc0) + 0.001;
        // torch::Tensor bias = torch::abs(bias0) + 0.001;
        // torch::Tensor vg = torch::abs(vg0) + 0.001;
        torch::Tensor diff_v = torch::abs(diff_v0) + 0.001;
        // std::cout << "diff_v: " << diff_v << std::endl; 
        torch::Tensor diff_i = torch::abs(diff_i0) + 0.001;
        // std::cout << "diff_i: " << diff_i << std::endl; 
        torch::Tensor L = torch::abs(L0) + 0.001;
        // std::cout << "L: " << L << std::endl; 

        torch::Tensor lap_cv = lap(cv, dx=dx, dy=dy);
        torch::Tensor lap_ci = lap(cv, dx=dx, dy=dy);
        torch::Tensor lap_eta = lap(eta, dx=dx, dy=dy);
        torch::Tensor h = eta.sub(1).pow(2);
        torch::Tensor j = eta.pow(2);

        torch::Tensor fs = energy_v * cv + energy_i * ci + kBT * (cv * log_with_mask(cv) + ci * log_with_mask(ci) + (1 - cv - ci) * log_with_mask(1 - cv - ci));
        torch::Tensor fs_mask = (1 - cv - ci).le(eps);
        fs.masked_fill_(fs_mask, 0.0);
        torch::Tensor dfs_dcv = energy_v + kBT * (log_with_mask(cv) - log_with_mask(1 - cv - ci));
        torch::Tensor dfs_dci = energy_i + kBT * (log_with_mask(ci) - log_with_mask(1 - cv - ci));
        dfs_dcv.masked_fill_(fs_mask, 0.0);
        dfs_dci.masked_fill_(fs_mask, 0.0);

        torch::Tensor fv = (cv - 1).pow(2) + ci.pow(2);
        torch::Tensor dfv_dcv = 2 * (cv - 1);
        torch::Tensor dfv_dci = 2 * ci;

        torch::Tensor dF_dcv = h * dfs_dcv + j * dfv_dcv - kappa_v * lap_cv;
        torch::Tensor dF_dci = h * dfs_dci + j * dfv_dci - kappa_i * lap_ci;
        torch::Tensor dF_deta = _N * (fs * 2 * (eta - 1) + fv * 2 * eta - kappa_eta * lap_eta);

        torch::Tensor mv = diff_v * cv / kBT;
        torch::Tensor mi = diff_i * ci / kBT;

        torch::Tensor cv_new = cv + dt * (mv * lap(dF_dcv, dx, dy));
        torch::Tensor ci_new = ci + dt * (mi * lap(dF_dci, dx, dy));
        torch::Tensor eta_new = eta + dt * (-L * dF_deta);

        cv_new = fix_deviations(cv_new);
        ci_new = fix_deviations(ci_new);
        eta_new = fix_deviations(eta_new);

        torch::Tensor concate_vals = torch::cat({cv_new, ci_new, eta_new});

        return concate_vals;
    }
};

TORCH_MODULE_IMPL(IrradiationSingleTimeStep, IrradiationSingleTimeStepImpl);

/// Run / training configuration with in-class default values.
struct Param
{
    /// Device type for tensors and models (default CPU). Typical selection:
    ///
    ///     if (torch::cuda::is_available()) {
    ///         device_type = torch::kCUDA;
    ///     } else {
    ///         device_type = torch::kCPU;
    ///     }
    ///     torch::Device device(device_type);
    ///
    ///     Net model;
    ///     model.to(device);
    torch::DeviceType device_type = torch::kCPU;

    // NOTE(review): semantics of the fields below are inferred from names only — confirm
    // against the code that reads them.
    double min_write_val = 1e-10;  // presumably a lower bound applied when writing values

    double lr = 1e-1;   // learning rate (presumably primary)

    double lr2 = 1e-3;  // learning rate (presumably secondary)

    double grad_clip_val = 1e10;

    int Nx = 130;  // grid size along x

    int Ny = 130;  // grid size along y

    // NOTE(review): integer spacings, while IrradiationSingleTimeStep takes double dx/dy.
    int dx = 1;

    int dy = 1;

    int nsteps = 5000;

    int nprint = 100;

    double dt = 2e-2;

    double eps = 1e-6;

    int _N = 1;

    int batch_size = 1; // 2^N (presumably: intended to be a power of two)

    int epoch = 1000; // 500 (presumably a previously used value)

    // `const char*`: string literals are immutable, and binding them to a non-const
    // `char*` is ill-formed since C++11.
    const char* data_path = "./output/irradiation";
    const char* filename_pkl = "./all_data.pkl";

    double fluct_norm = 100.0;

    double lambda1 = 10.0;

    double lambda2 = 10.0;

    int skip_step = 30;

    int embedding_features = 8;

    int start_skip = 9;

};

