import sys
import copy
import random
import json
import numpy as np
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import os
# NOTE(review): presumably a workaround for the abort raised when more than one
# OpenMP runtime gets loaded into the process — confirm it is still required.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Select matplotlib's "jshtml" HTML representation for animations.
plt.rcParams["animation.html"] = "jshtml"

# Small constant added under the square root of the training loss in train().
EPS = 1e-7


def test(net, test_data, dt, T_domain):
    net.eval()
    num = test_data.shape[0]
    error_list = []
    t = 0
    step_idx = 1
    step = test_data.shape[-1]
    pre_data = test_data[..., 0:1]
    print('Domain_Error----------------')
    while T_domain - t > 0.00001:
        with torch.no_grad():
            pre_data = net(pre_data).detach()
        gth = test_data[..., step_idx:step_idx+1]
        error = torch.norm((pre_data - gth).reshape(num, -1), dim=1) / torch.norm((gth).reshape(num, -1), dim=1)
        t = t + dt
        step_idx = step_idx + 1
        error_list.append(error.mean().item())
        if abs(t - round(t)) <  0.00001:
            print(t, error_list[-1])
    
    print('Future_Error----------------')
    while step_idx <= step - 1:
        with torch.no_grad():
            pre_data = net(pre_data).detach()
        gth = test_data[..., step_idx:step_idx+1]
        error = torch.norm((pre_data - gth).reshape(num, -1), dim=1) / torch.norm((gth).reshape(num, -1), dim=1)
        t = t + dt
        step_idx = step_idx + 1
        error_list.append(error.mean().item())
        if abs(t - round(t)) <  0.00001:
            print(t, error_list[-1])
    print('Mean Error: ', np.array(error_list).mean())
    return np.array(error_list).mean()


def train(config, net):
    """Train ``net`` as a one-step time-advance model and keep the best checkpoint.

    Reads the train/test JSON logs and the train/val/test tensors below
    ``config.data_path``, sub-samples them along the last (time) axis, and
    optimises ``net`` with Adam under a OneCycleLR schedule.

    Each iteration draws a jump length ``n`` in ``[1, config.num_step1]`` and a
    batch of (trajectory, start time) pairs. The model is applied ``n - 1``
    times without gradient tracking, then ``config.num_step2`` more times with
    gradients; those last outputs are compared with the snapshots at
    ``start + n .. start + n + num_step2 - 1`` using ``sqrt(MSE + EPS)``.

    Every 100 iterations the validation rollout error is evaluated with
    ``test``; when it improves, the weights are saved to
    ``model/{config.model_name}.pt`` and the test error is printed. After the
    loop the saved weights are loaded back and evaluated on the test set.

    Args:
        config: Object providing ``device``, ``data_path``, ``nu``, ``f``,
            ``dt``, ``num_train``, ``num_step1``, ``num_step2``, ``lr``,
            ``weight_decay``, ``num_iterations`` and ``model_name``.
        net: The model to train; moved to ``config.device``.

    Returns:
        None. Progress is printed to stdout and the checkpoint is written to
        disk.
    """
    device = config.device
    data_path = config.data_path
    nu, f = config.nu, config.f

    def read_log(split):
        """Parse the JSON log stored for the given dataset split."""
        with open(config.data_path + f'log/ns_{split}_nu_{nu}_f_{f}.txt', 'r') as file:
            return json.loads(file.read())

    train_data_dict = read_log('train')
    train_ratio = train_data_dict['record_ratio']
    T_domain = train_data_dict['T']
    test_ratio = read_log('test')['record_ratio']

    dt = config.dt
    # NOTE(review): presumably record_ratio is "snapshots per unit time", so the
    # strides below thin the data to one snapshot per dt — confirm.
    model_ratio = round(1 / dt)
    num_train = config.num_train
    train_stride = int(train_ratio // model_ratio)
    test_stride = int(test_ratio // model_ratio)

    # Training data is kept where torch.load puts it; only batches are moved to device.
    train_data = torch.load(data_path + f'dataset/ns_train_nu_{nu}_f_{f}').float()[:num_train, ..., ::train_stride]
    test_data = torch.load(data_path + f'dataset/ns_test_nu_{nu}_f_{f}').to(device).float()[..., ::test_stride]
    val_data = torch.load(data_path + f'dataset/ns_val_nu_{nu}_f_{f}').to(device).float()[..., ::train_stride]
    train_step = train_data.shape[-1]
    # Sample from the trajectories actually loaded: the file may contain fewer
    # than config.num_train, which would otherwise cause out-of-range indices.
    num_traj = train_data.shape[0]
    num_step1, num_step2 = config.num_step1, config.num_step2
    print(train_data.shape, val_data.shape, test_data.shape)
    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=config.num_iterations+1, max_lr=config.lr)
    err_record = 1e10
    batch_size = 5

    checkpoint_path = f'model/{config.model_name}.pt'
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    for step in range(config.num_iterations+1):
        net.train()

        n = random.randint(1, num_step1)

        # Latest start index for which the target window still fits in a trajectory.
        max_valid_start_t_for_this_n = train_step - n - num_step2

        if max_valid_start_t_for_this_n < 0:
            print(f"Warning: For step {step}, selected jump n={n} is too large for trajectory length {train_step} and num_step2 ({num_step2}). No valid start time exists. Skipping this batch.")
            continue

        if batch_size <= num_traj:
            random_traj_indices = random.sample(range(num_traj), batch_size)
        else:
            print(f"Warning: Requested batch_size ({batch_size}) is greater than total trajectories ({num_traj}). Sampling with replacement.")
            random_traj_indices = random.choices(range(num_traj), k=batch_size)

        sampled_start_times = random.choices(range(max_valid_start_t_for_this_n + 1), k=batch_size)

        # Assemble the batch on the data's device and transfer it once.
        batch_pairs = list(zip(random_traj_indices, sampled_start_times))
        input_data = torch.stack(
            [train_data[traj_idx, :, :, start_t:start_t + 1] for traj_idx, start_t in batch_pairs]
        ).to(device)
        output_data = torch.stack(
            [train_data[traj_idx, :, :, start_t + n:start_t + n + num_step2] for traj_idx, start_t in batch_pairs]
        ).to(device)

        # Warm-up: advance n-1 steps without building an autograd graph.
        pre_data = input_data
        with torch.no_grad():
            for _ in range(n-1):
                pre_data = net(pre_data)

        # Supervised part: num_step2 steps with gradients, each fed the last
        # predicted snapshot; concatenated once along the time axis.
        frames = [net(pre_data)]
        for _ in range(num_step2-1):
            frames.append(net(frames[-1][..., -1:]))
        pre_data = torch.concat(frames, dim=-1)

        loss = torch.sqrt(torch.mean(torch.square(pre_data - output_data)) + EPS)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()

        if step % 100 == 0:
            print(step, 'Training_Error: ', loss.detach().item())
            # Per-sample relative L2 error: the batch holds batch_size samples,
            # so that (not num_train) is the row count to flatten to.
            with torch.no_grad():
                residual = (pre_data - output_data).reshape(batch_size, -1)
                reference = output_data.reshape(batch_size, -1)
                train_error = torch.norm(residual, dim=1) / torch.norm(reference, dim=1)
            print(train_error.mean().item())
            print('Val_Error----------------')
            val_error = test(net, val_data, dt, T_domain)

            if val_error < err_record:
                err_record = val_error
                print('----------------------------MODEL_UPDATED-------------------------------')
                torch.save(net.state_dict(), checkpoint_path)
                print('Test_Error----------------')
                test(net, test_data, dt, T_domain)
            sys.stdout.flush()

    print('----------------------------FINAL_RESULT-----------------------------')
    # NOTE(review): if validation never improved during this run, this loads
    # whatever checkpoint already exists at the path (or fails if none does).
    net.load_state_dict(torch.load(checkpoint_path))
    test(net, test_data, dt, T_domain)
    sys.stdout.flush()