import os
import pandas as pd
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T

from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.data import NaNLabelEncoder
from pytorch_forecasting.data.examples import generate_ar_data

from deephfts.experiments.mimic.dataloader import mimic_trainval_loader_helper

class SlidingWindowDataset(Dataset):
    """Sliding-window view over a single time series.

    Sample ``idx`` pairs the ``window_length`` values whose index labels run
    from ``idx`` to ``idx + window_length - 1`` with the ``forecast_length``
    values that immediately follow them.

    Rows are selected with label-based ``.loc`` slicing (inclusive on both
    ends), so ``data`` is assumed to carry the default 0-based integer index
    -- TODO confirm against callers.

    Args:
        data: DataFrame providing a ``value`` column.
        transform: Optional callable applied to the window tensor.
        target_transform: Optional callable applied to the forecast tensor.
        window_length: Number of observations in each input window.
        forecast_length: Number of observations in each forecast target.
    """
    def __init__(self, data, transform=None, target_transform=None, window_length : int = 80, forecast_length : int = 20):
        self.data = data
        self.transform = transform
        # Store the caller's target transform; it used to be overwritten with
        # `transform`, so the forecast was transformed like the window.
        self.target_transform = target_transform
        self.window_length = window_length
        self.forecast_length = forecast_length

    def __len__(self):
        """Return the number of complete (window, forecast) pairs."""
        # The last valid start is len(data) - window - forecast, hence the +1.
        # Clamp at zero because len() rejects negative values.
        n_samples = len(self.data) - self.window_length - self.forecast_length + 1
        return max(n_samples, 0)

    def __getitem__(self, idx):
        """Return the ``(window, forecast)`` float32 tensors for sample ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``.
        """
        if not 0 <= idx < len(self):
            # Without this check .loc would quietly return truncated slices.
            raise IndexError(f"index {idx} is out of range for {len(self)} samples")

        window_end = idx + self.window_length - 1
        forecast_start = window_end + 1
        forecast_end = forecast_start + self.forecast_length - 1

        # .loc slices by label and includes both end points.
        window = self.data.loc[idx:window_end, 'value'].to_numpy()
        forecast = self.data.loc[forecast_start:forecast_end, 'value'].to_numpy()

        # torch.tensor always copies, so transforms cannot mutate `data`.
        window = torch.tensor(window, dtype=torch.float32)
        forecast = torch.tensor(forecast, dtype=torch.float32)

        if self.transform is not None:
            window = self.transform(window)

        if self.target_transform is not None:
            forecast = self.target_transform(forecast)

        return window, forecast

class BatchWindowDataset(Dataset):
    """Dataset yielding one (window, forecast) pair per time series.

    Each item corresponds to one distinct value of the ``series`` column. The
    first ``window_length`` rows of that series form the window and the next
    ``forecast_length`` rows form the forecast.

    Rows are taken in the order they appear in ``data``; they are not sorted
    by time here, so ``data`` is assumed to be time-ordered within each
    series -- TODO confirm against callers. A series with fewer than
    ``window_length + forecast_length`` rows yields shorter tensors.

    Args:
        data: DataFrame providing ``series`` and ``value`` columns.
        transform: Optional callable applied to the window tensor.
        target_transform: Optional callable applied to the forecast tensor.
        window_length: Number of observations in each input window.
        forecast_length: Number of observations in each forecast target.
    """
    def __init__(self, data, transform=None, target_transform=None, window_length : int = 80, forecast_length : int = 20):
        self.data = data
        self.transform = transform
        # Store the caller's target transform; it used to be overwritten with
        # `transform`, so the forecast was transformed like the window.
        self.target_transform = target_transform
        self.window_length = window_length
        self.forecast_length = forecast_length
        # Sorted distinct series ids, captured once at construction. Item i
        # maps to the i-th id, so every series is reachable whether the ids
        # start at 0, start at 1, or have gaps.
        self._series_ids = sorted(data['series'].unique().tolist())

    def __len__(self):
        """Return the number of distinct series in the data."""
        return len(self._series_ids)

    def __getitem__(self, idx):
        """Return the ``(window, forecast)`` float32 tensors for item ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        # List indexing raises IndexError for out-of-range positions.
        series_id = self._series_ids[idx]
        time_series = self.data.loc[self.data['series'] == series_id]

        # Positional slicing: first the window, then the rows right after it.
        values = time_series['value']
        forecast_stop = self.window_length + self.forecast_length
        window = values.iloc[0:self.window_length].to_numpy()
        forecast = values.iloc[self.window_length:forecast_stop].to_numpy()

        # torch.tensor always copies, so transforms cannot mutate `data`.
        window = torch.tensor(window, dtype=torch.float32)
        forecast = torch.tensor(forecast, dtype=torch.float32)

        if self.transform is not None:
            window = self.transform(window)

        if self.target_transform is not None:
            forecast = self.target_transform(forecast)

        return window, forecast




def online_trainloader_helper(
    data, 
    max_encoder_length: int = 80,
    max_prediction_length: int = 20,
    batch_size = 1):
    """Build sequential and shuffled loaders over sliding windows of ``data``.

    Both loaders wrap the same ``SlidingWindowDataset``; they differ only in
    whether samples are shuffled.

    Args:
        data (pd.DataFrame): Data to be loaded. Must be pandas dataframe with columns ["time_idx", "series", "value"]
        max_encoder_length (int, optional): Window length passed to the dataset. Defaults to 80.
        max_prediction_length (int, optional): Forecast length passed to the dataset. Defaults to 20.
        batch_size (int, optional): Batch size of both loaders. Defaults to 1.

    Returns:
        online_dataloader, offline_dataloader: the unshuffled loader followed
        by the shuffled loader.
    """
    windows = SlidingWindowDataset(
        data=data,
        transform=None,
        target_transform=None,
        window_length=max_encoder_length,
        forecast_length=max_prediction_length,
    )

    # In-order loader first, shuffled loader second.
    sequential_loader, shuffled_loader = (
        DataLoader(windows, batch_size=batch_size, shuffle=reorder)
        for reorder in (False, True)
    )

    return sequential_loader, shuffled_loader

def batch_trainloader_helper(
    data, 
    max_encoder_length: int = 80,
    max_prediction_length: int = 20,
    batch_size = 1,
    train_split = 0.80):
    """Build train and validation loaders over per-series windows.

    ``data`` is wrapped in a ``BatchWindowDataset`` (one item per series),
    which is randomly partitioned into disjoint train and validation subsets.

    Args:
        data (pd.DataFrame): Data to be loaded. Must be pandas dataframe with columns ["time_idx", "series", "value"]
        max_encoder_length (int, optional): Window length passed to the dataset. Defaults to 80.
        max_prediction_length (int, optional): Forecast length passed to the dataset. Defaults to 20.
        batch_size (int, optional): Batch size of both loaders. Defaults to 1.
        train_split (float, optional): Fraction of items assigned to the
            training subset, rounded down. Defaults to 0.80.

    Returns:
        train_dataloader, valid_dataloader

    Raises:
        ValueError: If ``train_split`` is not within ``[0, 1]``.
    """
    if not 0.0 <= train_split <= 1.0:
        raise ValueError(f"train_split must be within [0, 1], got {train_split!r}")

    dataset = BatchWindowDataset(data=data, 
        transform=None, target_transform=None, 
        window_length=max_encoder_length, 
        forecast_length=max_prediction_length
    )
    train_length = int(len(dataset) * train_split)
    valid_length = len(dataset) - train_length

    train_dataset, valid_dataset = random_split(dataset, lengths=[train_length, valid_length]) 

    # Each loader wraps its own subset. Both previously wrapped the full
    # dataset, so validation contained every training sample.
    # Shuffle flags are unchanged from the original implementation.
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)

    return train_dataloader, valid_dataloader

def offline_trainloader_helper(
    data,
    max_encoder_length: int = 80,
    max_prediction_length: int = 20,
    batch_size: int = 1):
    """Build offline train/validation loaders with a held-out tail.

    Rows up to ``max(time_idx) - max_prediction_length`` form the training
    set. The validation set is derived from the training set's configuration
    over the full data, with predictions starting right after the cutoff.
    Offline methods can sample in any order. Online methods cannot.

    Args:
        data (pd.DataFrame): Data to be loaded. Must be pandas dataframe with columns ["time_idx", "series", "value"]
        max_encoder_length (int, optional): Encoder length. Defaults to 80.
        max_prediction_length (int, optional): Prediction length. Defaults to 20.
        batch_size (int, optional): Batch size. Defaults to 1.

    Returns:
        train_dataloader, val_dataloader, training, validation
    """
    # Last time index that may appear in the training set.
    last_train_idx = data["time_idx"].max() - max_prediction_length
    train_rows = data[data["time_idx"] <= last_train_idx]

    training = TimeSeriesDataSet(
        train_rows,
        time_idx="time_idx",
        target="value",
        group_ids=["series"],
        time_varying_unknown_reals=["value"],
        max_encoder_length=max_encoder_length,
        max_prediction_length=max_prediction_length,
    )
    validation = TimeSeriesDataSet.from_dataset(
        training, data, min_prediction_idx=last_train_idx + 1
    )

    # Options shared by both loaders.
    loader_options = {"batch_size": batch_size, "num_workers": 0}
    train_dataloader = training.to_dataloader(train=True, **loader_options)
    val_dataloader = validation.to_dataloader(train=False, **loader_options)

    return train_dataloader, val_dataloader, training, validation


def offline_to_online_loader(
    data,
    max_encoder_length: int = 80,
    max_prediction_length: int = 20,
    batch_size: int = 1):
    """Build loaders over the complete series, without a training cutoff.

    The training set is built from all of ``data``. The validation set reuses
    its configuration over the same data with ``min_prediction_idx=1``.

    Args:
        data (pd.DataFrame): Data to be loaded. Must be pandas dataframe with columns ["time_idx", "series", "value"]
        max_encoder_length (int, optional): Encoder length. Defaults to 80.
        max_prediction_length (int, optional): Prediction length. Defaults to 20.
        batch_size (int, optional): Batch size. Defaults to 1.

    Returns:
        train_dataloader, val_dataloader, training, validation
    """
    training = TimeSeriesDataSet(
        data,
        time_idx="time_idx",
        target="value",
        group_ids=["series"],
        time_varying_unknown_reals=["value"],
        max_encoder_length=max_encoder_length,
        max_prediction_length=max_prediction_length,
    )
    # Predictions may start at time index 1.
    validation = TimeSeriesDataSet.from_dataset(training, data, min_prediction_idx=1)

    # Options shared by both loaders.
    loader_options = {"batch_size": batch_size, "num_workers": 0}
    train_dataloader = training.to_dataloader(train=True, **loader_options)
    val_dataloader = validation.to_dataloader(train=False, **loader_options)

    return train_dataloader, val_dataloader, training, validation

def ar_dataset(
        timesteps: int = 400,
        n_series : int = 10,
        seed : int = 213,
        seasonality : float = 3.0,
        trend : float = 3.0,
        max_encoder_length: int = 80,
        max_prediction_length: int = 20,
        batch_size: int = 1
    ):
    """Generate synthetic AR data and wrap it in offline train/val loaders.

    Borrowed from https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/ar.html

    Args:
        timesteps (int, optional): Passed to ``generate_ar_data``. Defaults to 400.
        n_series (int, optional): Passed to ``generate_ar_data``. Defaults to 10.
        seed (int, optional): Passed to ``generate_ar_data``. Defaults to 213.
        seasonality (float, optional): Passed to ``generate_ar_data``. Defaults to 3.0.
        trend (float, optional): Passed to ``generate_ar_data``. Defaults to 3.0.
        max_encoder_length (int, optional): Encoder length. Defaults to 80.
        max_prediction_length (int, optional): Prediction length. Defaults to 20.
        batch_size (int, optional): Batch size. Defaults to 1.

    Returns:
        train_dataloader, val_dataloader
    """
    # generate ar data. 
    data = generate_ar_data(seasonality=seasonality, timesteps=timesteps, n_series=n_series, seed=seed, trend=trend)

    # The helper returns (train_loader, val_loader, training, validation).
    # Unpacking it into two names raised ValueError, so take all four and
    # keep only the loaders.
    train_dataloader, val_dataloader, _training, _validation = offline_trainloader_helper(data, 
        max_encoder_length=max_encoder_length,
        max_prediction_length=max_prediction_length,
        batch_size=batch_size)

    return train_dataloader, val_dataloader


def online_ar_dataset(max_encoder_length, max_prediction_length, breakpoint):
    """Flatten AR training batches into a single 1-D tensor.

    For each batch, the first ``encoder_target`` entry and the first element
    of the target are flattened and appended in that order. Iteration stops
    after the first batch whose index exceeds ``breakpoint``.
    """
    train_loader, _unused_val_loader = ar_dataset(
        max_encoder_length=max_encoder_length,
        max_prediction_length=max_prediction_length,
    )

    pieces = []
    for batch_no, (inputs, targets) in enumerate(train_loader):
        encoder_part = torch.flatten(inputs['encoder_target'][0])
        target_part = torch.flatten(targets[0])
        pieces.extend((encoder_part, target_part))
        # The batch that trips the limit is still included.
        if batch_no > breakpoint:
            break

    return torch.concat(pieces)

def dataset_helper(dataset=None, mode: str = "online", seed: int=213, path=None, channel=None, window_size: int = 80, forecast_size: int = 20):
    """Select and build the data loaders for the requested dataset.

    Dispatch rules, checked in order:

    * ``dataset is None`` with a ``path``: the CSV at ``path`` is read and
      passed to ``online_trainloader_helper`` (``mode == "online"``) or
      ``batch_trainloader_helper`` (``mode == "offline"``).
    * ``dataset == "ar"``: the synthetic AR dataset.
    * ``dataset == "mimic"`` with both ``path`` and ``channel`` set: the
      MIMIC loaders.

    Args:
        dataset (str, optional): Name of the dataset. Defaults to None.
        mode (str, optional): Mode. Defaults to "online".
        seed (int, optional): Seed. Defaults to 213.
        path (str, optional): Path to csv file. Defaults to None.
        channel (optional): ECG channel for MIMIC data. Defaults to None.
        window_size (int, optional): Window size. Defaults to 80.
        forecast_size (int, optional): Forecast size. Defaults to 20.

    Returns:
        A pair of data loaders from the selected helper, or ``None`` when no
        dispatch rule matches.
    """
    if dataset is None and path is not None:
        # Only read the CSV once the mode is known to be supported.
        if mode == "online":
            return online_trainloader_helper(
                pd.read_csv(path),
                max_encoder_length=window_size,
                max_prediction_length=forecast_size,
            )
        if mode == "offline":
            return batch_trainloader_helper(
                pd.read_csv(path),
                max_encoder_length=window_size,
                max_prediction_length=forecast_size,
            )

    if dataset == "ar":
        # Forward the window and forecast sizes; they were previously ignored
        # for AR data. The defaults match, so default calls are unchanged.
        return ar_dataset(
            seed=seed,
            max_encoder_length=window_size,
            max_prediction_length=forecast_size,
        )

    if dataset == "mimic" and path and channel:
        train_dataloader, test_dataloader = mimic_trainval_loader_helper(
            path=path,
            channel_II=channel
        )
        return train_dataloader, test_dataloader

    # No rule matched; keep returning None as the original implicitly did.
    return None
    
    