import torch
import torch.nn as nn
import numpy as np
import random
import torch.backends.cudnn as cudnn
import os
import time
from torch.optim import lr_scheduler
import torch._dynamo
torch._dynamo.config.suppress_errors = True
import warnings
warnings.filterwarnings("ignore") 

from models import PDE_PFN
from utils import *

# Parse CLI arguments at import time. `parser` presumably comes from
# `from utils import *`.
# NOTE(review): Downstream_main() rebinds its own `args` from get_config(), so
# this module-level value looks unused in this file — confirm before removing.
args = parser.parse_args()

# Absolute location of this script and the folder that contains it.
FILE_ABS_PATH = os.path.abspath(__file__)
FILE_DIR_PATH = os.path.split(FILE_ABS_PATH)[0]

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def _seed_everything(random_seed):
    """Seed PyTorch (CPU and all CUDA devices), NumPy and `random`; force deterministic cuDNN.

    Args:
        random_seed: integer seed applied to every RNG listed above.
    """
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    np.random.seed(random_seed)
    # Disable the cuDNN autotuner and request deterministic kernels so that
    # repeated evaluations produce the same numbers.
    cudnn.benchmark = False
    cudnn.deterministic = True
    random.seed(random_seed)


def _print_device_info(args):
    """Print the compute device and, when enabled, the PINN-prior settings.

    Args:
        args: config object; reads `numerical` and `PINN_based_prior_ratio`.
    """
    print("=============[Device Info]==============")
    print("- Use Device :", device)
    print("- Available cuda devices :", torch.cuda.device_count())
    # current_device()/get_device_name() raise when no CUDA device is usable,
    # so only query them when the CUDA backend is actually available.
    if torch.cuda.is_available():
        print("- Current cuda device :", torch.cuda.current_device())
        print("- Name of cuda device :", torch.cuda.get_device_name(device))
    if not args.numerical:
        print("- Use PINN base prior")
        print("  - Ratio of PINN based prior :", f"{args.PINN_based_prior_ratio}%")
    print("========================================\n")


def _compute_metrics(output, target):
    """Return (MSE, mean abs L1, mean relative L2, mean relative L-inf) as 0-d tensors.

    `output` and `target` must share a shape. Downstream_main reshapes both to
    (-1, grid_num, n_out). The L2 / L-inf norms are taken over dim 1 and the
    per-slice ratios are then averaged.

    NOTE(review): a target slice with zero norm yields inf/nan in the relative
    metrics; this is left unchanged so reported numbers stay comparable.
    """
    diff = output - target
    mse = torch.mean(diff ** 2)
    l1_abs = torch.mean(torch.abs(diff))
    l2_rel = torch.mean(
        torch.linalg.norm(diff, 2, dim=1) / torch.linalg.norm(target, 2, dim=1)
    )
    linf_rel = torch.mean(
        torch.linalg.norm(diff, ord=torch.inf, dim=1)
        / torch.linalg.norm(target, ord=torch.inf, dim=1)
    )
    return mse, l1_abs, l2_rel, linf_rel


def Downstream_main():
    """Evaluate a pretrained PDE_PFN transformer on the test split and save the metrics.

    Steps:
    - Read the config via get_config() and seed all RNGs.
    - Build the test loader with data_provider(flag='test').
    - Build and compile the model, then load its weights through
      net_operator_ours(mode="load").
    - Print per-batch MSE, L1, relative L2 and relative L-inf errors.
    - Pass the averages to save_results().
    """
    args = get_config()
    _seed_everything(args.seed)
    _print_device_info(args)

    #################### Data Loader ####################
    test_loader, n_in, n_out, single_eval_pos = data_provider(args, flag='test')

    if isinstance(single_eval_pos, list):
        # Presumably per-axis grid sizes; the first entry is stored as
        # args.xgrid and the context length becomes the total point count.
        args.xgrid = single_eval_pos[0]
        single_eval_pos = np.prod(single_eval_pos)
    # Predictions/targets are reshaped into snapshots of `grid_num` points.
    # Previously `grid_num` was only set in the list branch above, so an
    # integer `single_eval_pos` crashed with NameError. In both cases it equals
    # the grid/mesh number reported below.
    grid_num = single_eval_pos
    #####################################################

    #################### Model ####################
    transformer = PDE_PFN.TransformerModel(n_in=n_in, n_out=n_out, args=args).to(device)
    # PyTorch compilation (graph breaks allowed).
    transformer = torch.compile(transformer, fullgraph=False)
    # Load weights for testing.
    transformer = net_operator_ours(transformer, device, data=args.data, mode="load")
    model_size = param_number(transformer)
    ###############################################

    task = task_specifier(args)
    print("=============[Test Info]===============")
    print(f"- Data               : {args.data}")
    print(f"- Task               : {task}")
    print(f"- Input feature dim  : {n_in}")
    print(f"- output feature dim : {n_out}")
    print(f"- Grid/Mesh Number   : {single_eval_pos}")
    print(f"- Train data num     : {args.N_train_shot}")
    if args.N_test != 0:
        print(f"- Test data num      : {args.N_test}")
    print(f"- Model size         : {model_size}")
    print("========================================\n")

    # Per-batch metric histories.
    test_loss = []
    test_l1_abs_loss = []
    test_l2_rel_loss = []
    test_linf_rel_loss = []

    #### Test evaluation (inference only, no gradients)
    with torch.no_grad():
        transformer.eval()
        for ind, batch_data in enumerate(test_loader):
            if "2D_cdr" in args.data:
                # This loader yields (index, batch) pairs; use its own index.
                ind, batch_data = batch_data
            # Original author's note: batch_data (2D) is 1, t*grid*grid, in_feature.

            # Targets are the last n_out features of every point after the
            # context prefix, regrouped into snapshots of grid_num points.
            test_solution = batch_data[:, single_eval_pos:, -n_out:].to(device)
            test_solution = test_solution.reshape(-1, grid_num, n_out)

            # NOTE(review): only the target slice is moved to `device`; the model
            # receives `batch_data` as yielded by the loader — confirm the model
            # handles device placement itself.
            output, _ = transformer(batch_data, single_eval_pos)
            output = output.reshape(-1, grid_num, n_out)

            loss, L1_abs_error, L2_rel_error, L_inf_rel_error = _compute_metrics(
                output, test_solution
            )

            test_loss.append(loss.item())
            test_l1_abs_loss.append(L1_abs_error.item())
            test_l2_rel_loss.append(L2_rel_error.item())
            test_linf_rel_loss.append(L_inf_rel_error.item())

            print('=================================================================')
            print('ind :', ind)
            print('Test Loss :', loss.item())
            print('L1_abs_err :', L1_abs_error.item())
            print('L2_rel_err :', L2_rel_error.item())
            print('L_inf_rel_err :', L_inf_rel_error.item())
            print('=================================================================')

    # Average each metric over all test batches.
    test_loss = np.average(test_loss)
    test_l1_abs_loss = np.average(test_l1_abs_loss)
    test_l2_rel_loss = np.average(test_l2_rel_loss)
    test_linf_rel_loss = np.average(test_linf_rel_loss)

    exp_results = {"MSE": test_loss, "L1_abs": test_l1_abs_loss,
                   "L2_rel": test_l2_rel_loss, "Linf_rel": test_linf_rel_loss}
    save_results(args, exp_results)

    print('===================== Average Value =============================')
    print('Test Loss :', test_loss)
    print('L1_abs_err :', test_l1_abs_loss)
    print('L2_rel_err :', test_l2_rel_loss)
    print('L_inf_rel_err :', test_linf_rel_loss)
    print('=================================================================')

if __name__ == '__main__':
    # Run the evaluation only when executed as a script, never on import.
    Downstream_main()