import torch
from torch.autograd import Variable
import numpy as np
import random
import torch.backends.cudnn as cudnn
from sklearn.metrics import explained_variance_score, max_error
import os
import time

from config import *
from transformer.transformer import *
from data_gen import *
from utils import *
from dataloader import *
from data_gen.dataset_util import *
args = parser.parse_args()

# Absolute location of this script and of the directory containing it.
FILE_ABS_PATH = os.path.abspath(__file__)
FILE_DIR_PATH = os.path.dirname(FILE_ABS_PATH)

# CUDA support: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Net operator.
# There are two possible modes: "save" and "load".
def net_operator(net, system, mode):
    """Save or load ``net``'s ``state_dict`` at a path derived from ``args``.

    The checkpoint file is
    ``./parameter/{system}/{args.u0_str}/{coefficient ranges}_{prior tag}.pt``.
    The ranges and the prior tag come from the module-level ``args``. The tag
    is "analytical" when ``args.numerical`` is truthy, otherwise
    ``PINN_based_{args.PINN_based_prior_ratio}``. The checkpoint directory is
    created if it does not exist.

    Args:
        net: Module whose parameters are saved, or restored in place.
        system: PDE system name, used as a directory component.
        mode: "save" writes the checkpoint. "load" restores it when the file
            exists and otherwise leaves the net untouched. Any other value
            only prints a message.

    Returns:
        The same ``net`` object.
    """
    str_analytical = "analytical" if args.numerical else f"PINN_based_{args.PINN_based_prior_ratio}"
    SAVE_PATH = f"./parameter/{system}/{args.u0_str}"
    SAVE_NAME = f"{args.beta_min}_{args.beta_max}_{args.nu_min}_{args.nu_max}_{args.rho_min}_{args.rho_max}_" +\
        f"{args.epsilon_min}_{args.epsilon_max}_{args.theta_min}_{args.theta_max}_{str_analytical}.pt"

    print(SAVE_NAME)
    # exist_ok avoids the check-then-create race of a separate isdir() test.
    os.makedirs(SAVE_PATH, exist_ok=True)
    checkpoint_path = os.path.join(SAVE_PATH, SAVE_NAME)

    if mode == "load":
        if os.path.isfile(checkpoint_path):
            # weights_only=True restricts unpickling to tensors and plain containers.
            net.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
            print("The net is loaded.")
        else:
            print("This is a new net.")
    elif mode == "save":
        torch.save(net.state_dict(), checkpoint_path)
    else:
        print("Net operator does not work.")

    return net

# Trained PDE: every (beta, nu, rho, epsilon, theta) grid point spanned by the
# configured [min, max] ranges and steps.
specified_pde_list = []

# Set steps.
beta_step = args.beta_step
nu_step = args.nu_step
rho_step = args.rho_step
epsilon_step = args.epsilon_step
theta_step = args.theta_step

# According to steps, determine intervals (number of whole steps in each range).
# NOTE(review): floor division on floats can drop the last grid point when the
# range is not an exact float multiple of the step (e.g. 1.0 // 0.1 == 9.0).
beta_interval = int((args.beta_max - args.beta_min) // beta_step)
nu_interval = int((args.nu_max - args.nu_min) // nu_step)
rho_interval = int((args.rho_max - args.rho_min) // rho_step)
epsilon_interval = int((args.epsilon_max - args.epsilon_min) // epsilon_step)
theta_interval = int((args.theta_max - args.theta_min) // theta_step)

# Make a specified PDE list. Grid point n of each coefficient is min + n * step,
# so the grid spans [min, max] for any step, not only for step == 1.
for i in range(beta_interval + 1):
    for j in range(nu_interval + 1):
        for k in range(rho_interval + 1):
            for l in range(epsilon_interval + 1):
                for m in range(theta_interval + 1):

                    beta = args.beta_min + i * beta_step
                    nu = args.nu_min + j * nu_step
                    rho = args.rho_min + k * rho_step
                    epsilon = args.epsilon_min + l * epsilon_step
                    theta = args.theta_min + m * theta_step

                    specified_pde_list.append((beta, nu, rho, epsilon, theta))

# Tag the (beta, nu, rho, epsilon, theta) tuples whose data uses the PINN-based prior.
sample_size = int(len(specified_pde_list) * args.PINN_based_prior_ratio / 100)
# main() seeds the global RNG only after this module is imported. Sampling from
# the global RNG here would therefore give a different prior split on every run.
# A dedicated generator seeded from the config makes the split reproducible
# and leaves the global RNG state alone. If no seed is configured, it falls
# back to OS entropy, as before.
PINN_based_PDE_list = random.Random(getattr(args, "seed", None)).sample(specified_pde_list, sample_size)

def Gaussian_noise(train_data, noise_rate=None):
    """Add zero-mean Gaussian noise to column 2 of ``train_data``, in place.

    The standard deviation is ``|mean(train_data[:, 2])| * noise_rate / 100``,
    so the noise level is relative to the column mean.

    Args:
        train_data: 2-D tensor; only column 2 is modified.
        noise_rate: Noise level in percent. Defaults to ``args.noise_rate``.

    Returns:
        ``train_data`` itself (modified in place).
    """
    if noise_rate is None:
        noise_rate = args.noise_rate
    column = train_data[:, 2]
    std_dev = torch.abs(torch.mean(column) * (noise_rate / 100))
    # randn_like keeps the noise on the data's device and dtype; a CPU noise
    # tensor could not be added to a CUDA column.
    train_data[:, 2] += torch.randn_like(column) * std_dev
    return train_data

def salt_and_pepper_noise(train_data, noise_rate=None):
    """Apply salt-and-pepper noise to column 2 of ``train_data``, in place.

    Each entry is independently replaced by the column's original maximum
    ("salt") with probability ``p / 2``, or by its original minimum ("pepper")
    with probability ``p / 2``, where ``p = noise_rate / 100``.

    Args:
        train_data: 2-D tensor; only column 2 is modified.
        noise_rate: Noise level in percent. Defaults to ``args.noise_rate``.

    Returns:
        ``train_data`` itself (modified in place).
    """
    if noise_rate is None:
        noise_rate = args.noise_rate
    prob = noise_rate / 100
    column = train_data[:, 2]
    # Capture both extremes before modifying anything. Otherwise, salting the
    # element that holds the minimum would change the pepper value.
    salt_value = column.max()
    pepper_value = column.min()
    mask = torch.rand(column.size(), device=column.device)
    train_data[mask < (prob / 2), 2] = salt_value
    train_data[mask > (1 - prob / 2), 2] = pepper_value
    return train_data

def uniform_noise(train_data, noise_rate=None):
    """Add uniform noise in ``[-r, r)`` to column 2 of ``train_data``, in place.

    ``r = noise_rate / 100`` is an absolute amplitude. It is not scaled by the
    data, unlike ``Gaussian_noise``.

    Args:
        train_data: 2-D tensor; only column 2 is modified.
        noise_rate: Noise level in percent. Defaults to ``args.noise_rate``.

    Returns:
        ``train_data`` itself (modified in place).
    """
    if noise_rate is None:
        noise_rate = args.noise_rate
    noise_range = noise_rate / 100
    column = train_data[:, 2]
    # Draw on the data's device and dtype so the in-place add always works.
    noise = torch.empty(column.shape, dtype=column.dtype, device=column.device).uniform_(-noise_range, noise_range)
    train_data[:, 2] += noise
    return train_data

def _print_device_info(prior_ratio):
    """Print the device banner and the PINN-based prior ratio.

    The CUDA-only queries are skipped when CUDA is unavailable, because
    ``torch.cuda.current_device()`` raises on such machines.
    """
    print("=============[Deivce Info]==============")
    print("- Use Device :", device)
    print("- Available cuda devices :", torch.cuda.device_count())
    if torch.cuda.is_available():
        print("- Current cuda device :", torch.cuda.current_device())
        print("- Name of cuda device :", torch.cuda.get_device_name(device))
    else:
        print("- Current cuda device :", "N/A")
        print("- Name of cuda device :", "N/A")
    print("- Ratio of PINN based prior :", f"{prior_ratio}%")
    print("========================================\n")


def _load_train_data(save_path_root, file_name):
    """Load one PDE's collocation, boundary and initial rows, concatenated and shuffled.

    The single shuffle keeps the three splits from appearing as separate
    contiguous segments in the model input.
    """
    collocation_df = pd.read_csv(f"{save_path_root}{DATASET_DIRECTORIES[3]}/{file_name}")
    boundary_df = pd.read_csv(f"{save_path_root}{DATASET_DIRECTORIES[1]}/{file_name}")
    initial_df = pd.read_csv(f"{save_path_root}{DATASET_DIRECTORIES[2]}/{file_name}")

    train_data = torch.cat(
        (
            cdr_DATA(collocation_df).get_all(),
            cdr_DATA(boundary_df).get_all(),
            cdr_DATA(initial_df).get_all(),
        ),
        dim=0,
    )
    return train_data[torch.randperm(train_data.size(0))]


def _evaluate(transformer, save_path_root, file_name, chunk_size=100):
    """Score ``transformer`` on one PDE's test split.

    The test collocation rows are prepended to every input, and
    ``single_eval_pos`` is set to their count. The test rows are appended
    ``chunk_size`` at a time. The caller is expected to have put the model
    in eval mode under ``torch.no_grad()``.

    Returns:
        Tuple ``(mean_abs_err, l2_rel_err, max_err, explained_variance)``.
        The first two are tensors; the last two come from sklearn.

    Raises:
        ValueError: if the test split has no rows.
    """
    test_collocation_df = pd.read_csv(f"{save_path_root}{DATASET_DIRECTORIES[5]}/{file_name}")
    test_df = pd.read_csv(f"{save_path_root}{DATASET_DIRECTORIES[7]}/{file_name}")
    test_given_data = cdr_DATA(test_collocation_df).get_all()
    test_data = cdr_DATA(test_df).get_all()
    if test_data.size(0) == 0:
        raise ValueError("The test set is empty.")

    # Chunk by size rather than by a fixed count, so every test row is scored
    # whatever the size of the test split.
    predictions = []
    for start in range(0, test_data.size(0), chunk_size):
        query = test_data[start:start + chunk_size, :]
        model_input = torch.cat((test_given_data, query), dim=0).unsqueeze(0).to(device)
        output = transformer(model_input, test_given_data.size(0))
        predictions.append(output.squeeze(0))
    result = torch.cat(predictions, dim=0)

    test_solution = test_data[:, 2:].to(device)
    error = result - test_solution
    L2_error_norm = torch.linalg.norm(error, 2, dim=0)
    L2_true_norm = torch.linalg.norm(test_solution, 2, dim=0)
    # NOTE: despite the "L2" name, this is the mean absolute error.
    L2_absolute_error = torch.mean(torch.abs(error))
    L2_relative_error = L2_error_norm / L2_true_norm

    # The sklearn metrics need host-side data.
    result = result.cpu()
    test_solution = test_solution.cpu()
    Max_err = max_error(test_solution, result)
    Ex_var_score = explained_variance_score(test_solution, result)
    return L2_absolute_error, L2_relative_error, Max_err, Ex_var_score


def main():
    """Train the transformer, taking one PDE instance per epoch.

    The epoch list is every tuple in ``specified_pde_list`` plus randomly
    generated tuples from ``random_parameter``, shuffled together. Each epoch
    uses the next (beta, nu, rho, epsilon, theta) tuple, loads that PDE's
    training CSVs, and takes one Adam step. The loss is the MSE between
    columns ``2:`` of the rows from ``int(0.3 * N)`` onward and the model
    output, with ``single_eval_pos = int(0.3 * N)``. Every 1000 epochs the
    model is scored on that PDE's test split and a checkpoint is saved via
    ``net_operator``.

    Raises:
        ValueError: if ``specified_pde_list`` has more entries than
            ``args.epoch``, or a PDE's file under ``DATASET_DIRECTORIES[0]``
            is missing.
    """
    # NOTE(review): this local ``args`` (from get_config) shadows the
    # module-level ``args`` read by net_operator and the PDE lists;
    # presumably both parse the same CLI. Confirm.
    args = get_config()
    random_seed = args.seed

    # Seed every RNG in use and force deterministic cuDNN kernels.
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    np.random.seed(random_seed)
    cudnn.benchmark = False
    cudnn.deterministic = True
    random.seed(random_seed)

    _print_device_info(args.PINN_based_prior_ratio)

    epoch = args.epoch
    transformer = TransformerModel(n_out=1,
                                   ninp=16,
                                   nhead=1,
                                   nhid=32,
                                   nlayers=3,
                                   dropout=0.2,
                                   decoder=None).to(device)
    model_size = param_number(transformer)
    initial_condition = args.u0_str

    print("=============[Train Info]===============")
    print(f"- Initial condition : {initial_condition}")
    print(f"- Model size : {model_size}")
    print("========================================\n")

    # NOTE(review): 10e-5 == 1e-4. The value is kept unchanged; confirm that
    # 1e-5 was not intended.
    optimizer = torch.optim.Adam(transformer.parameters(), lr=10e-5)

    # Every specified PDE must fit in the epoch budget; the remainder is random.
    if len(specified_pde_list) > epoch:
        raise ValueError("There are too many specified pdes.")
    # Each range tuple is (min, step, max) as floats.
    beta_range = (args.beta_min * 1.0, args.beta_step * 1.0, args.beta_max * 1.0)
    nu_range = (args.nu_min * 1.0, args.nu_step * 1.0, args.nu_max * 1.0)
    rho_range = (args.rho_min * 1.0, args.rho_step * 1.0, args.rho_max * 1.0)
    epsilon_range = (args.epsilon_min * 1.0, args.epsilon_step * 1.0, args.epsilon_max * 1.0)
    theta_range = (args.theta_min * 1.0, args.theta_step * 1.0, args.theta_max * 1.0)

    randomly_chosen_pde = random_parameter(beta_range, nu_range, rho_range, epsilon_range, theta_range, epoch - len(specified_pde_list))
    pde_list = specified_pde_list + randomly_chosen_pde
    random.shuffle(pde_list)

    eval_interval = 1000
    # Reference point for the cumulative "Time consumed" report.
    current_time = time.time()

    for EPOCH in range(epoch):
        transformer.train()

        beta, nu, rho, epsilon, theta = pde_list[EPOCH]
        system = system_determinator(beta, nu, rho, epsilon, theta)

        # Tuples sampled into PINN_based_PDE_list use the PINN-based prior data.
        str_analytical = "PINN_based" if pde_list[EPOCH] in PINN_based_PDE_list else "analytical"

        save_path_root = f"./data_gen/dataset/{system}/{args.u0_str}"
        file_name = f"{beta}_{nu}_{rho}_{epsilon}_{theta}_{str_analytical}.csv"
        full_save_path = f"{save_path_root}{DATASET_DIRECTORIES[0]}/{file_name}"
        if not os.path.isfile(full_save_path):
            raise ValueError("There is no file.")

        train_data = _load_train_data(save_path_root, file_name)

        # Inject noise. (salt and pepper noise is given as an instance)
        # if args.noise:
        #     train_data = salt_and_pepper_noise(train_data)

        # The first single_eval_pos rows are passed as-is; the loss is computed
        # on the remaining rows.
        model_input = train_data.unsqueeze(0).to(device)
        single_eval_pos = int(0.3 * train_data.size(0))
        train_solution = train_data[single_eval_pos:, 2:].to(device)
        output = transformer(model_input, single_eval_pos)
        loss = torch.mean((train_solution - output.squeeze(0)) ** 2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        end_time = time.time()

        if (EPOCH + 1) % eval_interval == 0:
            with torch.no_grad():
                transformer.eval()
                L2_absolute_error, L2_relative_error, Max_err, Ex_var_score = _evaluate(
                    transformer, save_path_root, file_name)

                print('=================================================================')
                print('EPOCH :', EPOCH + 1)
                print("(beta, nu, rho, epsilon, theta):", beta, nu, rho, epsilon, theta)
                # NOTE: this is the current epoch's training loss, printed
                # under the historical "Test Loss" label.
                print('Test Loss :', loss.item())
                print('L2_abs_err :', L2_absolute_error.item())
                print('L2_rel_err :', L2_relative_error.item())
                print('Max_error :', Max_err)
                print('Variance_score :', Ex_var_score)
                print('Time consumed (s):', round(end_time - current_time, 2))
                print('=================================================================')

            # The checkpoint directory depends on the current epoch's PDE system.
            net_operator(transformer, system, "save")

    # Equals the last EPOCH + 1, and stays well-defined when epoch == 0.
    print('Epoch number :', epoch)
    print('=================================================================')

if __name__ == "__main__":
    # Train only when run as a script. Importing this module still parses the
    # CLI arguments and builds the PDE lists at module level.
    main()


