import itertools
import os
import sys

import torch
import torch.backends.cudnn as cudnn

sys.path.append("..")

from utils import *
from data_gen.utils.spectral_method import convection_diffusion_reaction_discrete_solution_2D
from data_gen.utils.dataset_util_2D import generate_domain
from data_gen.PINN.naive_PINN_2D import *

# Command-line options are parsed once, at import time; the generator reads all settings from `args`.
args = parser.parse_args()

# Location of this script; the dataset output directory is anchored here.
FILE_ABS_PATH = os.path.abspath(__file__)
FILE_DIR_PATH = os.path.dirname(FILE_ABS_PATH)

# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Side length of the square spatial domain [0, pi] x [0, pi].
SPATIAL_SIZE = math.pi


# Number of even-indexed frames given as training context (frames 0, 2, ..., 10).
_N_TRAIN_FRAMES = 6
# NOTE(review): the domain reshape and the shared context-point permutation are hard-coded
# to a 32 x 32 spatial grid with 21 time steps, independently of --xgrid/--nt. Presumably
# this matches what generate_domain() returns — confirm before changing the grid flags.
_DOMAIN_NX, _DOMAIN_NT = 32, 21
# Number of shuffled training points used as context in each interpolation split.
_N_CTX = _DOMAIN_NX * _DOMAIN_NX

# Split name -> sub-directory (under the dataset root) that stores that split per PDE.
_DATA_SPLIT_DIRS = {
    "train": "train",
    "valid_extra": "valid/extrapolation",
    "valid_inter": "valid/interpolation",
    "test_extra": "test/extrapolation",
    "test_inter": "test/interpolation",
}


def _seed_everything(seed):
    """Seed the torch, CUDA, NumPy and ``random`` RNGs and force deterministic cuDNN."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Disable cuDNN autotuning so repeated runs select the same kernels.
    cudnn.benchmark = False
    cudnn.deterministic = True


def _print_device_info():
    """Print the selected device and, when CUDA is available, details of the current GPU."""
    print("=============[Device Info]==============")
    print("- Use Device :", device)
    print("- Available cuda devices :", torch.cuda.device_count())
    if torch.cuda.is_available():
        # These queries raise on CPU-only machines, so only issue them when CUDA exists.
        print("- Current cuda device :", torch.cuda.current_device())
        print("- Name of cuda device :", torch.cuda.get_device_name(device))
    print("========================================\n")


def _print_generation_settings():
    """Print the domain, grid resolution and which generator (solver or PINN) is used."""
    print("======================= Generation Setting =======================")
    print(f"Spatial Domain Size : [0,{SPATIAL_SIZE}] * [0,{SPATIAL_SIZE}] * [0.0, 1.0]")
    print(f"Grid number         : {args.xgrid}, {args.xgrid}, {args.nt}")
    if not args.numerical:
        print("\nGenerate using PINN")
        print(f"PINN batch number : {args.PINN_batch}")
    print("===================================================================")


def _num_steps(lo, hi, step):
    """Return how many whole ``step`` increments fit between ``lo`` and ``hi``.

    True division plus a small tolerance is used instead of ``//``: float floor
    division undercounts exact multiples (``1.0 // 0.1 == 9.0``), which would
    silently drop the ``*_max`` end point from the grid.
    """
    return int(math.floor((hi - lo) / step + 1e-9))


def _build_parameter_grid():
    """List every (beta, beta_y, nu, nu_y, rho, epsilon, theta) combination from ``args``.

    The order matches nested loops with beta outermost and theta innermost.
    """
    def axis(lo, hi, step):
        return [lo + n * step for n in range(_num_steps(lo, hi, step) + 1)]

    return list(itertools.product(
        axis(args.beta_min, args.beta_max, args.beta_step),
        # NOTE(review): beta_y / nu_y reuse beta_step / nu_step; confirm that is intended.
        axis(args.beta_y_min, args.beta_y_max, args.beta_step),
        axis(args.nu_min, args.nu_max, args.nu_step),
        axis(args.nu_y_min, args.nu_y_max, args.nu_step),
        axis(args.rho_min, args.rho_max, args.rho_step),
        axis(args.epsilon_min, args.epsilon_max, args.epsilon_step),
        axis(args.theta_min, args.theta_max, args.theta_step),
    ))


def _solve_with_pinn(params, domain_list, X_star):
    """Train a NaivePINN for one parameter tuple and evaluate it on every point of X_star.

    Training is retried with a new NaivePINN up to 10 times until ``train`` returns a
    positive value; if none does, the last model is still used and a warning is printed.
    """
    for _ in range(10):
        pinn = NaivePINN(*params)
        if pinn.train(domain_list, args.PINN_epoch) > 0:
            break
    else:
        print(f"Warning: PINN training never returned a positive result for {params}; "
              "using the last attempt.")
    return pinn.predict(
        X_star[:, 0:1].to(device, dtype=torch.float),
        X_star[:, 1:2].to(device, dtype=torch.float),
        X_star[:, 2:3].to(device, dtype=torch.float),
    ).to(device)


def _to_frames(full_data):
    """Convert a solution (array or tensor) to float32 with shape (nt, xgrid, xgrid, 1)."""
    if isinstance(full_data, torch.Tensor):
        # Detach so .numpy() works even if the PINN output still tracks gradients.
        full_data = full_data.detach().to(dtype=torch.float32)
    else:
        full_data = torch.as_tensor(full_data, dtype=torch.float32)
    # Incoming layout is (xgrid, xgrid, nt); move the time axis to the front.
    return full_data.reshape(args.xgrid, args.xgrid, args.nt, 1).permute(2, 0, 1, 3)


def _split_frames(frames, inter_given_point, n_ctx=_N_CTX):
    """Cut a time-major (nt, nx, ny, C) tensor into the train / validation / test splits.

    - train:       even frames 0, 2, ..., 10
    - valid_extra: even frames 10, 12
    - test_extra:  even frames 12, 14, ...
    - valid_inter / test_inter: ``n_ctx`` shuffled training points (flattened to
      (1, n_ctx, C)) followed by every point of odd frame 5 / odd frames 1, 3, 7, 9.

    NOTE(review): valid_extra shares frame 10 with train and frame 12 with test_extra;
    confirm this overlap is intended.
    """
    channels = frames.shape[-1]
    even = frames[::2]
    train = even[:_N_TRAIN_FRAMES]
    ctx = train.reshape(1, -1, channels)[:, inter_given_point[:n_ctx]]
    return {
        "train": train,
        "valid_extra": even[_N_TRAIN_FRAMES - 1:_N_TRAIN_FRAMES + 1],
        "valid_inter": torch.concat((ctx, frames[5:6].reshape(1, -1, channels)), dim=1),
        "test_extra": even[_N_TRAIN_FRAMES:],
        "test_inter": torch.concat((ctx, frames[[1, 3, 7, 9]].reshape(1, -1, channels)), dim=1),
    }


def _save_domain_splits(X_star, save_root, inter_given_point):
    """Save the coordinate grid, the context permutation, and the domain splits."""
    # (32, 32, 21, 3) -> (21, 32, 32, 3): same time-major layout as the saved solutions.
    X_star_PINN = X_star.reshape(_DOMAIN_NX, _DOMAIN_NX, _DOMAIN_NT, 3).permute(2, 0, 1, 3)
    # NOTE(review): full_domain.npy is the flat, un-permuted X_star, while full/ data is
    # saved time-major; confirm downstream loaders expect this.
    np.save(f"{save_root}/full_domain.npy", X_star.cpu().numpy())
    np.save(f"{save_root}/inter_given_index.npy", inter_given_point)
    for name, split in _split_frames(X_star_PINN, inter_given_point).items():
        np.save(f"{save_root}/{name}_domain.npy", split.cpu().numpy())


def generate_dataset():
    """Generate and save the 2D convection-diffusion-reaction dataset.

    For every parameter combination on the grid defined by the CLI ``args``, compute
    a solution (the discrete solver when ``args.numerical`` is set, otherwise a
    NaivePINN), then save the full trajectory and its train / validation / test splits
    as ``.npy`` files under ``<script dir>/dataset/2D_cdr``. The coordinate-domain
    splits are written once, alongside the first newly generated case. Cases whose
    ``full/`` file already exists are skipped unless ``args.total`` is set.
    """
    # Imported locally so the module-level star imports cannot shadow this function.
    from data_gen.utils.dataset_util_2D import generate_domain

    _seed_everything(args.seed)
    # One fixed permutation of the 6 * 32 * 32 training points; its first _N_CTX entries
    # pick the context points shared by every data and domain interpolation split.
    inter_given_point = np.arange(_N_TRAIN_FRAMES * _N_CTX)
    np.random.shuffle(inter_given_point)

    _print_device_info()
    _print_generation_settings()

    X_star, X_star_noboundary_noinitial, X_star_initial, X_star_boundary = (
        tensor.to(device) for tensor in generate_domain()
    )
    domain_list = [X_star, X_star_noboundary_noinitial, X_star_initial, X_star_boundary]

    save_path_root = f"{FILE_DIR_PATH}/dataset/2D_cdr"
    # np.save does not create missing parent directories.
    for sub_dir in ("full", *_DATA_SPLIT_DIRS.values()):
        os.makedirs(f"{save_path_root}/{sub_dir}", exist_ok=True)

    str_analytical = "analytical" if args.numerical else "PINN_based"
    domain_saved = False
    for params in _build_parameter_grid():
        beta, beta_y, nu, nu_y, rho, epsilon, theta = params
        file_name = f"{beta}_{beta_y}_{nu}_{nu_y}_{rho}_{epsilon}_{theta}_{str_analytical}.npy"

        if (not args.total) and os.path.isfile(f"{save_path_root}/full/{file_name}"):
            print(f"Passing the already generated case : {beta},{beta_y},{nu},{nu_y},{rho},{epsilon},{theta}")
            continue

        if args.numerical:
            full_data, _ = convection_diffusion_reaction_discrete_solution_2D(
                'sin', *params, nx=args.xgrid, ny=args.xgrid, nt=args.nt, total_nt=1001)
        else:
            full_data = _solve_with_pinn(params, domain_list, X_star)

        full_data = _to_frames(full_data)
        np.save(f"{save_path_root}/full/{file_name}", full_data.cpu().numpy())
        for split_name, split in _split_frames(full_data, inter_given_point).items():
            np.save(f"{save_path_root}/{_DATA_SPLIT_DIRS[split_name]}/{file_name}", split.cpu().numpy())

        # The coordinate splits do not depend on the PDE parameters; write them once.
        if not domain_saved:
            _save_domain_splits(X_star, save_path_root, inter_given_point)
            domain_saved = True

        print("Done!", *params)

    print("Generating dataset is done.")


if __name__ == "__main__":
    generate_dataset()