import torch
import numpy as np
import random
import torch.backends.cudnn as cudnn
import statistics
from sklearn.metrics import explained_variance_score, max_error
import os

from config import *
from transformer.transformer import *
from data_gen import *
from utils import *
from train import *
from data_gen.dataset_util import *
# Module-level setup: parse the command line, pick the compute device, and
# enumerate the PDE coefficient grid that main() iterates over.
# NOTE(review): `parser` is not defined in this file; presumably it comes from
# one of the star imports above (likely `config`) -- confirm.
args = parser.parse_args()

FILE_ABS_PATH = os.path.abspath(__file__)
FILE_DIR_PATH = os.path.dirname(FILE_ABS_PATH)

# CUDA support: fall back to the CPU when no CUDA device is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# PDE coefficient tuples (beta, nu, rho, epsilon, theta) to evaluate.
test_pde_list = []

# Step size of each coefficient sweep.
beta_step = args.beta_step
nu_step = args.nu_step
rho_step = args.rho_step
epsilon_step = args.epsilon_step
theta_step = args.theta_step

# Number of whole steps that fit between each coefficient's min and max.
beta_interval = int((args.beta_max - args.beta_min) // beta_step)
nu_interval = int((args.nu_max - args.nu_min) // nu_step)
rho_interval = int((args.rho_max - args.rho_min) // rho_step)
epsilon_interval = int((args.epsilon_max - args.epsilon_min) // epsilon_step)
theta_interval = int((args.theta_max - args.theta_min) // theta_step)

# Build the full Cartesian grid of coefficient values.
# Each value is min + index * step; the previous code added the bare index,
# which only matched the interval counts above when every step was 1.
for beta_idx in range(beta_interval + 1):
    for nu_idx in range(nu_interval + 1):
        for rho_idx in range(rho_interval + 1):
            for epsilon_idx in range(epsilon_interval + 1):
                for theta_idx in range(theta_interval + 1):

                    beta = args.beta_min + beta_idx * beta_step
                    nu = args.nu_min + nu_idx * nu_step
                    rho = args.rho_min + rho_idx * rho_step
                    epsilon = args.epsilon_min + epsilon_idx * epsilon_step
                    theta = args.theta_min + theta_idx * theta_step

                    test_pde_list.append((beta, nu, rho, epsilon, theta))
def main():
    """Evaluate the transformer on every PDE in ``test_pde_list`` and log metrics.

    For each ``(beta, nu, rho, epsilon, theta)`` tuple this function reads the
    matching CSV files, loads weights through ``net_operator(..., "load")``,
    predicts the columns from index 2 onward of the test points, and records
    the mean squared error, mean absolute error, relative L2 error, max error
    and explained variance score.

    Side effects:
        * Prints device/test information and per-PDE metrics.
        * Writes one CSV of per-PDE metrics under ``./result/``.
        * When ``args.plot`` is set, writes one CSV per PDE under ``./plot/``
          containing the full grid with predictions substituted in.

    Returns:
        None. If ``test_pde_list`` is empty a message is printed and nothing
        is written.
    """
    args = get_config()
    random_seed = args.seed

    # Seed every RNG used here and force deterministic cuDNN kernels.
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    np.random.seed(random_seed)
    cudnn.benchmark = False
    cudnn.deterministic = True
    random.seed(random_seed)

    print("=============[Deivce Info]==============")
    print("- Use Device :", device)
    print("- Available cuda devices :", torch.cuda.device_count())
    # These queries raise without a usable CUDA runtime, so only issue them
    # when CUDA is available; the script otherwise runs on the CPU fallback.
    if torch.cuda.is_available():
        print("- Current cuda device :", torch.cuda.current_device())
        print("- Name of cuda device :", torch.cuda.get_device_name(device))
    print("- Ratio of PINN based prior :", f"{args.PINN_based_prior_ratio}%")
    print("========================================\n")

    transformer = TransformerModel(n_out=1,
                                   ninp=16,
                                   nhead=1,
                                   nhid=32,
                                   nlayers=3,
                                   dropout=0.2,
                                   decoder=None).to(device)

    model_size = sum(p.numel() for p in transformer.parameters())
    initial_condition = args.u0_str

    print("=============[Test Info]===============")
    print(f"- Initial condition : {initial_condition}")
    print(f"- Model size : {model_size}")
    print("========================================\n")

    # Per-PDE statistics, one entry per evaluated coefficient tuple.
    beta_list = []
    nu_list = []
    rho_list = []
    epsilon_list = []
    theta_list = []
    test_loss_list = []
    AE_list = []
    RE_list = []
    ME_list = []
    VS_list = []

    # Interpolation: number of query rows sent through the model per call.
    TEST_BATCH_SIZE = 100
    # Extrapolation: the range (0.6, 1.0] of column 1 is split into sections
    # that are predicted one after another.
    NUM_SECTION = 5
    INTERVAL = 0.4 / NUM_SECTION

    # Test.
    for item in test_pde_list:

        # Specify the coefficient values.
        beta, nu, rho, epsilon, theta = item

        # Locate the data files for this PDE.
        # NOTE(review): the meaning of each DATASET_DIRECTORIES index is
        # defined outside this file -- confirm against data_gen.
        system = system_determinator(beta, nu, rho, epsilon, theta)
        str_analytical = "analytical"
        save_path_root = f"./data_gen/dataset/{system}/{args.u0_str}"
        file_name = f"{beta}_{nu}_{rho}_{epsilon}_{theta}_{str_analytical}.csv"

        full_df = pd.read_csv(f"{save_path_root}{DATASET_DIRECTORIES[0]}/{file_name}")
        test_collocation_df = pd.read_csv(f"{save_path_root}{DATASET_DIRECTORIES[5]}/{file_name}")\
            if not args.extrapolation else pd.read_csv(f"{save_path_root}{DATASET_DIRECTORIES[6]}/{file_name}")
        test_df = pd.read_csv(f"{save_path_root}{DATASET_DIRECTORIES[7]}/{file_name}")\
            if not args.extrapolation else pd.read_csv(f"{save_path_root}{DATASET_DIRECTORIES[8]}/{file_name}")

        full_loader = cdr_DATA(full_df)
        test_collocation_loader = cdr_DATA(test_collocation_df)
        test_loader = cdr_DATA(test_df)

        # Query the data.
        test_given = test_collocation_loader.get_all()
        test_data = test_loader.get_all()

        # Shuffle the given (context) points.
        mixed_test_indices = torch.randperm(test_given.size(0))
        test_given = test_given[mixed_test_indices]

        # Load the weights, then switch to eval mode so dropout is disabled
        # regardless of the mode the loaded model comes back in.
        transformer = net_operator(transformer, system, "load")
        transformer.eval()

        # No gradients are needed for evaluation. Disabling autograd keeps the
        # extrapolation loop from growing a graph across sections and lets the
        # sklearn metrics convert the tensors to NumPy on any device.
        with torch.no_grad():

            shuffled_solution = None

            # Interpolation case.
            if not args.extrapolation:
                result_chunks = []
                # Walk over every test row in fixed-size chunks so the
                # prediction always lines up with test_data[:, 2:].
                for start in range(0, test_data.size(0), TEST_BATCH_SIZE):
                    query = test_data[start:start + TEST_BATCH_SIZE, :]
                    model_input = torch.cat((test_given, query), dim=0).unsqueeze(0).to(device)
                    output = transformer(model_input, test_given.size(0))
                    result_chunks.append(output.squeeze(0))
                result = torch.cat(result_chunks, dim=0)
                test_solution = test_data[:, 2:].to(device)

            # Extrapolation case.
            else:
                updated_train_data = test_given.to(device)
                result_chunks = []
                solution_chunks = []
                for i in range(NUM_SECTION):
                    in_section = ((test_data[:, 1] > 0.6 + (INTERVAL*i))
                                  & (test_data[:, 1] <= 0.6 + INTERVAL*(i+1)))
                    selected_test_data = test_data[in_section].to(device)
                    model_input = torch.cat((updated_train_data, selected_test_data), dim=0).unsqueeze(0).to(device)
                    output = transformer(model_input, updated_train_data.size(0))
                    # Feed the predictions back in as context for the next section.
                    output_points = torch.cat((selected_test_data[:, :2], output.squeeze(0)), dim=1)
                    updated_train_data = torch.cat((updated_train_data, output_points), dim=0).to(device)
                    result_chunks.append(output.squeeze(0))
                    solution_chunks.append(selected_test_data)
                result = torch.cat(result_chunks, dim=0)
                # Ground truth in the same section order as the predictions.
                shuffled_solution = torch.cat(solution_chunks, dim=0)
                test_solution = shuffled_solution[:, 2:].to(device)

            loss = torch.mean((test_solution - result)**2)

            # Statistics
            L2_error_norm = torch.linalg.norm(result-test_solution, 2, dim=0)
            L2_true_norm = torch.linalg.norm(test_solution, 2, dim=0)

            L2_absolute_error = torch.mean(torch.abs(result-test_solution))
            L2_relative_error = L2_error_norm / L2_true_norm

            # The sklearn metrics need host-side data.
            result = result.cpu()
            test_solution = test_solution.cpu()

            Max_err = max_error(test_solution, result)
            Ex_var_score = explained_variance_score(test_solution, result)

            beta_list.append(beta)
            nu_list.append(nu)
            rho_list.append(rho)
            epsilon_list.append(epsilon)
            theta_list.append(theta)
            test_loss_list.append(loss.item())
            AE_list.append(L2_absolute_error.item())
            RE_list.append(L2_relative_error.item())
            ME_list.append(Max_err)
            VS_list.append(Ex_var_score)

            print('=================================================================')  
            print("(beta, nu, rho, epsilon, theta):", beta, nu, rho, epsilon, theta)
            print('Test Loss :', loss.item())
            print('L2_abs_err :', L2_absolute_error.item())
            print('L2_rel_err :', L2_relative_error.item())
            print('Max_error :', Max_err)
            print('Variance_score :', Ex_var_score)                 
            print('=================================================================')

            # Save the plot.
            if args.plot:
                if not args.extrapolation:
                    predicted_solution = torch.cat((test_data[:, 0:2], result), dim=1)
                else:
                    predicted_solution = torch.cat((shuffled_solution[:, 0:2].cpu(), result), dim=1)

                # Overwrite the rows of the full grid whose first two columns
                # match a predicted point with the predicted row.
                full_solution = full_loader.get_all()
                for row in range(predicted_solution.shape[0]):
                    mask = (full_solution[:, 0] == predicted_solution[row, 0]) & (full_solution[:, 1] == predicted_solution[row, 1])
                    full_solution[mask] = predicted_solution[row]

                GRAPH_PATH = f"./plot/{system}/{args.u0_str}"
                GRAPH_NAME = f"{beta}_{nu}_{rho}_{epsilon}_{theta}.csv"
                if not os.path.isdir(GRAPH_PATH): os.makedirs(GRAPH_PATH)
                make_csv(full_solution, ['x_data', 't_data', 'u_data'], GRAPH_PATH + "/" + GRAPH_NAME)

    # Nothing was evaluated (e.g. a max below its min): there is no `system`
    # to build the output path from and no values to average.
    if not test_loss_list:
        print("No PDE configuration to evaluate; check the *_min / *_max / *_step arguments.")
        return

    # Save the log.
    column = ['beta', 'nu', 'rho', 'epsilon', 'theta', 'test_loss', 'absolute_error', 'relative_error', 'max_error', 'variance_score']

    log_list = [beta_list, nu_list, rho_list, epsilon_list, theta_list, test_loss_list, AE_list, RE_list, ME_list, VS_list]
    zipped_list = list(zip(*log_list))

    # The log directory is named after the system of the last evaluated PDE.
    str_analytical = "analytical" if args.numerical else f"PINN_based_{args.PINN_based_prior_ratio}"
    SAVE_PATH = f"./result/{system}/{args.u0_str}"
    SAVE_NAME = f"{args.beta_min}_{args.beta_max}_{args.nu_min}_{args.nu_max}_{args.rho_min}_{args.rho_max}_" +\
        f"{args.epsilon_min}_{args.epsilon_max}_{args.theta_min}_{args.theta_max}_{str_analytical}.csv"
    if not os.path.isdir(SAVE_PATH): os.makedirs(SAVE_PATH)
    make_csv(zipped_list, column, SAVE_PATH + "/" + SAVE_NAME)

    # Print the average values.
    print('===================== Average Value =============================')  
    print('Test Loss :', statistics.mean(test_loss_list))
    print('L2_abs_err :', statistics.mean(AE_list))
    print('L2_rel_err :', statistics.mean(RE_list))
    print('Max_error :', statistics.mean(ME_list))
    print('Variance_score :', statistics.mean(VS_list))                 
    print('=================================================================')  


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()

