from simulator import Simulator
import torch
import time
from utils import NodeType
import matplotlib.pyplot as plt
import numpy as np
import math
import copy
    
import os

def _step_rmse(predicted, target, mask):
    """Return ``(position RMSE, per-component stress RMSE)`` over nodes selected by ``mask``.

    The position error of a node is the squared Euclidean distance over columns
    ``0:3``; the position RMSE is the square root of its mean over the selected
    nodes. The stress RMSE is computed independently for every column in ``3:``.
    """
    squared = (predicted - target) ** 2
    rmse_pos = torch.sqrt(torch.mean(torch.sum(squared[mask, 0:3], dim=1)))
    rmse_stress = torch.sqrt(torch.mean(squared[mask, 3:], dim=0))
    return rmse_pos, rmse_stress


def rollout(model:Simulator, dataloader, device):
    """Run an autoregressive rollout over ``dataloader`` and return mean RMSEs.

    Every batch yielded by ``dataloader`` is one time step of a trajectory. On
    the first step of a trajectory the batch's ground-truth state is used as-is.
    On later steps, the ``world_pos``, ``last_pos`` and ``stress`` entries of
    nodes whose type is ``NodeType.NORMAL`` are replaced with the model's own
    previous predictions. When a batch's ``final_flag`` is true, the next batch
    starts a new trajectory.

    The de-normalized model output is treated as an increment. Columns ``0:3``
    are added to the current position and columns ``3:`` to the current stress.
    The target's columns ``3:`` are likewise an increment over the batch's
    ground-truth stress. Its columns ``0:3`` are compared directly with the
    predicted position.

    Args:
        model: Simulator called as ``model(data_list, False)``. The first
            element of the returned tuple is the normalized prediction, which
            is de-normalized via ``model.model._output_normalizer.inverse``.
        dataloader: Iterable of dicts mapping names to tensors. Each dict must
            contain at least ``node_type``, ``world_pos``, ``last_pos``,
            ``stress``, ``target`` and a single-element ``final_flag``.
        device: Device every tensor of a batch is moved to.

    Returns:
        ``(mean_pos_rmse, mean_stress_rmse)`` averaged over all steps. The
        first is a float. The second is a NumPy array of shape ``(2,)`` with
        per-component stress RMSE; a single stress component is broadcast to
        both entries.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()

    cnt = 0  # step index inside the current trajectory
    current_pos = last_pos = current_stress = None
    valid_cnt = 0
    valid_pos, valid_stress = 0.0, np.zeros(2)

    with torch.no_grad():  # evaluation only: no autograd graph is needed
        for mydata in dataloader:
            # Build a new dict so the batch yielded by the dataloader is not mutated.
            data_list = {key: value.to(device) for key, value in mydata.items()}

            mask = (data_list["node_type"] == NodeType.NORMAL).reshape(-1)
            # Ground-truth stress of this step, captured before it is overwritten
            # below; the target stores a stress increment relative to it.
            refer_last_stress = data_list["stress"]

            if cnt > 0:
                # Feed previous predictions back in for the NORMAL nodes. Clone
                # first so tensors possibly shared with the dataset stay intact.
                for key, source in (
                    ("world_pos", current_pos),
                    ("last_pos", last_pos),
                    ("stress", current_stress),
                ):
                    field = data_list[key].clone()
                    field[mask, :] = source[mask, :]
                    data_list[key] = field
            else:
                # First step of a trajectory: start from the ground-truth state.
                current_pos = data_list["world_pos"].clone()
                last_pos = data_list["world_pos"].clone()
                current_stress = data_list["stress"].clone()

            predicted_results, _ = model(data_list, False)
            predicted_results = model.model._output_normalizer.inverse(predicted_results)
            predicted_results[:, 0:3] += current_pos
            predicted_results[:, 3:] += current_stress
            predicted_results = predicted_results.detach()

            # Clone so the in-place add does not write into the batch's target.
            target_results = data_list["target"].detach().clone()
            target_results[:, 3:] += refer_last_stress

            rmse_pos, rmse_stress = _step_rmse(predicted_results, target_results, mask)
            valid_pos += rmse_pos.item()
            # rmse_stress holds one value per stress component, so .item() would
            # fail for more than one component; accumulate it as an array instead.
            valid_stress += rmse_stress.cpu().numpy()
            valid_cnt += 1

            last_pos = current_pos
            current_pos = predicted_results[:, :3].clone()
            current_stress = predicted_results[:, 3:].clone()
            cnt += 1

            if bool(data_list["final_flag"]):
                cnt = 0  # the next batch begins a new trajectory

    if valid_cnt == 0:
        raise ValueError("rollout: dataloader yielded no batches")
    return valid_pos / valid_cnt, valid_stress / valid_cnt
