import torch
from torch.utils.data import Dataset
import h5py
import pickle
import numpy as np
import h5py
from utils import UnitGaussianNormalizer

class ConvectionDataset(Dataset):
    """Dataset of convection solutions loaded from a pickle file.

    The pickled object is sliced with ``[beta_start - 1:beta_end]`` (i.e. a
    1-based, inclusive range of entries), converted to a float tensor and
    reshaped to ``(-1, seq_len1, seq_len2, out_dim)``.

    In test mode an untouched copy of the data is kept in ``u_wo_mask`` and,
    when ``mask_pct`` is given, ``u`` is truncated along its second axis to
    the first ``int(seq_len1 * mask_pct)`` entries.

    Args:
        data_path: Path of the pickle file to load.
        beta_start: 1-based index of the first entry to keep.
        beta_end: 1-based index of the last entry to keep (inclusive).
        seq_len1: Size of the second axis after reshaping.
        seq_len2: Size of the third axis after reshaping.
        out_dim: Size of the last axis after reshaping.
        is_train: Accepted for interface compatibility; not used here.
        is_test: When true, items also carry the unmasked sample.
        mask_pct: Fraction of ``seq_len1`` kept in test mode. ``None``
            disables the truncation.
    """

    def __init__(self, data_path, beta_start, beta_end, seq_len1, seq_len2, out_dim, is_train=True, is_test=False, mask_pct=None):
        # NOTE: pickle.load can execute arbitrary code; only load files from
        # a trusted source.
        with open(data_path, 'rb') as f:
            conv_data = pickle.load(f)
        self.is_test = is_test
        self.u = torch.tensor(conv_data[beta_start - 1:beta_end], dtype=torch.float).reshape(-1, seq_len1, seq_len2, out_dim)
        self.idx_lst = range(self.u.shape[0])

        if is_test:
            # Keep the full samples so __getitem__ can return them next to
            # the (possibly truncated) inputs.
            self.u_wo_mask = self.u.clone()
            # Without a mask fraction there is nothing to truncate; the
            # multiplication below would otherwise fail on None.
            if mask_pct is not None:
                self.u = self.u[:, :int(seq_len1 * mask_pct), :, :]

        print(f"Dataset shape - Input: {self.u.shape}")

    def __len__(self):
        """Return the number of samples."""
        return self.u.shape[0]

    def __getitem__(self, idx):
        """Return ``(u, index)`` or, in test mode, ``(u, index, unmasked_u)``."""
        if self.is_test:
            return self.u[idx], self.idx_lst[idx], self.u_wo_mask[idx]
        return self.u[idx], self.idx_lst[idx]

    def get_vis_data(self):
        """Return the last (up to) ten samples together with their indices."""
        return self.u[-10:], self.idx_lst[-10:]
    
class HelmholtzDataset(Dataset):
    """Dataset of Helmholtz solutions read from a fixed pickle location.

    The file ``./datasets/helmholtz/helmholtz_data{a_start}to{a_end}.pkl`` is
    unpickled and the resulting object is reshaped to
    ``(-1, seq_len1, seq_len2, out_dim)``.

    Args:
        a_start: First value used to build the file name.
        a_end: Second value used to build the file name.
        seq_len1: Size of the second axis after reshaping.
        seq_len2: Size of the third axis after reshaping.
        out_dim: Size of the last axis after reshaping.
        is_train: Accepted but not used by this loader.
        is_val: Accepted but not used by this loader.
    """

    def __init__(self, a_start, a_end, seq_len1, seq_len2, out_dim, is_train=True, is_val=False):
        # NOTE(review): the unpickled object must provide ``reshape``; it is
        # presumably a tensor or array -- confirm against the data files.
        with open(f'./datasets/helmholtz/helmholtz_data{a_start}to{a_end}.pkl', 'rb') as handle:
            raw = pickle.load(handle)

        target_shape = (-1, seq_len1, seq_len2, out_dim)
        self.u = raw.reshape(*target_shape)

        sample_count = self.u.shape[0]
        self.idx_lst = range(sample_count)

        print(f"Dataset shape - Input: {self.u.shape}")

    def __len__(self):
        """Return the number of samples."""
        return len(self.u)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` together with its index."""
        sample = self.u[idx]
        sample_id = self.idx_lst[idx]
        return sample, sample_id

    def get_vis_data(self):
        """Return the last (up to) ten samples and their indices."""
        tail = slice(-10, None)
        return self.u[tail], self.idx_lst[tail]
    


class NavierStokesDataset3D(Dataset):
    """Dataset built from the ``'u'`` array of an HDF5 file.

    The array is indexed as ``[::sub, :, :, :n_train]``, cut to the window
    ``[t1:t2 + 1]`` along its first axis, has its first and last axes swapped
    and is finally reshaped to
    ``(-1, seq_len1, seq_len2, t2 - t1 + 1, out_dim)``. The result is passed
    through ``x_normalizer.encode``.

    Args:
        datapath: Path of the HDF5 file.
        t1: First index kept along the first axis of ``'u'``.
        t2: Last index kept along the first axis of ``'u'`` (inclusive).
        seq_len1: Size of the second axis after reshaping.
        seq_len2: Size of the third axis after reshaping.
        seq_len3: Ignored; the value is always recomputed as ``t2 - t1 + 1``.
        out_dim: Size of the last axis after reshaping.
        n_train: Number of entries kept along the last axis of ``'u'``.
            ``None`` keeps all of them.
        x_normalizer: Normalizer to reuse. When ``None`` a new
            ``UnitGaussianNormalizer`` is built from the loaded data.
    """

    def __init__(self, datapath, t1, t2, seq_len1, seq_len2, seq_len3, out_dim, n_train=None, x_normalizer=None):
        self.n_train = n_train

        sub = 1  # stride along the first axis; 1 keeps every entry
        t2 = t2 // sub
        # The caller-supplied seq_len3 is overridden by the window length.
        seq_len3 = t2 - t1 + 1

        # Read inside a context manager so the file handle is released as
        # soon as the data is in memory.
        with h5py.File(datapath, 'r') as f:
            data = f['u'][::sub, :, :, :n_train]

        self.u = torch.tensor(data, dtype=torch.float)
        self.u = self.u[t1:t2 + 1].transpose(0, 3)

        self.u = self.u.reshape(-1, seq_len1, seq_len2, seq_len3, out_dim)
        # Index by what was actually loaded: this also works for
        # n_train=None and when the file holds fewer samples than requested.
        self.idx_lst = range(self.u.shape[0])

        if x_normalizer is None:
            self.x_normalizer = UnitGaussianNormalizer(self.u)
        else:
            self.x_normalizer = x_normalizer

        self.u = self.x_normalizer.encode(self.u)

        print(f"Dataset shape - Input: {self.u.shape}")

    def __len__(self):
        """Return the number of samples."""
        return len(self.idx_lst)

    def __getitem__(self, idx):
        """Return the encoded sample at ``idx`` together with its index."""
        return self.u[idx], self.idx_lst[idx]

    def get_normalizer(self):
        """Return the normalizer that was used to encode the data."""
        return self.x_normalizer

    def get_vis_data(self):
        """Return the first sample (as a batch of one) and its index."""
        return self.u[:1], self.idx_lst[:1]
    



class KuramotoSivashinskyDataset(Dataset):
    """Dataset built from the first ``pde_*`` entry of an HDF5 ``train`` group.

    Rows ``[:n_train]`` are used for training and rows
    ``[n_train:n_train + n_test]`` for testing; both splits are read from the
    same ``'train'`` group. The selected rows are reshaped to
    ``(-1, seq_len1, seq_len2, out_dim)`` and optionally passed through a
    normalizer's ``encode``.

    Args:
        data_path: Path of the HDF5 file.
        seq_len1: Size of the second axis after reshaping.
        seq_len2: Size of the third axis after reshaping.
        out_dim: Size of the last axis after reshaping.
        n_train: Number of leading rows that form the training split.
            ``None`` keeps every row in training mode; it is required in
            test mode.
        n_test: Number of rows in the test split; required in test mode.
        normalize: Whether to encode the data with a normalizer.
        x_normalizer: Normalizer to reuse when ``normalize`` is true. When
            ``None`` a new ``UnitGaussianNormalizer`` is built from the data.
        mask_pct: In test mode, fraction of ``seq_len1`` kept along the second
            axis of ``u``. ``None`` disables the truncation.
        is_test: Selects the test split instead of the training split.

    Raises:
        KeyError: If the ``'train'`` group has no key starting with ``pde_``.
    """

    def __init__(self, data_path, seq_len1, seq_len2, out_dim, n_train=None, n_test=None,
                 normalize=False, x_normalizer=None, mask_pct=None, is_test=False):
        if is_test:
            rows = slice(n_train, n_train + n_test)
        else:
            rows = slice(None, n_train)

        with h5py.File(data_path, "r") as f:
            g = f['train']
            dset_name = next((k for k in g.keys() if k.startswith("pde_")), None)
            if dset_name is None:
                # Fail with a descriptive error instead of a bare StopIteration.
                raise KeyError(f"No dataset starting with 'pde_' in group 'train' of {data_path}")
            data_u = np.array(g[dset_name])

        self.u = torch.tensor(data_u, dtype=torch.float)[rows].reshape(-1, seq_len1, seq_len2, out_dim)

        # Index by what was actually loaded: this also works for n_train=None
        # and when the file holds fewer rows than requested.
        self.idx_lst = range(self.u.shape[0])

        # Always define the attribute so get_normalizer() is safe to call
        # even when normalization is disabled.
        self.x_normalizer = None
        if normalize:
            if x_normalizer is None:
                self.x_normalizer = UnitGaussianNormalizer(self.u)
            else:
                self.x_normalizer = x_normalizer

            self.u = self.x_normalizer.encode(self.u)

        if is_test and mask_pct is not None:
            # Keep the full samples before truncating the inputs.
            self.u_wo_mask = self.u.clone()
            self.u = self.u[:, :int(seq_len1 * mask_pct), :, :]

        print(f"Dataset shape - Input: {self.u.shape}")

    def __len__(self):
        """Return the number of samples."""
        return len(self.idx_lst)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` together with its index."""
        return self.u[idx], self.idx_lst[idx]

    def get_vis_data(self):
        """Return the first (up to) ten samples and their indices."""
        return self.u[:10], self.idx_lst[:10]

    def get_true_data(self):
        """Return the stored ``u`` tensor.

        NOTE(review): in masked test mode this is the truncated tensor; the
        untruncated copy lives in ``u_wo_mask`` -- confirm which one callers
        expect.
        """
        return self.u

    def get_normalizer(self):
        """Return the normalizer in use, or ``None`` if data is not normalized."""
        return self.x_normalizer
    


def init_dataset(args, is_test=False, normalizer=None):
    """Build the dataset selected by ``args.dataset``.

    Only the attributes needed for the requested split are read from
    ``args``, so a training run does not need test-only options such as
    ``mask_pct`` or ``ntest``.

    Args:
        args: Namespace-like object. ``args.dataset`` must be one of
            ``'conv'``, ``'helm'``, ``'ns3d'`` or ``'ks'``; the remaining
            attributes read depend on the dataset and split.
        is_test: Build the test split instead of the training split.
        normalizer: Normalizer handed to the ``'ks'`` test dataset as
            ``x_normalizer``. Ignored for every other dataset and split.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: If ``args.dataset`` is unknown, or if a test split is
            requested for ``'ns3d'``, which defines none.
    """
    dataset_name = args.dataset
    is_test = bool(is_test)

    if dataset_name == 'conv':
        kwargs = {
            'data_path': args.data_path,
            'beta_start': args.beta_start,
            'beta_end': args.beta_end,
            'seq_len1': args.seq_len1,
            'seq_len2': args.seq_len2,
            'out_dim': args.out_dim,
            'is_train': not is_test,
            'is_test': is_test,
        }
        if is_test:
            kwargs['mask_pct'] = args.mask_pct
        dataset_class = ConvectionDataset
    elif dataset_name == 'helm':
        # HelmholtzDataset.__init__ has no mask_pct parameter, so it must not
        # be forwarded (doing so raises TypeError).
        kwargs = {
            'a_start': args.a_start,
            'a_end': args.a_end,
            'seq_len1': args.seq_len1,
            'seq_len2': args.seq_len2,
            'out_dim': args.out_dim,
            'is_train': not is_test,
            'is_val': False,
        }
        dataset_class = HelmholtzDataset
    elif dataset_name == 'ns3d':
        if is_test:
            # No test configuration exists for this dataset; fail clearly
            # instead of with an opaque lookup error.
            raise ValueError("No test configuration is defined for dataset: ns3d")
        kwargs = {
            'datapath': args.data_path,
            't1': args.t1,
            't2': args.t2,
            'seq_len1': args.seq_len1,
            'seq_len2': args.seq_len2,
            'seq_len3': args.seq_len3,
            'n_train': args.ntrain,
            'out_dim': args.out_dim,
            'x_normalizer': None,
        }
        dataset_class = NavierStokesDataset3D
    elif dataset_name == 'ks':
        kwargs = {
            'data_path': args.data_path,
            'seq_len1': args.seq_len1,
            'seq_len2': args.seq_len2,
            'out_dim': args.out_dim,
            'n_train': args.ntrain,
            'normalize': args.normalize,
            'x_normalizer': None,
        }
        if is_test:
            kwargs.update(
                n_test=args.ntest,
                x_normalizer=normalizer,
                mask_pct=args.mask_pct,
                is_test=True,
            )
        dataset_class = KuramotoSivashinskyDataset
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    return dataset_class(**kwargs)