#include "grain_growth.hpp"
#include "cstring"
#include "iostream"
#include <fstream>
#include <sstream>
#include <string>
#include <algorithm>
#include <iterator>
#include <cassert>
#include <iomanip>

GrainGrowthDataset::GrainGrowthDataset(char* data_path, int _start_skip, int _skip_step) {

    // FILE *inp = fopen(data_path, "r");
    // fscanf(inp, "%u,%u,%u,%u", &Nx, &Ny, &n_grains, &n_step);
    // fclose(inp);

    all_data = read_data_file(data_path, Nx, Ny, n_grains, n_step);

    // valueType* test_init_file = read_init_state(data_path, Nx, Ny, n_grains, n_step);

    // test_init_file = read_init_state_1("../data_128", Nx, Ny, n_grains);

    // assert(1==0);

    // int _start_skip = 9;
    // start_skip = torch::from_blob(&_start_skip, {1}, torch::kInt);
    start_skip = _start_skip;
    // skip_step = torch::from_blob(&_skip_step, {1}, torch::kInt);
    skip_step = _skip_step;
    int _cnt = n_step - _start_skip * 2 - _skip_step;
    // cnt = torch::from_blob(&_cnt, {1}, torch::kInt);
    cnt = _cnt;

    // if (debug_on) {
    //     std::cout << "cv shape: " << cv.sizes() << std::endl;
    //     std::cout << "ci shape: " << ci.sizes() << std::endl;
    //     std::cout << "eta shape: " << eta.sizes() << std::endl;
    //     std::cout << "video shape: " << video.sizes() << std::endl;
    //     std::cout << "start skip tensor: " << start_skip << std::endl;
    //     std::cout << "skip step tensor: " << skip_step << std::endl;
    //     std::cout << "cnt tensor: " << cnt << std::endl;
    //     // std::cout << "cnt item value: " << cnt.item<int>() << std::endl;
    // }

}


ReturnItem GrainGrowthDataset::get_item(size_t index)
{
    // valueType* mtx = this->all_data;
    // for (uint pg = 0; pg < n_grains; ++pg) {
    //     int n = 0;
    //     if (n==0) {
    //         valueType sum = 0;
    //         valueType *eta = mtx + pg*Nx*Ny;
    //         for (uint i = 0; i < 128*128; ++i) {
    //             sum += eta[i];
    //             // if (pg==1) {
    //             //     std::cout << "eta2_" << i << ": " << eta[i] << std::endl;
    //             // }
    //         }
    //         std::cout << "sum of mtx first frame eta_" << pg+1 << ": " << sum << std::endl;
    //     }
    //     n = 1500;
    //     if (n==1500) {
    //         valueType sum = 0;
    //         valueType *eta = mtx + n*n_grains*Nx*Ny + pg*Nx*Ny;
    //         for (uint i = 0; i < 128*128; ++i) {
    //             sum += eta[i];
    //             // if (pg==1) {
    //             //     std::cout << "eta2_" << i << ": " << eta[i] << std::endl;
    //             // }
    //         }
    //         std::cout << "sum of mtx 1500 frame eta_" << pg+1 << ": " << sum << std::endl;
    //         // valueType sum = 0;
    //         // valueType *eta = mtx + n*n_grains*Nx*Ny + pg*Nx*Ny;
    //         // for (uint i = 0; i < Nx*Ny; ++i) {
    //         //     if (eta[i] == 0.0) {
    //         //         sum++;
    //         //     }
    //         // }
    //         // std::cout << "Zero elements in eta_" << pg+1 << ": " << sum << std::endl;
    //         sum = 0;
    //         eta = mtx + pg*Nx*Ny;
    //         for (uint i = 0; i < 128*128; ++i) {
    //             sum += eta[i];
    //         }
    //         std::cout << "sum of mtx after1500 first frame eta_" << pg+1 << ": " << sum << std::endl;
    //     }
    // }





    if (index >= cnt) 
        return ReturnItem(ReturnData(nullptr, nullptr, nullptr), ReturnLabel(nullptr, nullptr, nullptr));
    // prepare data
    valueType* eta1_eta2 = all_data + index*n_grains*Nx*Ny;
    valueType* eta1 = eta1_eta2;
    valueType* eta2 = eta1 + Nx*Ny;

    // prepare label
    size_t index_ref = index + skip_step;
    valueType* eta1_eta2_ref = all_data + index_ref*n_grains*Nx*Ny;
    valueType* eta1_ref = eta1_eta2_ref;
    valueType* eta2_ref = eta1_ref + Nx*Ny;
    
    ReturnData return_data(eta1, eta2, eta1_eta2);
    ReturnLabel return_label(eta1_ref, eta2_ref, eta1_eta2_ref);

    ReturnItem return_item(return_data, return_label);
    return return_item;
}


void calculate_mse_loss(valueType* eta, valueType* eta_ref, valueType* dloss, uint size) {
    for (uint i = 0; i < size; ++i) {
        dloss[i] = 2 * (eta[i] - eta_ref[i]);
    }
}

valueType sum_mse_loss(valueType* eta, valueType* eta_ref, uint size) {
    valueType sum = 0;
    for (uint i = 0; i < size; ++i) {
        sum += (eta[i] - eta_ref[i]) * (eta[i] - eta_ref[i]);
    }
    return sum;
}

valueType sum_mtx(valueType* eta, uint size) {
    valueType sum = 0;
    for (uint i = 0; i < size; ++i) {
        sum += eta[i];
    }
    return sum;
}

valueType *read_data_file(const char *filename, uint &Nx, uint &Ny, uint &n_grains, uint &n_step)
{
    FILE *inp = fopen(filename, "r");
    fscanf(inp, "%u,%u,%u,%u", &Nx, &Ny, &n_grains, &n_step);
    // std::cout << "header: " << Nx << ", " << Ny << ", " << n_grains << ", " << n_step << std::endl;
    valueType *mtx = new valueType[Nx * Ny * n_grains * n_step];
    // std::cout << "inside dataset init" << std::endl;
    for (uint n = 0; n < n_step; ++n) {    
        for (uint pg = 0; pg < n_grains; ++pg) {
            for (uint x = 0; x < Nx; ++x) {
                fscanf(inp, "%lf", mtx + n*n_grains*Nx*Ny + pg*Nx*Ny + x*Ny);
                for (uint y = 1; y < Ny; ++y)
                    fscanf(inp, ",%lf", mtx + n*n_grains*Nx*Ny + pg*Nx*Ny + x*Ny + y);
            }
            // if (n==0) {
            //     valueType sum = 0;
            //     valueType *eta = mtx + pg*Nx*Ny;
            //     std::cout << "address of mtx first frame eta_" << pg+1 << ": " << eta << std::endl;
            //     for (uint i = 0; i < 128*128; ++i) {
            //         sum += eta[i];
            //         // if (pg==1) {
            //         //     std::cout << "eta2_" << i << ": " << eta[i] << std::endl;
            //         // }
            //     }
            //     std::cout << "sum of mtx first frame eta_" << pg+1 << ": " << sum << std::endl;
            //     std::cout << std::endl;
            // }
            // if (n==1500) {
            //     valueType sum = 0;
            //     valueType *eta = mtx + n*n_grains*Nx*Ny + pg*Nx*Ny;
            //     std::cout << "address of mtx 1500 frame eta_" << pg+1 << ": " << eta << std::endl;
            //     for (uint i = 0; i < 128*128; ++i) {
            //         sum += eta[i];
            //         // if (pg==1) {
            //         //     std::cout << "eta2_" << i << ": " << eta[i] << std::endl;
            //         // }
            //     }
            //     std::cout << "sum of mtx 1500 frame eta_" << pg+1 << ": " << sum << std::endl;
            //     std::cout << std::endl;
            //     // valueType sum = 0;
            //     // valueType *eta = mtx + n*n_grains*Nx*Ny + pg*Nx*Ny;
            //     // for (uint i = 0; i < Nx*Ny; ++i) {
            //     //     if (eta[i] == 0.0) {
            //     //         sum++;
            //     //     }
            //     // }
            //     // std::cout << "Zero elements in eta_" << pg+1 << ": " << sum << std::endl;
            //     // sum = 0;
            //     // eta = mtx + pg*Nx*Ny;
            //     // for (uint i = 0; i < 128*128; ++i) {
            //     //     sum += eta[i];
            //     // }
            //     // std::cout << "sum of mtx after1500 first frame eta_" << pg+1 << ": " << sum << std::endl;
            // }
        }
        // if (n==0)
        //     std::cout << "sum of mtx first frame: " << sum_mtx(mtx, 128*128*2) << std::endl;
        // if (n==1500)
        //     assert(1==0);
    }
    // std::cout << "outside dataset init" << std::endl;
    fclose(inp);

    return mtx;
}

valueType* read_init_state_1(const char* filename, uint &Nx, uint &Ny, uint &n_grains) {
  FILE *inp = fopen(filename, "r");
  fscanf(inp, "%u,%u,%u", &Nx, &Ny, &n_grains);
  std::cout << "header: " << Nx << ", " << Ny << ", " << n_grains << std::endl;
  valueType* mtx = new valueType[Nx*Ny*n_grains];
  for (uint pg = 0; pg < n_grains; ++ pg) {
    for (uint x = 0; x < Nx; ++ x) {
      fscanf(inp, "%lf", mtx + pg*Nx*Ny + x*Ny);
      for (uint y = 1; y < Ny; ++ y)
        fscanf(inp, ",%lf", mtx + pg*Nx*Ny + x*Ny + y);
    }
  }
  
//   for (uint pg = 0; pg < n_grains; ++pg) {
//         valueType sum = 0;
//         valueType *eta = mtx + pg*Nx*Ny;
//         for (uint i = 0; i < Nx*Ny; ++i) {
//             sum += eta[i];
//             std::cout << "read eta_" << pg+1 << "_" << i << ": " << eta[i] << std::endl;
//         }
//         std::cout << "sum of mtx first frame eta_" << pg+1 << ": " << sum << std::endl;
//     }
  fclose(inp);
  return mtx;
}

