import torch
from torch.utils.data import DataLoader, Dataset
import os.path as osp

from utils.normalizer import UnitGaussianNormalizer, GaussianNormalizer


class CarraDataset:
    """Build train/valid/test datasets and data loaders from a Carra tensor file.

    The raw file at ``data_path`` is read with ``torch.load`` and must unpack
    into ``(data, lat, lon)``.  One field is selected from the leading axis of
    ``data`` (index 0 for ``prop='v10'``, index 1 for ``prop='sp'``); the
    remaining tensor is treated as a sequence of 2-D frames ``(T, H, W)`` and
    split chronologically into train / valid / test segments.  Each segment is
    turned into one-step-ahead pairs by :meth:`pre_process`.

    The processed splits and the fitted normalizers are cached next to the raw
    file as ``<data_path without extension>_<prop>_processed.pt`` and reused on
    later runs.

    NOTE: the cache file name only depends on ``data_path`` and ``prop``.  If
    the ratios, ``normalize`` or ``normalizer_type`` change, delete the cached
    file manually, otherwise the old processed tensors are loaded.

    Args:
        data_path (str): Path to the raw ``(data, lat, lon)`` file.
        sample_factor (sequence of 2 ints): Spatial subsampling strides along
            the two grid axes, forwarded to ``CarraBase``.
        train_ratio (float): Fraction of frames used for the training segment.
        valid_ratio (float): Fraction of frames used for the validation segment.
        test_ratio (float): Accepted for interface symmetry but not read; the
            test segment is whatever remains after train and valid.
        train_batchsize (int): Batch size of ``train_loader`` (shuffled).
        eval_batchsize (int): Batch size of ``valid_loader`` / ``test_loader``.
        normalize (bool): Whether to normalize fields and coordinates.
        normalizer_type (str): ``'PGN'`` selects ``UnitGaussianNormalizer``,
            anything else selects ``GaussianNormalizer``.
        prop (str): Field to extract, ``'v10'`` or ``'sp'``.
        sub (float or False): If not ``False``, keep only the first
            ``int(len(train_x) * sub)`` training samples.
        **kwargs: Ignored.

    Raises:
        ValueError: If raw data is processed and ``prop`` is not ``'v10'`` or
            ``'sp'``.

    Attributes:
        train_dataset, valid_dataset, test_dataset: ``CarraBase`` instances.
        train_loader, valid_loader, test_loader: Matching ``DataLoader`` objects.
    """

    def __init__(self, data_path, sample_factor=(1, 1),
                 train_ratio=0.6, valid_ratio=0.2, test_ratio=0.2,
                 train_batchsize=10, eval_batchsize=10,
                 normalize=True, normalizer_type='PGN',
                 prop='v10', sub=False,
                 **kwargs):
        self.__file__ = osp.abspath(__file__)
        # splitext strips only the final extension, so paths such as
        # './data/carra.pt' keep their directory part (splitting on the first
        # '.' would reduce that path to an empty string).
        process_path = osp.splitext(data_path)[0] + '_' + prop + '_processed.pt'

        # NOTE(review): torch.load unpickles arbitrary Python objects (the cache
        # stores normalizer instances). Only load files from trusted sources,
        # and confirm the installed torch version's `weights_only` default
        # accepts these objects.
        if osp.exists(process_path):
            print('Loading processed data from ', process_path)
            (train_x, train_y), (valid_x, valid_y), (test_x, test_y), x_normalizer, y_normalizer = torch.load(process_path)
        else:
            print('Processing raw data from ', data_path)
            data, lat, lon = torch.load(data_path)

            if prop == 'v10':
                data = data[0, ...]
            elif prop == 'sp':
                data = data[1, ...]
            else:
                raise ValueError("Invalid property type.")

            # Chronological split boundaries; the test segment is the remainder.
            n_frames = len(data)
            train_idx = int(n_frames * train_ratio)
            valid_idx = int(n_frames * (train_ratio + valid_ratio))

            # Normalizers are fitted on the training segment only and reused
            # unchanged for the validation and test segments.
            train_x, train_y, x_normalizer, y_normalizer = self.pre_process(
                data[:train_idx], lat=lat, lon=lon, mode='train', prop=prop,
                normalize=normalize, normalizer_type=normalizer_type)
            valid_x, valid_y = self.pre_process(
                data[train_idx:valid_idx], lat=lat, lon=lon, mode='valid', prop=prop,
                normalize=normalize, x_normalizer=x_normalizer, y_normalizer=y_normalizer)
            test_x, test_y = self.pre_process(
                data[valid_idx:], lat=lat, lon=lon, mode='test', prop=prop,
                normalize=normalize, x_normalizer=x_normalizer, y_normalizer=y_normalizer)

            torch.save(((train_x, train_y), (valid_x, valid_y), (test_x, test_y), x_normalizer, y_normalizer), process_path)

        if sub is not False:
            # Keep only the leading fraction `sub` of the training samples.
            sub_index = int(len(train_x) * sub)
            train_x = train_x[:sub_index]
            train_y = train_y[:sub_index]

        self.train_dataset = CarraBase(train_x, train_y, mode='train', sample_factor=sample_factor,
                                      x_normalizer=x_normalizer, y_normalizer=y_normalizer)
        self.valid_dataset = CarraBase(valid_x, valid_y, mode='valid', sample_factor=sample_factor,
                                      x_normalizer=x_normalizer, y_normalizer=y_normalizer)
        self.test_dataset = CarraBase(test_x, test_y, mode='test', sample_factor=sample_factor,
                                     x_normalizer=x_normalizer, y_normalizer=y_normalizer)

        # Only the training loader is shuffled.
        self.train_loader = DataLoader(self.train_dataset, batch_size=train_batchsize, shuffle=True)
        self.valid_loader = DataLoader(self.valid_dataset, batch_size=eval_batchsize, shuffle=False)
        self.test_loader = DataLoader(self.test_dataset, batch_size=eval_batchsize, shuffle=False)

    def pre_process(self, data, lat, lon, mode='train', prop='temp', normalize=False, normalizer_type='PGN',
                    x_normalizer=None, y_normalizer=None, **kwargs):
        """Turn a frame sequence into one-step-ahead (input, target) pairs.

        Frame ``t`` becomes the input and frame ``t + 1`` the target, so a
        segment of ``T`` frames yields ``T - 1`` samples.  The (optionally
        normalized) latitude and longitude grids are prepended to the input as
        two extra channels.

        Args:
            data (Tensor): Frames of shape ``(T, H, W)``.
            lat (Tensor): Latitude values with ``H * W`` elements.
            lon (Tensor): Longitude values with ``H * W`` elements.
            mode (str): ``'train'`` fits new normalizers and returns them; any
                other value reuses the normalizers passed in.
            prop (str): Unused; kept for interface compatibility.
            normalize (bool): If true, encode fields with the normalizers and
                standardize ``lat`` / ``lon`` with their own mean and std.
            normalizer_type (str): ``'PGN'`` selects ``UnitGaussianNormalizer``,
                anything else ``GaussianNormalizer`` (train mode only).
            x_normalizer: Fitted input normalizer (required when normalizing in
                a non-train mode).
            y_normalizer: Fitted target normalizer (required when normalizing
                in a non-train mode).
            **kwargs: Ignored.

        Returns:
            ``(x, y, x_normalizer, y_normalizer)`` when ``mode == 'train'``,
            otherwise ``(x, y)``.  ``x`` has shape ``(T - 1, H, W, 3)`` with
            channels ``[lat, lon, field]`` and ``y`` has shape
            ``(T - 1, H, W, 1)``.  Both normalizers are ``None`` when
            ``normalize`` is false.

        Raises:
            ValueError: If ``normalize`` is true, ``mode`` is not ``'train'``
                and a normalizer is missing.
        """
        # Consecutive frames form the (input, target) pairs; add a channel axis.
        x = data[:-1, ...].unsqueeze(-1)
        y = data[1:, ...].unsqueeze(-1)

        B, H, W, C = x.shape

        # reshape (rather than view) also accepts non-contiguous tensors.
        lat = lat.reshape(1, H, W, 1)
        lon = lon.reshape(1, H, W, 1)

        if normalize:
            # The normalizers are applied to the flattened (B, H*W, C) layout.
            x = x.reshape(B, -1, C)
            y = y.reshape(B, -1, C)
            if mode == 'train':
                if normalizer_type == 'PGN':
                    normalizer_cls = UnitGaussianNormalizer
                else:
                    normalizer_cls = GaussianNormalizer
                x_normalizer = normalizer_cls(x)
                y_normalizer = normalizer_cls(y)
            elif x_normalizer is None or y_normalizer is None:
                raise ValueError(
                    "x_normalizer and y_normalizer must be provided when "
                    "normalize=True and mode is not 'train'.")
            x = x_normalizer.encode(x).reshape(B, H, W, C)
            y = y_normalizer.encode(y).reshape(B, H, W, C)

            # Coordinates are standardized with their own global statistics.
            lat = (lat - lat.mean()) / lat.std()
            lon = (lon - lon.mean()) / lon.std()
        else:
            x_normalizer = None
            y_normalizer = None

        # Broadcast the coordinate grids to every sample and prepend them.
        lat = lat.repeat(B, 1, 1, 1)
        lon = lon.repeat(B, 1, 1, 1)
        x = torch.cat([lat, lon, x], dim=-1)

        if mode == 'train':
            return x, y, x_normalizer, y_normalizer
        return x, y


class CarraBase(Dataset):
    """
    Map-style dataset over pre-processed Carra (input, target) tensors.

    The spatial grid is subsampled with ``sample_factor`` and then flattened,
    so every sample is a pair of ``(num_points, channels)`` float tensors.

    Args:
        x (Tensor): The input data, shape ``(N, H, W, C_in)``.
        y (Tensor): The target data, shape ``(N, H, W, C_out)``.
        mode (str, optional): The mode of the dataset. Defaults to 'train'.
        sample_factor (sequence of 2 ints, optional): Strides used to
            subsample the ``H`` and ``W`` axes. Defaults to ``(1, 1)``.
        x_normalizer (optional): Normalizer associated with the inputs; stored
            but not applied here.
        y_normalizer (optional): Normalizer associated with the targets; stored
            but not applied here.
        **kwargs: Additional keyword arguments (ignored).

    Attributes:
        mode (str): The mode of the dataset.
        x (Tensor): Inputs of shape ``(N, H' * W', C_in)``, dtype float32.
        y (Tensor): Targets of shape ``(N, H' * W', C_out)``, dtype float32.
        x_normalizer: The input normalizer passed to the constructor.
        y_normalizer: The target normalizer passed to the constructor.
    """

    def __init__(self, x, y, mode='train',
                 sample_factor=(1, 1),
                 x_normalizer=None, y_normalizer=None,
                 **kwargs):
        super().__init__()
        self.mode = mode
        self.x_normalizer = x_normalizer
        self.y_normalizer = y_normalizer

        step_h, step_w = sample_factor[0], sample_factor[1]
        x = x[:, ::step_h, ::step_w, :]
        y = y[:, ::step_h, ::step_w, :]

        # The strided slices are generally non-contiguous, so `.view` cannot
        # merge the two spatial axes and raises; `.reshape` copies when needed.
        self.x = x.reshape(x.shape[0], -1, x.shape[-1]).float()
        self.y = y.reshape(y.shape[0], -1, y.shape[-1]).float()

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at position ``idx``."""
        return self.x[idx], self.y[idx]
