'''
European Centre for Medium-Range Weather Forecasts Reanalysis v5 (ERA5) dataset.
ERA5 hourly data on single levels from 1940 to present.

https://cds.climate.copernicus.eu/datasets
'''
import os
import sys
# Add the current working directory to the module search path.
sys.path.append('.')
# Directory component of the invoked script path (sys.argv[0]); the ERA5 class
# below joins it with 'data' to locate the GRIB file and to save figures.
path = os.path.dirname(sys.argv[0])

from functools import partial

from cv2 import resize
import cdsapi
import numpy as np
import torch
import xarray
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

from torch.utils.data import Dataset, random_split


class ERA5(Dataset):
    '''
    The ERA5 dataset for heteroscedastic regression.
    
    Original Dataset:
    
    - ERA5 hourly data on single levels from 1940 to present
    - Variable: 2m temperature
    - Area: 83, -169, 7, -35
    - Year: 2020, 2021, 2022, 2023, 2024
    - Month: 01-12
    - Day: 01, 05, 10, 15, 20, 25
    - Time: 00:00, 06:00, 12:00, 18:00
    - Raw data shape: (time, latitude, longitude) = (n_year*n_month*n_day*n_hour, 305, 537)
    - Resized data shape: (time, latitude, longitude) = (time, image_size, image_size)
        
    Heteroscedastic Dataset:
    
    - Input: time (month), latitude, longitude
    - Output: 2m temperature
    - Unobservable variable: time (day and hour), which brings noise to the output
        
    Parameters
    ----------
    image_size : int
        The size of the image.
        
    split : str
        The split of the dataset, either "train" or "test".

    max_samples : int
        If given, keep at most ``int(max_samples*0.8)`` training samples and
        ``int(max_samples*0.2)`` test samples. If None, keep everything.

    download : bool
        Whether to download the dataset (only done when the file is missing).
        
    gpu_id : int
        GPU ID. If None or a negative integer, use CPU.

    seed : int
        Seed of the private random generator used for the 80/20 split. With a
        fixed seed, the "train" and "test" instances are built from the same
        permutation and therefore never overlap. Pass None to draw the split
        from torch's global random state instead.
        
    Attributes
    ----------    
    num_raw_samples : int
        Number of time steps found in the raw file.

    dataset : numpy.ndarray (num_raw_samples, image_size, image_size)
        Resized temperature fields, normalized to [-1, 1].

    x : torch.Tensor (num_samples, 3)
        Normalized (time, latitude, longitude) inputs of the selected split.
        
    y : torch.Tensor (num_samples, 1)
        Normalized 2m temperature of the selected split.

    dim_input, dim_output : int
        Number of columns of ``x`` and ``y``.
    '''

    # Calendar layout of the request; the raw time axis is assumed to be
    # ordered as (year, month, day, hour) with hour varying fastest.
    _YEARS = (2020, 2021, 2022, 2023, 2024)
    _DAYS = (1, 5, 10, 15, 20, 25)
    _HOURS = (0, 6, 12, 18)
    _SPLITS = ("train", "test")

    def __init__(self, image_size: int, split: str, max_samples: int = None,
                download: bool = False, gpu_id: int = None, seed: int = 0) -> None:
        super().__init__()

        # Fail fast: reject a bad split before any download or heavy loading.
        if split not in self._SPLITS:
            raise ValueError(f"Invalid split {split} provided.")

        self.fname_data = os.path.join(path, 'data', 'era5_t2m.grib')
        self.image_size = image_size
        self.split = split
        self.max_samples = max_samples
        self.gpu_id = gpu_id
        self.seed = seed
        
        if download and not os.path.exists(self.fname_data):
            
            print("Downloading ERA5...")
            dataset_name = "reanalysis-era5-single-levels"
            request = {
                "product_type": ["reanalysis"],
                "variable": ["2m_temperature"],
                "year": [2020, 2021, 2022, 2023, 2024],
                "month": ["01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12"],
                "day": [1, 5, 10, 15, 20, 25],
                "time": ["00:00", "06:00", "12:00", "18:00"],
                "data_format": "grib",
                "area": [83, -169, 7, -35] # North, West, South, East
            }
            # The target folder may not exist yet on a fresh checkout.
            os.makedirs(os.path.dirname(self.fname_data), exist_ok=True)
            client = cdsapi.Client()
            client.retrieve(dataset_name, request, self.fname_data)
            
        elif not os.path.exists(self.fname_data):
            raise FileNotFoundError(f"ERA5 data file {self.fname_data} not found!")
        
        self.n_year = len(self._YEARS)
        self.n_month = 12
        self.n_day = len(self._DAYS)
        self.n_hour = len(self._HOURS)
        
        print("Loading temperature dataset (time, latitude, longitude) -- resize --> shape [n_raw_samples, image_size, image_size]")
        # Read everything into memory, then release the file handle.
        with xarray.open_dataset(self.fname_data) as grib:
            raw = grib["t2m"].values # (time, latitude, longitude)
        print(f"Raw dataset shape: {raw.shape}")
        self.num_raw_samples = raw.shape[0]

        # Pre-processing (resize and normalize)
        resized = np.array([resize(frame, dsize=(image_size, image_size)) for frame in raw])
        self.dataset = normalize_range(resized, low=-1, high=1)

        # Create the dataset for heteroscedastic regression
        self._assemble_heteroscedastic_dataset()
                
        # Split the dataset into train and test. A dedicated generator keeps
        # the permutation identical across instances, so the two splits are
        # disjoint even when they are created by separate constructor calls.
        split_kwargs = {}
        if seed is not None:
            split_kwargs["generator"] = torch.Generator().manual_seed(seed)
        train_split, test_split = random_split(range(len(self.x_cpu)), [0.8, 0.2], **split_kwargs)

        if split == "train":
            indices, fraction = train_split.indices, 0.8
        else:
            indices, fraction = test_split.indices, 0.2

        if self.max_samples is not None:
            indices = indices[:int(self.max_samples*fraction)]

        # Index with a plain integer array (explicit dtype so that an empty
        # selection is still a valid index).
        indices = np.asarray(indices, dtype=np.int64)
        self.x_cpu = self.x_cpu[indices]
        self.y_cpu = self.y_cpu[indices]

        self.x = torch.tensor(self.x_cpu).float()
        self.y = torch.tensor(self.y_cpu).float()
        
        # None or a negative id means "stay on CPU".
        use_gpu = self.gpu_id is not None and self.gpu_id >= 0
        if use_gpu and torch.cuda.is_available():
            self.x = self.x.to(self.gpu_id)
            self.y = self.y.to(self.gpu_id)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

    def _assemble_heteroscedastic_dataset(self) -> None:
        '''
        Assemble the heteroscedastic dataset.

        Builds ``x_cpu`` with one row (time, latitude, longitude) per pixel of
        every raw sample and ``y_cpu`` with the matching temperature value.
        Time is the running month index (0 .. n_year*n_month-1) scaled to
        [0, 1]; day and hour are deliberately not encoded.
        '''
        n_pixels = self.image_size**2
        n_months = self.n_year*self.n_month

        # Normalized time coordinates: every sample of a month shares one value.
        month_of_sample = np.arange(self.num_raw_samples) // (self.n_day*self.n_hour)
        # Samples beyond the expected calendar range keep time 0.
        month_of_sample = np.where(month_of_sample < n_months, month_of_sample, 0)
        time_expanded = np.repeat(month_of_sample.astype(float), n_pixels) / (n_months - 1)

        # Create coordinate grids for latitude and longitude
        lat_coords = np.linspace(0, 1, self.image_size)  # Normalized latitude coordinates
        lon_coords = np.linspace(0, 1, self.image_size)  # Normalized longitude coordinates
        lat_grid, lon_grid = np.meshgrid(lat_coords, lon_coords, indexing='ij')
        
        # Repeat spatial coordinates for each time sample
        lat_expanded = np.tile(lat_grid.flatten(), self.num_raw_samples)
        lon_expanded = np.tile(lon_grid.flatten(), self.num_raw_samples)
        
        # Stack to create (time, latitude, longitude) features
        self.x_cpu = np.column_stack([time_expanded, lat_expanded, lon_expanded]) # shape: (num_raw_samples*image_size**2, 3)
        self.y_cpu = self.dataset.reshape(-1, 1) # shape: (num_raw_samples*image_size**2, 1)

        self.dim_input = self.x_cpu.shape[1]
        self.dim_output = self.y_cpu.shape[1]

    def _decode_raw_index(self, idx: int):
        '''
        Translate a raw sample index into (year, month, day, hour).

        Assumes the raw time axis is ordered year > month > day > hour.
        '''
        rest, i_hour = divmod(idx, self.n_hour)
        rest, i_day = divmod(rest, self.n_day)
        rest, i_month = divmod(rest, self.n_month)
        i_year = rest % self.n_year
        return self._YEARS[i_year], i_month + 1, self._DAYS[i_day], self._HOURS[i_hour]

    def _month_slice(self, i_year: int, i_month: int) -> slice:
        '''
        Return the slice of raw samples belonging to one (year, month) pair.

        Raises
        ------
        ValueError
            If ``i_year`` or ``i_month`` is outside the calendar layout.
        '''
        if not 0 <= i_year < self.n_year:
            raise ValueError(f"i_year must be in [0, {self.n_year - 1}], got {i_year}.")
        if not 0 <= i_month < self.n_month:
            raise ValueError(f"i_month must be in [0, {self.n_month - 1}], got {i_month}.")
        per_month = self.n_day*self.n_hour
        start = (i_year*self.n_month + i_month)*per_month
        return slice(start, start + per_month)
        
    def plot_raw_sample(self, idx: int) -> None:
        '''
        Plot a sample of the dataset.

        Saves the normalized field of raw sample ``idx`` as
        ``data/era5_t2m_<idx>.png`` next to the script.
        '''
        year, month, day, hour = self._decode_raw_index(idx)
        
        plt.rcParams.update({'font.size': 14})
        plt.imshow(self.dataset[idx])
        plt.title(f"2m temperature - Year: {year}, Month: {month}, Day: {day}, Hour: {hour}", fontsize=16)
        plt.xlabel("Longitude", fontsize=14)
        plt.ylabel("Latitude", fontsize=14)
        plt.savefig(os.path.join(path, 'data', 'era5_t2m_%d.png'%(idx)))
        plt.close()
        
    def plot_mean_std(self, i_year: int, i_month: int) -> None:
        '''
        Plot the mean and std of the dataset.

        The statistics are taken per pixel over all samples (days and hours)
        of the month selected by the zero-based ``i_year`` and ``i_month``.
        The figure is saved as PNG and PDF in the ``data`` folder.
        '''
        plt.rcParams.update({'font.size': 14})
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        
        dataset = self.dataset[self._month_slice(i_year, i_month), :, :]
        mean = np.mean(dataset, axis=0)
        std = np.std(dataset, axis=0)
        
        # Plot mean
        im1 = ax1.imshow(mean)
        ax1.set_title("Mean", fontsize=20)
        ax1.set_xlabel("Longitude", fontsize=20)
        ax1.set_ylabel("Latitude", fontsize=20)
        cbar1 = plt.colorbar(im1, ax=ax1)
        cbar1.ax.tick_params(labelsize=16)
        
        # Plot standard deviation
        im2 = ax2.imshow(std)
        ax2.set_title("Standard Deviation", fontsize=20)
        ax2.set_xlabel("Longitude", fontsize=20)
        ax2.set_ylabel("Latitude", fontsize=20)
        cbar2 = plt.colorbar(im2, ax=ax2)
        cbar2.ax.tick_params(labelsize=16)
        
        plt.suptitle(f"2m temperature (normalized) - Year: {2020+i_year}, Month: {1+i_month}", fontsize=20)
        
        plt.tight_layout()
        plt.savefig(os.path.join(path, 'data', 'era5_t2m_mean_std_%d_%d.png'%(2020+i_year, 1+i_month)))
        plt.savefig(os.path.join(path, 'data', 'era5_t2m_mean_std_%d_%d.pdf'%(2020+i_year, 1+i_month)), dpi=300)
        plt.close()


def normalize_range(x, low=-1, high=1):
    """
    Linearly rescale the values of an array to the range [low, high].

    The minimum of ``x`` is mapped to ``low`` and the maximum to ``high``.
    If all values are equal (zero range), every element is mapped to ``low``
    instead of producing NaNs from a 0/0 division.

    :param x: array-like object exposing ``min()`` and ``max()``
    :param low: low end of the range
    :param high: high end of the range
    :return: normalized array with the same shape as ``x``
    """
    x_min = x.min()
    span = x.max() - x_min
    shifted = x - x_min
    if span != 0:
        unit = shifted / span
    else:
        # Constant input: avoid dividing by zero, keep shape and float type.
        unit = shifted * 0.0
    return ((high - low) * unit) + low


if __name__ == "__main__":

    # Build the 64x64 training split, fetching the GRIB file if it is missing.
    era5_train = ERA5(image_size=64, split="train", download=True)

    # Diagnostic figures: two raw frames and one mean/std panel.
    for raw_idx in (0, 5):
        era5_train.plot_raw_sample(raw_idx)
    era5_train.plot_mean_std(4, 0)

    # Short summary of the assembled regression data.
    print(f'Dataset size: {len(era5_train)}')
    print(f'Input dimension: {era5_train.dim_input}')
    print(f'Output dimension: {era5_train.dim_output}')
    first_input, first_output = era5_train[0]
    print(f'Input data: {first_input}')
    print(f'Output data: {first_output}')
    