import os
import numpy as np
import pandas as pd
import os
import random
import torch
from torch.utils.data import Dataset, DataLoader
import warnings

warnings.filterwarnings('ignore')


######################################## For Darcy dataset ########################################
# From PDEbench - forcing term 0.1
class Dataset_Darcy(Dataset):
    """Darcy dataset pairing an input field (``nu.npy``) with an output field (``tensor.npy``).

    Each sample stacks two frames: the input field on a coordinate grid with
    ``t = 0`` and the output field on the same grid with ``t = 1``. The last
    axis of a sample is laid out as ``(x, y, t, value)``.

    The files are split into consecutive chunks: the first ``N_train_shot``
    samples are the training split, the next ``N_valid`` the validation split
    and the following ``N_test`` the test split.

    Args:
        args: Namespace providing ``extrapolation``, ``system``,
            ``N_train_shot``, ``N_valid``, ``N_test`` and ``total``.
        flag: Split to expose; one of ``'train'``, ``'valid'`` or ``'test'``.
    """

    def __init__(self, args, flag='train'):
        self.root_path = os.path.abspath("./data_gen/dataset/Darcy/beta0.1")
        assert flag in ['train', 'test', 'valid']
        self.flag = flag
        self.extrapolation = args.extrapolation
        self.system = args.system
        self.N_train = args.N_train_shot
        self.N_valid = args.N_valid
        self.N_test = args.N_test
        self.total = args.total      # Method in train
        self.__read_data__()

    def __read_data__(self):
        """Load fields and coordinate grids from ``root_path`` and keep the slice for ``flag``."""
        # A trailing channel axis is appended so the fields can be concatenated
        # with the coordinate grid later on.
        # NOTE(review): the original notes give (10000, 128, 128) for both files -- not checked here.
        self.Total_input = torch.tensor(np.load(f"{self.root_path}/nu.npy"), dtype=torch.float32).unsqueeze(-1)       # N, H, W, 1
        self.Total_output = torch.tensor(np.load(f"{self.root_path}/tensor.npy"), dtype=torch.float32).unsqueeze(-1)  # N, H, W, 1

        x_grid = np.load(f"{self.root_path}/x_coordinate.npy")      # H,
        y_grid = np.load(f"{self.root_path}/y_coordinate.npy")      # W,
        t_grid = np.array([0.0, 1.0])                               # 2 pseudo-time stamps: input / output

        t_grid, x_grid, y_grid = np.meshgrid(t_grid, x_grid, y_grid, indexing='ij')
        t_grid = torch.tensor(t_grid, dtype=torch.float32).unsqueeze(-1)    # 2, H, W, 1
        x_grid = torch.tensor(x_grid, dtype=torch.float32).unsqueeze(-1)    # 2, H, W, 1
        y_grid = torch.tensor(y_grid, dtype=torch.float32).unsqueeze(-1)    # 2, H, W, 1
        self.Domain = torch.concat([x_grid, y_grid, t_grid], dim=-1)        # 2, H, W, 3
        self.Domain_input = self.Domain[:1]     # 1, H, W, 3  (t = 0)
        self.Domain_output = self.Domain[1:]    # 1, H, W, 3  (t = 1)

        if self.flag == 'train':
            self.Total_input = self.Total_input[:self.N_train]                              # N_train, H, W, 1
            self.Total_output = self.Total_output[:self.N_train]                            # N_train, H, W, 1
        elif self.flag == 'valid':
            valid_end = self.N_train + self.N_valid
            self.Total_input = self.Total_input[self.N_train:valid_end]                     # N_valid, H, W, 1
            self.Total_output = self.Total_output[self.N_train:valid_end]                   # N_valid, H, W, 1
        else:
            test_start = self.N_train + self.N_valid
            test_end = test_start + self.N_test
            self.Total_input = self.Total_input[test_start:test_end]                        # N_test, H, W, 1
            self.Total_output = self.Total_output[test_start:test_end]                      # N_test, H, W, 1

    def __getitem__(self, index):
        """Return sample ``index`` as a ``(2, H, W, 4)`` tensor, or ``(1, 2*H*W, 4)`` when ``total`` is truthy.

        Raises:
            IndexError: If ``index`` is outside the split.
        """
        # Integer indexing (rather than the slice ``[index:index+1]``) makes an
        # out-of-range index raise IndexError and lets negative indices work;
        # a slice would silently yield an empty tensor and fail later in concat.
        field_in = self.Total_input[index].unsqueeze(0)                     # 1, H, W, 1
        field_in = torch.concat([self.Domain_input, field_in], dim=-1)      # 1, H, W, 4

        field_out = self.Total_output[index].unsqueeze(0)                   # 1, H, W, 1
        field_out = torch.concat([self.Domain_output, field_out], dim=-1)   # 1, H, W, 4

        data = torch.concat([field_in, field_out], dim=0)                   # 2, H, W, 4
        # data feature : spatial domain(2), time domain(1), target(1)

        if self.total:
            data = data.reshape(1, -1, 4)                                   # 1, 2*H*W, 4
        return data

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.Total_input)
    

######################################## For CNSE dataset ######################################## 
# From PDEbench : 2D_CFD_Rand_M1.0_Eta0.1_Zeta0.1_periodic_128
class Dataset_CNSE(Dataset):
    """CNSE dataset: four physical fields on an ``(x, y, t)`` grid, keeping every 5th time step.

    A sample's last axis is ``(x, y, t, Vx, Vy, density, pressure)``.
    The stored simulations are split consecutively into train
    (``N_train_shot``), valid (``N_valid``) and test (``N_test``) chunks.

    Args:
        args: Namespace providing ``system``, ``N_train_shot``, ``N_valid``,
            ``N_test``, ``extrapolation`` and ``total``.
        flag: Split to expose; one of ``'train'``, ``'valid'`` or ``'test'``.
    """

    def __init__(self, args, flag='train'):
        self.root_path = os.path.abspath("./data_gen/dataset/CNSE")
        assert flag in ['train', 'test', 'valid']
        self.flag = flag
        self.system = args.system
        self.N_train = args.N_train_shot
        self.N_valid = args.N_valid
        self.N_test = args.N_test
        self.extrapolation = args.extrapolation
        self.total = args.total      # Method in train
        self.__read_data__()

    def __read_data__(self):
        """Read the field and coordinate files and keep the samples belonging to ``flag``."""

        def load(file_name):
            # Every array of this dataset lives directly under root_path.
            return torch.tensor(np.load(os.path.join(self.root_path, file_name)))

        # One trailing channel per physical quantity, stacked on the last axis.
        # NOTE(review): original notes give (1000, 21, 64, 64) per file -- not checked here.
        channels = [
            load(file_name).unsqueeze(-1)
            for file_name in ("Vx.npy", "Vy.npy", "density.npy", "pressure.npy")
        ]
        # Keep every 5th time step (original note: 21 -> 5 frames).
        self.target = torch.concat(channels, dim=-1)[:, ::5]    # N, T, H, W, 4

        xs = load("x-coordinate.npy")
        ys = load("y-coordinate.npy")
        ts = load("t-coordinate.npy")[::5]

        mesh_t, mesh_x, mesh_y = np.meshgrid(ts, xs, ys, indexing='ij')
        coordinate_channels = [
            torch.tensor(mesh).unsqueeze(-1) for mesh in (mesh_x, mesh_y, mesh_t)
        ]
        self.domain = torch.concat(coordinate_channels, dim=-1)  # T, H, W, 3

        # Total : [0, 20]
        # Given 0 - Predict 5, 10, 15, 20

        if self.flag == 'train':
            window = slice(None, self.N_train)
        elif self.flag == 'valid':
            window = slice(self.N_train, self.N_train + self.N_valid)
        else:
            first_test = self.N_train + self.N_valid
            window = slice(first_test, first_test + self.N_test)
        self.target = self.target[window]

    def __getitem__(self, index):
        """Return sample ``index`` as ``(T, H, W, 7)``, or ``(1, T*H*W, 7)`` when ``total == 1``."""
        # data feature : spatial domain(2), time domain(1), target(4)
        sample = torch.concat([self.domain, self.target[index]], dim=-1)
        return sample.reshape(1, -1, 7) if self.total == 1 else sample

    def __len__(self):
        """Number of simulations in the selected split."""
        return len(self.target)


######################################## For SWE dataset ######################################## 
# From PDEbench
class Dataset_SWE(Dataset):
    """SWE dataset: one scalar field on an ``(x, y, t)`` grid, read from ``data_0.npy`` .. ``data_999.npy``.

    A sample's last axis is ``(x, y, t, value)``. Which time frames are exposed
    and how they are packed depends on ``flag``, ``args.total`` and
    ``args.extrapolation`` (see ``__read_data__`` and ``__getitem__``).

    Args:
        args: Namespace providing ``system``, ``N_train_shot``,
            ``extrapolation`` and ``total``. For ``flag='test'``, ``system``
            must be ``'Seen dataset'`` or ``'Unseen dataset'``.
        flag: Split to expose; one of ``'train'``, ``'valid'`` or ``'test'``.
    """

    def __init__(self, args, flag='train'):
        self.root_path = os.path.abspath("./data_gen/dataset/SWE")
        assert flag in ['train', 'test', 'valid']
        self.flag = flag
        if flag == 'test':
            assert args.system in ['Seen dataset', 'Unseen dataset']  # Seen / Unseen task in test
        self.system = args.system
        self.N_train = args.N_train_shot
        self.extrapolation = args.extrapolation
        self.total = args.total      # Method in train
        self.__read_data__()

    def __read_data__(self):
        """Load all simulations, subsample the time axis and select the frames for ``flag``."""
        simulations = [
            torch.tensor(np.load(f"{self.root_path}/data_{i}.npy")).unsqueeze(0)
            for i in range(1000)
        ]

        # Keep every 5th of the stored time steps: 21 frames, indexed 0..20 below.
        subsample = list(range(0, 101, 5))
        Total_Target = torch.concat(simulations)[:, subsample]     # 1000, 21, H, W, 1

        t_grid = np.load(f"{self.root_path}/t_grid.npy")[subsample]
        x_grid = np.load(f"{self.root_path}/x_grid.npy")
        y_grid = np.load(f"{self.root_path}/y_grid.npy")

        t_grid, x_grid, y_grid = np.meshgrid(t_grid, x_grid, y_grid, indexing='ij')
        t_grid = torch.tensor(t_grid).unsqueeze(-1)
        x_grid = torch.tensor(x_grid).unsqueeze(-1)
        y_grid = torch.tensor(y_grid).unsqueeze(-1)
        Total_Domain = torch.concat([x_grid, y_grid, t_grid], dim=-1)   # 21, H, W, 3

        # Total : [0, 10]
        # Train : [0, 5]
        # Valid : given 5 - eval 6
        # Test  : given 6 - eval 7, 8, 9, 10

        if self.flag == 'train':
            self.target = Total_Target[:self.N_train]       # N_train, 21, H, W, 1
            self.target = self.target[:, :12:2]             # N_train, 6, H, W, 1
            self.domain = Total_Domain[:12:2]               # 6, H, W, 3
        elif self.flag == 'valid':
            self.target = Total_Target[:self.N_train]       # N_train, 21, H, W, 1
            if self.total == 1:
                # Six training frames, then frame 5 and frame 12.
                frames = [0, 2, 4, 6, 8, 10, 5, 12]
                self.target = self.target[:, frames]        # N_train, 8, H, W, 1
                self.domain = Total_Domain[frames]          # 8, H, W, 3
            else:
                self.target = self.target[:, [10, 12]]      # N_train, 2, H, W, 1
                self.domain = Total_Domain[[10, 12]]        # 2, H, W, 3
        else:
            if self.system == "Seen dataset":
                self.target = Total_Target[:self.N_train]   # N_train, 21, H, W, 1
            elif self.system == "Unseen dataset":
                self.target = Total_Target[self.N_train:]   # 1000-N_train, 21, H, W, 1

            if self.extrapolation:
                self.target = self.target[:, 12::2]         # :, 5, H, W, 1
                self.domain = Total_Domain[12::2]           # 5, H, W, 3
            else:
                # Six training frames followed by four in-between frames.
                frames = [0, 2, 4, 6, 8, 10, 1, 3, 7, 9]
                self.target = self.target[:, frames]        # :, 10, H, W, 1
                self.domain = Total_Domain[frames]          # 10, H, W, 3

        self.inter_point = np.load(f"{self.root_path}/inter_given_index.npy")

    def __getitem__(self, index):
        """Return sample ``index`` packed for the configured method.

        The result is a ``torch.Tensor``, except for the training branches that
        pair consecutive frames with ``np.concatenate``; those yield a
        ``numpy.ndarray`` (kept for compatibility with existing callers).
        """
        sample = torch.concat([self.domain, self.target[index]], dim=-1)  # t, H, W, 4
        # data feature : spatial domain(2), time domain(1), target(1)

        if self.total == 2:   # DeepONet extrapolation
            if self.flag == 'train':
                sample = sample.reshape(6, -1, 4)                               # 6, H*W, 4
                # Pair every frame with its successor along the point axis.
                sample = np.concatenate((sample[:-1], sample[1:]), axis=1)      # 5, 2*H*W, 4
            else:
                sample = sample.reshape(1, -1, 4)                               # 1, t*H*W, 4

        elif self.total == 1:     # Ours, DeepONet interpolation
            if self.flag == 'train':
                if self.extrapolation:  # Ours training
                    sample = sample.reshape(6, -1, 4)                           # 6, H*W, 4
                    sample = np.concatenate((sample[:-1], sample[1:]), axis=1)  # 5, 2*H*W, 4
                    # Shuffle the points within each frame pair; the permutation is
                    # converted to NumPy because ``sample`` is an ndarray here.
                    perm = torch.randperm(sample.shape[1]).numpy()              # 2*H*W
                    sample = sample[:, perm]                                    # 5, 2*H*W, 4
                else:   # DeepONet interpolation
                    sample = sample.reshape(1, -1, 4)                           # 1, 6*H*W, 4
                    sample = sample[:, self.inter_point]                        # 1, len(inter_point), 4
            else:
                # Number of grid points in one frame (128*128 for the shipped
                # data); used to cap the context to a single frame's worth.
                n_grid = sample.shape[1] * sample.shape[2]
                if self.flag == 'valid':     # Ours, DeepONet inter validation
                    ctx1 = sample[:6].reshape(1, -1, 4)                         # 1, 6*H*W, 4
                    ctx1 = ctx1[:, self.inter_point][:, :n_grid]                # 1, H*W, 4
                    ctx2 = sample[5:6].reshape(1, -1, 4)                        # 1, H*W, 4
                    ctx = torch.concat([ctx1, ctx2], dim=0)                     # 2, H*W, 4
                    trg = sample[6:8].reshape(2, -1, 4)                         # 2, H*W, 4
                    sample = torch.concat([ctx, trg], dim=1)                    # 2, 2*H*W, 4
                elif self.extrapolation:     # test: Ours extrapolation
                    sample = sample.reshape(1, -1, 4)                           # 1, t*H*W, 4
                else:                        # test: Ours, DeepONet inter test
                    ctx = sample[:6].reshape(1, -1, 4)                          # 1, 6*H*W, 4
                    ctx = ctx[:, self.inter_point][:, :n_grid]                  # 1, H*W, 4
                    trg = sample[6:].reshape(1, -1, 4)                          # 1, 4*H*W, 4
                    sample = torch.concat([ctx, trg], dim=1)                    # 1, 5*H*W, 4
        return sample

    def __len__(self):
        """Number of simulations actually held by this split.

        Derived from ``self.target`` so it is correct for every ``flag``; it
        equals ``N_train`` for seen data and ``1000 - N_train`` for the unseen
        test split.
        """
        return len(self.target)


    
    
######################################## For family of 2D CDR equations ########################################
# Our own generation
class Dataset_2D_cdr(Dataset):
    """Family of 2D CDR equations, one ``.npy`` file per coefficient combination.

    The dataset enumerates every combination of
    ``(beta, beta_y, nu, nu_y, rho, epsilon, theta)`` on the grid described by
    the ``*_min`` / ``*_max`` / ``*_step`` arguments and loads the matching
    file lazily in ``__getitem__``. ``beta_y`` and ``nu_y`` reuse ``beta_step``
    and ``nu_step``.

    Args:
        args: Namespace providing ``system``, ``N_train_shot``, ``total``,
            the coefficient ranges, ``numerical``, ``extrapolation``,
            ``PINN_based_prior_ratio`` (percentage of PDEs whose training data
            is read from the ``*_PINN_based.npy`` files) and ``seed`` (read
            only when that ratio is positive).
        flag: Split to expose; one of ``'train'``, ``'valid'`` or ``'test'``.
    """

    def __init__(self, args, flag='train'):
        self.root_path = "./data_gen/dataset/2D_cdr"
        assert flag in ['train', 'test', 'valid']
        self.flag = flag
        if flag == 'test':
            assert args.system in ["Seen coeff", "Seen coeff given noisy", "Inter coeff", "Extra coeff"] # Seen / Unseen task in test
        self.system = args.system
        self.N_train = args.N_train_shot
        self.total = args.total      # Method in train

        self.beta_min = args.beta_min
        self.beta_max = args.beta_max
        self.beta_step = args.beta_step
        self.beta_y_min = args.beta_y_min
        self.beta_y_max = args.beta_y_max
        self.nu_min = args.nu_min
        self.nu_max = args.nu_max
        self.nu_step = args.nu_step
        self.nu_y_min = args.nu_y_min
        self.nu_y_max = args.nu_y_max
        self.rho_min = args.rho_min
        self.rho_max = args.rho_max
        self.rho_step = args.rho_step
        self.epsilon_min = args.epsilon_min
        self.epsilon_max = args.epsilon_max
        self.epsilon_step = args.epsilon_step
        self.theta_min = args.theta_min
        self.theta_max = args.theta_max
        self.theta_step = args.theta_step

        self.numerical = args.numerical
        self.extrapolation = args.extrapolation

        self.inter_point = np.load(f"{self.root_path}/inter_given_index.npy")

        self.__read_data__()

        # Always defined, so __getitem__ can test membership when the ratio is 0.
        self.PINN_based_PDE_list = []
        if args.PINN_based_prior_ratio > 0:
            sample_size = int(len(self.specified_pde_list) * args.PINN_based_prior_ratio / 100)
            # Kept for callers that rely on NumPy's global RNG being seeded here.
            np.random.seed(args.seed)
            # random.sample draws from Python's RNG, which np.random.seed does
            # not touch; a dedicated seeded generator makes the selection
            # reproducible without altering the global `random` state.
            self.PINN_based_PDE_list = random.Random(args.seed).sample(self.specified_pde_list, sample_size)

    def __read_data__(self):
        """Build ``specified_pde_list``: every coefficient tuple on the configured grid."""
        from itertools import product

        def axis(lower, upper, step):
            # Values lower, lower+step, ... while staying within `upper`.
            n_steps = int((upper - lower) // step)
            return [lower + (i * step) for i in range(n_steps + 1)]

        axes = (
            axis(self.beta_min, self.beta_max, self.beta_step),
            axis(self.beta_y_min, self.beta_y_max, self.beta_step),
            axis(self.nu_min, self.nu_max, self.nu_step),
            axis(self.nu_y_min, self.nu_y_max, self.nu_step),
            axis(self.rho_min, self.rho_max, self.rho_step),
            axis(self.epsilon_min, self.epsilon_max, self.epsilon_step),
            axis(self.theta_min, self.theta_max, self.theta_step),
        )
        # product() varies the last axis fastest, i.e. the same order as nested
        # loops over beta, beta_y, nu, nu_y, rho, epsilon, theta.
        self.specified_pde_list = list(product(*axes))

    def _with_domain(self, values, domain_file):
        """Prepend the coordinate channels stored in ``domain_file`` to ``values`` (last axis)."""
        domain = np.load(f"{self.root_path}/{domain_file}")
        return np.concatenate((domain, values), axis=-1)

    def __getitem__(self, index):
        """Load the sample for coefficient tuple ``index``.

        Returns:
            A float32 tensor for ``'train'`` / ``'valid'``; for ``'test'`` a
            pair ``(coefficients, tensor)``.

        Raises:
            ValueError: If the analytical file is missing under ``full/``.
        """
        coeffs = self.specified_pde_list[index]
        beta, beta_y, nu, nu_y, rho, epsilon, theta = coeffs

        stem = f"{beta}_{beta_y}_{nu}_{nu_y}_{rho}_{epsilon}_{theta}"
        file_name_analy = f"{stem}_analytical.npy"
        file_name_pinn = f"{stem}_PINN_based.npy"
        # Training reads the PINN-based file only for the sampled subset, and
        # never when `numerical` is set.
        if self.numerical or coeffs not in self.PINN_based_PDE_list:
            file_name = file_name_analy
        else:
            file_name = file_name_pinn

        if not os.path.isfile(os.path.abspath(self.root_path+"/full/"+file_name_analy)):
            raise ValueError("There is no file.")

        # Total : [0, 20]
        # Train : [0,2,4,6,8,10]
        #   - Total o                : random sample - rest of it
        #   - Total x, Extrapolation : given 0,2,4,6,8 - 2,4,6,8,10
        #   - Total x, Interpolation : given sampled - rest of it
        if self.flag == 'train':
            data = self._with_domain(np.load(f"{self.root_path}/train/{file_name}"),
                                     "train_domain.npy")                # 6, 32, 32, 4
            if self.total == 2:
                data = data.reshape(6, -1, 4)                           # 6, 32*32, 4
                data = np.concatenate((data[:-1], data[1:]), axis=1)    # 5, 2*32*32, 4
            elif self.total == 1 and self.extrapolation:
                # Shuffle all points across frames before regrouping; each
                # point keeps its own (x, y, t) channels.
                data = data.reshape(1, -1, 4)                           # 1, 6*32*32, 4
                perm = torch.randperm(data.shape[1]).numpy()            # 6*32*32
                data = data[:, perm]                                    # 1, 6*32*32, 4
                data = data.reshape(6, -1, 4)                           # 6, 32*32, 4
                data = np.concatenate((data[:-1], data[1:]), axis=1)    # 5, 2*32*32, 4
            elif self.total == 1:
                data = data.reshape(1, -1, 4)                           # 1, 6*32*32, 4
                data = data[:, self.inter_point]                        # 1, len(inter_point), 4
            return torch.tensor(data, dtype=torch.float32)

        # Valid
        #   - Extrapolation : given 10 - 12
        #   - Interpolation : given sampled - 5
        elif self.flag == 'valid':
            extra = self._with_domain(
                np.load(f"{self.root_path}/valid/extrapolation/{file_name_analy}"),
                "valid_extra_domain.npy")                               # 2, 32, 32, 4
            if self.total == 2:
                data = extra.reshape(1, -1, 4)                          # 1, 2*32*32, 4
            elif self.total == 1:
                inter = self._with_domain(
                    np.load(f"{self.root_path}/valid/interpolation/{file_name_analy}"),
                    "valid_inter_domain.npy")                           # 1, 32*32+32*32, 4
                data = np.concatenate((extra.reshape(1, -1, 4), inter), axis=0)     # 2, 2*32*32, 4
            else:
                data = extra
            return torch.tensor(data, dtype=torch.float32)

        # Test
        #   - Extrapolation : given 12 - 14, 16, 18, 20
        #   - Interpolation : given sampled - 1, 3, 7, 9
        elif self.flag == 'test':
            if self.extrapolation:
                task, domain_file = "extrapolation", "test_extra_domain.npy"
            else:
                task, domain_file = "interpolation", "test_inter_domain.npy"

            values = np.load(f"{self.root_path}/test/{task}/{file_name_analy}")
            if "given noisy" in self.system:
                # Replace only the given (context) part with the PINN-based data.
                noisy = np.load(f"{self.root_path}/test/{task}/{file_name_pinn}")
                if self.extrapolation:
                    values[0] = noisy[0]                    # first frame is the given one
                else:
                    values[:, :1024] = noisy[:, :1024]      # first 32*32 points are the given ones

            data = self._with_domain(values, domain_file)
            if self.extrapolation and self.total:
                data = data.reshape(1, -1, 4)                           # 1, 5*32*32, 4
            return coeffs, torch.tensor(data, dtype=torch.float32)

    def __len__(self):
        """Number of coefficient combinations."""
        return len(self.specified_pde_list)
    
    
    
    

######################################## Data Factory ########################################

def identity_collate(batch):
    """Collate function for single-sample batches: hand back the first element untouched."""
    first_sample = batch[0]
    return first_sample

# Registry looked up by ``args.data`` in ``data_provider``. Each entry is
# unpacked there as [dataset class, n_in, n_out, grid_num].
# n_in equals the size of the last axis of that dataset's samples (4 or 7).
# NOTE(review): n_out presumably counts the predicted field channels and
# grid_num is presumably the spatial resolution -- confirm against the model code.
DATA_DICT = {'2D_cdr'   : [Dataset_2D_cdr, 4, 1, [32,32]],
             'SWE'      : [Dataset_SWE, 4, 1, [128,128]],
             'CNSE'     : [Dataset_CNSE, 7, 4, [64,64]],
             'Darcy01'  : [Dataset_Darcy, 4, 1, [128,128]],
             }

def data_provider(args, flag):
    """Build the dataset registered under ``args.data`` and wrap it in a DataLoader.

    ``'train'`` batches ``args.batch`` samples with shuffling, ``'valid'``
    batches ``args.batch`` samples in order, and any other flag yields one
    un-collated sample at a time.

    Returns:
        Tuple ``(data_loader, n_in, n_out, grid_num)``; the last three values
        come straight from the ``DATA_DICT`` entry.
    """
    dataset_cls, n_in, n_out, grid_num = DATA_DICT[args.data]
    data_set = dataset_cls(args=args, flag=flag)

    # Settings shared by every split.
    loader_options = {'pin_memory': True, 'num_workers': 0}
    if flag in ('train', 'valid'):
        loader_options['batch_size'] = args.batch
        loader_options['shuffle'] = flag == 'train'
    else:
        # Test samples are served individually and left as the dataset returned them.
        loader_options['batch_size'] = 1
        loader_options['shuffle'] = False
        loader_options['collate_fn'] = identity_collate

    data_loader = DataLoader(data_set, **loader_options)
    return data_loader, n_in, n_out, grid_num