import os
from typing import Optional, Tuple

import numpy as np  # type: ignore
import torch
from torch.utils.data import DataLoader

from data.tabular import PBPDataset  # type: ignore

T = torch.Tensor  # shorthand for the tensor type, used in the type annotations below

# Mini-batch size for each supported PBP / UCI regression dataset.
batch_sizes = {
    "boston-housing": 50,
    "concrete": 100,
    "energy": 100,
    "kin8nm": 500,
    "naval-propulsion-plant": 500,
    "power-plant": 500,
    "wine-quality-red": 50,
    "yacht": 50,
}

# Names of the supported datasets, in the same order as ``batch_sizes``.
# NOTE: "protein-tertiary-structure" is currently excluded from this list.
pbp_sets = list(batch_sizes)


def get_pbp_idx_kmeans(
    datapath: str, run: int, shifted: bool = False, n_clusters: int = 4
) -> Tuple[T, T, Optional[T]]:
    """Load the pbp train/test index splits (regular or k-means shifted variant).

    Args:
        datapath: directory holding ``index_train_{run}.txt`` / ``index_test_{run}.txt``
            and, for the shifted variant, a ``shifted/{run}/`` sub-directory containing
            ``index_test_cluster.txt`` and ``index_train_cluster_{i}.txt`` files.
        run: number of the split to load.
        shifted: if False, load the random train/test splits from the MC Dropout
            experiments; if True, load the clustered/shifted splits.
        n_clusters: number of ``index_train_cluster_{i}.txt`` files (i = 0 .. n_clusters-1)
            read in the shifted case. Defaults to 4, the value previously hard-coded.

    Returns:
        ``(train_idx, test_idx, train_cluster_class)`` as int64 tensors.
        ``train_cluster_class[k]`` is the number ``i`` of the cluster file that
        ``train_idx[k]`` was read from. It is ``None`` when ``shifted`` is False.
    """

    def load(path: str) -> torch.Tensor:
        # ndmin=1 keeps single-index files 1-D; a 0-d array would make torch.cat fail
        return torch.from_numpy(np.loadtxt(path, ndmin=1)).long()

    if not shifted:
        # random train/test splits from the MC Dropout experiments
        train_idx = load(os.path.join(datapath, f"index_train_{run}.txt"))
        test_idx = load(os.path.join(datapath, f"index_test_{run}.txt"))
        return train_idx, test_idx, None

    # train/test splits of the clustered and shifted datasets
    shift_idx_path = os.path.join(datapath, "shifted", f"{run}")
    test_idx = load(os.path.join(shift_idx_path, "index_test_cluster.txt"))

    train_parts = [
        load(os.path.join(shift_idx_path, f"index_train_cluster_{i}.txt"))
        for i in range(n_clusters)
    ]
    # the leading empty tensor keeps torch.cat valid when n_clusters == 0
    empty = torch.empty(0, dtype=torch.long)
    train_idx = torch.cat([empty, *train_parts])
    # label every training index with the number of the cluster file it came from
    train_cluster_class = torch.cat(
        [empty, *(torch.full((part.size(0),), i, dtype=torch.long) for i, part in enumerate(train_parts))]
    )
    return train_idx, test_idx, train_cluster_class


def get_pbp_idx_spectral(datapath: str, run: int, shifted: bool = False) -> Tuple[T, T, Optional[T]]:
    """Load the pbp train/test index splits (regular or spectral-clustering shifted variant).

    Args:
        datapath: directory holding ``index_train_{run}.txt`` / ``index_test_{run}.txt``
            and, for the shifted variant, a ``shifted-spectral-max/{run}/`` sub-directory.
        run: number of the split to load.
        shifted: if False, load the random train/test splits from the MC Dropout
            experiments; if True, load the index files of the shifted split.

    Returns:
        ``(train_idx, test_idx, train_cluster_class)`` as int64 tensors.

        In the shifted case:

        - Every file whose name contains "test" is concatenated into ``test_idx``.
        - Every file whose name contains "train" is concatenated into ``train_idx``.
        - ``train_cluster_class[k]`` is the position (0, 1, ...) of the train file
          that ``train_idx[k]`` came from.

        Files are read in natural-sorted filename order (``..._2`` before ``..._10``),
        so this numbering is reproducible. ``train_cluster_class`` is ``None`` when
        ``shifted`` is False.
    """
    import re

    def load(path: str) -> torch.Tensor:
        # ndmin=1 keeps single-index files 1-D; a 0-d array would make torch.cat fail
        return torch.from_numpy(np.loadtxt(path, ndmin=1)).long()

    def natural_key(fname: str) -> list:
        # split out digit runs so they compare numerically ("cluster_2" < "cluster_10")
        return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", fname)]

    if not shifted:
        # random train/test splits from the MC Dropout experiments
        train_idx = load(os.path.join(datapath, f"index_train_{run}.txt"))
        test_idx = load(os.path.join(datapath, f"index_test_{run}.txt"))
        return train_idx, test_idx, None

    # train/test splits of the clustered and shifted datasets
    shift_idx_path = os.path.join(datapath, "shifted-spectral-max", f"{run}")
    # os.listdir order is arbitrary and filesystem-dependent; sort so that the
    # concatenation order and the cluster numbering are reproducible
    idx_files = sorted(os.listdir(shift_idx_path), key=natural_key)

    test_parts = [load(os.path.join(shift_idx_path, f)) for f in idx_files if "test" in f]
    train_parts = [load(os.path.join(shift_idx_path, f)) for f in idx_files if "train" in f]

    # the leading empty tensor keeps torch.cat valid when no matching files exist
    empty = torch.empty(0, dtype=torch.long)
    test_idx = torch.cat([empty, *test_parts])
    train_idx = torch.cat([empty, *train_parts])
    train_cluster_class = torch.cat(
        [empty, *(torch.full((part.size(0),), i, dtype=torch.long) for i, part in enumerate(train_parts))]
    )
    return train_idx, test_idx, train_cluster_class


def get_run_folds(dataset: str, datadir: str, run: int) -> int:
    """Count the training-index files of one run of the spectral-shifted split.

    The count decides the number of folds for k-fold validation. It is taken over
    ``<datadir>/UCI_Datasets/<dataset>/data/shifted-spectral-max/<run>``, and every
    entry whose name contains the substring "train" is counted.
    """
    run_dir = os.path.join(datadir, "UCI_Datasets", dataset, "data", "shifted-spectral-max", f"{run}")
    return sum("train" in entry for entry in os.listdir(run_dir))


class Loader(DataLoader):
    """Thin subclass of ``torch.utils.data.DataLoader``.

    It currently adds no attributes or methods of its own. The functions in this
    module construct and return it in place of a plain ``DataLoader``.
    """


def get_pbp_sets(
    name: str,
    datadir: str,
    run: int,
    val_pct: float = 0.1,
    get_val: bool = True,
    shifted: bool = False,
) -> Tuple[Loader, Loader, Loader]:
    """Build train/val/test loaders for one of the PBP UCI regression datasets.

    These are the datasets used in the following papers:
    http://papers.nips.cc/paper/7219-simple-and-scalable-predictive-uncertainty-estimation-using-deep-ensembles.pdf
    https://arxiv.org/abs/1502.05336
    https://arxiv.org/abs/1506.02142

    Args:
        name: dataset name; must be one of ``pbp_sets``.
        datadir: root directory containing ``UCI_Datasets/<name>/data``.
        run: number of the train/test split to load.
        val_pct: fraction of the training indices held out for validation when
            ``get_val`` is True. The last ``val_pct`` fraction of the training
            indices, in index-file order, is used.
        get_val: if False, no validation split is made; the training loader
            (with ``drop_last=True``) is returned in both the train and the
            validation slot.
        shifted: if True, load the spectral-clustering shifted splits instead of
            the MC Dropout random splits.

    Returns:
        ``(train_loader, val_loader, test_loader)``. The val/test datasets are
        normalized with the parameters returned by the training set's
        ``standard_normalize_x`` / ``standard_normalize_y``.

    Raises:
        ValueError: if ``name`` is unknown, if a shifted split comes back without
            cluster labels, or if the training split is empty.
    """
    if name not in pbp_sets:
        raise ValueError(f"{name} is an unknown pbp dataset")

    data_path = os.path.join(datadir, "UCI_Datasets", name, "data")
    # rows of data.txt: feature columns followed by the target in the last column
    data = torch.from_numpy(np.loadtxt(os.path.join(data_path, "data.txt"))).float()

    train_idx, test_idx, cluster_idx = get_pbp_idx_spectral(data_path, run, shifted=shifted)
    batch_size = batch_sizes[name]

    test_ft = data[test_idx, :-1]
    test_label = data[test_idx, -1]

    if not get_val:
        train = PBPDataset(x=data[train_idx, :-1], y=data[train_idx, -1], cluster_idx=cluster_idx, name=name)
        test = PBPDataset(x=test_ft, y=test_label, name=name)

        xparams = train.standard_normalize_x()
        yparams = train.standard_normalize_y()
        test.standard_normalize_x(*xparams)
        test.standard_normalize_y(*yparams)

        train_ldr = Loader(train, shuffle=True, batch_size=batch_size, drop_last=True)
        return (train_ldr, train_ldr, Loader(test, batch_size=batch_size))

    if shifted and cluster_idx is None:
        raise ValueError("cluster idx cannot be none")

    # Split the training indices in stored order: the first train_n go to training,
    # the rest to validation. For the unshifted data the index files were already
    # shuffled, and holding out the tail mirrors the MC Dropout protocol.
    # NOTE(review): for the shifted splits the indices are concatenated file by file
    # (cluster by cluster), so this tail-based validation set comes from the last
    # cluster file(s). Confirm whether a random or per-cluster split was intended.
    train_n = int(len(train_idx) * (1 - val_pct))
    t_idx, v_idx = train_idx[:train_n], train_idx[train_n:]

    train_ft, train_label = data[t_idx, :-1], data[t_idx, -1]
    val_ft, val_label = data[v_idx, :-1], data[v_idx, -1]

    train_cluster_idx: Optional[T] = None
    val_cluster_idx: Optional[T] = None
    if shifted:
        # cluster labels are aligned with train_idx, so slice them identically
        train_cluster_idx, val_cluster_idx = cluster_idx[:train_n], cluster_idx[train_n:]

    if train_ft.shape[0] == 0:
        # the doubling loop below would never terminate on an empty tensor
        raise ValueError(f"training split for {name} (run {run}) is empty")

    # Spectral clustering can leave very small training splits. Tile the training
    # set (doubling each pass) until it has at least 100 rows, tiling the cluster
    # labels too so they stay aligned with the rows.
    while train_ft.shape[0] < 100:
        train_ft = train_ft.repeat(2, 1)
        train_label = train_label.repeat(2)
        if train_cluster_idx is not None:
            train_cluster_idx = train_cluster_idx.repeat(2)

    train = PBPDataset(x=train_ft, cluster_idx=train_cluster_idx, y=train_label, name=name)
    val = PBPDataset(x=val_ft, cluster_idx=val_cluster_idx, y=val_label, name=name)
    test = PBPDataset(x=test_ft, y=test_label, name=name)

    # normalize val/test with the training-set statistics
    xparams = train.standard_normalize_x()
    yparams = train.standard_normalize_y()
    for split in (val, test):
        split.standard_normalize_x(*xparams)
        split.standard_normalize_y(*yparams)

    return (
        Loader(train, shuffle=True, batch_size=batch_size),
        # val/test are served as a single batch; max(..., 1) because DataLoader
        # rejects batch_size=0 (e.g. val_pct == 0 gives an empty validation split)
        Loader(val, batch_size=max(len(val), 1)),
        Loader(test, batch_size=max(len(test), 1)),
    )
