import csv
import os
import pickle
import time

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, TensorDataset
from torchvision.datasets import CIFAR10, MNIST, FashionMNIST

from typing import cast
from torch import Tensor

from datasets.wrappers import BSDS300, GAS, HEPMASS, MINIBOONE, POWER
from datasets.wrappers.artificial import (
    banana_sample,
    cosine_sample,
    funnel_sample,
    multi_rings_sample,
    rotate_samples,
    single_ring_sample,
    spiral_sample,
    crossing_rings_sample,
)
from datasets.wrappers.celeba import CELEBA

# Registry of dataset names, grouped by family. The lists below are plain
# constants; none of the loaders in this file validates `name` against them.

# Small UCI datasets (presumably the ones read by load_small_uci_dataset --
# TODO confirm against callers).
SMALL_UCI_DATASETS = ["biofam", "flare", "lymphography", "spect", "tumor", "votes"]

# Binary datasets (presumably the ones stored as `<name>/<name>.<split>.data`
# csv files and read by load_binary_dataset -- TODO confirm against callers).
BINARY_DATASETS = [
    "accidents",
    "ad",
    "baudio",
    "bbc",
    "binarized_mnist",
    "bnetflix",
    "book",
    "c20ng",
    "cr52",
    "cwebkb",
    "dna",
    "jester",
    "kdd",
    "kosarek",
    "msnbc",
    "msweb",
    "mushrooms",
    "nltcs",
    "ocr_letters",
    "plants",
    "pumsb_star",
    "tmovie",
    "tretail",
]

# Names recognised by load_image_dataset.
IMAGE_DATASETS = ["MNIST", "FashionMNIST", "CIFAR10", "CelebA"]

# Names recognised by load_continuous_dataset.
CONTINUOUS_DATASETS = ["power", "gas", "hepmass", "miniboone", "bsds300"]

# Names recognised by load_artificial_dataset.
ARTIFICIAL_DATASETS = [
    "ring",
    "mring",
    "funnel",
    "banana",
    "cosine",
    "spiral",
    "crossing-rings",
]

# Currently empty, and NOT included in ALL_DATASETS below.
LANGUAGE_DATASETS = []


# Concatenation of every non-empty family above (order: small UCI, binary,
# image, continuous, artificial).
ALL_DATASETS = (
    SMALL_UCI_DATASETS
    + BINARY_DATASETS
    + IMAGE_DATASETS
    + CONTINUOUS_DATASETS
    + ARTIFICIAL_DATASETS
)


def load_small_uci_dataset(
    name: str, path: str = "datasets", dtype: str = "int64", seed: int = 42
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Read a pickled small UCI dataset and split it into train/valid/test.

    The file ``path/name`` is unpickled (latin1 encoding) and its first
    element is taken as the data array. Rows are shuffled with a
    ``RandomState(seed)`` permutation, then the last two chunks of
    ``max(1, int(0.05 * num_rows))`` rows each become the validation and test
    splits; everything before them is the training split.

    Args:
        name: File name of the pickled dataset inside ``path``.
        path: Directory that contains the file.
        dtype: Dtype the data array is cast to.
        seed: Seed of the row permutation.

    Returns:
        The ``(train, valid, test)`` arrays.

    Raises:
        AssertionError: If the distinct values in the data are not exactly
            ``0, 1, ..., k - 1``.
    """
    dataset_file = os.path.join(path, name)
    # NOTE: pickle can execute arbitrary code; only load trusted files.
    with open(dataset_file, "rb") as handle:
        samples = pickle.load(handle, encoding="latin1")[0]
        samples = samples.astype(dtype, copy=False)

    shuffled = np.random.RandomState(seed).permutation(samples)

    # The values must form a contiguous range of category ids starting at 0.
    categories = np.unique(shuffled)
    assert np.all(categories == np.arange(len(categories)))

    total_rows = shuffled.shape[0]
    held_out_rows = max(1, int(0.05 * total_rows))
    valid_start = total_rows - 2 * held_out_rows
    test_start = valid_start + held_out_rows

    return (
        shuffled[:valid_start],
        shuffled[valid_start:test_start],
        shuffled[test_start:],
    )


def csv_2_numpy(
    filename: str, path: str = "datasets", sep: str = ",", dtype: str = "int8"
) -> np.ndarray:
    """
    Read a dataset stored as a csv file into a numpy array.

    Args:
        filename: Name of the csv file, relative to ``path``.
        path: Directory that contains the file.
        sep: Field delimiter used in the file.
        dtype: Dtype of the returned array; every field is converted to it.

    Returns:
        An array with one row per csv record.
    """
    file_path = os.path.join(path, filename)
    # Use a context manager so the handle is closed deterministically instead
    # of being leaked; newline="" is what the csv module asks for so that it
    # can do its own line-ending handling.
    with open(file_path, newline="") as csv_file:
        rows = list(csv.reader(csv_file, delimiter=sep))
    return np.array(rows, dtype=dtype)


def load_binary_dataset(
    name: str,
    path: str = "datasets",
    sep: str = ",",
    dtype: str = "int64",
    suffix: str = "data",
    splits: list[str] | None = None,
    verbose: bool = False
) -> list[np.ndarray]:
    """
    Load the csv splits of a binary dataset.

    Each split is read from ``path/name/name.<split>.<suffix>``.

    Args:
        name: Dataset name; used both as sub-directory and as file prefix.
        path: Root directory of the datasets.
        sep: Field delimiter of the csv files.
        dtype: Dtype of the returned arrays.
        suffix: File extension that follows the split name.
        splits: Split names to load, in order. Defaults to
            ``["train", "valid", "test"]``.
        verbose: If true, print the loading time and the shape of each split.

    Returns:
        One array per requested split, in the order of ``splits``.
    """
    if splits is None:
        splits = ["train", "valid", "test"]

    relative_files = []
    for split_name in splits:
        relative_files.append(os.path.join(name, f"{name}.{split_name}.{suffix}"))

    started_at = time.perf_counter()
    loaded = []
    for relative_file in relative_files:
        loaded.append(csv_2_numpy(relative_file, path, sep, dtype))
    finished_at = time.perf_counter()

    if not verbose:
        return loaded

    print(
        "Dataset splits for {} loaded in {} secs".format(
            name, finished_at - started_at
        )
    )
    for data, split in zip(loaded, splits):
        print(f"\t{split}:\t{data.shape}")
    return loaded


def load_categorical_dataset(
    name: str,
    path: str = "datasets",
    sep: str = ",",
    dtype: str = "int64",
    suffix: str = "data",
    splits: list[str] | None = None,
    verbose: bool = False,
    num_variables: int = -1,
    num_categories: int = 2
) -> list[np.ndarray]:
    # Construct a categorical dataset by starting from a binary dataset and by grouping
    # a certain number of binary variables into categorical ones

    dataset_splits = load_binary_dataset(
        name, path=path, sep=sep, dtype=dtype, suffix=suffix, splits=splits, verbose=verbose
    )

    assert num_categories > 1
    pow_of_two = (num_categories & (num_categories - 1) == 0) and num_categories != 0
    assert pow_of_two
    num_group_variables = int(np.log2(num_categories))
    if num_categories != 2:
        assert all(num_variables * num_group_variables <= d.shape[1] for d in dataset_splits)
        exp2values = np.array([2 ** k for k in range(num_group_variables)], dtype=dtype)
        dataset_splits = [d[:, :num_variables * num_group_variables] for d in dataset_splits]
        dataset_splits = [
            np.sum(d.reshape(d.shape[0], num_variables, num_group_variables) * exp2values, axis=2)
            for d in dataset_splits
        ]

    return dataset_splits


def load_image_dataset(name: str, path: str = "datasets") -> tuple[
    tuple[int, int, int],
    tuple[Dataset, Dataset, Dataset],
]:
    """
    Load an image dataset and return its shape with train/valid/test splits.

    Supported names are "MNIST", "FashionMNIST", "CIFAR10" and "CelebA".
    MNIST, FashionMNIST and CIFAR10 are fetched through torchvision with
    ``download=True``; they ship no validation split, so 5% of the training
    images are held out with a fixed seed (42). Their images are cast to
    int64, permuted to channels-first, flattened per image and wrapped in
    ``TensorDataset`` objects. CelebA splits that are ``Dataset`` instances
    are returned as-is.

    Args:
        name: Name of the dataset to load.
        path: Root directory passed to the underlying dataset classes.

    Returns:
        ``(image_shape, (train, valid, test))`` where ``image_shape`` is
        ``(channels, height, width)``.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """

    def channels_first_shape(images):
        # Images are stored as (N, H, W, C); report (C, H, W).
        return (images.shape[3], images.shape[1], images.shape[2])

    def as_tensor(images):
        # CIFAR10 exposes a numpy array, the MNIST variants a torch tensor.
        if isinstance(images, np.ndarray):
            return torch.from_numpy(images)
        return images

    def flat_dataset(images):
        # (N, H, W, C) -> (N, C, H, W) -> (N, C * H * W)
        return TensorDataset(
            images.permute(0, 3, 1, 2).flatten(start_dim=1).contiguous()
        )

    if name in ("MNIST", "FashionMNIST"):
        dataset_cls = MNIST if name == "MNIST" else FashionMNIST
        # Grayscale images come without a channel axis; add a trailing one.
        train_data = dataset_cls(path, train=True, download=True).data.unsqueeze(
            dim=-1
        )
        valid_data = None
        test_data = dataset_cls(path, train=False, download=True).data.unsqueeze(
            dim=-1
        )
        image_shape = channels_first_shape(train_data)
    elif name == "CIFAR10":
        train_data = CIFAR10(path, train=True, download=True).data
        valid_data = None
        test_data = CIFAR10(path, train=False, download=True).data
        image_shape = channels_first_shape(train_data)
    elif name == "CelebA":
        train_data = CELEBA(path, split="train", ycc=True)
        valid_data = CELEBA(path, split="valid", ycc=True)
        test_data = CELEBA(path, split="test", ycc=True)
        image_shape = (3, 64, 64)
    else:
        raise ValueError(f"Unknown datasets called {name}")

    # Ready-made Dataset objects need no further processing.
    if isinstance(train_data, Dataset):
        assert isinstance(valid_data, Dataset)
        assert isinstance(test_data, Dataset)
        return image_shape, (train_data, valid_data, test_data)

    train_data = as_tensor(train_data).to(torch.int64)
    if valid_data is not None:
        # NOTE(review): unlike train/test, a provided validation split is not
        # cast to int64 here; no branch above currently reaches this case.
        valid_data = as_tensor(valid_data)
    test_data = cast(Tensor, as_tensor(test_data)).to(torch.int64)

    if valid_data is None:
        # Hold out 5% of the training images as validation data.
        all_indices = np.arange(train_data.shape[0])
        train_idx, valid_idx = train_test_split(
            all_indices,
            test_size=0.05,
            random_state=42,
            shuffle=True,
        )
        valid_data, train_data = train_data[valid_idx], train_data[train_idx]

    train_data = flat_dataset(train_data)
    valid_data = flat_dataset(valid_data)
    test_data = flat_dataset(test_data)
    return image_shape, (train_data, valid_data, test_data)


def load_continuous_dataset(
    name: str, path: str = "datasets", dtype: np.dtype = np.float32
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Load the train/valid/test splits of a continuous dataset.

    Supported names are "power", "gas", "hepmass", "miniboone" and "bsds300".
    The matching wrapper is constructed with ``path`` and its ``trn.x``,
    ``val.x`` and ``tst.x`` arrays are cast to ``dtype``.

    Args:
        name: Name of the dataset to load.
        path: Root directory passed to the dataset wrapper.
        dtype: Dtype the three arrays are cast to.

    Returns:
        The ``(train, valid, test)`` arrays.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    # Factories are wrapped in lambdas so that only the requested wrapper is
    # looked up and constructed.
    factories = {
        "power": lambda: POWER(path),
        "gas": lambda: GAS(path),
        "hepmass": lambda: HEPMASS(path),
        "miniboone": lambda: MINIBOONE(path),
        "bsds300": lambda: BSDS300(path),
    }
    if name not in factories:
        raise ValueError(f"Unknown continuous dataset called {name}")

    data = factories[name]()
    raw_splits = (data.trn.x, data.val.x, data.tst.x)
    data_train, data_valid, data_test = (
        split.astype(dtype, copy=False) for split in raw_splits
    )
    return data_train, data_valid, data_test


def load_artificial_dataset(
    name: str,
    num_samples: int,
    valid_test_perc: float = 0.2,
    seed: int = 42,
    dtype: np.dtype = np.float32,
    discretize: bool = False,
    discretize_unique: bool = False,
    discretize_bins: int = 32,
    shuffle_bins: bool = False,
    **kwargs,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Sample a 2D artificial dataset and split it into train/valid/test.

    ``num_samples`` is the training-set size; ``int(num_samples *
    valid_test_perc * 0.5)`` validation and ``int(num_samples *
    valid_test_perc)`` test samples are drawn on top of it. The "funnel" and
    "cosine" samples are additionally passed through ``rotate_samples``.

    Args:
        name: One of "ring", "mring", "funnel", "banana", "cosine", "spiral"
            or "crossing-rings".
        num_samples: Number of training samples to generate.
        valid_test_perc: Fraction of ``num_samples`` used to size the
            held-out splits (see above).
        seed: Seed for the sampler, the bin shuffling and the train/held-out
            split.
        dtype: Dtype the sampled data is cast to.
        discretize: If true, standardize the data and quantize both
            coordinates into ``discretize_bins`` equal-width bins.
        discretize_unique: Only used with ``discretize``. Keep the distinct
            quantized points only, halve ``valid_test_perc`` and recompute
            the held-out sizes from the number of distinct points.
        discretize_bins: Number of bins per coordinate.
        shuffle_bins: Only used with ``discretize``. Relabel the bins with a
            seeded random permutation (the same one for both coordinates).
        **kwargs: Forwarded to the sampler.

    Returns:
        The ``(train, valid, test)`` arrays.

    Raises:
        ValueError: If ``name`` is not a known artificial dataset.
    """
    num_valid_samples = int(num_samples * valid_test_perc * 0.5)
    num_test_samples = int(num_samples * valid_test_perc)
    total_num_samples = num_samples + num_valid_samples + num_test_samples

    # Pick the sampler first, then draw all samples with a single call.
    if name == "ring":
        sampler = single_ring_sample
    elif name == "mring":
        sampler = multi_rings_sample
    elif name == "funnel":
        sampler = funnel_sample
    elif name == "banana":
        sampler = banana_sample
    elif name == "cosine":
        sampler = cosine_sample
    elif name == "spiral":
        sampler = spiral_sample
    elif name == "crossing-rings":
        sampler = crossing_rings_sample
    else:
        raise ValueError(f"Unknown dataset called {name}")

    data = sampler(total_num_samples, seed=seed, **kwargs)
    if name in ("funnel", "cosine"):
        data = rotate_samples(data)
    data = data.astype(dtype=dtype, copy=False)

    if discretize:
        # Standardize data before "quantizing" it
        data = (data - np.mean(data, axis=0)) / (np.std(data, axis=0) + 1e-10)
        xs, ys = data[:, 0], data[:, 1]
        bounds = [(np.min(xs), np.max(xs)), (np.min(ys), np.max(ys))]
        # Only the bin edges of the 2D histogram are needed.
        _, xedges, yedges = np.histogram2d(
            xs, ys, bins=discretize_bins, range=bounds
        )
        # Searching among the left edges only makes the maximum value land in
        # the last bin rather than one past it.
        bin_ids = [
            np.searchsorted(edges[:-1], values, side="right") - 1
            for edges, values in ((xedges, xs), (yedges, ys))
        ]
        if shuffle_bins:
            relabel = np.random.RandomState(seed).permutation(discretize_bins)
            bin_ids = [relabel[ids] for ids in bin_ids]
        data = np.stack(bin_ids, axis=1)
        if discretize_unique:
            data = np.unique(data, axis=0)
            num_samples = len(data)
            valid_test_perc *= 0.5
            num_valid_samples = int(num_samples * valid_test_perc * 0.5)
            num_test_samples = int(num_samples * valid_test_perc)

    # Shuffle once to carve out the held-out pool, then cut that pool in
    # order into validation and test.
    num_held_out = num_valid_samples + num_test_samples
    train_data, held_out_data = train_test_split(
        data,
        test_size=num_held_out,
        shuffle=True,
        random_state=seed,
    )
    valid_data, test_data = train_test_split(
        held_out_data, test_size=num_test_samples, shuffle=False
    )
    return train_data, valid_data, test_data
