import torch
import torch.nn as nn
import numpy as np
import random
import torch.backends.cudnn as cudnn
import os
import time
from torch.optim import lr_scheduler
from torch.cuda.amp import autocast, GradScaler
import torch._dynamo
torch._dynamo.config.suppress_errors = True
import warnings
warnings.filterwarnings("ignore") 
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")

from models import PDE_PFN
from utils import *

# Module-level CLI parse. NOTE(review): Train_main() rebinds its own local
# ``args`` from get_config(), so this global is not read anywhere in this
# file; it appears to be kept only for its import-time parsing side effect.
args = parser.parse_args()

# Absolute path of this script, and the directory that contains it.
FILE_ABS_PATH = os.path.abspath(__file__)
FILE_DIR_PATH = os.path.split(FILE_ABS_PATH)[0]

# CUDA support: use the GPU whenever PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def _seed_everything(seed):
    """Seed every RNG used during training and make cuDNN deterministic.

    Args:
        seed: Seed given to Python's ``random``, NumPy, and torch (CPU and
            every CUDA device).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # PyTorch silently ignores the CUDA seeding calls when CUDA is unavailable.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    cudnn.benchmark = False
    cudnn.deterministic = True


def _print_device_info(args):
    """Print the compute-device banner and, for PINN-prior runs, the prior ratio.

    Args:
        args: Config namespace; reads ``numerical`` and
            ``PINN_based_prior_ratio``.
    """
    print("=============[Deivce Info]==============")
    print("- Use Device :", device)
    print("- Available cuda devices :", torch.cuda.device_count())
    # current_device()/get_device_name() raise when no CUDA device is usable,
    # so they are only queried when training actually runs on the GPU.
    if device.type == "cuda":
        print("- Current cuda device :", torch.cuda.current_device())
        print("- Name of cuda device :", torch.cuda.get_device_name(device))
    if not args.numerical:
        print("- Use PINN base prior")
        print("  - Ratio of PINN based prior :", f"{args.PINN_based_prior_ratio}%")
    print("========================================\n")


def _train_one_epoch(transformer, train_loader, optimizer, scaler,
                     scheduler_method, args, single_eval_pos, n_out):
    """Run one optimisation pass over ``train_loader``.

    Args:
        transformer: Model called as ``transformer(batch, single_eval_pos)``,
            returning ``(output, ortho_loss_list)``.
        train_loader: Iterable of tensors whose first two axes are merged into
            one batch axis.
        optimizer: Optimizer over ``transformer``'s parameters.
        scaler: ``GradScaler`` wrapping the backward pass and optimizer step.
        scheduler_method: LR scheduler; stepped per batch here only when
            ``args.scheduler == "OneCycleLR"``.
        args: Config namespace; reads ``scheduler``.
        single_eval_pos: Index along axis 1 where the target points start
            (presumably the context/query split — confirm in PDE_PFN).
        n_out: Number of trailing feature columns that hold the solution.

    Returns:
        List of per-batch training losses as Python floats.
    """
    transformer.train()
    amp_enabled = device.type == "cuda"
    losses = []
    for batch_data in train_loader:
        # Merge the two leading axes into a single batch axis.
        batch_data = torch.flatten(batch_data, start_dim=0, end_dim=1).to(device)
        # Targets: last n_out features of every point after single_eval_pos.
        train_solution = batch_data[:, single_eval_pos:, -n_out:]

        with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=amp_enabled):
            output, ortho_loss_list = transformer(batch_data, single_eval_pos)
            ortho_loss = torch.mean(ortho_loss_list)
            # MSE on the targets plus the orthogonality penalty scaled by 1/single_eval_pos.
            loss = torch.mean((train_solution - output) ** 2) + (1 / single_eval_pos) * ortho_loss

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # OneCycleLR is defined per optimisation step; other schedulers step per epoch.
        if args.scheduler == "OneCycleLR":
            scheduler_method.step()

        losses.append(loss.item())
    return losses


def _validate(transformer, valid_loader, single_eval_pos, n_out):
    """Evaluate ``transformer`` on ``valid_loader`` without tracking gradients.

    Args:
        transformer: Model called as ``transformer(batch, single_eval_pos)``.
        valid_loader: Iterable of tensors whose first two axes are merged into
            one batch axis.
        single_eval_pos: Index along axis 1 where the target points start.
        n_out: Number of trailing feature columns that hold the solution.

    Returns:
        Tuple ``(mse, l1_abs, l2_rel, linf_rel)``. Each is the ``np.average``
        of the per-batch values.
    """
    mse, l1_abs, l2_rel, linf_rel = [], [], [], []
    transformer.eval()
    with torch.no_grad():
        for batch_data in valid_loader:
            batch_data = torch.flatten(batch_data, start_dim=0, end_dim=1).to(device)
            # Regroup the target points into chunks of single_eval_pos points.
            valid_solution = batch_data[:, single_eval_pos:, -n_out:].reshape(-1, single_eval_pos, n_out)

            output, _ = transformer(batch_data, single_eval_pos)
            output = output.reshape(-1, single_eval_pos, n_out)
            diff = output - valid_solution

            # Vector norms over the point axis (dim=1), one per (chunk, feature).
            # NOTE(review): an all-zero target chunk yields inf/nan relative errors.
            l2_rel_err = torch.mean(
                torch.linalg.norm(diff, 2, dim=1) / torch.linalg.norm(valid_solution, 2, dim=1)
            )
            linf_rel_err = torch.mean(
                torch.linalg.norm(diff, ord=torch.inf, dim=1)
                / torch.linalg.norm(valid_solution, ord=torch.inf, dim=1)
            )

            mse.append(torch.mean(diff ** 2).item())
            l1_abs.append(torch.mean(torch.abs(diff)).item())
            l2_rel.append(l2_rel_err.item())
            linf_rel.append(linf_rel_err.item())

    return np.average(mse), np.average(l1_abs), np.average(l2_rel), np.average(linf_rel)


def Train_main():
    """Train the PDE_PFN transformer and checkpoint the best validation model.

    Reads its configuration from ``get_config()`` and builds the train and
    valid loaders with ``data_provider``. When ``args.data`` does not contain
    "2D_cdr", it loads pretrained "2D_cdr" weights for fine-tuning. Each epoch
    trains, steps per-epoch schedulers, validates, and saves through
    ``net_operator_ours(mode="save")`` whenever the validation MSE improves.
    """
    args = get_config()
    _seed_everything(args.seed)
    _print_device_info(args)

    #################### Data Loader ####################
    train_loader, n_in, n_out, single_eval_pos = data_provider(args, flag='train')
    valid_loader, _, _, _ = data_provider(args, flag='valid')

    # A list-valued split position is flattened to its total point count;
    # its first entry is kept on args as the grid size.
    if isinstance(single_eval_pos, list):
        args.xgrid = single_eval_pos[0]
        single_eval_pos = np.prod(single_eval_pos)
    #####################################################

    #################### Model, Optimizer ####################
    transformer = PDE_PFN.TransformerModel(n_in=n_in, n_out=n_out, args=args).to(device)
    # PyTorch compilation.
    transformer = torch.compile(transformer, mode="reduce-overhead", fullgraph=False)
    # Load pretrained weights when finetuning.
    if "2D_cdr" not in args.data:
        transformer = net_operator_ours(transformer, device, data="2D_cdr", mode="finetuning", n_in=n_in)

    model_size = param_number(transformer)
    # Fused AdamW kernels are CUDA-only on many PyTorch releases; requesting
    # them for CPU parameters raises, so fuse only on GPU.
    optimizer = torch.optim.AdamW(
        transformer.parameters(),
        lr=args.lr,
        weight_decay=1e-6,
        fused=(device.type == "cuda"),
    )
    scheduler_method = construct_scheduler(optimizer, args)
    ###########################################################

    print("=============[Train Info]===============")
    print(f"- Data               : {args.data}")
    print(f"- Input feature dim  : {n_in}")
    print(f"- output feature dim : {n_out}")
    print(f"- Mesh Number        : {single_eval_pos}")
    print(f"- Train data num     : {args.N_train_shot}")
    print(f"- Batch size         : {args.batch}")
    print(f"- Model size         : {model_size}")
    print(f"- scheduler          : {args.scheduler}")
    print("========================================\n")

    start_time = time.time()
    # bfloat16 shares fp32's exponent range, so loss scaling is not strictly
    # required; the scaler is harmless and is disabled off-GPU.
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
    valid_best_loss = torch.inf

    #### Train + Validation
    for EPOCH in range(args.epoch):
        train_loss = _train_one_epoch(
            transformer, train_loader, optimizer, scaler,
            scheduler_method, args, single_eval_pos, n_out,
        )

        if args.scheduler != "OneCycleLR":
            scheduler_method.step()
            print(f"Parameter update with scheduler : {scheduler_method.get_last_lr()}")

        valid_loss, L1_abs_error, L2_rel_error, L_inf_rel_error = _validate(
            transformer, valid_loader, single_eval_pos, n_out
        )

        print('=================================================================')
        print('EPOCH :', EPOCH + 1)
        print('train Loss :', np.average(train_loss))
        print('valid Loss :', valid_loss)
        print('valid Best :', valid_loss < valid_best_loss)
        print('L1_abs_err :', L1_abs_error)
        print('L2_rel_err :', L2_rel_error)
        print('L_inf_rel_err :', L_inf_rel_error)
        print('Time consumed (s):', round(time.time() - start_time, 2))
        print('=================================================================')

        if valid_loss < valid_best_loss:
            net_operator_ours(transformer, device, data=args.data, mode="save")
            valid_best_loss = valid_loss

    # Use args.epoch rather than the loop variable, which is unbound when
    # args.epoch == 0.
    print('Epoch number :', args.epoch)
    print('=================================================================')

if __name__ == "__main__":
    # Launch training only when this file is executed directly, not on import.
    Train_main()