import torch
from torch.utils.data import Dataset
import h5py
import pickle
import numpy as np
from utils import *



class AirfoilDataset(Dataset):
    """Airfoil dataset of (input, output) grid pairs with normalization.

    ``input_data`` / ``output_data`` hold every sample (indexed on axis 0);
    the constructor keeps one split of them:

    * ``is_train=True``  -> the first ``n_train`` samples;
    * ``is_val=True``    -> the last ``n_val`` samples;
    * otherwise (test)   -> samples ``n_train : n_train + n_test``.

    Each sample is cropped to at most 221 x 51 along axes 1 and 2, converted
    to a float32 tensor and encoded with ``x_normalizer`` / ``y_normalizer``.
    A normalizer that is not supplied is fitted on this split's own data,
    using ``norm_type`` ('unit_gaussian' or 'gaussian'). The encoded output
    receives a trailing singleton axis.

    Raises:
        ValueError: if a normalizer has to be fitted and ``norm_type`` is
            not a supported kind.
    """

    def __init__(self, input_data, output_data, n_train, n_val=None, n_test=None,
                 is_train=True, is_val=False, x_normalizer=None, y_normalizer=None, norm_type='unit_gaussian'):

        # Crop extents for axes 1 and 2: only the first s1 x s2 points are kept.
        s1 = 221
        s2 = 51

        if is_train:
            x_rows = y_rows = slice(0, n_train)
        elif is_val:
            # Compute the tail start explicitly: ``arr[-n_val:]`` would return
            # the WHOLE array when n_val == 0 instead of an empty split.
            x_rows = slice(max(len(input_data) - n_val, 0), None)
            y_rows = slice(max(len(output_data) - n_val, 0), None)
        else:
            x_rows = y_rows = slice(n_train, n_train + n_test)

        x_data = torch.tensor(input_data[x_rows, :s1, :s2], dtype=torch.float)
        y_data = torch.tensor(output_data[y_rows, :s1, :s2], dtype=torch.float)
        # Index list matches the number of rows actually selected, even if the
        # requested split size exceeds what the arrays contain.
        self.idx_lst = range(x_data.shape[0])

        self.x_normalizer = (x_normalizer if x_normalizer is not None
                             else self._fit_normalizer(x_data, norm_type))
        self.y_normalizer = (y_normalizer if y_normalizer is not None
                             else self._fit_normalizer(y_data, norm_type))

        self.input_data = self.x_normalizer.encode(x_data)
        self.output_data = self.y_normalizer.encode(y_data).unsqueeze(-1)

        print(f"Dataset shape - Input: {self.input_data.shape}, Output: {self.output_data.shape}")

    @staticmethod
    def _fit_normalizer(data, norm_type):
        """Fit a new normalizer of kind ``norm_type`` on ``data``."""
        if norm_type == 'unit_gaussian':
            return UnitGaussianNormalizer(data)
        if norm_type == 'gaussian':
            return GaussianNormalizer(data)
        raise ValueError(f"Unsupported norm_type for AirfoilDataset: {norm_type!r}")

    def __len__(self):
        """Number of samples in this split."""
        return self.input_data.shape[0]

    def __getitem__(self, idx):
        """Return ``(encoded input, encoded output, index within split)``."""
        return self.input_data[idx], self.output_data[idx], self.idx_lst[idx]

    def get_normalizers(self):
        """Return the ``(x_normalizer, y_normalizer)`` used by this split."""
        return self.x_normalizer, self.y_normalizer

    def get_vis_data(self):
        """Return the first (up to) 10 encoded input/output samples."""
        return self.input_data[:10], self.output_data[:10]


class PipeDataset(Dataset):
    """Pipe dataset of (input, output) pairs with optional normalization.

    The split is chosen as in ``AirfoilDataset`` (no spatial crop):

    * ``is_train=True``  -> the first ``n_train`` samples;
    * ``is_val=True``    -> the last ``n_val`` samples;
    * otherwise (test)   -> samples ``n_train : n_train + n_test``.

    With ``norm_type=None`` the data are left un-encoded and both normalizers
    are ``None``. Otherwise each normalizer that is not supplied is fitted on
    this split's data ('unit_gaussian' or 'gaussian'). The output always
    receives a trailing singleton axis.

    Raises:
        ValueError: if a normalizer has to be fitted and ``norm_type`` is
            not a supported kind.
    """

    def __init__(self, input_data, output_data, n_train, n_val=None, n_test=None, is_train=True, is_val=False,
                 x_normalizer=None, y_normalizer=None, norm_type='unit_gaussian'):

        if is_train:
            x_rows = y_rows = slice(0, n_train)
        elif is_val:
            # Explicit tail start: ``arr[-n_val:]`` would return the WHOLE
            # array when n_val == 0.
            x_rows = slice(max(len(input_data) - n_val, 0), None)
            y_rows = slice(max(len(output_data) - n_val, 0), None)
        else:
            x_rows = y_rows = slice(n_train, n_train + n_test)

        self.input_data = torch.tensor(input_data[x_rows], dtype=torch.float)
        self.output_data = torch.tensor(output_data[y_rows], dtype=torch.float)
        # Index list matches the number of rows actually selected.
        self.idx_lst = range(self.input_data.shape[0])

        if norm_type is None:
            # Define the attributes so get_normalizers() works for raw data.
            self.x_normalizer = None
            self.y_normalizer = None
            self.output_data = self.output_data.unsqueeze(-1)
        else:
            self.x_normalizer = (x_normalizer if x_normalizer is not None
                                 else self._fit_normalizer(self.input_data, norm_type))
            self.y_normalizer = (y_normalizer if y_normalizer is not None
                                 else self._fit_normalizer(self.output_data, norm_type))

            self.input_data = self.x_normalizer.encode(self.input_data)
            self.output_data = self.y_normalizer.encode(self.output_data).unsqueeze(-1)

        print(f"Dataset shape - Input: {self.input_data.shape}, Output: {self.output_data.shape}")

    @staticmethod
    def _fit_normalizer(data, norm_type):
        """Fit a new normalizer of kind ``norm_type`` on ``data``."""
        if norm_type == 'unit_gaussian':
            return UnitGaussianNormalizer(data)
        if norm_type == 'gaussian':
            return GaussianNormalizer(data)
        raise ValueError(f"Unsupported norm_type for PipeDataset: {norm_type!r}")

    def __len__(self):
        """Number of samples in this split."""
        return self.input_data.shape[0]

    def __getitem__(self, idx):
        """Return ``(input, output, index within split)``."""
        return self.input_data[idx], self.output_data[idx], self.idx_lst[idx]

    def get_vis_data(self):
        """Return the first (up to) 10 input/output samples."""
        return self.input_data[:10], self.output_data[:10]

    def get_normalizers(self):
        """Return ``(x_normalizer, y_normalizer)``; both ``None`` when unnormalized."""
        return self.x_normalizer, self.y_normalizer
    
    
class FWIDataset(Dataset):
    """Training pairs of ``seis`` -> ``vel`` arrays, min-max scaled.

    The first ``n_train`` samples of each array are kept, and axis 1 is moved
    to the last position. ``seis`` is log-transformed (``k=1``) and then
    min-max scaled between the log-transformed global extrema. ``vel`` is
    min-max scaled between its raw global extrema. "Global" means computed
    over the full input arrays, not just the ``n_train`` slice.
    """

    def __init__(self, seis, vel, n_train):
        # Extrema over the complete arrays (all samples).
        self.seis_max = np.max(seis)
        self.seis_min = np.min(seis)
        seis_log_hi = log_transform(self.seis_max, k=1)
        seis_log_lo = log_transform(self.seis_min, k=1)
        self.vel_max = np.max(vel)
        self.vel_min = np.min(vel)

        axis1_to_last = (0, 2, 3, 1)
        seis_t = torch.tensor(seis[:n_train], dtype=torch.float).permute(*axis1_to_last)
        vel_t = torch.tensor(vel[:n_train], dtype=torch.float).permute(*axis1_to_last)

        seis_t = LogTransform(k=1)(seis_t)
        self.seis = MinMaxNormalize(seis_log_lo, seis_log_hi)(seis_t)
        self.vel = MinMaxNormalize(self.vel_min, self.vel_max)(vel_t)

        self.idx_lst = range(n_train)

    def __len__(self):
        """Number of stored training samples."""
        n_samples = self.seis.shape[0]
        return n_samples

    def __getitem__(self, idx):
        """Return ``(scaled seis, scaled vel, sample index)``."""
        sample_seis = self.seis[idx]
        sample_vel = self.vel[idx]
        return sample_seis, sample_vel, self.idx_lst[idx]

    def get_vis_data(self):
        """Return the first (up to) 10 scaled seis/vel samples."""
        head = slice(None, 10)
        return self.seis[head], self.vel[head]

    def get_minmax(self):
        """Return the raw extrema ``(seis_min, seis_max, vel_min, vel_max)``."""
        extrema = (self.seis_min, self.seis_max, self.vel_min, self.vel_max)
        return extrema
    

        
        

class NavierStokesDataset3D_twoway(Dataset):
    """Paired observation/source windows read from dataset ``'u'`` of an HDF5 file.

    Two windows of length ``t_interval`` are sliced from axis 0 of ``'u'``:
    ``[t1, t1 + t_interval)`` becomes ``u_obs`` (input) and
    ``[t2, t2 + t_interval)`` becomes ``u_src`` (target). Axis 3 is limited to
    the first ``n_train`` entries (all of them when ``n_train`` is None). After
    swapping axes 0 and 3, each tensor is reshaped to
    ``(n_samples, seq_len1, seq_len2, seq_len3, out_dim)`` for its side and
    encoded with its normalizer. A missing normalizer is fitted on this data
    as a ``UnitGaussianNormalizer``.
    """

    def __init__(self, datapath, t1, t2, t_interval, seq_len1_obs, seq_len2_obs, seq_len3_obs,
                 seq_len1_src, seq_len2_src, seq_len3_src, out_dim_obs, out_dim_src,
                 n_train=None, x_normalizer=None, y_normalizer=None):

        self.n_train = n_train
        sub = 1  # stride along axes 1 and 2 (1 keeps every point)

        # Open read-only and close the handle once the slices have been copied
        # into tensors. Previously the file object was never closed.
        with h5py.File(datapath, 'r') as h5_file:
            data = h5_file['u']
            u_obs = torch.tensor(data[t1:t1 + t_interval, ::sub, ::sub, :n_train], dtype=torch.float)
            u_src = torch.tensor(data[t2:t2 + t_interval, ::sub, ::sub, :n_train], dtype=torch.float)

        # Swap axis 0 (the t-window) with axis 3 (the :n_train axis) so samples lead.
        u_obs = u_obs.transpose(0, 3)
        u_src = u_src.transpose(0, 3)

        # Use the number of samples actually read, so n_train=None (the
        # default) works instead of failing in reshape(None, ...).
        n_samples = u_obs.shape[0]
        self.u_obs = u_obs.reshape(n_samples, seq_len1_obs, seq_len2_obs, seq_len3_obs, out_dim_obs)
        self.u_src = u_src.reshape(n_samples, seq_len1_src, seq_len2_src, seq_len3_src, out_dim_src)
        self.idx_lst = range(n_samples)

        # Each normalizer is fitted only if it was not supplied.
        self.x_normalizer = (x_normalizer if x_normalizer is not None
                             else UnitGaussianNormalizer(self.u_obs))
        self.y_normalizer = (y_normalizer if y_normalizer is not None
                             else UnitGaussianNormalizer(self.u_src))

        self.u_obs = self.x_normalizer.encode(self.u_obs)
        self.u_src = self.y_normalizer.encode(self.u_src)

        print(f"Dataset shape - Input: {self.u_obs.shape}, Output: {self.u_src.shape}")

    def __len__(self):
        """Number of samples."""
        return len(self.idx_lst)

    def __getitem__(self, idx):
        """Return ``(encoded u_obs, encoded u_src, sample index)``."""
        return self.u_obs[idx], self.u_src[idx], self.idx_lst[idx]

    def get_normalizers(self):
        """Return ``(x_normalizer, y_normalizer)``."""
        return self.x_normalizer, self.y_normalizer

    def get_vis_data(self):
        """Return the first sample of each tensor, keeping the leading axis."""
        return self.u_obs[0:1], self.u_src[0:1]


        
        



# Default FWI arrays; can be overridden with ``args.seis_path`` / ``args.vel_path``.
_FWI_SEIS_PATH = '../data/fwi/seis6_1_35.npy'
_FWI_VEL_PATH = '../data/fwi/vel6_1_35.npy'


def _load_stacked_inputs(args, output_index):
    """Load both input .npy files stacked on a new last axis, plus one output slice."""
    input1 = np.load(args.input1_path)
    input2 = np.load(args.input2_path)
    input_data = np.stack([input1, input2], axis=-1)
    # Keep only position ``output_index`` along axis 1 of the output array.
    output_data = np.load(args.output_path)[:, output_index]
    return input_data, output_data


def _supervised_split_configs(args, input_data, output_data):
    """Return train/valid/test constructor kwargs for AirfoilDataset / PipeDataset."""
    shared = {
        'input_data': input_data,
        'output_data': output_data,
        'n_train': args.ntrain,
        # Every split uses the same norm_type, so a split that fits its own
        # normalizer (args.normalize off) uses the same kind as train.
        'norm_type': args.norm,
    }
    return {
        'train': {**shared, 'is_train': True, 'is_val': False},
        'valid': {**shared, 'n_val': args.nval, 'is_train': False, 'is_val': True,
                  'x_normalizer': None, 'y_normalizer': None},
        'test': {**shared, 'n_test': args.ntest, 'is_train': False, 'is_val': False,
                 'x_normalizer': None, 'y_normalizer': None},
    }


def init_dataset(args, is_test=False):
    """Build the training dataset and one held-out dataset for ``args.dataset``.

    Args:
        args: namespace with at least ``dataset`` and ``normalize``, plus the
            fields read by the chosen dataset's branch below (e.g.
            ``input1_path``, ``ntrain``, ``norm`` ...).
        is_test: if True, the second return value is built from the 'test'
            split; otherwise from the 'valid' split.

    Returns:
        ``(train_dataset, held_out_dataset)``. The second item is ``None`` for
        datasets that define no such split (currently 'fwi' and 'ns3d_twoway').

    Raises:
        ValueError: if ``args.dataset`` is not a known dataset name.
    """
    datasets = {
        'airfoil': AirfoilDataset,
        'pipe': PipeDataset,
        'fwi': FWIDataset,
        'ns3d_twoway': NavierStokesDataset3D_twoway
    }

    dataset_name = args.dataset
    if dataset_name not in datasets:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    dataset_class = datasets[dataset_name]

    if dataset_name == 'airfoil':
        input_data, output_data = _load_stacked_inputs(args, output_index=4)
        dataset_configs = _supervised_split_configs(args, input_data, output_data)
    elif dataset_name == 'pipe':
        input_data, output_data = _load_stacked_inputs(args, output_index=0)
        dataset_configs = _supervised_split_configs(args, input_data, output_data)
    elif dataset_name == 'fwi':
        seis = np.load(getattr(args, 'seis_path', _FWI_SEIS_PATH))
        vel = np.load(getattr(args, 'vel_path', _FWI_VEL_PATH))
        dataset_configs = {
            'train': {
                'seis': seis,
                'vel': vel,
                'n_train': args.ntrain,
            }
        }
    else:  # dataset_name == 'ns3d_twoway' (validated above)
        dataset_configs = {
            'train': {
                'datapath': args.data_path,
                't1': args.t1,
                't2': args.t2,
                't_interval': args.t_interval,
                'seq_len1_obs': args.seq_len1_obs,
                'seq_len2_obs': args.seq_len2_obs,
                'seq_len3_obs': args.seq_len3_obs,
                'seq_len1_src': args.seq_len1_src,
                'seq_len2_src': args.seq_len2_src,
                'seq_len3_src': args.seq_len3_src,
                'out_dim_obs': args.out_dim_obs,
                'out_dim_src': args.out_dim_src,
                'n_train': args.ntrain,
                'x_normalizer': None,
                'y_normalizer': None
            }
        }

    train_dataset = dataset_class(**dataset_configs['train'])

    # Reuse the training normalizers for held-out splits so all splits share
    # one encoding. Only datasets that actually define held-out splits take
    # part. Previously 'fwi' crashed here: it has no get_normalizers() and no
    # 'valid'/'test' configs.
    held_out_splits = [name for name in ('valid', 'test') if name in dataset_configs]
    if args.normalize and held_out_splits:
        x_normalizer, y_normalizer = train_dataset.get_normalizers()
        for split in held_out_splits:
            dataset_configs[split]['x_normalizer'] = x_normalizer
            dataset_configs[split]['y_normalizer'] = y_normalizer

    split = 'test' if is_test else 'valid'
    valid_dataset = dataset_class(**dataset_configs[split]) if split in dataset_configs else None

    return train_dataset, valid_dataset