from sklearn.preprocessing import MaxAbsScaler, OneHotEncoder
import scipy.sparse as sparse
from pathlib import Path
import random
from typing import List
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
import torch
from src.seq_dataset import SeqDataset
# Directory two levels above the one containing this file; the config's
# 'data_dir' entry is resolved against it in load_temporal_dataset.
proj_root = Path(__file__).parent.parent.parent.resolve()


def samples_to_datasets(X, y, N: int, seed: int, T: int):
    """Cut an ordered sample stream into buckets and pick ``T`` consecutive ones.

    The rows of ``X`` (and the matching entries of ``y``) are divided, in
    order, into ``int(X.shape[0] / N)`` buckets of close to ``N`` samples
    each: [o,o,o,o,...] -> [[o,o,o], [o,o,o], ...].  A window of ``T``
    consecutive buckets is then selected, its start position being drawn from
    the ``random`` module after seeding it with ``seed``.

    Args:
        X: Samples indexed along axis 0.  Either anything ``np.array_split``
            accepts, or a row-sliceable scipy sparse matrix (e.g. CSR).
        y: Targets aligned with the rows of ``X``.
        N: Approximate number of samples per bucket.
        seed: Value passed to ``random.seed``; note that this reseeds the
            global ``random`` generator.
        T: Number of consecutive buckets to return.

    Returns:
        ``(X_split, y_split)``: two lists holding the selected ``T`` buckets.

    Raises:
        ValueError: If fewer than ``N`` samples are available, or if fewer
            than ``T`` buckets can be formed.
    """
    max_T = int(X.shape[0] / N)
    if max_T < 1:
        raise ValueError(
            f"Need at least N={N} samples to form one dataset, "
            f"got {X.shape[0]}.")

    num_start_index_choices = max_T - T
    if num_start_index_choices < 0:
        raise ValueError(
            f"Cannot select T={T} consecutive datasets out of only {max_T}.")

    # randomly select a sequence of T datasets:
    random.seed(seed)
    if num_start_index_choices == 0:
        # Exactly T buckets exist, so the only valid window starts at 0.
        # (random.randint(1, 0) would raise on this perfectly valid input.)
        start_index_split = 0
    else:
        # NOTE(review): the draw is kept on 1..num_start_index_choices so that
        # existing seeds keep selecting the same window; as a consequence
        # start index 0 is never drawn when other choices exist.
        start_index_split = random.randint(1, num_start_index_choices)
    stop_index_split = start_index_split + T

    if sparse.issparse(X):
        # np.array_split goes through swapaxes, which scipy sparse matrices
        # do not provide, so split the row indices instead and slice the
        # selected row ranges out of the matrix.
        row_chunks = np.array_split(np.arange(X.shape[0]), max_T)
        X_split = [
            X[int(rows[0]):int(rows[-1]) + 1]
            for rows in row_chunks[start_index_split:stop_index_split]
        ]
    else:
        X_split = np.array_split(X, max_T)[start_index_split:stop_index_split]
    y_split = np.array_split(y, max_T)[start_index_split:stop_index_split]

    return X_split, y_split


def get_elec_data(data_dir: Path, seed: int,  N: int = 100, T: int = 20):
    """Load the electricity table and return a window of ``T`` datasets.

    Reads ``elec/raw.csv`` below ``data_dir``.  Every column except the last
    is treated as a feature and the last column as the target.  Features are
    rescaled with a ``MinMaxScaler`` fitted on the whole table before being
    bucketed by ``samples_to_datasets``.

    Args:
        data_dir: Directory that contains the ``elec`` folder.
        seed: Seed forwarded to ``samples_to_datasets``.
        N: Approximate number of samples per dataset.
        T: Number of consecutive datasets to return.

    Returns:
        The ``(X_split, y_split)`` pair produced by ``samples_to_datasets``.
    """
    frame = pd.read_csv(data_dir / "elec" / "raw.csv")
    features = frame.iloc[:, :-1].values
    targets = frame.iloc[:, -1].values
    features = MinMaxScaler().fit_transform(features)
    return samples_to_datasets(features, targets, N, seed, T)


def get_epicgames_data(data_dir: Path, seed: int,  N: int = 100, T: int = 20):
    """Load the Epic Games data and return a window of ``T`` datasets.

    Reads ``epicgame/epicgames.pt`` below ``data_dir`` with ``torch.load`` and
    uses the loaded object's ``msg`` tensor as features and its ``y`` tensor
    as targets, both converted to numpy arrays.

    Args:
        data_dir: Directory that contains the ``epicgame`` folder.
        seed: Seed forwarded to ``samples_to_datasets``.
        N: Approximate number of samples per dataset.
        T: Number of consecutive datasets to return.

    Returns:
        The ``(X_split, y_split)`` pair produced by ``samples_to_datasets``.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on
    # the torch version/arguments; only load files from a trusted source.
    loaded = torch.load(data_dir/"epicgame/epicgames.pt")
    features, targets = loaded.msg.numpy(), loaded.y.numpy()
    return samples_to_datasets(features, targets, N, seed, T)


def get_yelp_data(data_dir: Path, seed: int,  N: int = 100, T: int = 20):
    """Load the Yelp data and return a window of ``T`` datasets.

    Reads ``yelp/yelpchi.pt`` below ``data_dir`` with ``torch.load``, takes
    the first element of the loaded object, and uses that element's ``msg``
    tensor as features and its ``y`` tensor as targets, both converted to
    numpy arrays.

    Args:
        data_dir: Directory that contains the ``yelp`` folder.
        seed: Seed forwarded to ``samples_to_datasets``.
        N: Approximate number of samples per dataset.
        T: Number of consecutive datasets to return.

    Returns:
        The ``(X_split, y_split)`` pair produced by ``samples_to_datasets``.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on
    # the torch version/arguments; only load files from a trusted source.
    loaded = torch.load(data_dir/"yelp/yelpchi.pt")
    first_entry = loaded[0]
    features, targets = first_entry.msg.numpy(), first_entry.y.numpy()
    return samples_to_datasets(features, targets, N, seed, T)


def get_airplanes_data(data_dir: Path, seed: int,  N: int = 100, T: int = 20, onehot=False):
    """Load the airlines table and return a window of ``T`` datasets.

    Reads ``airlines/raw.csv`` below ``data_dir``; the last column is the
    target.  Without ``onehot`` every other column is passed on unchanged.
    With ``onehot`` the categorical columns are one-hot encoded, stacked
    (sparse) with the ``Time`` and ``Length`` columns, rescaled with a
    ``MaxAbsScaler`` and handed to ``samples_to_datasets`` as a CSR matrix.

    Args:
        data_dir: Directory that contains the ``airlines`` folder.
        seed: Seed forwarded to ``samples_to_datasets``.
        N: Approximate number of samples per dataset.
        T: Number of consecutive datasets to return.
        onehot: Whether to one-hot encode the categorical columns.

    Returns:
        The ``(X_split, y_split)`` pair produced by ``samples_to_datasets``.
    """
    frame = pd.read_csv(data_dir / "airlines" / "raw.csv")
    targets = frame.iloc[:, -1].values

    if not onehot:
        # Raw path: all columns but the target, exactly as stored in the CSV.
        return samples_to_datasets(
            frame.iloc[:, :-1].values, targets, N, seed, T)

    categorical_columns = [
        "Airline",
        "Flight",
        "AirportFrom",
        "AirportTo",
        "DayOfWeek",
    ]
    numerical_columns = ["Time", "Length"]

    encoded = OneHotEncoder().fit_transform(frame[categorical_columns])
    # One-hot columns first, numerical columns last.
    stacked = sparse.hstack([encoded, frame[numerical_columns]])
    scaled = MaxAbsScaler().fit_transform(stacked).tocsr()
    return samples_to_datasets(scaled, targets, N, seed, T)


def load_temporal_dataset(cfg, seed, dataset_name):
    """Build a ``SeqDataset`` for one of the supported temporal datasets.

    Args:
        cfg: Mapping providing the keys ``'offline_t'``, ``'test_frac'``,
            ``'N'``, ``'T'`` and ``'data_dir'``.  ``'data_dir'`` is resolved
            relative to ``proj_root``; ``'N'`` and ``'T'`` are forwarded to
            the dataset loader; ``'offline_t'`` and ``'test_frac'`` are
            forwarded to ``SeqDataset``.
        seed: Seed forwarded to both the loader and ``SeqDataset``.
        dataset_name: One of ``'electricity'``, ``'airplanes'``,
            ``'epicgames'`` or ``'yelp'``.

    Returns:
        A ``SeqDataset`` built from the loaded sequence of datasets, with
        ``val_split=0.1``.

    Raises:
        KeyError: If ``cfg`` lacks one of the required keys.
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    offline_split = cfg['offline_t']
    test_split = cfg['test_frac']
    size_datasets = cfg['N']
    T = cfg['T']
    data_dir_entry = cfg['data_dir']

    # Pick the loader up front so that an unsupported name fails with a clear
    # message instead of an unbound-variable error further down.
    if dataset_name == 'electricity':
        loader = get_elec_data
    elif dataset_name == 'airplanes':
        loader = get_airplanes_data
    elif dataset_name == 'epicgames':
        loader = get_epicgames_data
    elif dataset_name == 'yelp':
        loader = get_yelp_data
    else:
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected one of "
            "'electricity', 'airplanes', 'epicgames', 'yelp'.")

    data_dir = proj_root/data_dir_entry
    X_full, y_full = loader(data_dir, seed, size_datasets, T)

    seq_dataset = SeqDataset(
        X_full, y_full, offline_split, test_split, val_split=0.1, seed=seed)
    return seq_dataset
