import netCDF4
import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning as pl

import sys
sys.path.append("../..")

import numpy as np

from pathlib import Path

from utils.utils import get_output_root

def _broadcast_to_grid(column, t, h, w):
    """Broadcast a NetCDF variable tensor to shape (t, h, w); return None for scalars.

    ``h`` is the grid_yt (latitude) length and ``w`` the grid_xt (longitude) length,
    matching the coordinate meshgrid built in ``nc_to_tensor_ts``. 1D variables are
    matched against those lengths, 2D variables are treated as (h, w) fields repeated
    over time, and 3D tensors are returned unchanged.
    """
    if column.dim() == 0:
        # Scalar variables carry no per-gridpoint information.
        return None
    if column.dim() == 1:
        # Compare lengths as ints: comparing a torch.Size to a list is always False.
        if column.shape[0] == h:
            return column.reshape(1, h, 1).expand(t, h, w)
        if column.shape[0] == w:
            return column.reshape(1, 1, w).expand(t, h, w)
        raise ValueError(
            f"Cannot broadcast 1D variable of length {column.shape[0]} "
            f"to grid (h={h}, w={w})."
        )
    if column.dim() == 2:
        return column.unsqueeze(0).expand(t, h, w)
    return column


def _flatten_grid(grid):
    """Reshape a (c, t, h, w) tensor to (w*h, t, c).

    The flattened spatial index is ``w_index * h + h_index`` (h varies fastest).
    """
    channels, steps = grid.shape[0], grid.shape[1]
    # (c, t, h, w) -> (c, t, w, h) -> (c, t, w*h) -> (w*h, t, c)
    return grid.transpose(2, 3).reshape(channels, steps, -1).transpose(0, 2)


def nc_to_tensor_ts(file_path):
    """
    Convert a NetCDF file to flattened coordinate/target tensors, with on-disk caching.

    The tensors are cached as ``coords_v8_<stem>.pt`` / ``targets_v8_<stem>.pt`` in
    ``get_output_root() / 'notebooks/intermediate_artefacts/'``. If both cache files
    load, they are returned directly. Otherwise the NetCDF file is read, the tensors
    are built, saved to the cache, and returned.

    Args:
        file_path (str or Path): The path to the NetCDF file to read.

    Returns:
        tuple: A tuple containing:
            - coords (torch.Tensor): shape (w*h, t, 3). The last axis holds
              (grid_xt - 180, grid_yt, time), i.e. longitude, latitude, time.
            - targets (torch.Tensor): shape (w*h, t, nc), one channel per
              non-scalar, non-coordinate NetCDF variable, in file order.
        The spatial index is ``lon_index * h + lat_index``.

    Raises:
        ValueError: If a 1D variable matches neither grid dimension.
        Errors from ``netCDF4.Dataset`` propagate if no cached tensors exist and the
        NetCDF file cannot be opened.
    """
    file_stem = Path(file_path).name.split('.')[0]
    artefact_path = get_output_root() / 'notebooks/intermediate_artefacts/'
    artefact_path.mkdir(parents=True, exist_ok=True)

    coords_path = artefact_path / f'coords_v8_{file_stem}.pt'
    targets_path = artefact_path / f'targets_v8_{file_stem}.pt'

    try:
        return torch.load(coords_path), torch.load(targets_path)
    except FileNotFoundError:
        print("File not found, creating new tensors...")

    with netCDF4.Dataset(file_path, 'r') as dataset:
        # Grid size from the variable's shape metadata, without reading its data.
        t, h, w = dataset.variables['DLWRFsfc'].shape[:3]

        coordinate_names = {'time', 'grid_xt', 'grid_yt'}
        channels = []
        for name in dataset.variables:
            if name in coordinate_names:
                continue
            # NOTE(review): netCDF4 may return masked arrays here; confirm how
            # masked/fill values should be handled before conversion.
            column = _broadcast_to_grid(torch.tensor(dataset.variables[name][:]), t, h, w)
            if column is not None:
                channels.append(column)

        # (nc, t, h, w); nc follows from the file instead of being hard-coded.
        targets = torch.stack(channels, 0)

        lon = torch.tensor(dataset.variables['grid_xt'][:] - 180)
        lat = torch.tensor(dataset.variables['grid_yt'][:])
        time = torch.tensor(dataset.variables['time'][:]).to(torch.float64)

    # 'ij' indexing (the historical default) gives grids of shape (t, h, w).
    time_grid, lat_grid, lon_grid = torch.meshgrid(time, lat, lon, indexing='ij')
    coords = torch.stack([lon_grid, lat_grid, time_grid], dim=0)  # (3, t, h, w)

    coords = _flatten_grid(coords)
    targets = _flatten_grid(targets)

    torch.save(coords, coords_path)
    torch.save(targets, targets_path)
    return coords, targets


def split_ace_dataframe_ts(
        targets,
        train_fraction=0.2,
        val_fraction=0.1,
        test_fraction=0.7,
        T = None,
        mode = "spatio_temporal_interpolation",
        val_threshold_forecast = 0.4,
        test_threshold_forecast = 0.5,
        generator=None,
        ):
    """
    Split the ACE dataset targets into training, validation, and test selectors.

    Args:
        targets (torch.Tensor): The target tensor with shape (n, t, ...).
        train_fraction (float, optional): Fraction of data for training. Defaults to 0.2.
        val_fraction (float, optional): Fraction of data for validation. Defaults to 0.1.
        test_fraction (float, optional): Fraction of data for testing. Defaults to 0.7.
        T (int, optional): Number of timesteps to consider. If None, uses targets.shape[1].
        mode (str, optional): Splitting mode ("spatio_temporal_interpolation" or "forecast").
            - "spatio_temporal_interpolation": the fractions apply to all N*T
              (point, timestep) pairs. The three masks are mutually disjoint.
            - "forecast": the fractions apply to the N spatial points. Training uses
              timesteps [0, int(T*val_threshold_forecast)). Validation uses
              [int(T*val_threshold_forecast), int(T*test_threshold_forecast)).
              Test uses every timestep of its points. Each split's points are drawn
              independently, so test may overlap train/val.
        val_threshold_forecast (float, optional): Fraction of T at which the forecast
            validation window starts. Defaults to 0.4.
        test_threshold_forecast (float, optional): Fraction of T at which the forecast
            validation window ends. Must exceed val_threshold_forecast. Defaults to 0.5.
        generator (torch.Generator, optional): RNG used for the random permutations.
            If None, torch's global RNG is used.

    Returns:
        tuple: (train_selector, val_selector, test_selector) as flat boolean masks of
        length N*T, in row-major (point, timestep) order.

    Raises:
        AssertionError: If val_threshold_forecast >= test_threshold_forecast.
        ValueError: If the three fractions sum to more than 1.
        NotImplementedError: If mode is not supported.
    """
    assert val_threshold_forecast < test_threshold_forecast, (
        "val_threshold_forecast must be smaller than test_threshold_forecast"
    )
    N = targets.shape[0]
    T = targets.shape[1] if T is None else T

    # Small tolerance so fractions that sum to exactly 1 mathematically are not
    # rejected because of floating-point rounding.
    if train_fraction + val_fraction + test_fraction > 1 + 1e-9:
        raise ValueError("Sum of train_fraction+val_fraction+test_fraction must be less than or equal to 1.")

    print("Create split selectors...")

    if mode == "spatio_temporal_interpolation":
        total = N * T
        train_size = int(total * train_fraction)
        val_size = int(total * val_fraction)
        test_size = int(total * test_fraction)

        indexes = torch.randperm(total, generator=generator)

        # Mark consecutive, non-overlapping ranges, then apply one shared permutation
        # so the three masks stay mutually exclusive.
        train_selector = torch.zeros(total, dtype=torch.bool)
        val_selector = torch.zeros(total, dtype=torch.bool)
        test_selector = torch.zeros(total, dtype=torch.bool)

        train_end = train_size
        val_end = train_end + val_size
        train_selector[:train_end] = True
        val_selector[train_end:val_end] = True
        test_selector[val_end:val_end + test_size] = True

        return train_selector[indexes], val_selector[indexes], test_selector[indexes]

    if mode == "forecast":
        train_size = int(N * train_fraction)
        val_size = int(N * val_fraction)
        test_size = int(N * test_fraction)

        train_steps = int(T * val_threshold_forecast)
        # Derive the window end from the absolute threshold. int(T * (test - val))
        # truncates on float error, e.g. 10 * (0.5 - 0.4) = 0.999... -> 0 steps.
        val_end = int(T * test_threshold_forecast)

        train_selector = torch.zeros((N, T), dtype=torch.bool)
        val_selector = torch.zeros((N, T), dtype=torch.bool)
        test_selector = torch.zeros((N, T), dtype=torch.bool)  # test uses all timesteps

        train_selector[:train_size, :train_steps] = True
        val_selector[:val_size, train_steps:val_end] = True
        test_selector[:test_size, :] = True

        # Shuffle the spatial (first) axis of each selector independently.
        train_selector = train_selector[torch.randperm(N, generator=generator), :].flatten()
        val_selector = val_selector[torch.randperm(N, generator=generator), :].flatten()
        test_selector = test_selector[torch.randperm(N, generator=generator), :].flatten()

        return train_selector, val_selector, test_selector

    raise NotImplementedError(f"Unsupported split mode: {mode!r}")

class ACE_TS_DataModule(pl.LightningDataModule):
    """
    LightningDataModule serving (coordinate, target) rows built from ACE NetCDF files.

    ``prepare_data`` does the following:
      - loads every file with ``nc_to_tensor_ts``;
      - keeps the channels listed in ``variable_selection``;
      - concatenates the files along the time axis;
      - randomly keeps roughly ``subset_fraction`` of the spatial points;
      - rescales the time coordinate to [-1, 1];
      - builds boolean split selectors with ``split_ace_dataframe_ts``.

    ``setup`` flattens the (space, time) axes into rows, z-scores the targets with
    training-split statistics (once), and wraps each split in a ``TensorDataset`` of
    ``(coords, targets)``.

    NOTE(review): Lightning calls ``prepare_data`` on a single process, so the
    attributes assigned there are not shared with other ranks in multi-process
    training. Confirm this module is only used single-process.
    """

    _SUPPORTED_MODES = ("spatio_temporal_interpolation", "forecast")

    def __init__(self,
        file_paths,
        train_fraction=0.2,
        val_fraction=0.1,
        test_fraction=0.7,
        num_timesteps=None,
        mode="spatio_temporal_interpolation",
        num_workers=0,
        batch_size=1000,
        perturbed_training=False,
        perturbation_scale=0.1,
        shuffle_training_data=True,
        variable_selection=None,
        subset_fraction=1.0,
    ):
        """
        Initialize the ACE_TS_DataModule.

        Args:
            file_paths (list): NetCDF file paths, relative to ``get_output_root()``.
            train_fraction (float): Fraction of data used for training.
            val_fraction (float): Fraction of data used for validation.
            test_fraction (float): Fraction of data used for testing.
            num_timesteps (int, optional): Number of timesteps to use. If None, it is set
                from the loaded data in ``prepare_data``.
            mode (str): Mode for splitting data ("spatio_temporal_interpolation" or "forecast").
            num_workers (int): Number of workers for data loading.
            batch_size (int): Batch size for the train/val/test loaders.
            perturbed_training (bool): Stored only; not used by this module's code.
            perturbation_scale (float): Stored only; not used by this module's code.
            shuffle_training_data (bool): Whether to shuffle training data.
            variable_selection (list, optional): Channel indices to keep from each file.
                Defaults to ``[54]``.
            subset_fraction (float): Probability that each spatial point is kept.
        """
        super().__init__()
        self.file_paths = file_paths
        self.train_fraction = train_fraction
        self.val_fraction = val_fraction
        self.test_fraction = test_fraction
        self.mode = mode
        self.num_timesteps = num_timesteps
        self.num_workers = num_workers
        self.batch_size = batch_size

        self.perturbed_training = perturbed_training
        self.perturbation_scale = perturbation_scale

        self.shuffle_training_data = shuffle_training_data

        # Avoid a shared mutable default argument.
        self.variable_selection = [54] if variable_selection is None else variable_selection
        self.subset_fraction = subset_fraction

        # Set once setup() has flattened and normalized the tensors; reset by prepare_data().
        self._normalized = False

        self.artifact_path = self.build_artifact_path(self.mode, self.train_fraction, self.file_paths, self.subset_fraction)

    def prepare_data(self):
        """
        Load, select, concatenate and subsample the data, then build split selectors.

        Raises:
            ValueError: If ``file_paths`` is empty.
            NotImplementedError: If ``mode`` is not supported.
        """
        if len(self.file_paths) == 0:
            raise ValueError("file_paths must contain at least one NetCDF file.")
        if self.mode not in self._SUPPORTED_MODES:
            raise NotImplementedError(f"Unsupported mode: {self.mode!r}")

        selection = torch.as_tensor(self.variable_selection, dtype=torch.long)

        coords_list, targets_list = [], []
        for file_path in self.file_paths:
            coords, targets = nc_to_tensor_ts(get_output_root() / file_path)
            coords_list.append(coords)
            # Targets are (space, time, channel); keep only the requested channels.
            targets_list.append(torch.index_select(targets, dim=2, index=selection))
        self.coords_list = tuple(coords_list)
        self.selected_targets_list = tuple(targets_list)
        self.C = self.selected_targets_list[0].shape[2]

        # Concatenate data along the time dimension (dim=1). This assumes all files
        # share the same spatial grid.
        self.coords = torch.concat(self.coords_list, dim=1)
        self.selected_targets = torch.concat(self.selected_targets_list, dim=1)

        # Keep a random subset of spatial points.
        self.spatial_selector = torch.rand(self.coords.shape[0]) < self.subset_fraction
        self.coords = self.coords[self.spatial_selector, :, :]
        self.selected_targets = self.selected_targets[self.spatial_selector, :, :]

        # Rescale time to [-1, 1]. A constant time axis (e.g. a single timestep) maps
        # to 0 instead of producing NaN from 0 / 0.
        t_coord = self.coords[:, :, 2]
        t_min, t_max = t_coord.min(), t_coord.max()
        t_span = t_max - t_min
        if t_span > 0:
            self.coords[:, :, 2] = ((t_coord - t_min) / t_span - 0.5) * 2
        else:
            self.coords[:, :, 2] = 0.0

        # self.coords has shape (N_SPATIAL_GRIDPOINTS, N_TIME_GRIDPOINTS, N_COORDINATES).
        # NOTE(review): a num_timesteps smaller than the data's time length yields
        # selectors that do not match the flattened data in setup().
        if self.num_timesteps is None:
            self.num_timesteps = self.coords.shape[1]

        self.train_selector, self.val_selector, self.test_selector = split_ace_dataframe_ts(
            targets=self.selected_targets,
            train_fraction=self.train_fraction,
            val_fraction=self.val_fraction,
            test_fraction=self.test_fraction,
            T=self.num_timesteps,
            mode=self.mode,
        )
        self._normalized = False

    def _flatten_and_normalize(self):
        """Flatten to (rows, features) and z-score targets with train stats; runs once."""
        if getattr(self, "_normalized", False):
            # Re-normalizing would overwrite self.means/self.stds with the statistics
            # of already-normalized data.
            return
        self.coords = self.coords.reshape(-1, 3)
        self.selected_targets = self.selected_targets.reshape(-1, self.C)

        train_targets = self.selected_targets[self.train_selector]
        self.means = torch.mean(train_targets, dim=0, keepdim=True)
        stds = torch.std(train_targets, dim=0, keepdim=True)
        # Constant channels have zero std; divide by 1 instead of producing inf/NaN.
        self.stds = torch.where(stds == 0, torch.ones_like(stds), stds)

        self.selected_targets = (self.selected_targets - self.means) / self.stds
        self._normalized = True

    def _subset(self, selector):
        """Wrap the rows chosen by a boolean ``selector`` in a TensorDataset."""
        return TensorDataset(self.coords[selector], self.selected_targets[selector])

    def setup(self, stage: str):
        """
        Set up the datasets for the specified stage.

        Args:
            stage (str): One of "fit", "validate", "test", "predict" or None. "fit" and
                "validate" build train/valid datasets. "test" and "predict" build
                test/predict datasets. None builds all of them.

        Returns:
            None

        Raises:
            NotImplementedError: If ``mode`` is not supported.
        """
        if self.mode not in self._SUPPORTED_MODES:
            raise NotImplementedError(f"Unsupported mode: {self.mode!r}")

        self._flatten_and_normalize()

        if stage in ("fit", "validate") or stage is None:
            self.train_ds = self._subset(self.train_selector)
            self.valid_ds = self._subset(self.val_selector)

        if stage in ("test", "predict") or stage is None:
            self.test_ds = self._subset(self.test_selector)

            selectors = (self.train_selector, self.val_selector, self.test_selector)
            self.predict_ds = TensorDataset(
                torch.cat([self.coords[s] for s in selectors], dim=0),
                torch.cat([self.selected_targets[s] for s in selectors], dim=0),
            )

    def train_dataloader(self):
        """Return the training DataLoader (shuffled per ``shuffle_training_data``)."""
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=self.shuffle_training_data,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return the validation DataLoader (not shuffled)."""
        return DataLoader(self.valid_ds, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        """Return the test DataLoader (not shuffled)."""
        return DataLoader(self.test_ds, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def predict_dataloader(self):
        """Return the prediction DataLoader over train+val+test rows (not shuffled)."""
        # Batch size is len / num_timesteps, floored at 1 so that fewer rows than
        # timesteps does not yield an invalid batch_size of 0.
        return DataLoader(
            self.predict_ds,
            batch_size=max(1, len(self.predict_ds) // self.num_timesteps),
            num_workers=self.num_workers,
            shuffle=False,
        )

    def build_artifact_path(self, mode, train_fraction, file_paths, spatial_subset_fraction):
        """Return the relative output path encoding the main data-module settings."""
        return (Path("ACE") /
            f"task_{mode}" /
            f"train_fraction_{train_fraction}" /
            f"spatial_subset_fraction_{spatial_subset_fraction}" /
            f"number_input_files_{len(file_paths)}"
        )
