import torch
from torch.utils.data import DataLoader, Dataset
import os.path as osp


# normalization, pointwise gaussian
class UnitGaussianNormalizer(object):
    """Pointwise Gaussian normalizer.

    The mean and standard deviation are computed over dimension 0 of the
    fitting tensor, so every remaining position gets its own statistics.

    Args:
        x: Tensor used to fit the statistics; dimension 0 is reduced.
        eps: Small constant added to the standard deviation to avoid division
            by zero.
        time_last: Whether the time dimension (if any) is the last dimension.
            Only consulted by ``decode`` when a ``sample_idx`` is supplied.
    """

    def __init__(self, x, eps=0.00001, time_last=True):
        super().__init__()

        # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T in 1D
        # x could be in shape of ntrain*w*l or ntrain*T*w*l or ntrain*w*l*T in 2D
        self.mean = torch.mean(x, 0)
        self.std = torch.std(x, 0)
        self.eps = eps
        self.time_last = time_last  # if the time dimension is the last dim

    def encode(self, x):
        """Return ``(x - mean) / (std + eps)``."""
        return (x - self.mean) / (self.std + self.eps)

    def decode(self, x, sample_idx=None):
        """Invert ``encode``, optionally restricted to sampled positions.

        Args:
            x: Normalized tensor, in shape of batch*(spatial discretization
                size) or T*batch*(spatial discretization size).
            sample_idx: Optional spatial sampling mask/index. When given, the
                stored statistics are indexed with it before being applied.

        Returns:
            ``x * (std + eps) + mean`` using the selected statistics.

        Raises:
            ValueError: If ``time_last`` is False and ``sample_idx`` has more
                dimensions than the stored statistics, a combination for
                which no indexing rule is defined.
        """
        if sample_idx is None:
            std = self.std + self.eps  # n
            mean = self.mean
        elif self.mean.ndim == sample_idx.ndim or self.time_last:
            # Index the leading dimensions of the statistics directly.
            std = self.std[sample_idx] + self.eps  # batch*n
            mean = self.mean[sample_idx]
        elif self.mean.ndim > sample_idx.ndim:
            # Time is not last: keep the leading dims, index the trailing one.
            std = self.std[..., sample_idx] + self.eps  # T*batch*n
            mean = self.mean[..., sample_idx]
        else:
            # Previously this fell through with std/mean unbound.
            raise ValueError(
                "sample_idx has more dimensions than the stored statistics "
                "(%d > %d) and time_last is False"
                % (sample_idx.ndim, self.mean.ndim)
            )
        return (x * std) + mean

    def to(self, device):
        """Move the statistics to ``device`` and return ``self``.

        Statistics that are not tensors are converted with
        ``torch.from_numpy`` first.
        """
        if torch.is_tensor(self.mean):
            self.mean = self.mean.to(device)
            self.std = self.std.to(device)
        else:
            self.mean = torch.from_numpy(self.mean).to(device)
            self.std = torch.from_numpy(self.std).to(device)
        return self

    def cuda(self):
        """Move the statistics to the current CUDA device and return ``self``."""
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()
        return self  # returned for parity with ``to``

    def cpu(self):
        """Move the statistics to host memory and return ``self``."""
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()
        return self  # returned for parity with ``to``

# normalization, Gaussian
class GaussianNormalizer(object):
    """Global Gaussian normalizer using one scalar mean and one scalar std.

    Args:
        x: Tensor used to fit the statistics; all elements are reduced.
        eps: Small constant added to the standard deviation to avoid division
            by zero.
    """

    def __init__(self, x, eps=0.00001):
        super().__init__()

        self.mean = torch.mean(x)
        self.std = torch.std(x)
        self.eps = eps

    def encode(self, x):
        """Return ``(x - mean) / (std + eps)``."""
        return (x - self.mean) / (self.std + self.eps)

    def decode(self, x, sample_idx=None):
        """Invert ``encode``.

        ``sample_idx`` is accepted but unused here, since the statistics are
        scalars and there is nothing to index.
        """
        return (x * (self.std + self.eps)) + self.mean

    def to(self, device):
        """Move the statistics to ``device`` and return ``self``."""
        self.mean = self.mean.to(device)
        self.std = self.std.to(device)
        return self

    def cuda(self):
        """Move the statistics to the current CUDA device and return ``self``."""
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()
        return self

    def cpu(self):
        """Move the statistics to host memory and return ``self``."""
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()
        return self



class ERA5Dataset:
    """Build train/test datasets and data loaders from a saved ERA5 tensor.

    The tensor loaded from ``data_path`` is indexed with five dimensions. The
    first ``train_day`` entries along dimension 0 form the training split and
    the last ``test_day`` entries form the test split. Each split is turned
    into (input, target) pairs by ``pre_process`` and wrapped in ``ERA5Base``.

    Args:
        data_path: Path handed to ``torch.load``.
        raw_resolution: Resolution of the stored data; only entries 0 and 1
            are consumed (by ``ERA5Base``).
        sample_resolution: Sub-sampled resolution for the training set.
        eval_resolution: Sub-sampled resolution for the test set.
        in_t: Number of time steps in each input window.
        out_t: Index of the first target step in test mode.
        duration_t: Number of sliding windows extracted per entry.
        train_day: Number of leading entries used for training.
        test_day: Number of trailing entries used for testing.
        train_batchsize: Batch size of ``train_loader``.
        eval_batchsize: Batch size of ``test_loader``.
        normalize: Whether to normalize the data.
        normalizer_type: ``'PGN'`` selects ``UnitGaussianNormalizer``; any
            other value selects ``GaussianNormalizer``.
        prop: Physical property passed to ``ERA5Base``.
            NOTE(review): the default ``'temp'`` is rejected by ``ERA5Base``
            as defined in this file (it accepts ``'wind_u'`` / ``'wind_v'``),
            so callers currently have to pass ``prop`` explicitly.
        sub: ``False`` keeps every training sample; otherwise only the first
            ``int(len(train_x) * sub)`` samples are kept.
        **kwargs: Ignored.

    Attributes:
        ntrain: Number of training samples actually in ``train_dataset``.
        ntest: Number of test samples.
        train_dataset, test_dataset: ``ERA5Base`` instances.
        train_loader, test_loader: ``DataLoader`` objects (train is shuffled).
        x_normalizer: Normalizer fitted on the training inputs, or ``None``.
        y_normalizer: Normalizer fitted on the training targets, or ``None``.
    """

    def __init__(self, data_path, raw_resolution=[512, 512, 80], 
                 sample_resolution=[512, 512, 20], eval_resolution=[512, 512, 20], 
                 in_t=1, out_t=1, duration_t=23, train_day=6, test_day=2,
                 train_batchsize=10, eval_batchsize=10, 
                 normalize=True, normalizer_type='PGN', prop='temp', sub=False,
                 **kwargs):
        # The list defaults above are only read, never mutated.
        print('Processing raw data from ', data_path)
        # NOTE(security): torch.load unpickles the file; only load data from
        # trusted sources.
        data = torch.load(data_path)

        # Filled in by pre_process(mode='train').
        self.x_normalizer = None

        train_x, train_y, normalizer = self.pre_process(
            data[:train_day], mode='train',
            in_t=in_t, out_t=out_t, duration_t=duration_t,
            normalize=normalize, normalizer_type=normalizer_type)
        print(train_x.shape)
        print(train_y.shape)

        # Test inputs must go through the same transform as the training
        # inputs, i.e. the input normalizer rather than the target one. Fall
        # back to the returned normalizer if none was recorded (e.g. an
        # overriding pre_process that does not set x_normalizer).
        input_normalizer = self.x_normalizer
        if input_normalizer is None:
            input_normalizer = normalizer

        test_x, test_y = self.pre_process(
            data[-test_day:], mode='test',
            in_t=in_t, out_t=out_t, duration_t=duration_t,
            normalize=normalize, normalizer=input_normalizer)
        print(test_x.shape)
        print(test_y.shape)

        if sub is not False:
            sub_index = int(len(train_x) * sub)
            train_x = train_x[:sub_index]
            train_y = train_y[:sub_index]

        # Recorded after the optional truncation so it matches the dataset.
        self.ntrain = train_x.shape[0]
        self.ntest = test_x.shape[0]

        self.train_dataset = ERA5Base(
            train_x, train_y, mode='train', prop=prop,
            raw_resolution=raw_resolution, sample_resolution=sample_resolution)
        self.test_dataset = ERA5Base(
            test_x, test_y, mode='test', prop=prop,
            raw_resolution=raw_resolution, sample_resolution=eval_resolution)

        self.train_loader = DataLoader(self.train_dataset, batch_size=train_batchsize, shuffle=True)
        self.test_loader = DataLoader(self.test_dataset, batch_size=eval_batchsize, shuffle=False)

        self.y_normalizer = normalizer

    def pre_process(self, data, in_t, out_t, duration_t, mode='train', 
                    normalize=False, normalizer_type='PGN', normalizer=None, **kwargs):
        """Slice ``data`` into sliding (input, target) windows and add a grid.

        For window ``i`` the target is the single step at ``start + i`` along
        dimension 1 and the input is the ``in_t`` steps right before it, where
        ``start`` is ``in_t`` in train mode and ``out_t`` otherwise. Windows
        are stacked along dimension 0.

        After ``squeeze(1)``, two coordinate channels are appended to the last
        dimension of the inputs: ``linspace(-90, 90)`` along dimension 1 and
        ``linspace(-180, 180)`` along dimension 2 (presumably latitude and
        longitude in degrees). The grid construction assumes the squeezed
        input is 4-D, i.e. ``in_t == 1``.

        Args:
            data: Tensor indexed as (entry, time, dim1, dim2, channel).
            in_t: Number of input steps per window.
            out_t: First target step in test mode (unused in train mode).
            duration_t: Number of windows; values below 1 still yield one.
            mode: ``'train'`` or anything else for test behaviour.
            normalize: Whether to normalize.
            normalizer_type: ``'PGN'`` for ``UnitGaussianNormalizer``,
                otherwise ``GaussianNormalizer`` (train mode only).
            normalizer: Fitted normalizer applied to the inputs in test mode.
            **kwargs: Ignored.

        Returns:
            Train mode: ``(x, y, y_normalizer)`` with both ``x`` and ``y``
            normalized when ``normalize`` is true; ``y_normalizer`` is
            ``None`` otherwise. The input normalizer is stored on
            ``self.x_normalizer``.
            Test mode: ``(x, y)``; only ``x`` is normalized, ``y`` stays raw.

        Raises:
            ValueError: In test mode when ``normalize`` is true but no
                ``normalizer`` is supplied.
        """
        is_train = mode == 'train'
        start = in_t if is_train else out_t

        # Build every window first and concatenate once; concatenating inside
        # the loop recopies the accumulated tensor at every step.
        steps = range(max(duration_t, 1))
        x = torch.cat([data[:, start + i - in_t:start + i] for i in steps], dim=0)
        y = torch.cat([data[:, start + i:start + i + 1] for i in steps], dim=0)

        x_normalizer = None
        y_normalizer = None
        if normalize:
            if is_train:
                if normalizer_type == 'PGN':
                    x_normalizer = UnitGaussianNormalizer(x)
                    y_normalizer = UnitGaussianNormalizer(y)
                else:
                    x_normalizer = GaussianNormalizer(x)
                    y_normalizer = GaussianNormalizer(y)
                x = x_normalizer.encode(x)
                y = y_normalizer.encode(y)
            else:
                if normalizer is None:
                    raise ValueError(
                        "normalize=True in test mode requires a fitted normalizer")
                x = normalizer.encode(x)

        if is_train:
            # Kept so the test split can reuse the exact input transform.
            self.x_normalizer = x_normalizer

        x = x.squeeze(1)
        y = y.squeeze(1)

        n_samples, size_1, size_2 = x.shape[0], x.shape[1], x.shape[2]
        grid_x = torch.linspace(-90, 90, size_1)
        grid_x = grid_x.reshape(1, size_1, 1, 1).repeat(n_samples, 1, size_2, 1)
        grid_y = torch.linspace(-180, 180, size_2)
        grid_y = grid_y.reshape(1, 1, size_2, 1).repeat(n_samples, size_1, 1, 1)

        # The coordinate channels become the last two input channels.
        x = torch.cat([x, grid_x, grid_y], dim=-1)

        if is_train:
            return x, y, y_normalizer
        return x, y


class ERA5Base(Dataset):
    """Map-style dataset over sub-sampled ERA5 input/target tensors.

    Both tensors are indexed as (sample, dim1, dim2, channel) and strided
    along dimensions 1 and 2 by ``raw_resolution[k] // sample_resolution[k]``.
    The inputs consist of channels ``1:3`` of ``x`` followed by the last two
    channels of ``x`` (treated as the grid). The target is channel ``1`` of
    ``y`` for ``prop='wind_u'`` and channel ``2`` for ``prop='wind_v'``.

    Args:
        x: Input tensor.
        y: Target tensor.
        mode (str, optional): Stored on the instance; not otherwise used
            here. Defaults to 'train'.
        prop (str, optional): ``'wind_u'`` or ``'wind_v'``.
        raw_resolution: Resolution of ``x``/``y``; only entries 0 and 1 are
            read.
        sample_resolution: Target resolution; only entries 0 and 1 are read.
        **kwargs: Ignored.

    Attributes:
        mode (str): The mode of the dataset.
        x: Sub-sampled inputs with the grid channels appended.
        y: Sub-sampled single-channel target.

    Raises:
        ValueError: If ``prop`` is neither ``'wind_u'`` nor ``'wind_v'``.
    """

    def __init__(self, x, y, mode='train', prop='wind_u', raw_resolution=[512, 512, 20], sample_resolution=[512, 512, 20], **kwargs):
        self.mode = mode
        stride_0 = raw_resolution[0] // sample_resolution[0]
        stride_1 = raw_resolution[1] // sample_resolution[1]

        def subsample(tensor, channels):
            # Stride the two middle dimensions and pick a channel range.
            return tensor[:, ::stride_0, ::stride_1, channels]

        grid = subsample(x, slice(-2, None))

        # Both properties share the same inputs; only the target channel
        # differs.
        if prop == 'wind_u':
            target_channels = slice(1, 2)
        elif prop == 'wind_v':
            target_channels = slice(2, 3)
        else:
            raise ValueError('Invalid property')

        self.x = torch.cat([subsample(x, slice(1, 3)), grid], dim=-1)
        self.y = subsample(y, target_channels)

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at ``idx``."""
        return self.x[idx], self.y[idx]
