import torch
import gzip
import numpy as np
import os
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm
import scipy.io
import pickle


class CosmogridDataset(Dataset):
    """Map-style dataset over the Cosmogrid ``X``/``y`` arrays found in ``root_dir``.

    The constructor loads ``X_maps_Cosmogrid_100k.npy`` and
    ``y_maps_Cosmogrid_100k.npy`` (and, when ``mask_suffix`` is given,
    ``masks/X_maps_Cosmogrid_100k{mask_suffix}.npy``), cuts them
    sequentially (no shuffling) into train/val/test portions of roughly
    80/10/10 percent, and keeps the portion selected by ``split``.

    Only two of the six label columns are kept (indices 0 and 3, which the
    column listing in ``__init__`` names Omega_m and sigma_8).

    Args:
        root_dir: Directory containing the ``.npy`` files.
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        data_size: ``-1`` keeps the whole split; any other value is used as
            the stop index of a slice (``[:data_size]``) over the split.
        mask_suffix: Optional suffix selecting the mask file; ``None`` means
            no masks are loaded.
        mask_transform: Optional callable applied to each mask tensor in
            ``__getitem__``; ``None`` returns the mask tensor unchanged.

    Raises:
        KeyError: If ``split`` is not one of the three known split names.
    """

    def __init__(self, root_dir, split='train', data_size=-1, mask_suffix=None,
                 mask_transform=None):
        # Fail fast, before the arrays are read from disk.  The exception is
        # the same one the lookup in ``self.splits`` would raise later.
        if split not in ('train', 'val', 'test'):
            raise KeyError(split)

        self.split = split
        self.root_dir = root_dir
        self.data_size = data_size
        self.mask_suffix = mask_suffix
        self.mask_transform = mask_transform

        # load cosmological parameters -------
        # Omega_m, H0, ns, sigma_8, w, omega_b
        # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
        # objects, which can execute code.  Only point root_dir at trusted
        # files.
        Xvals = np.load(os.path.join(root_dir, 'X_maps_Cosmogrid_100k.npy'), allow_pickle=True)
        Yvals = np.load(os.path.join(root_dir, 'y_maps_Cosmogrid_100k.npy'), allow_pickle=True)

        # number of samples
        num_samples = len(Yvals)

        # split the sample for training ----------
        train_split = int(0.80 * num_samples)
        val_split = int(0.10 * num_samples)
        # np.split below hands every remaining sample to the test split, so
        # derive its size from the remainder to keep the log truthful.
        test_split = num_samples - train_split - val_split

        print('# samples used for training:', train_split)
        print('# samples used for validation:', val_split)
        print('# samples used for testing:', test_split)
        print('# total samples:', train_split + val_split + test_split)

        boundaries = [train_split, train_split + val_split]
        train_x, val_x, test_x = np.split(Xvals, boundaries)
        train_y, val_y, test_y = np.split(Yvals, boundaries)
        print('x shape', train_x.shape, val_x.shape, test_x.shape)
        print('y shape', train_y.shape, val_y.shape, test_y.shape)

        if mask_suffix is not None:
            masks_vals = np.load(os.path.join(root_dir, 'masks',
                                              f'X_maps_Cosmogrid_100k{mask_suffix}.npy'),
                                 allow_pickle=True)
            # Masks are cut at the same boundaries as the maps.
            train_masks, val_masks, test_masks = np.split(masks_vals, boundaries)
            print('masks shape', train_masks.shape, val_masks.shape, test_masks.shape)
        else:
            train_masks, val_masks, test_masks = None, None, None

        params_mask = np.array([True, False, False, True, False, False])
        self.output_num = int(params_mask.sum())

        # let's focus on omega_m and sigma_8
        train_y, val_y, test_y = train_y[:, params_mask], val_y[:, params_mask], test_y[:, params_mask]

        self.splits = {
            'train': (train_x, train_y, train_masks),
            'val': (val_x, val_y, val_masks),
            'test': (test_x, test_y, test_masks)
        }

        split_x, split_y, split_masks = self.splits[split]
        # Insert a singleton axis at position 1
        # (presumably a channel axis: (N, H, W) -> (N, 1, H, W) -- TODO confirm).
        self.images = split_x[:, None, :, :]
        self.labels = split_y
        self.masks = split_masks
        print('-- ALL --')
        print('max', self.images.max())
        print('min', self.images.min())

        if data_size != -1:
            self.images = self.images[:data_size]
            self.labels = self.labels[:data_size]
            if self.masks is not None:
                self.masks = self.masks[:data_size]

        print(f'-- SPLIT {split} --')
        print('max', self.images.max().item())
        print('min', self.images.min().item())

        print(f'Finished loading {len(self.labels)} {split} images ... ')

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        Returns:
            ``(image, label)`` when no masks were loaded, otherwise
            ``(image, label, mask, mask_i)`` where ``mask_i`` is the stored
            mask entry and ``mask`` is ``torch.tensor(mask_i)`` after
            ``mask_transform`` (if one was supplied).
        """
        image = self.images[idx]
        label = self.labels[idx]
        if self.masks is None:
            return image, label

        mask_i = self.masks[idx]
        mask = torch.tensor(mask_i)
        # mask_transform defaults to None; calling it unconditionally would
        # raise TypeError whenever masks are used without a transform.
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)
        return image, label, mask, mask_i

        
    