import sys
import copy
import random
import json
import numpy as np
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import os
# Set the KMP_DUPLICATE_LIB_OK environment variable for this process.
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort; it is set after numpy/torch are imported above -- confirm that this
# is early enough to take effect.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Use the "jshtml" HTML representation for matplotlib animations.
plt.rcParams["animation.html"] = "jshtml"

# Small constant added under the square root of the training loss in train().
EPS = 1e-7


def test(net, test_data, dt, T_domain):
    """Roll ``net`` out from the first snapshot and report relative L2 errors.

    The model is called once as ``net(first_snapshot, n_step=total_steps - 1)``
    and its output is compared, step by step, with ``test_data[..., 1:]``.
    For every predicted step the per-sample relative L2 error
    ``||pred - gth|| / ||gth||`` is averaged over the samples. Steps whose
    time ``(i + 1) * dt`` is (numerically) an integer are printed, first
    under the "Domain_Error" header (steps inside ``T_domain``) and then
    under the "Future_Error" header (the remaining steps).

    Args:
        net: Model providing ``eval()`` and callable as ``net(x, n_step=k)``;
            its output must have the same shape as ``test_data[..., 1:]``.
        test_data: Tensor whose first axis indexes samples and whose last
            axis indexes snapshots.
        dt: Time between two consecutive snapshots of ``test_data``.
        T_domain: Length of the time window reported as "Domain_Error".

    Returns:
        numpy.float64: Mean over all predicted steps of the sample-averaged
        relative L2 error.

    Raises:
        ValueError: If ``test_data`` holds fewer than two snapshots.
        AssertionError: If the model output shape differs from the ground
            truth shape.
    """
    net.eval()
    num = test_data.shape[0]
    total_steps = test_data.shape[-1]
    n_pred = total_steps - 1
    if n_pred < 1:
        raise ValueError("test_data must contain at least two snapshots along the last axis.")

    with torch.no_grad():
        predicted_steps = net(test_data[..., 0:1], n_step=n_pred)
    gth_steps = test_data[..., 1:]

    assert predicted_steps.shape == gth_steps.shape, "Predicted and ground truth shapes do not match."

    # All steps at once: flatten everything between the sample axis and the
    # time axis, then take the 2-norm over that flattened axis.
    with torch.no_grad():
        diff_norm = torch.norm((predicted_steps - gth_steps).reshape(num, -1, n_pred), dim=1)
        gth_norm = torch.norm(gth_steps.reshape(num, -1, n_pred), dim=1)
        error_list = (diff_norm / gth_norm).mean(dim=0).tolist()

    # Truncate with a small tolerance so that e.g. 0.3 / 0.1 counts as 3
    # steps, and never index past the steps that were actually predicted.
    domain_steps = min(max(int(T_domain / dt + 1e-6), 0), n_pred)

    def _report(first, last):
        # Print only the steps that fall on (numerically) integer times.
        for i in range(first, last):
            t_val = (i + 1) * dt
            if abs(t_val - round(t_val)) < 0.00001:
                print(t_val, error_list[i])

    print('Domain_Error----------------')
    _report(0, domain_steps)

    print('Future_Error----------------')
    _report(domain_steps, n_pred)

    mean_error = np.array(error_list).mean()
    print('Mean Error: ', mean_error)
    return mean_error


# The name matches pytest's default ``test*`` collection prefix; mark the
# helper so it is never collected as a test case when imported into a test
# module.
test.__test__ = False


def _load_json_log(path):
    """Parse the JSON metadata file at ``path`` and return its content."""
    with open(path, 'r') as file:
        return json.load(file)


def _subsample_stride(record_ratio, model_ratio, label):
    """Return the number of recorded snapshots spanned by one model step."""
    stride = int(record_ratio // model_ratio)
    if stride < 1:
        raise ValueError(
            f"{label} record_ratio ({record_ratio}) is smaller than the model ratio "
            f"({model_ratio}); cannot subsample the data with a zero stride."
        )
    return stride


def train(config, net):
    """Train ``net`` on recorded trajectories and evaluate it.

    Reads ``{data_path}log/ns_train_nu_{nu}_f_{f}.txt`` (keys
    ``record_ratio`` and ``T``) and ``{data_path}log/ns_test_nu_{nu}_f_{f}.txt``
    (key ``record_ratio``) as JSON, then loads the train/val/test tensors from
    ``{data_path}dataset/``. Each iteration draws a random jump ``n`` in
    ``[1, num_step1]``, advances the input ``n - 1`` steps without gradients,
    and fits the following ``num_step2`` predicted steps to the data with an
    RMSE loss. Every 100 iterations the model is validated; whenever the
    validation error improves, the weights are saved to
    ``model/{config.model_name}.pt`` and the test error is printed.

    Args:
        config: Object providing ``device``, ``data_path``, ``nu``, ``f``,
            ``dt``, ``num_train``, ``num_step1``, ``num_step2``, ``lr``,
            ``weight_decay``, ``num_iterations`` and ``model_name``.
        net: Model callable as ``net(x, n_step=k)``.

    Returns:
        tuple: ``(val_error, test_error, train_error)`` -- the last validation
        error computed during training (NaN if none was computed), the test
        error of the final model and the mean single-step relative L2 error on
        the training data.

    Raises:
        ValueError: If a record ratio is smaller than ``round(1 / dt)``.
    """
    print("开始训练！")
    device = config.device
    data_path = config.data_path
    nu, f = config.nu, config.f

    train_data_dict = _load_json_log(data_path + f'log/ns_train_nu_{nu}_f_{f}.txt')
    train_ratio = train_data_dict['record_ratio']
    T_domain = train_data_dict['T']

    test_data_dict = _load_json_log(data_path + f'log/ns_test_nu_{nu}_f_{f}.txt')
    test_ratio = test_data_dict['record_ratio']
    print("读入数据成功！")

    dt = config.dt
    model_ratio = round(1 / dt)
    # One model step of length dt spans this many recorded snapshots.
    train_stride = _subsample_stride(train_ratio, model_ratio, 'train')
    test_stride = _subsample_stride(test_ratio, model_ratio, 'test')

    # Training data stays at recording resolution (and off the device); it is
    # subsampled per batch below.
    train_data = torch.load(data_path + f'dataset/ns_train_nu_{nu}_f_{f}').float()[:config.num_train, ...]
    # The file may hold fewer trajectories than requested; never index past it.
    num_train = min(config.num_train, train_data.shape[0])

    # Only the first 3001 recorded test snapshots are used (hard-coded horizon).
    test_data = torch.load(data_path + f'dataset/ns_test_nu_{nu}_f_{f}').to(device).float()[..., :3001][..., ::test_stride]
    val_data = torch.load(data_path + f'dataset/ns_val_nu_{nu}_f_{f}').to(device).float()[..., ::train_stride]
    size, train_step = train_data.shape[1], train_data.shape[-1]
    num_step1, num_step2 = config.num_step1, config.num_step2
    print(train_data.shape, val_data.shape, test_data.shape)

    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=config.num_iterations+1, max_lr=config.lr)
    err_record = 1e10

    model_path = f'model/{config.model_name}.pt'
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    checkpoint_saved = False
    val_error = float('nan')

    batch_size = 5
    sample_with_replacement = batch_size > num_train
    if sample_with_replacement:
        print(f"Warning: Requested batch_size ({batch_size}) is greater than total trajectories ({num_train}). Sampling with replacement.")

    for step in range(config.num_iterations+1):
        net.train()

        n = random.randint(1, num_step1)
        max_start_t = train_step - (n + num_step2) * train_stride
        if max_start_t < 0:
            print(f"Warning: For step {step}, selected jump n={n} is too large for trajectory length {train_step} and num_step2 ({num_step2}). No valid start time exists. Skipping this batch.")
            continue

        if sample_with_replacement:
            traj_indices = random.choices(range(num_train), k=batch_size)
        else:
            traj_indices = random.sample(range(num_train), batch_size)
        start_times = random.choices(range(max_start_t + 1), k=batch_size)

        input_data = torch.zeros(batch_size, size, size, 1, device=device)
        output_data = torch.zeros(batch_size, size, size, num_step2, device=device)
        for k_batch, (traj_idx, start_t) in enumerate(zip(traj_indices, start_times)):
            # Input: the snapshot at start_t. Targets: num_step2 snapshots
            # spaced one model step (train_stride recorded snapshots) apart,
            # beginning n model steps after the input.
            first_target = start_t + n * train_stride
            last_target = first_target + num_step2 * train_stride
            input_data[k_batch] = train_data[traj_idx, :, :, start_t:start_t + 1].to(device)
            output_data[k_batch] = train_data[traj_idx, :, :, first_target:last_target:train_stride].to(device)

        # Advance n - 1 steps without gradients; only the following
        # num_step2-step rollout is back-propagated.
        with torch.no_grad():
            if n > 1:
                pre_data_n = net(input_data, n_step=n - 1)[..., -1:]
            else:
                pre_data_n = input_data

        predicted_output = net(pre_data_n, n_step=num_step2)

        assert predicted_output.shape == output_data.shape, "Shape mismatch between prediction and ground truth!"

        loss = torch.sqrt(torch.mean(torch.square(predicted_output - output_data)) + EPS)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()

        if step % 100 == 0:
            print(step, 'Training_Error: ', loss.detach().item())

            print('Val_Error----------------')
            val_error = test(net, val_data, dt, T_domain)

            if val_error < err_record:
                err_record = val_error
                print('----------------------------MODEL_UPDATED-------------------------------')
                torch.save(net.state_dict(), model_path)
                checkpoint_saved = True
                print('Test_Error----------------')
                test_error = test(net, test_data, dt, T_domain)
            sys.stdout.flush()

    print('----------------------------FINAL_RESULT-----------------------------')
    if checkpoint_saved:
        net.load_state_dict(torch.load(model_path))
    else:
        # Avoid loading a stale checkpoint left over from an earlier run.
        print('Warning: no checkpoint was saved during this run; evaluating the current weights.')
    test_error = test(net, test_data, dt, T_domain)
    sys.stdout.flush()

    print('----------------------------Train_SingleStep_Error-----------------------------')
    net.eval()
    with torch.no_grad():
        T_steps = train_data.shape[-1]
        num_traj = train_data.shape[0]
        total_error = 0.0
        count = 0
        # One model step corresponds to train_stride recorded snapshots, so
        # the target of snapshot t is snapshot t + train_stride.
        for t in range(T_steps - train_stride):
            input_t = train_data[..., t:t+1].to(device)
            gt_t1 = train_data[..., t+train_stride:t+train_stride+1].to(device)
            pred_t1 = net(input_t, n_step=1)

            error_t = torch.norm((pred_t1 - gt_t1).reshape(num_traj, -1), dim=1) / \
                      torch.norm(gt_t1.reshape(num_traj, -1), dim=1)
            total_error += error_t
            count += 1

        if count:
            train_error = (total_error / count).mean().item()
        else:
            # Trajectories too short to form a single (input, target) pair.
            train_error = float('nan')
        print('Train single-step mean relative L2 error:', train_error)

    return val_error, test_error, train_error
