#include "grain_growth.hpp"

#include <chrono>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <string>
#include <vector>

#include "cxxopts.hpp"

// Global switch for extra progress messages; main() checks it before printing
// "finish data loading", "success ..." and similar trace lines. Off by default.
bool debug_on = false;


/**
 * Read a multi-grain field from a CSV file in the layout written by
 * print_img_in_csv().
 *
 * File layout: a header line "Nx,Ny,n_grains" (an optional fourth field
 * ",n_step" is also accepted), followed by n_grains * Nx rows of Ny
 * comma-separated numbers. Element (pg, x, y) is stored at
 * mtx[pg * Nx * Ny + x * Ny + y].
 *
 * @param filename  path of the file to read
 * @param Nx        [out] first grid dimension from the header
 * @param Ny        [out] second grid dimension from the header
 * @param n_grains  [out] number of grain fields from the header
 * @param n_step    [out] assigned only when the header carries a fourth
 *                  field; left unchanged otherwise
 * @return a new[]-allocated buffer of Nx * Ny * n_grains values that the
 *         caller releases with delete[], or nullptr if the file cannot be
 *         opened or is malformed (a message is written to stderr).
 */
valueType *read_init_state(const char *filename, uint &Nx, uint &Ny, uint &n_grains, uint &n_step)
{
    FILE *inp = fopen(filename, "r");
    if (inp == nullptr)
    {
        perror(filename);
        return nullptr;
    }

    // The header used to be read with three conversions but four pointers,
    // so n_step was never assigned. Read the three mandatory fields, then the
    // optional fourth one only if it is actually present.
    if (fscanf(inp, "%u,%u,%u", &Nx, &Ny, &n_grains) != 3)
    {
        fprintf(stderr, "%s: malformed header, expected \"Nx,Ny,n_grains\"\n", filename);
        fclose(inp);
        return nullptr;
    }
    uint header_step = 0;
    if (fscanf(inp, ",%u", &header_step) == 1)
        n_step = header_step;

    // size_t arithmetic so large grids do not overflow uint offsets.
    const size_t plane = static_cast<size_t>(Nx) * Ny;
    valueType *mtx = new valueType[plane * n_grains];

    for (size_t pg = 0; pg < n_grains; ++pg)
    {
        for (size_t x = 0; x < Nx; ++x)
        {
            valueType *row = mtx + pg * plane + x * Ny;
            for (size_t y = 0; y < Ny; ++y)
            {
                // Parse into a double and convert, so the %lf conversion is
                // correct whatever type valueType is.
                double v = 0.0;
                const int got = (y == 0) ? fscanf(inp, "%lf", &v)
                                         : fscanf(inp, ",%lf", &v);
                if (got != 1)
                {
                    fprintf(stderr, "%s: malformed data at grain %zu, row %zu, column %zu\n",
                            filename, pg, x, y);
                    delete[] mtx;
                    fclose(inp);
                    return nullptr;
                }
                row[y] = static_cast<valueType>(v);
            }
        }
    }

    fclose(inp);
    return mtx;
}

/**
 * Write a multi-grain field to a CSV file.
 *
 * Output layout: a header line "Nx,Ny,n_grains", then n_grains * Nx lines of
 * Ny comma-separated values printed with "%lf". Element (pg, i, j) is read
 * from img[pg * Nx * Ny + i * Ny + j].
 *
 * @param img       buffer of at least Nx * Ny * n_grains values
 * @param filename  output path (created or truncated)
 * @param Nx        first grid dimension
 * @param Ny        second grid dimension
 * @param n_grains  number of grain fields
 *
 * If the file cannot be opened, a message is written to stderr and nothing
 * else happens.
 */
void print_img_in_csv(valueType *img, const char *filename, uint Nx, uint Ny,
                      uint n_grains)
{
    FILE *oup = fopen(filename, "w");
    if (oup == nullptr)
    {
        // Previously the null FILE* was passed straight to fprintf.
        perror(filename);
        return;
    }

    fprintf(oup, "%u,%u,%u\n", Nx, Ny, n_grains);

    const size_t plane = static_cast<size_t>(Nx) * Ny;
    for (size_t pg = 0; pg < n_grains; ++pg)
    {
        for (size_t i = 0; i < Nx; ++i)
        {
            const valueType *row = img + pg * plane + i * Ny;
            for (size_t j = 0; j < Ny; ++j)
            {
                if (j != 0)
                    fputc(',', oup);
                // Explicit cast: %lf must receive a double regardless of valueType.
                fprintf(oup, "%lf", static_cast<double>(row[j]));
            }
            fputc('\n', oup);
        }
    }
    fclose(oup);
}

/**
 * Command-line settings produced by parse_args().
 *
 * Each member defaults to the value parse_args() uses when the matching
 * option is absent, so a default-constructed Args is fully initialized
 * instead of holding indeterminate values.
 */
class Args
{
public:
    uint nsteps = 100;                        // -s / --nsteps
    std::string input = "grain.in";           // -i / --input
    std::string output = "grain.out";         // -o / --output
    std::string bucket_output = "bucket.out"; // --bucket_output
    uint lshL = 1;                            // --lshL
    uint lshK = 1;                            // --lshK
    double lshr = 1e-4;                       // --lshr
};

/**
 * Parse the command line into a newly allocated Args.
 *
 * Every option has a default, so the program runs with no arguments. Each
 * option that was given explicitly is echoed to stdout as "  name = value",
 * between "[Parse Args]" and "[End of Parse Args]".
 *
 * --help prints the usage text and exits with status 0. A cxxopts parsing
 * error, or a negative value for an option stored as an unsigned integer,
 * prints an "error parsing options: ..." message and exits with status 1.
 *
 * @return heap-allocated Args owned by the caller (release with delete).
 */
Args *parse_args(int argc, const char *argv[])
{
    try
    {
        cxxopts::Options options(argv[0], " - test forward simulation of grain growth.");
        options
            .positional_help("[optional args]")
            .show_positional_help();

        options
            .set_width(70)
            .set_tab_expansion()
            .allow_unrecognised_options()
            .add_options()
            ("s,nsteps", "Number of steps of simulation (default=100)", cxxopts::value<int>(), "N")
            ("o,output", "Output file (default=grain.out)", cxxopts::value<std::string>(), "FILE")
            ("i,input", "Input file (default=grain.in)", cxxopts::value<std::string>(), "FILE")
            ("lshK", "K for LSH (default=1)", cxxopts::value<int>(), "INT")
            ("lshL", "L for LSH (default=1)", cxxopts::value<int>(), "INT")
            // double rather than float so values such as 1e-4 keep full precision.
            ("lshr", "r for LSH (default=1e-4)", cxxopts::value<double>(), "FLOAT")
            ("bucket_output", "Output file of the bucket information (default=bucket.out)", cxxopts::value<std::string>(), "FILE")
            ("h,help", "Print help")
#ifdef CXXOPTS_USE_UNICODE
            ("unicode", u8"A help option with non-ascii: à. Here the size of the"
                        " string should be correct")
#endif
            ;
        // There are deliberately no Nx/Ny options: the grid size comes from the data.

        auto result = options.parse(argc, argv);

        if (result.count("help"))
        {
            std::cout << options.help({"", "Group"}) << std::endl;
            exit(0);
        }

        std::cout << "[Parse Args]" << std::endl;

        // Return a string option, echoing it when it was supplied explicitly.
        auto string_option = [&result](const char *name, const char *fallback) -> std::string {
            if (!result.count(name))
                return fallback;
            const std::string value = result[name].as<std::string>();
            std::cout << "  " << name << " = " << value << std::endl;
            return value;
        };

        // Return an integer option destined for an unsigned field, echoing it
        // when supplied. Negative values are rejected: casting them to uint
        // used to wrap them to ~4e9 (e.g. --nsteps=-1 ran for billions of epochs).
        auto uint_option = [&result](const char *name, uint fallback) -> uint {
            if (!result.count(name))
                return fallback;
            const int value = result[name].as<int>();
            std::cout << "  " << name << " = " << value << std::endl;
            if (value < 0)
            {
                std::cout << "error parsing options: " << name
                          << " must be non-negative, got " << value << std::endl;
                exit(1);
            }
            return static_cast<uint>(value);
        };

        Args parsed;
        parsed.nsteps = uint_option("nsteps", 100);
        parsed.output = string_option("output", "grain.out");
        parsed.input = string_option("input", "grain.in");
        parsed.bucket_output = string_option("bucket_output", "bucket.out");
        parsed.lshK = uint_option("lshK", 1);
        parsed.lshL = uint_option("lshL", 1);

        if (result.count("lshr"))
        {
            parsed.lshr = result["lshr"].as<double>();
            std::cout << "  lshr = " << parsed.lshr << std::endl;
        }
        else
        {
            parsed.lshr = 1e-4;
        }

        auto arguments = result.arguments();
        std::cout << "  Saw " << arguments.size() << " arguments" << std::endl;

        std::cout << "[End of Parse Args]" << std::endl;

        // Allocate only after parsing succeeded, so the error exits above leak nothing.
        return new Args(parsed);
    }
    catch (const cxxopts::OptionException &e)
    {
        std::cout << "error parsing options: " << e.what() << std::endl;
        exit(1);
    }
}

int main(int argc, const char *argv[])
{
    

    Args *args = parse_args(argc, argv);

    // def parameters
    uint Nx = 64; //1024;   these will be changed later.
    uint Ny = 64; //1024;
    uint n_grains = 2;
    uint n_step = 500;

    uint lshK = args->lshK;
    uint lshL = args->lshL;
    valueType lsh_r = args->lshr;
    uint nsteps = args->nsteps;

    valueType h = 0.5;

    // valueType A = 1.0;
    // valueType B = 1.0;
    // valueType L = 5.0;
    // valueType kappa = 0.1;

    valueType dtime = 0.05;
    valueType ttime = 0.0;

    valueType init_L = 2.0; // try to learn to 5.0
    valueType init_A = 2.0; // try to learn to 1.0
    valueType init_B = 3.0; // try to learn to 1.0
    valueType init_kappa = 0.9; // try to learn to 0.1

    double lr = 1e-5;
    uint start_skip = 1;
    uint skip_step = 30;
    // uint skip_step = 5;
    uint epoch = nsteps;

    char* data_path = "../data/grain_growth_all_data_1";
    GrainGrowthDataset dataset(data_path, start_skip, skip_step);

    Nx = dataset.Nx;
    Ny = dataset.Ny;
    n_grains = dataset.n_grains;
    n_step = dataset.n_step;

    if (debug_on) {
        std::cout << "finish data loading" << std::endl;
    }

    double min_loss = 1000.0;

    // epoch = 1;

    for (int i = 0; i < epoch; ++ i) {
        double loss = 0.0;
        int total_size = 0;
        printf("epoch:\t%d\n", i);

        // for (int index = start_skip; index < start_skip + 10; ++index) {        // for testing
        for (int index = start_skip; index < dataset.get_len() - 10; ++index) {
            ReturnItem rt = dataset.get_item(index);

            valueType* eta1_eta2_start = rt.data.eta1_eta2;
            valueType lshr = 0.01;
            uint lshK = 3;
            uint lshL = 10;
            int img_size = dataset.Nx;
            valueType h = 0.5;
            valueType dtime = 0.05;
            valueType ttime = 0.0;
            uint eta1_eta2_len = img_size * img_size * n_grains;

            // if (debug_on) {
            //     std::cout << "sum of eta1_eta2_start: " << sum_mtx(eta1_eta2_start, eta1_eta2_len) << std::endl;
            // }

            GrainGrowthOneStep one_step(img_size, img_size, n_grains, lshK, lshL, h,\
                                     init_A, init_B, init_L, init_kappa, dtime, lshr);
            one_step.encode_from_img(eta1_eta2_start);

            auto start = std::chrono::high_resolution_clock::now();
            for (int j = 0; j < skip_step; ++j) {
                one_step.next();
                // std::cout << "sim step: " << j << std::endl;
                if (j % 10 == 0)
                {
                    one_step.hash_t.clean_up_l_list();
                }
            }
            auto stop = std::chrono::high_resolution_clock::now();
            auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);
            std::cout << "time of ts model forward: " << duration.count() << "ms in " << skip_step << "steps" << std::endl;

            if (debug_on) {
                std::cout << "success forward in ts model" << std::endl;
            }

            valueType* eta1_eta2_sim = one_step.decode_to_img();

            // if (debug_on) {
            //     std::cout << "sum of eta1_eta2_sim: " << sum_mtx(eta1_eta2_sim, eta1_eta2_len) << std::endl;
            // }

            valueType* eta1_eta2_ref = rt.ref.eta1_eta2_ref;

            // if (debug_on) {
            //     std::cout << "sum of eta1_eta2_ref: " << sum_mtx(eta1_eta2_ref, eta1_eta2_len) << std::endl;
            // }

            valueType dloss[eta1_eta2_len];

            calculate_mse_loss(eta1_eta2_sim, eta1_eta2_ref, dloss, eta1_eta2_len);

            // if (debug_on) {
            //     std::cout << "sum of dloss: " << sum_mtx(dloss, eta1_eta2_len) << std::endl;
            // }

            if (debug_on) {
                std::cout << "success get loss" << std::endl;
            }

            valueType lshr_back = 0.01;
            uint lshK_back = 3;
            uint lshL_back = 10;

            GrainGrowthOneBack one_back(img_size, img_size, n_grains, lshK_back, lshL_back, h, \
                                         init_A, init_B, init_L, init_kappa, dtime, lshr_back);

            one_back.encode_from_img(eta1_eta2_sim, dloss);

            auto start_back = std::chrono::high_resolution_clock::now();
            for (int j = 0; j < skip_step; ++j) {
                // auto start = std::chrono::high_resolution_clock::now();
                one_back.next();
                // auto stop = std::chrono::high_resolution_clock::now();
                // auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);
                // std::cout << "time of one forward: " << duration.count() << std::endl;
                // std::cout << "sim back step: " << j << std::endl;
            }
            auto stop_back = std::chrono::high_resolution_clock::now();
            auto duration_back = std::chrono::duration_cast<std::chrono::milliseconds>(stop_back - start_back);
            std::cout << "time of ts model backward: " << duration_back.count() << "ms in " << skip_step << "steps" << std::endl;

            valueType* tsm_back_grad = one_back.decode_derivative(); // L A B kappa

            if (debug_on) {
                std::cout << "success batch loss backward" << std::endl;
            }

            init_L -= lr * (std::abs(tsm_back_grad[0])>(10/lr)?0:tsm_back_grad[0]);
            init_A -= lr * (std::abs(tsm_back_grad[1])>(10/lr)?0:tsm_back_grad[1]);
            init_B -= lr * (std::abs(tsm_back_grad[2])>(10/lr)?0:tsm_back_grad[2]);
            init_kappa -= lr * (std::abs(tsm_back_grad[3])>(10/lr)?0:tsm_back_grad[3]);

            // init_L -= lr * (tsm_back_grad[0]>(10/lr)?0:tsm_back_grad[0]);
            // init_A -= lr * (tsm_back_grad[1]>(10/lr)?0:tsm_back_grad[1]);
            // init_B -= lr * (tsm_back_grad[2]>(10/lr)?0:tsm_back_grad[2]);
            // init_kappa -= lr * (tsm_back_grad[3]>(10/lr)?0:tsm_back_grad[3]);

            if (debug_on) {
                std::cout << "success opt1 opt2 step()" << std::endl;
            }

            int this_size = 1;
            valueType batch_loss = sum_mse_loss(eta1_eta2_sim, eta1_eta2_ref, img_size * img_size * n_grains);
            loss += batch_loss;
            if (true) {
                std::cout << "--------------loss-----------------" << std::endl;
                std::cout << "batch loss: " << batch_loss << std::endl;
                std::cout << "--------------grad-----------------" << std::endl;
                one_back.print_derivative();
                std::cout << "--------------param----------------" << std::endl;
                std::cout << init_L<< std::endl;
                std::cout << init_A << std::endl;
                std::cout << init_B << std::endl;
                std::cout << init_kappa << std::endl;
            }
            total_size += this_size;

            delete eta1_eta2_sim;
            delete tsm_back_grad;
        }
    }
    
    printf("Have a nice day!\n");
    return 0;
}
