// unet
#include "cunet.h"
#include "partialconv.h"

// nanovoid sim
#include "irradiation_model.h"

// data loading
#include "load_data.h"

// torch and other utils
#include <torch/torch.h>
#include <torch/script.h>
#include <cstdio>
#include <exception>
#include <iostream>
// #include <boost/filesystem.hpp>
#include <memory>
#include <string>

// time issue
#include <chrono>

// hash model
#include "nanovoid_app.h"
#include "png.h"

// namespace fs = boost::filesystem;
// namespace idx = torch::indexing;
// static idx::Slice all_select(idx::None, idx::None, idx::None);
// static idx::Slice one_one_select(1, -1, idx::None);
// static bool debug_on = true;


// torch::Tensor stitch_by_mask(torch::Tensor cv1, torch::Tensor cv_ori, torch::Tensor mask1, torch::Tensor ul1) {
//     torch::Tensor ul = ul1.unsqueeze(-1).unsqueeze(-1);
//     ul = ul.repeat({1, 128, 128});
//     torch::Tensor mask2 = mask1 * ul + (1.0 - mask1) * (1.0 - ul);
//     return cv_ori * mask2 + cv1 * (1.0 - mask2);
// }

int main(int argc, char **argv) {

    int seed;
    char* run_idx;

    if (argc > 2) {
        run_idx = argv[1];
        seed = std::stoi(argv[2]);
    }
    else {
        run_idx = "";
        seed = 4321;
    }

    char* cv_path = "../../data/irradiation_v3/cv_all_data.pt";
    char* ci_path = "../../data/irradiation_v3/ci_all_data.pt";
    char* eta_path = "../../data/irradiation_v3/eta_all_data.pt";
    char* video_path = "../../data/irradiation_v3/video_all_data.pt";
    // char* filename_pkl = "container_all_data.pt";
    // char* video_pkl = "container_video_data.pt";

    // char* best_ts_model_out = "best_ts_model_out_v3.0_ef8"
    // char* best_unet_model_out = ""
    // char* best_unet_model_pretrained = ""

    char* best_ts_model_pretrained = NULL;

    double min_loss = 1000.0;

    // torch::autograd::AnomalyMode::set_enabled(true); 

    // skip for testing
    // IrradiationVideoDataset dataset(cv_path, ci_path, eta_path, video_path, param.skip_step);

    ParameterSet para(0.0);

    para.energy_v0 = 3.90086937;
    para.energy_i0 = 0.20275441;
    para.kBT0 = -2.56907487;
    para.kappa_v0 = -0.38098839;
    para.kappa_i0 = -0.78148192;
    para.kappa_eta0 = 0.34326860;
    para.r_bulk0 = 4.99900007;
    para.r_surf0 = 9.99899960;
    para.p_casc0 = 0.00900000;
    para.bias0 = 0.29899999;
    para.vg0 = 0.00900000;
    para.diff_v0 = 1.15292716;
    para.diff_i0 = -0.19592035;
    para.L0 = -0.52126276;



    // testing speed of hash model locally
    // auto start_test = std::chrono::high_resolution_clock::now();
    // uint Nx = 128;
    // uint Ny = 128;
    // uint channel = 3;
    // uint nsteps = 1;
    // valueType lshr_test = 0.01;
    // uint lshK_test = 3;
    // uint lshL_test = 10;
    // valueType ***org_mat = read_from_data("../data_img_128", Nx, Ny);
    // NanoVoidOneStep one_step(Nx, Ny, para, lshr_test, lshK_test, lshL_test);
    // one_step.encode_from_img(org_mat);
    // one_step.next();
    // valueType*** decoded_img = one_step.decode_to_img();
    // valueType*** ground_truth = init_zero_mat(Nx, Ny, 3);
    // valueType*** dloss = get_dloss(Nx, Ny, 3, decoded_img, ground_truth);
    // NanoVoidOneBack one_back(Nx, Ny, para, lshr_test, lshK_test, lshL_test);
    // one_back.encode_from_img(decoded_img, dloss);
    // one_back.next();
    // one_back.print_derivative();
    // delete_3d_array(org_mat, Nx, Ny, 3);
    // delete_3d_array(decoded_img, Nx, Ny, 3);
    // delete_3d_array(ground_truth, Nx, Ny, 3);
    // delete_3d_array(dloss, Nx, Ny, 3);
    // auto stop_test = std::chrono::high_resolution_clock::now();
    // auto duration_test = std::chrono::duration_cast<std::chrono::milliseconds>(stop_test - start_test);
    // std::cout << "time of one forward backward: " << duration_test.count() << std::endl;


    // IrradiationSingleTimeStep ts_model(param.dt, param.dx, param.dy, param.eps, param._N, para);
    Param param;

    param.lr = 1e-1;
    param.skip_step = 1;

    param.lr2 = 1e-3;
    param.lambda1 = 10.0;
    param.lambda2 = 10.0;
    param.batch_size = 1;
    param.epoch = 1;
    param.embedding_features = 8;
    param.eps = 1e-1;

    torch::manual_seed(seed);
    torch::cuda::manual_seed(seed);
    IrradiationVideoDataset dataset(cv_path, ci_path, eta_path, video_path, param.skip_step);

    if (debug_on) {
        std::cout << "finish data loading" << std::endl;
    }

    torch::Device device(torch::kCPU);

    // ts_model->to(device);

    if (debug_on) {
        std::cout << "finish ts model init" << std::endl;
    }

    // if (debug_on) {
    //     for (const auto& pair : ts_model->named_parameters()) {
    //         std::cout << pair.key() << ": " << pair.value() << std::endl;
    //         std::cout << "is_leaf: " << pair.value().is_leaf() << std::endl;
    //     }
    // }

    // torch::nn::MSELossOptions mseopt(torch::kNone);

    // mseopt.reduction()

    torch::nn::MSELoss mse(torch::nn::MSELossOptions(torch::kSum));

    if (debug_on) {
        std::cout << "finish mseloss init" << std::endl;
    }

    // torch::optim::Adam optimizer(ts_model->parameters(), torch::optim::AdamOptions(param.lr));

    if (debug_on) {
        std::cout << "finish optim1 init" << std::endl;
    }

    // CUNet2dWithEmbeddingGen video2pf(3, 3, 32, 5, 3, true, true, true, true, false, false, param.embedding_features, dataset.get_len() + param.skip_step, 8, false);
    // for testing
    CUNet2dWithEmbeddingGen video2pf(3, 3, 32, 5, 3, true, true, true, true, false, false, param.embedding_features, dataset.get_len() + param.skip_step, 8, false);


    // video2pf.eval();
    if (debug_on) {
        std::cout << "finish video2pf init" << std::endl;
    }

    torch::optim::Adam optimizer2(video2pf->parameters(), torch::optim::AdamOptions(param.lr2));

    if (debug_on) {
        std::cout << "finish optim2 init" << std::endl;
    }

    torch::Tensor mask = torch::ones({128, 128}, torch::dtype(torch::kFloat32).requires_grad(false));

    mask.index_put_({idx::Slice(64, 128, idx::None), all_select}, 0.0);

    if (debug_on) {
        std::cout << "start training..." << std::endl;
    }

    for (int i = 0; i < param.epoch; ++ i) {
        double loss = 0.0;
        int total_size = 0;
        printf("epoch:\t%d\n", i);
        // if (debug_on) {
        //     std::cout << "len of data: " << dataset.get_len() << std::endl;
        // }
        for (int index = param.start_skip + 1; index < (dataset.get_len() - param.start_skip - param.skip_step*2); ++ index) { // should be batch in loader, iterate all data point in train set
            
            ReturnItem rt = dataset.get_item(index);
            
            // ground truth at time 0
            torch::Tensor cv = rt.rd.cv;
            torch::Tensor ci = rt.rd.ci;
            torch::Tensor eta = rt.rd.eta;

            torch::Tensor frame1 = rt.rd.v;
            // torch::Tensor indicies1 = torch::from_blob(&rt.rd.index, {1}, torch::kInt64);
            int indicies1 = rt.rd.index;

            torch::Tensor ul1 = rt.rd.ul;

            if (debug_on) {
                std::cout << "success get data" << std::endl;
                // std::cout << "frame1 size: " << frame1.sizes() << std::endl;
            }

            // torch::Tensor pf = video2pf->forward(frame1, indicies1);

            // if (debug_on) {
            //     std::cout << "success forward in unet" << std::endl;
            // }

            // // learned cv ci eta from unet
            // torch::Tensor cv = pf.index({all_select, 0, all_select, all_select});
            // torch::Tensor ci = pf.index({all_select, 1, all_select, all_select});
            // torch::Tensor eta = pf.index({all_select, 2, all_select, all_select});
            // std::cout << "cv size: " << cv.sizes() << ", ci size: " << ci.sizes() << ", eta size: " << eta.sizes() << std::endl;
            // // should be 1,128,128

            // eta = stitch_by_mask(eta, eta1, mask, ul1);

            // if (debug_on) {
            //     std::cout << "success get learned cv ci eta from unet" << std::endl;
            //     std::cout << "dtype of cv ci eta: " << cv.dtype() << ci.dtype() << eta.dtype() << std::endl;
            // }

            valueType* cv_data = cv.data_ptr<valueType>();
            // valueType* cv_data = cv.data_ptr
            valueType* ci_data = ci.data_ptr<valueType>();
            valueType* eta_data = eta.data_ptr<valueType>();

            valueType lshr = 0.01;
            uint lshK = 3;
            uint lshL = 10;
            int img_size = 128;
            NanoVoidOneStep one_step(img_size, img_size, para, lshr, lshK, lshL);
            // printf("before encode_from_img\n");
            one_step.encode_from_img_torch(cv_data, ci_data, eta_data);
            // printf("finish encode_from_img\n");

            auto start = std::chrono::high_resolution_clock::now();
            for (int j = 0; j < param.skip_step; ++j) {
                one_step.next();
                std::cout << "sim step: " << j << std::endl;
            }
            auto stop = std::chrono::high_resolution_clock::now();
            auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);
            std::cout << "time of ts model forward: " << duration.count() << "ms in " << param.skip_step << "steps" << std::endl;
            // torch::Tensor concate_vals;
            // for (int j = 0; j < param.skip_step; ++ j) {
            //     auto start = std::chrono::high_resolution_clock::now();
            //     concate_vals = ts_model->forward(cv, ci, eta);
            //     auto stop = std::chrono::high_resolution_clock::now();
            //     auto duration = std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
            //     std::cout << "time of one forward: " << duration.count() << std::endl;
            //     cv = concate_vals.index({0});
            //     ci = concate_vals.index({1});
            //     eta = concate_vals.index({2});
            //     cv.unsqueeze_(0);
            //     ci.unsqueeze_(0);
            //     eta.unsqueeze_(0);
            //     std::cout << "concate_vals size: " << concate_vals.sizes() << std::endl;
            //     // std::cout << "cv size: " << cv.sizes() << ", ci size: " << ci.sizes() << ", eta size: " << eta.sizes() << std::endl;
            // }
            valueType** contate_vals = one_step.decode_to_img_torch();
            valueType* cv_sim = contate_vals[0];
            valueType* ci_sim = contate_vals[1];
            valueType* eta_sim = contate_vals[2];
            torch::Tensor cv_sim_t = torch::from_blob(cv_sim, {1, 128, 128}, torch::dtype(torch::kFloat)).clone();
            // std::cout << "cv sim t sum: " << cv_sim_t.sum().item<valueType>() << std::endl; 
            torch::Tensor ci_sim_t = torch::from_blob(ci_sim, {1, 128, 128}, torch::dtype(torch::kFloat)).clone();
            // std::cout << "ci sim t sum: " << ci_sim_t.sum().item<valueType>() << std::endl; 
            torch::Tensor eta_sim_t = torch::from_blob(eta_sim, {1, 128, 128}, torch::dtype(torch::kFloat)).clone();
            // std::cout << "eta sim t sum: " << eta_sim_t.sum().item<valueType>() << std::endl; 
            // torch::Tensor grad_t = torch::from_blob(grad, {3,3}, torch::dtype(torch::kFloat64)).clone();

            if (debug_on) {
                std::cout << "success forward in ts model" << std::endl;
            }

            torch::Tensor cv_ref = rt.rl.cv_ref;
            torch::Tensor ci_ref = rt.rl.ci_ref;
            torch::Tensor eta_ref = rt.rl.eta_ref;

            torch::Tensor frame2 = rt.rl.v_ref;

            // torch::Tensor indicies2 = torch::from_blob(&rt.rl.index_ref, {1}, torch::kInt);
            int indicies2 = rt.rl.index_ref;

            torch::Tensor ul2 = rt.rl.ul_ref;

            // pf = video2pf(frame2, indicies2);

            // if (debug_on) {
            //     std::cout << "success forward in unet, ref" << std::endl;
            // }

            // torch::Tensor cv_frame = pf.index({all_select, 0, all_select, all_select});
            // torch::Tensor ci_frame = pf.index({all_select, 1, all_select, all_select});
            // torch::Tensor eta_frame = pf.index({all_select, 2, all_select, all_select});

            if (debug_on) {
                std::cout << "success get cv ci eta ref" << std::endl;
            }

            ul2.unsqueeze_(-1).unsqueeze_(-1);
            ul2 = ul2.repeat({1, 128, 128});
            torch::Tensor mask2 = mask * ul2 + (1.0 - mask) * (1.0 - ul2);

            // torch::Tensor cv_new = concate_vals.index({0});
            // torch::Tensor ci_new = concate_vals.index({1});
            // torch::Tensor eta_new = concate_vals.index({2});

            // cv_new.unsqueeze_(0);
            // ci_new.unsqueeze_(0);
            // eta_new.unsqueeze_(0);

            // std::cout << "cv frame size: " << cv_frame.sizes() << ", cv new size: " << cv_new.sizes() << std::endl;
            // std::cout << "ci frame size: " << ci_frame.sizes() << ", ci new size: " << ci_new.sizes() << std::endl;
            // std::cout << "eta frame size: " << eta_frame.sizes() << ", eta new size: " << eta_new.sizes() << ", eta ref size: " << eta_ref.sizes() << std::endl;
            // should be 1,128,128


            // torch::Tensor cv_batch_loss = param.lambda2 * mse->forward(cv_frame, cv_sim_t);
            torch::Tensor cv_batch_loss = param.lambda2 * 2 *(cv_ref - cv_sim_t);
            // torch::Tensor ci_batch_loss = param.lambda2 * mse->forward(ci_frame, ci_sim_t);
            torch::Tensor ci_batch_loss = param.lambda2 * 2 * (ci_ref - ci_sim_t);
            // torch::Tensor eta_batch_loss = mse->forward(mask2 * eta_ref, mask2 * eta_sim_t) + \
            //              param.lambda1 * mse->forward(mask2 * eta_ref, mask2 * eta_frame) + \
            //              param.lambda2 * mse->forward(eta_frame, eta_sim_t);
            torch::Tensor eta_batch_loss = param.lambda2 * 2 * (eta_ref - eta_sim_t);          
            
            torch::Tensor cv_batch_loss_sum = mse->forward(cv_ref, cv_sim_t);
            torch::Tensor ci_batch_loss_sum = mse->forward(ci_ref, ci_sim_t);
            torch::Tensor eta_batch_loss_sum = mse->forward(eta_ref, eta_sim_t);
            torch::Tensor batch_loss = param.lambda2 * (cv_batch_loss_sum + ci_batch_loss_sum + eta_batch_loss_sum);

            valueType* dloss_cv = cv_batch_loss.data_ptr<valueType>();
            valueType* dloss_ci = ci_batch_loss.data_ptr<valueType>();
            valueType* dloss_eta = eta_batch_loss.data_ptr<valueType>();

            if (debug_on) {
                std::cout << "success get loss" << std::endl;
            }

            // optimizer.zero_grad();
            // optimizer2.zero_grad();

            if (debug_on) {
                std::cout << "success opt1 opt2 zero grad" << std::endl;
            }

            // batch_loss.backward();
            valueType lshr_back = 0.01;
            uint lshK_back = 3;
            uint lshL_back = 10;
            NanoVoidOneBack one_back(img_size, img_size, para, lshr_back, lshK_back, lshL_back);
            // printf("before encode_from_img\n");
            one_back.encode_from_img_torch(cv_sim, ci_sim, eta_sim, dloss_cv, dloss_ci, dloss_eta);
            // printf("finish encode_from_img\n");

            auto start_back = std::chrono::high_resolution_clock::now();
            for (int j = 0; j < param.skip_step; ++j) {
                // auto start = std::chrono::high_resolution_clock::now();
                one_back.next();
                // auto stop = std::chrono::high_resolution_clock::now();
                // auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);
                // std::cout << "time of one forward: " << duration.count() << std::endl;
                std::cout << "sim back step: " << j << std::endl;
            }
            auto stop_back = std::chrono::high_resolution_clock::now();
            auto duration_back = std::chrono::duration_cast<std::chrono::milliseconds>(stop_back - start_back);
            std::cout << "time of ts model backward: " << duration_back.count() << "ms in " << param.skip_step << "steps" << std::endl;

            valueType** contate_dev = one_back.decode_to_img_torch();
            valueType* dloss_cv_n = contate_dev[3];
            valueType* dloss_ci_n = contate_dev[4];
            valueType* dloss_eta_n = contate_dev[5];

            torch::Tensor dloss_cv_n_t = torch::from_blob(dloss_cv_n, {1, 128, 128}, torch::dtype(torch::kFloat32)).clone();
            torch::Tensor dloss_ci_n_t = torch::from_blob(dloss_ci_n, {1, 128, 128}, torch::dtype(torch::kFloat32)).clone();
            torch::Tensor dloss_eta_n_t = torch::from_blob(dloss_eta_n, {1, 128, 128}, torch::dtype(torch::kFloat32)).clone();

            // cv.backward(dloss_cv_n_t);
            // ci.backward(dloss_ci_n_t);
            // eta.backward(dloss_eta_n_t);

            // cv_frame.backward(cv_batch_loss);
            // ci_frame.backward(ci_batch_loss);
            // eta_frame.backward(eta_batch_loss);

            valueType* tsm_back_grad = one_back.decode_derivative();

            if (debug_on) {
                std::cout << "success batch loss backward" << std::endl;
            }

            // optimizer.step();
            // optimizer2.step();

            double tsm_lr = param.lr;
            para.energy_v0 -= tsm_lr * tsm_back_grad[0];
            para.energy_i0 -= tsm_lr * tsm_back_grad[1];
            para.kBT0 -= tsm_lr * tsm_back_grad[2];
            para.kappa_v0 -= tsm_lr * tsm_back_grad[3];
            para.kappa_i0 -= tsm_lr * tsm_back_grad[4];
            para.kappa_eta0 -= tsm_lr * tsm_back_grad[5];
            para.diff_v0 -= tsm_lr * tsm_back_grad[6];
            para.diff_i0 -= tsm_lr * tsm_back_grad[7];
            para.L0 -= tsm_lr * tsm_back_grad[8];

            if (debug_on) {
                std::cout << "success opt1 opt2 step()" << std::endl;
            }

            int this_size = cv.size(0);
            loss += (batch_loss.item<valueType>());
            if (true) {
                std::cout << "--------------loss-----------------" << std::endl;
                std::cout << "batch loss: " << (batch_loss.item<valueType>()) << std::endl;
                std::cout << "cv loss: " << cv_batch_loss_sum.item<valueType>() << ", ci loss: " \ 
                            << ci_batch_loss_sum.item<valueType>() << ", eta loss: " << eta_batch_loss_sum.item<valueType>() << std::endl;
                std::cout << "--------------grad-----------------" << std::endl;
                one_back.print_derivative();
                std::cout << "--------------param----------------" << std::endl;
                std::cout << para.energy_v0 << std::endl;
                std::cout << para.energy_i0 << std::endl;
                std::cout << para.kBT0 << std::endl;
                std::cout << para.kappa_v0 << std::endl;
                std::cout << para.kappa_i0 << std::endl;
                std::cout << para.kappa_eta0 << std::endl;
                std::cout << para.diff_v0 << std::endl;
                std::cout << para.diff_i0 << std::endl;
                std::cout << para.L0 << std::endl;
            }
            total_size += this_size;

            // delete cv_data;
            // delete ci_data;
            // delete eta_data;
            delete cv_sim;
            delete ci_sim;
            delete eta_sim;
            delete contate_vals;
            // delete dloss_cv;
            // delete dloss_ci;
            // delete dloss_eta;
            delete contate_dev[0];
            delete contate_dev[1];
            delete contate_dev[2];
            delete contate_dev[3];
            delete contate_dev[4];
            delete contate_dev[5];
            delete contate_dev;
            delete tsm_back_grad;
        }

        loss /= total_size;
        printf("loss:\t%.8f\n", loss);

        if (loss < min_loss) {
            min_loss = loss;
            // save ts_model
            // save video2pf 
        }
        else {
            printf("Above min_loss\n");
        }
    }

    // testing
    valueType lshr_test = 0.1;
    uint lshK_test = 1;
    uint lshL_test = 1;
    int img_size = 128;
    NanoVoidOneStep one_step(img_size, img_size, para, lshr_test, lshK_test, lshL_test);


    return 0;
}

