from dataset.getdataset_MSB import GetDataSetMSB
from dataset.getdataset_DP import GetDataSetDP
from dataset.getdataset_GC import GetDataSetGC
from simulator import Simulator
import torch
import time
from utils import NodeType
import matplotlib.pyplot as plt
import numpy as np
import math
import copy
import threading
import h5py
    
import os

# Which checkpoint to evaluate: checkpoint/{dataset}-{model}{set}{version}.pth
model_name = "MGN"
model_set = ""
model_version = "-1000000"

dataset_name = "MSB"
# NOTE(review): hard-coded GPU index; adjust to the host's available devices.
device = torch.device('cuda:3')

# Root directory of each supported dataset, keyed by dataset_name.
DATASET_DIRS = {
    "MSB": "/strip_simple/",
    "DP": "/meshgraphnet_data/deepmind_h5/deforming_plate/",
    "GC": "/cavity_grasping_dataset/",
}
if dataset_name not in DATASET_DIRS:
    # Fail fast instead of leaving dataset_dir undefined for later code.
    raise ValueError(
        f"Unknown dataset_name {dataset_name!r}; "
        f"expected one of {sorted(DATASET_DIRS)}"
    )
dataset_dir = DATASET_DIRS[dataset_name]

batch_size = 1

config_dir = f"config/config_{model_name}_{dataset_name}.yml"

simulator = Simulator(config_dir, device=device).to(device)

simulator.load_checkpoint(f"checkpoint/{dataset_name}-{model_name}{model_set}{model_version}.pth")
    
def _report_rollout_errors(loss_sum, loss_cnt, windows=(1, 50, 75, 100, 999)):
    """Print the mean per-step error over the first ``pw`` rollout steps.

    Args:
        loss_sum: list indexed by rollout step; entry ``[pos, s0, s1]`` holds
            the summed position RMSE and per-channel stress RMSE over every
            trajectory that reached that step.
        loss_cnt: list indexed by rollout step; number of trajectories that
            reached that step.
        windows: rollout horizons ``pw`` to report.

    Each value is printed multiplied by 1000. If no step was evaluated in a
    window, the result is nan rather than a division error.
    """
    for pw in windows:
        steps = loss_sum[:pw]
        total_cnt = sum(loss_cnt[:pw])
        for dim in range(3):
            if total_cnt == 0:
                res = float("nan")
            else:
                res = sum(step[dim] for step in steps) / total_cnt
            print(f"dim:{dim}, rollout-{pw}steps: error {res * 1000:.3f}")


def test(model, dataloader):
    """Roll ``model`` out autoregressively over ``dataloader`` and print errors.

    Frames are consumed in order. A trajectory ends when the frame's
    ``final_flag`` is true.
    - On the first frame of a trajectory the model sees ground-truth state.
    - On later frames, the NORMAL nodes' ``world_pos``, ``last_pos`` and
      ``stress`` inputs are overwritten with the model's previous prediction.

    The inverse-normalised model output is treated as an increment:
    - columns ``0:3`` are added to the current position;
    - columns ``3:`` are added to the current stress.

    Target columns ``0:3`` are compared as-is. Target columns ``3:`` get the
    frame's ground-truth input stress added.

    For every rollout step, two errors are accumulated over NORMAL nodes:
    - position RMSE: squared error summed over xyz, averaged over nodes,
      then square-rooted;
    - per-channel stress RMSE, for the first two stress channels.

    Args:
        model: simulator called as ``model(data, False)`` returning
            ``(prediction, _)`` and exposing
            ``model.model._output_normalizer.inverse``.
        dataloader: iterable of dicts of tensors. One frame per item is
            assumed (``final_flag`` is compared as a scalar).
    """
    model.eval()
    # Per-rollout-step accumulators, grown on demand so trajectories of any
    # length fit (no fixed upper bound on the number of steps).
    loss_sum = []
    loss_cnt = []

    cnt = 0
    current_pos = None
    current_stress = None
    last_pos = None
    trace_num = 0

    # Pure inference: skip building an autograd graph every step.
    with torch.no_grad():
        for data_list in dataloader:
            for key in data_list.keys():
                data_list[key] = data_list[key].to(device)

            node_type = data_list["node_type"]
            # Ground-truth stress of this frame, saved before it is overwritten
            # below; the stress target is expressed relative to it.
            refer_last_stress = data_list["stress"].clone()
            mask = (node_type == NodeType.NORMAL).reshape(-1)
            if cnt > 0:
                # Feed the previous prediction back in for simulated nodes.
                data_list["world_pos"][mask, :] = current_pos[mask, :]
                data_list["last_pos"][mask, :] = last_pos[mask, :]
                data_list["stress"][mask, :] = current_stress[mask, :]
            else:
                # First frame of a trajectory: start from ground truth. The
                # stress baseline must match the one added to the target.
                current_pos = data_list["world_pos"].clone()
                last_pos = data_list["world_pos"].clone()
                current_stress = data_list["stress"].clone()

            predicted_results, _ = model(data_list, False)
            predicted_results = model.model._output_normalizer.inverse(predicted_results)

            # Increments -> absolute position / stress.
            predicted_results[:, 0:3] += current_pos
            predicted_results[:, 3:] += current_stress
            predicted_results = predicted_results.detach()
            # Clone so the loaded batch's target tensor is not mutated.
            target_results = data_list["target"].detach().clone()
            target_results[:, 3:] += refer_last_stress

            sq_err = ((predicted_results - target_results) ** 2)[mask]
            rmse_pos = torch.sqrt(torch.mean(torch.sum(sq_err[:, 0:3], dim=1)))
            rmse_stress = torch.sqrt(torch.mean(sq_err[:, 3:], dim=0))

            while len(loss_sum) <= cnt:
                loss_sum.append([0.0, 0.0, 0.0])
                loss_cnt.append(0)
            # Store plain floats rather than (possibly GPU) 0-dim tensors.
            loss_sum[cnt][0] += rmse_pos.item()
            loss_sum[cnt][1] += rmse_stress[0].item()
            loss_sum[cnt][2] += rmse_stress[1].item()
            loss_cnt[cnt] += 1

            last_pos = current_pos
            current_pos = predicted_results[:, :3].clone()
            current_stress = predicted_results[:, 3:].clone()
            cnt += 1
            if data_list["final_flag"] == True:  # noqa: E712 -- tensor comparison
                # End of trajectory: the next frame starts a fresh rollout.
                cnt = 0
                current_pos = None
                current_stress = None
                print(trace_num)
                trace_num += 1

    _report_rollout_errors(loss_sum, loss_cnt)

if __name__ == '__main__':
    # Loader class for each supported dataset_name.
    dataset_loaders = {
        "MSB": GetDataSetMSB,
        "GC": GetDataSetGC,
        "DP": GetDataSetDP,
    }
    if dataset_name not in dataset_loaders:
        # Otherwise test_dataset would be unbound and fail with a NameError.
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; "
            f"expected one of {sorted(dataset_loaders)}"
        )
    # The rollout in test() assumes one frame per batch.
    test_dataset = dataset_loaders[dataset_name](
        dataset_dir=dataset_dir, split='test', batch_size=1
    )

    test(simulator, test_dataset)
