import os
import sys
import torch
import torch.backends.cudnn as cudnn
import statistics

sys.path.append("..")

from utils import *
from data_gen.utils.dataset_util_2D import generate_domain
from data_gen.PINN.naive_PINN_2D import *

# Parse the command-line options once; everything below reads them from `args`.
args = parser.parse_args()
# Prefer the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
def measure(full_analy, full_pinn, n_steps=21):
    """Compare a PINN prediction with the analytical solution.

    Both tensors are viewed as ``(n_steps, -1)``; the norms are taken along
    the last axis (one value per row) and the per-row ratios are averaged.

    Args:
        full_analy: Reference tensor whose element count is a multiple of
            ``n_steps``.
        full_pinn: Predicted tensor with the same number of elements, laid
            out in the same element order as ``full_analy``.
        n_steps: Number of rows to split the data into. Defaults to 21,
            which reproduces the previous fixed ``(21, 1024)`` reshape for
            21 * 1024-element inputs.

    Returns:
        A tuple ``(L2_absolute_error, L2_relative_error,
        L_inf_relative_error)`` of 0-dim tensors. Despite its historical
        name, the first entry is the mean absolute difference over all
        elements, not an L2 norm.

    Note:
        A row of ``full_analy`` that is entirely zero makes the relative
        errors ``inf``/``nan`` because the true norm is used as a divisor.
    """
    full_analy = full_analy.reshape(n_steps, -1)
    full_pinn = full_pinn.reshape(n_steps, -1)
    # Compute the residual once; every metric below is derived from it.
    diff = full_pinn - full_analy

    L2_error_norm = torch.linalg.norm(diff, 2, dim=-1)
    L2_true_norm = torch.linalg.norm(full_analy, 2, dim=-1)
    L2_relative_error = torch.mean(L2_error_norm / L2_true_norm)

    L_inf_error_norm = torch.linalg.norm(diff, ord=torch.inf, dim=-1)
    L_inf_true_norm = torch.linalg.norm(full_analy, ord=torch.inf, dim=-1)
    L_inf_relative_error = torch.mean(L_inf_error_norm / L_inf_true_norm)

    L2_absolute_error = torch.mean(torch.abs(diff))

    return L2_absolute_error, L2_relative_error, L_inf_relative_error
    
def _seed_everything(seed):
    """Seed the torch, NumPy and `random` RNGs and make cuDNN deterministic."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True


def _axis_values(lower, upper, step):
    """Return ``lower + k * step`` for ``k = 0 .. int((upper - lower) // step)``.

    The floor-division count and the ``lower + (k * step)`` expression are kept
    exactly as they were: the resulting floats are embedded in the file names
    that are loaded later, so the arithmetic must not change.
    """
    count = int((upper - lower) // step)
    return [lower + (k * step) for k in range(count + 1)]


def _build_parameter_grid():
    """List every (beta, beta_y, nu, nu_y, rho, epsilon, theta) combination.

    The last coordinate (theta) varies fastest, matching a nested loop with
    beta outermost and theta innermost.
    """
    from itertools import product

    axes = [
        _axis_values(args.beta_min, args.beta_max, args.beta_step),
        # The *_y ranges reuse the step of their non-y counterpart.
        _axis_values(args.beta_y_min, args.beta_y_max, args.beta_step),
        _axis_values(args.nu_min, args.nu_max, args.nu_step),
        _axis_values(args.nu_y_min, args.nu_y_max, args.nu_step),
        _axis_values(args.rho_min, args.rho_max, args.rho_step),
        _axis_values(args.epsilon_min, args.epsilon_max, args.epsilon_step),
        _axis_values(args.theta_min, args.theta_max, args.theta_step),
    ]
    return list(product(*axes))


def _retrain_until_accurate(pde_params, full_data_analy, thres, max_attempts=6):
    """Retrain a NaivePINN with growing budgets until the error is acceptable.

    Args:
        pde_params: Tuple (beta, beta_y, nu, nu_y, rho, epsilon, theta).
        full_data_analy: Analytical reference as loaded from disk.
        thres: Relative L2 error above which another attempt is made.
        max_attempts: Upper bound on training runs, counting failed ones.

    Returns:
        ``None`` when no training run reported success, otherwise a tuple
        ``(full_data, errors)`` where ``full_data`` is the last successful
        prediction arranged as (nt, xgrid, xgrid, 1) and ``errors`` is the
        triple returned by ``measure`` for it.
    """
    X_star, X_interior, X_initial, X_boundary = (
        part.to(device, dtype=torch.float) for part in generate_domain()
    )
    domain_list = [X_star, X_interior, X_initial, X_boundary]

    base_batch = args.PINN_batch
    pinn_epoch = args.PINN_epoch
    outcome = None
    try:
        # A bounded loop: a failed training run consumes an attempt too, so a
        # model that never trains successfully cannot stall the script.
        for attempt in range(max_attempts):
            args.PINN_batch += attempt * 1000
            pinn_epoch += attempt * 10

            pinn = NaivePINN(*pde_params)
            result = pinn.train(domain_list, pinn_epoch)
            if not (result > 0):
                # NOTE(review): a non-positive result is treated as a failed
                # run, as before - confirm against NaivePINN.train.
                continue

            prediction = pinn.predict(X_star[:, 0:1], X_star[:, 1:2], X_star[:, 2:3]).to(device)
            prediction = prediction.detach().clone().to(dtype=torch.float32)
            # Bring the prediction into the (nt, x, y, 1) layout used by the
            # stored files *before* measuring, so it is compared element by
            # element with the analytical data loaded from disk.
            full_data = prediction.reshape(args.xgrid, args.xgrid, args.nt, 1).permute(2, 0, 1, 3)
            errors = measure(full_data_analy, full_data)
            outcome = (full_data, errors)

            if not (errors[1] > thres):
                break
    finally:
        # Do not leak the enlarged batch size into later parameter sets.
        args.PINN_batch = base_batch
    return outcome


def _save_splits(full_data, save_path_root, file_name_pinn, inter_given_point):
    """Write the full prediction and its train/valid/test splits as .npy files.

    ``full_data`` is indexed along its first axis; the shape comments below
    assume the (21, 32, 32, 1) layout produced by the retraining step.
    """
    np.save(f"{save_path_root}/full/{file_name_pinn}", full_data.cpu().numpy())

    every_other = full_data[::2]

    # Train data
    train_data = every_other[:6]                                # 6, 32, 32, 1
    np.save(f"{save_path_root}/train/{file_name_pinn}", train_data.cpu().numpy())

    # Context points shared by the interpolation splits: 32*32 entries picked
    # from the flattened training block via the pre-shuffled index array.
    inter_ctx = train_data.reshape(1, -1, 1)                    # 1, 6*32*32, 1
    inter_ctx = inter_ctx[:, inter_given_point[:32*32]]         # 1, 32*32, 1

    # Validation data
    valid_extra = every_other[5:7]                              # 2, 32, 32, 1
    valid_inter = torch.concat((inter_ctx, full_data[5:6].reshape(1, -1, 1)), dim=1)    # 1, 32*32+32*32, 1
    np.save(f"{save_path_root}/valid/extrapolation/{file_name_pinn}", valid_extra.cpu().numpy())
    np.save(f"{save_path_root}/valid/interpolation/{file_name_pinn}", valid_inter.cpu().numpy())

    # Test data
    test_extra = every_other[6:]                                # 5, 32, 32, 1
    test_inter = torch.concat((inter_ctx, full_data[[1, 3, 7, 9]].reshape(1, -1, 1)), dim=1)  # 1, 32*32+4*32*32, 1
    np.save(f"{save_path_root}/test/extrapolation/{file_name_pinn}", test_extra.cpu().numpy())
    np.save(f"{save_path_root}/test/interpolation/{file_name_pinn}", test_inter.cpu().numpy())


def generate_dataset():
    """Check stored PINN predictions and regenerate the inaccurate ones.

    For every parameter combination in the configured grid, the stored
    PINN-based and analytical arrays under ``./dataset/2D_cdr/full`` are
    compared with ``measure``. When the relative L2 error exceeds the
    threshold, a NaivePINN is retrained (a bounded number of times) and the
    full/train/valid/test files are overwritten with the new prediction.
    Finally the high-error combinations and the mean errors are printed.
    """
    _seed_everything(args.seed)

    # Shuffled indices into the flattened 6*32*32 training block; the first
    # 32*32 of them select the interpolation context points.
    inter_given_point = np.arange(6*32*32)
    np.random.shuffle(inter_given_point)

    test_pde_list = _build_parameter_grid()

    AE_list = []
    RE_list = []
    IRE_list = []
    high_error = []
    thres = 0.1
    save_path_root = f"./dataset/2D_cdr"

    for item in test_pde_list:
        beta, beta_y, nu, nu_y, rho, epsilon, theta = item

        stem = f"{beta}_{beta_y}_{nu}_{nu_y}_{rho}_{epsilon}_{theta}"
        file_name_analy = f"{stem}_analytical.npy"
        file_name_pinn = f"{stem}_PINN_based.npy"

        full_data_pinn = torch.tensor(np.load(f"{save_path_root}/full/{file_name_pinn}")).to(device)
        full_data_analy = torch.tensor(np.load(f"{save_path_root}/full/{file_name_analy}")).to(device)

        L2_absolute_error, L2_relative_error, L_inf_relative_error = measure(full_data_analy, full_data_pinn)

        if L2_relative_error > thres:
            print(item, "<------------------------")
            high_error.append(item)

            outcome = _retrain_until_accurate(item, full_data_analy, thres)
            if outcome is None:
                # Keep the stored files and their errors when nothing better
                # could be produced.
                print(item, "retraining failed; keeping the stored prediction")
            else:
                full_data, (L2_absolute_error, L2_relative_error, L_inf_relative_error) = outcome
                _save_splits(full_data, save_path_root, file_name_pinn, inter_given_point)
        else:
            print(item)

        AE_list.append(L2_absolute_error.item())
        RE_list.append(L2_relative_error.item())
        IRE_list.append(L_inf_relative_error.item())

    print(f'===================== error over {thres} items list=============================')
    for ind in high_error:
        print(ind)
    # The averages below cover every evaluated item (after any retraining).
    print('===================== Low error items Average Value =============================')
    if AE_list:
        print('L2_abs_err :', statistics.mean(AE_list))
        print('L2_rel_err :', statistics.mean(RE_list))
        print('L_inf_rel_err :', statistics.mean(IRE_list))
    else:
        # statistics.mean raises on an empty sequence (empty parameter grid).
        print('No parameter combinations were evaluated.')
    print('=================================================================')

# Run the evaluation / regeneration pass only when executed as a script.
if __name__ == "__main__":
    generate_dataset()