#include "grain_growth.hpp"

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <ctime>

#include <chrono>
#include <iostream>
#include <memory>
#include <string>

#include "cxxopts.hpp"

// Global switch for extra progress messages (e.g. "finish data loading" in main()).
bool debug_on = false;


// Reads a multi-grain initial state stored as CSV.
//
// Layout:
//   line 1 : "Nx,Ny,n_grains", optionally followed by ",n_step"
//   then n_grains * Nx lines of Ny comma-separated values; value (pg, x, y)
//   is stored at index pg*Nx*Ny + x*Ny + y of the returned buffer.
// This is the layout written by print_img_in_csv() (which omits n_step).
//
// Nx, Ny and n_grains are overwritten from the header. n_step is overwritten
// only when the header carries a fourth field; otherwise it is left untouched.
//
// Returns a buffer allocated with new[] (release it with delete[]), or nullptr
// if the file cannot be opened or its header/data are malformed or truncated;
// in that case a message is printed to stderr.
// NOTE(review): the %lf conversions require valueType to be double.
valueType *read_init_state(const char *filename, uint &Nx, uint &Ny, uint &n_grains, uint &n_step)
{
    std::unique_ptr<FILE, int (*)(FILE *)> inp(fopen(filename, "r"), &fclose);
    if (!inp)
    {
        fprintf(stderr, "read_init_state: cannot open '%s'\n", filename);
        return nullptr;
    }

    // Three mandatory header fields plus an optional n_step. If the fourth
    // field is absent, the ',' literal fails to match and fscanf returns 3.
    uint header_n_step = 0;
    const int n_fields = fscanf(inp.get(), "%u,%u,%u,%u", &Nx, &Ny, &n_grains, &header_n_step);
    if (n_fields < 3)
    {
        fprintf(stderr, "read_init_state: malformed header in '%s'\n", filename);
        return nullptr;
    }
    if (n_fields == 4)
        n_step = header_n_step;

    // size_t arithmetic so large grids do not overflow 32-bit uint.
    const std::size_t plane = static_cast<std::size_t>(Nx) * Ny;
    std::unique_ptr<valueType[]> mtx(new valueType[plane * n_grains]);
    for (uint pg = 0; pg < n_grains; ++pg)
    {
        for (uint x = 0; x < Nx; ++x)
        {
            valueType *row = mtx.get() + pg * plane + static_cast<std::size_t>(x) * Ny;
            for (uint y = 0; y < Ny; ++y)
            {
                // The first value of a row is preceded only by whitespace
                // (skipped by %lf); later values are preceded by a comma.
                if (fscanf(inp.get(), y == 0 ? "%lf" : ",%lf", row + y) != 1)
                {
                    fprintf(stderr, "read_init_state: truncated or malformed data in '%s'\n", filename);
                    return nullptr;
                }
            }
        }
    }
    return mtx.release();
}

// Writes a multi-grain image as CSV.
//
// Layout: a header line "Nx,Ny,n_grains", then n_grains * Nx lines of Ny
// comma-separated values, where value (pg, i, j) is taken from
// img[pg*Nx*Ny + i*Ny + j]. This is the layout read_init_state() parses.
//
// If the file cannot be opened, a message is printed to stderr and nothing is
// written.
void print_img_in_csv(valueType *img, const char *filename, uint Nx, uint Ny,
                      uint n_grains)
{
    FILE *oup = fopen(filename, "w");
    if (oup == nullptr)
    {
        fprintf(stderr, "print_img_in_csv: cannot open '%s' for writing\n", filename);
        return;
    }

    fprintf(oup, "%u,%u,%u\n", Nx, Ny, n_grains);
    // size_t arithmetic so large grids do not overflow 32-bit uint.
    const std::size_t plane = static_cast<std::size_t>(Nx) * Ny;
    for (uint pg = 0; pg < n_grains; ++pg)
    {
        for (uint i = 0; i < Nx; ++i)
        {
            const valueType *row = img + pg * plane + static_cast<std::size_t>(i) * Ny;
            // Comma-separate values within a row; an empty row (Ny == 0) reads
            // nothing from img and just emits the newline.
            for (uint j = 0; j < Ny; ++j)
                fprintf(oup, j == 0 ? "%lf" : ",%lf", row[j]);
            fputc('\n', oup);
        }
    }
    fclose(oup);
}

// Command-line configuration produced by parse_args().
// Every member has an in-class default equal to the fallback parse_args()
// uses when the option is absent, so a default-constructed Args is fully
// initialized rather than holding indeterminate values.
class Args
{
public:
    uint nsteps = 100;                        // --nsteps: simulation steps
    std::string input = "grain.in";           // --input: input file
    std::string output = "grain.out";         // --output: output file
    std::string bucket_output = "bucket.out"; // --bucket_output: bucket information file
    uint lshL = 1;                            // --lshL: L for LSH
    uint lshK = 1;                            // --lshK: K for LSH
    double lshr = 1e-4;                       // --lshr: r for LSH
};

Args *parse_args(int argc, const char *argv[])
{
    try
    {
        Args *args = new Args;

        cxxopts::Options options(argv[0], " - test forward simulation of grain growth.");
        options
            .positional_help("[optional args]")
            .show_positional_help();

        options
            .set_width(70)
            .set_tab_expansion()
            .allow_unrecognised_options()
            .add_options()("s,nsteps", "Number of steps of simulation (default=100)", cxxopts::value<int>(), "N")("o,output", "Output file (default=grain.out)", cxxopts::value<std::string>(), "FILE")("i,input", "Input file (default=grain.in)", cxxopts::value<std::string>(), "FILE")("lshK", "K for LSH (default=1)", cxxopts::value<int>(), "INT")("lshL", "L for LSH (default=1)", cxxopts::value<int>(), "INT")("lshr", "r for LSH (default=1e-4)", cxxopts::value<float>(), "FLOAT")("bucket_output", "Output file of the bucket information (default=bucket.out)", cxxopts::value<std::string>(), "FILE")("h,help", "Print help")
#ifdef CXXOPTS_USE_UNICODE
                ("unicode", u8"A help option with non-ascii: à. Here the size of the"
                            " string should be correct")
#endif
            ;
        //("Nx", "size of x-axis (default=64)", cxxopts::value<int>(), "INT")
        //("Ny", "size of y-axis (default=64)", cxxopts::value<int>(), "INT")

        auto result = options.parse(argc, argv);

        if (result.count("help"))
        {
            std::cout << options.help({"", "Group"}) << std::endl;
            exit(0);
        }

        std::cout << "[Parse Args]" << std::endl;

        if (result.count("nsteps"))
        {
            std::cout << "  nsteps = " << result["nsteps"].as<int>() << std::endl;
            args->nsteps = (uint)result["nsteps"].as<int>();
        }
        else
        {
            args->nsteps = 100;
        }

        if (result.count("output"))
        {
            std::cout << "  output = " << result["output"].as<std::string>()
                      << std::endl;
            args->output = result["output"].as<std::string>();
        }
        else
        {
            args->output = "grain.out";
        }

        if (result.count("input"))
        {
            std::cout << "  input = " << result["input"].as<std::string>()
                      << std::endl;
            args->input = result["input"].as<std::string>();
        }
        else
        {
            args->input = "grain.in";
        }

        if (result.count("bucket_output"))
        {
            std::cout << "  bucket_output = " << result["bucket_output"].as<std::string>()
                      << std::endl;
            args->bucket_output = result["bucket_output"].as<std::string>();
        }
        else
        {
            args->bucket_output = "bucket.out";
        }
        if (result.count("lshK"))
        {
            std::cout << "  lshK = " << result["lshK"].as<int>()
                      << std::endl;
            args->lshK = (uint)result["lshK"].as<int>();
        }
        else
        {
            args->lshK = 1;
        }

        if (result.count("lshL"))
        {
            std::cout << "  lshL = " << result["lshL"].as<int>()
                      << std::endl;
            args->lshL = (uint)result["lshL"].as<int>();
        }
        else
        {
            args->lshL = 1;
        }

        if (result.count("lshr"))
        {
            std::cout << "  lshr = " << result["lshr"].as<float>()
                      << std::endl;
            args->lshr = (double)result["lshr"].as<float>();
        }
        else
        {
            args->lshr = 1e-4;
        }

        auto arguments = result.arguments();
        std::cout << "  Saw " << arguments.size() << " arguments" << std::endl;

        std::cout << "[End of Parse Args]" << std::endl;

        /*
    if (result.count("Nx"))
    {
      std::cout << "  Nx = " << result["Nx"].as<int>()
        << std::endl;
      args->Nx = (uint)result["Nx"].as<int>();
    }else{
      args->Nx = 64;
    }
    if (result.count("Ny"))
    {
      std::cout << "  Ny = " << result["Ny"].as<int>()
        << std::endl;
      args->Ny = (uint)result["Ny"].as<int>();
    }else{
      args->Ny = 64;
    }
    */

        return args;
    }
    catch (const cxxopts::OptionException &e)
    {
        std::cout << "error parsing options: " << e.what() << std::endl;
        exit(1);
    }
}

int main(int argc, const char *argv[])
{
    
    Args *args = parse_args(argc, argv);

    // def parameters
    uint Nx = 64; //1024;   these will be changed later.
    uint Ny = 64; //1024;
    uint n_grains = 2;
    uint n_step = 500;

    uint lshK = args->lshK;
    uint lshL = args->lshL;
    valueType lsh_r = args->lshr;
    uint nsteps = args->nsteps;

    valueType h = 0.5;

    valueType A = 1.0;
    valueType B = 1.0;
    valueType L = 5.0;
    valueType kappa = 0.1;

    valueType dtime = 0.05;
    valueType ttime = 0.0;

    valueType init_L = 5.0; // try to learn to 5.0
    valueType init_A = 1.0; // try to learn to 1.0
    valueType init_B = 1.0; // try to learn to 1.0
    valueType init_kappa = 0.1; // try to learn to 0.1

    double lr = 1e-5;
    uint start_skip = 1;
    uint skip_step = 30;
    // uint skip_step = 5;
    uint epoch = 500;

    char* data_path = "../grain_growth_all_data";
    GrainGrowthDataset dataset(data_path, start_skip, skip_step);

    Nx = dataset.Nx;
    Ny = dataset.Ny;
    n_grains = dataset.n_grains;
    n_step = dataset.n_step;

    if (debug_on) {
        std::cout << "finish data loading" << std::endl;
    }

    valueType* mtx = read_init_state_1("../grain_growth_init_cond_1", Nx, Ny, n_grains);

    double min_loss = 1000.0;

    GrainGrowthOneStep one_step(Nx, Ny, n_grains, lshK, lshL, h, \
                                A, B, L, kappa, dtime, lsh_r);

    one_step.encode_from_img(mtx);
    one_step.hash_t.clean_up_l_list();
    std::string filename2 = std::string("../gg_intuition/hb_") + std::to_string(0);
    one_step.print_PNBuckets_to_file(filename2.c_str());
    for (uint step = 0; step < 1700;)
    {
        std::cout << "step: " << step << std::endl;
        one_step.next();

        ttime += dtime;
        ++step;

        if (step % 10 == 0)
        {
            printf("step=%u, ttime=%lf\n", step, ttime);
            one_step.hash_t.clean_up_l_list();
        }
        if (step % 50 == 0 || step == 1){
            valueType* end_result = one_step.decode_to_img();
            std::string filename = std::string("../gg_intuition/hf_") + std::to_string(step);
            filename2 = std::string("../gg_intuition/hb_") + std::to_string(step);
            print_img_in_csv(end_result, filename.c_str(), Nx, Ny, n_grains);
            one_step.print_PNBuckets_to_file(filename2.c_str());
            delete [] end_result;
        }
    }
    delete [] mtx;
    delete [] args;
    // epoch = 1;

    // for (int i = 0; i < epoch; ++ i) {
    //     double loss = 0.0;
    //     int total_size = 0;
    //     printf("epoch:\t%d\n", i);

    //     // for (int index = start_skip; index < start_skip + 10; ++index) {        // for testing
    //     for (int index = start_skip; index < dataset.get_len() - 10; ++index) {
    //         ReturnItem rt = dataset.get_item(index);

    //         valueType* eta1_eta2_start = rt.data.eta1_eta2;
    //         valueType lshr = 0.01;
    //         uint lshK = 3;
    //         uint lshL = 10;
    //         int img_size = dataset.Nx;
    //         valueType h = 0.5;
    //         valueType dtime = 0.05;
    //         valueType ttime = 0.0;
    //         uint eta1_eta2_len = img_size * img_size * n_grains;

    //         // if (debug_on) {
    //         //     std::cout << "sum of eta1_eta2_start: " << sum_mtx(eta1_eta2_start, eta1_eta2_len) << std::endl;
    //         // }

    //         GrainGrowthOneStep one_step(img_size, img_size, n_grains, lshK, lshL, h,\
    //                                  init_A, init_B, init_L, init_kappa, dtime, lshr);
    //         one_step.encode_from_img(eta1_eta2_start);

    //         auto start = std::chrono::high_resolution_clock::now();
    //         for (int j = 0; j < skip_step; ++j) {
    //             one_step.next();
    //             // std::cout << "sim step: " << j << std::endl;
    //         }
    //         auto stop = std::chrono::high_resolution_clock::now();
    //         auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);
    //         std::cout << "time of ts model forward: " << duration.count() << "ms in " << skip_step << "steps" << std::endl;

    //         if (debug_on) {
    //             std::cout << "success forward in ts model" << std::endl;
    //         }

    //         valueType* eta1_eta2_sim = one_step.decode_to_img();

    //         // if (debug_on) {
    //         //     std::cout << "sum of eta1_eta2_sim: " << sum_mtx(eta1_eta2_sim, eta1_eta2_len) << std::endl;
    //         // }

    //         valueType* eta1_eta2_ref = rt.ref.eta1_eta2_ref;

    //         // if (debug_on) {
    //         //     std::cout << "sum of eta1_eta2_ref: " << sum_mtx(eta1_eta2_ref, eta1_eta2_len) << std::endl;
    //         // }

    //         valueType dloss[eta1_eta2_len];

    //         calculate_mse_loss(eta1_eta2_sim, eta1_eta2_ref, dloss, eta1_eta2_len);

    //         // if (debug_on) {
    //         //     std::cout << "sum of dloss: " << sum_mtx(dloss, eta1_eta2_len) << std::endl;
    //         // }

    //         if (debug_on) {
    //             std::cout << "success get loss" << std::endl;
    //         }

    //         valueType lshr_back = 0.01;
    //         uint lshK_back = 3;
    //         uint lshL_back = 10;

    //         GrainGrowthOneBack one_back(img_size, img_size, n_grains, lshK_back, lshL_back, h, \
    //                                      init_A, init_B, init_L, init_kappa, dtime, lshr_back);

    //         one_back.encode_from_img(eta1_eta2_sim, dloss);

    //         auto start_back = std::chrono::high_resolution_clock::now();
    //         for (int j = 0; j < skip_step; ++j) {
    //             // auto start = std::chrono::high_resolution_clock::now();
    //             one_back.next();
    //             // auto stop = std::chrono::high_resolution_clock::now();
    //             // auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);
    //             // std::cout << "time of one forward: " << duration.count() << std::endl;
    //             // std::cout << "sim back step: " << j << std::endl;
    //         }
    //         auto stop_back = std::chrono::high_resolution_clock::now();
    //         auto duration_back = std::chrono::duration_cast<std::chrono::milliseconds>(stop_back - start_back);
    //         std::cout << "time of ts model backward: " << duration_back.count() << "ms in " << skip_step << "steps" << std::endl;

    //         valueType* tsm_back_grad = one_back.decode_derivative(); // L A B kappa

    //         if (debug_on) {
    //             std::cout << "success batch loss backward" << std::endl;
    //         }

    //         // init_L -= lr * (std::abs(tsm_back_grad[0])>(10/lr)?0:tsm_back_grad[0]);
    //         // init_A -= lr * (std::abs(tsm_back_grad[1])>(10/lr)?0:tsm_back_grad[1]);
    //         // init_B -= lr * (std::abs(tsm_back_grad[2])>(10/lr)?0:tsm_back_grad[2]);
    //         // init_kappa -= lr * (std::abs(tsm_back_grad[3])>(10/lr)?0:tsm_back_grad[3]);

    //         init_L -= lr * (tsm_back_grad[0]>(10/lr)?0:tsm_back_grad[0]);
    //         init_A -= lr * (tsm_back_grad[1]>(10/lr)?0:tsm_back_grad[1]);
    //         init_B -= lr * (tsm_back_grad[2]>(10/lr)?0:tsm_back_grad[2]);
    //         init_kappa -= lr * (tsm_back_grad[3]>(10/lr)?0:tsm_back_grad[3]);

    //         if (debug_on) {
    //             std::cout << "success opt1 opt2 step()" << std::endl;
    //         }

    //         int this_size = 1;
    //         valueType batch_loss = sum_mse_loss(eta1_eta2_sim, eta1_eta2_ref, img_size * img_size * n_grains);
    //         loss += batch_loss;
    //         if (true) {
    //             std::cout << "--------------loss-----------------" << std::endl;
    //             std::cout << "batch loss: " << batch_loss << std::endl;
    //             std::cout << "--------------grad-----------------" << std::endl;
    //             one_back.print_derivative();
    //             std::cout << "--------------param----------------" << std::endl;
    //             std::cout << init_L<< std::endl;
    //             std::cout << init_A << std::endl;
    //             std::cout << init_B << std::endl;
    //             std::cout << init_kappa << std::endl;
    //         }
    //         total_size += this_size;

    //         delete eta1_eta2_sim;
    //         delete tsm_back_grad;
    //     }
    // }
    
    printf("Have a nice day!\n");
    return 0;
}
