import torch
import math
import torch.nn as nn
import numpy as np

import time
import os
from torch.utils.data.dataset import Dataset

from  SINO import SINOmodel

# Restrict this process to the CUDA device with index 4.
# NOTE(review): this only takes effect if CUDA has not been initialised yet;
# `SINO` is imported above — confirm it does not touch CUDA at import time.
os.environ['CUDA_VISIBLE_DEVICES'] = '4'
# Newly created floating-point tensors default to float32.
torch.set_default_dtype(torch.float32)

# Seed the PyTorch and NumPy random number generators with a fixed value.
torch.manual_seed(66)
np.random.seed(66)


class SINORoll(nn.Module):
    """Autoregressive roll-out wrapper around a single ``SINOmodel`` cell.

    The cell is applied ``step`` times in a row, each application consuming
    the output of the previous one, and every intermediate output is kept.

    Args:
        dt: Forwarded to ``SINOmodel`` as ``dt``.
        k_num: Forwarded to ``SINOmodel`` as ``k_num``.
        channel: Forwarded to ``SINOmodel`` as ``channel``.
        step: Number of times the cell is applied in ``forward``.
        effective_step: Stored on the instance; not read by this class.
    """

    def __init__(self, dt=1e-2, k_num=6, channel=32, step=1, effective_step=None):
        super().__init__()
        self.cell = SINOmodel(dt=dt, k_num=k_num, channel=channel)
        self.effective_step = effective_step
        self.step = step

    def forward(self, x):
        """Roll the cell forward ``self.step`` times starting from ``x``.

        Returns:
            The successive cell outputs stacked along a new dimension 1.
        """
        state = x
        trajectory = []
        for _ in range(self.step):
            # Each prediction becomes the input of the next application.
            state = self.cell(state)
            trajectory += [state]
        return torch.stack(trajectory, dim=1)

def get_parameter_number(model):
    """Count the scalar parameters of ``model``.

    Returns:
        A dict with key ``'Total'`` (element count over all parameters) and
        key ``'Trainable'`` (element count over parameters whose
        ``requires_grad`` is true).
    """
    overall = 0
    learnable = 0
    for param in model.parameters():
        size = param.numel()
        overall += size
        if param.requires_grad:
            learnable += size
    return {'Total': overall, 'Trainable': learnable}






def permute_tensor(tensor):
    """Return the 5-D ``tensor`` with its first two dimensions swapped.

    The remaining three dimensions keep their positions.
    """
    axis_order = (1, 0, 2, 3, 4)
    return tensor.permute(*axis_order)


def calculate_rmse_iterative(predictions, actuals):
    """Compute one root-mean-square error per leading index.

    For every index ``i`` of ``predictions`` the mean squared error between
    ``predictions[i]`` and ``actuals[i]`` is evaluated with ``nn.MSELoss`` and
    its square root is taken with NumPy.

    Returns:
        A list with one RMSE value per entry of ``predictions``.
    """
    mse = nn.MSELoss()
    return [
        np.sqrt(mse(predicted, actuals[idx]).item())
        for idx, predicted in enumerate(predictions)
    ]
 

def correlation(u, truth):
    """Return the Pearson correlation coefficient between two tensors.

    Both inputs are flattened to a single row each before the coefficient is
    computed with ``np.corrcoef``.

    Args:
        u: Tensor of predicted values.
        truth: Tensor of reference values with the same number of elements.

    Returns:
        The correlation coefficient as a Python float.
    """
    u_row = u.reshape(1, -1)
    truth_row = truth.reshape(1, -1)
    stacked = torch.cat((u_row, truth_row), dim=0)
    # NumPy can only read CPU tensors that are not tracked by autograd, so
    # detach and move to host memory before handing the data over.
    coef = np.corrcoef(stacked.detach().cpu().numpy())[0][1]
    return coef.item()
def cal_cur_time_corre(u, truth):
    """Compute the correlation between ``u[i]`` and ``truth[i]`` for every ``i``.

    The current index is printed every 100 entries as a progress indicator.

    Returns:
        A list with one correlation coefficient per entry of ``u``.
    """
    coefficients = []
    for idx, predicted in enumerate(u):
        if not idx % 100:
            print(idx)
        coefficients.append(correlation(predicted, truth[idx]))
    return coefficients



if __name__ == '__main__':
    # Evaluation script: load the test data and a trained cell, roll the model
    # forward from the first frame and print the parameter count, the
    # wall-clock time of the roll-out and the mean relative L2 error.

    time_steps = 3000  # number of frames to roll out before sub-sampling
    Ndt = 1            # stride used to sub-sample the reference data along dim 1
    k_num = 6
    c = 32

    data = torch.load("./data/ns_test_nu_1e-05_f_kf")

    # Move the last axis of the 4-D data to dim 1 (used as the time axis
    # below) and insert a singleton dim at position 2
    # (presumably the channel axis — TODO confirm against SINOmodel).
    UVT = data.float().permute(0, 3, 1, 2).unsqueeze(2)

    # `steps` is derived from the un-strided step count, as before.
    steps = time_steps + 1
    effective_step = list(range(0, steps))
    time_steps = time_steps // Ndt
    modelshortv3 = SINORoll(dt=0.005, step=time_steps, k_num=k_num, channel=c,
                            effective_step=effective_step).cuda()
    # NOTE(review): the model is never switched to eval mode; confirm that
    # SINOmodel contains no layers whose behaviour depends on training mode.

    zz = get_parameter_number(modelshortv3)
    print(zz)

    checkpoint = torch.load('./net-seed-0-modelv1-num_step1-4-num_step2-8-lr-0.01-channel-32-k_num-6-dt-0.005-step_forward-rk4-3_26_18_13_5.pt', map_location='cuda:0')
    modelshortv3.cell.load_state_dict(checkpoint)

    truth_clean = UVT[:, ::Ndt]
    # Copy the first frame to the GPU once and reuse it.
    initial_state = truth_clean[:, 0, :, :, :].cuda()
    # SINORoll.forward does not read this attribute; kept so the module
    # carries the same attributes as before.
    modelshortv3.init_state = initial_state

    # CUDA kernels are launched asynchronously, so synchronise before reading
    # the clock on both sides; otherwise the measured time can exclude GPU
    # work that is still queued.
    torch.cuda.synchronize()
    start_time = time.time()
    with torch.no_grad():
        output = modelshortv3(initial_state)
    torch.cuda.synchronize()
    end_time = time.time()
    print(end_time - start_time)

    # Relative L2 error per leading-dimension entry: the k-th roll-out output
    # is compared with reference frame k + 1. Sizes are taken from the tensors
    # and `time_steps` instead of hard-coded literals.
    predicted = output[:, :time_steps].cpu()
    reference = truth_clean[:, 1:time_steps + 1]
    n_samples = reference.shape[0]
    error = (torch.norm((predicted - reference).reshape(n_samples, -1), dim=1)
             / torch.norm(reference.reshape(n_samples, -1), dim=1))
    eq = error.mean().item()
    print(eq)


