from typing import Dict, List, Tuple

from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor, Compose, Normalize, RandomCrop, RandomHorizontalFlip
import numpy as np
from typing import List, Optional, Tuple
import torch
from omegaconf import DictConfig
from torch.utils.data import DataLoader, random_split, Dataset
import torch.utils.data as data
import h5py
import os
import math, os, random

from utils_helper import letter_to_vec, word_to_indices

"""Partition the data and create the dataloaders."""
from dataset_preparation import (
    partition_data,
    partition_data_dirichlet,
    partition_data_label_quantity,
    _partition_data,
    split_train_validation_test_clients,
)

def get_dataset():
    """Build the CIFAR-10 training and test datasets.

    Both splits are stored under ``./data`` and are requested with
    ``download=True``. The training split is augmented with a padded random
    crop and a random horizontal flip before tensor conversion and
    normalisation; the test split is only converted and normalised.

    Returns
    -------
    tuple
        ``(trainset, testset)`` as ``torchvision.datasets.CIFAR10`` objects.
    """
    # Per-channel mean / std handed to Normalize for both splits.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    data_root = "./data"

    augmented_pipeline = Compose(
        [
            RandomCrop(32, padding=4),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize(channel_mean, channel_std),
        ]
    )
    plain_pipeline = Compose(
        [
            ToTensor(),
            Normalize(channel_mean, channel_std),
        ]
    )

    train_split = datasets.CIFAR10(
        root=data_root,
        train=True,
        download=True,
        transform=augmented_pipeline,
    )
    test_split = datasets.CIFAR10(
        root=data_root,
        train=False,
        download=True,
        transform=plain_pipeline,
    )
    return train_split, test_split


def prepare_dataset(num_part: int, batch_size: int, val_ratio=0.1, seed=0):
    """Load the CIFAR-10 splits through ``get_dataset``.

    NOTE(review): this function looks unfinished. None of its parameters
    (``num_part``, ``batch_size``, ``val_ratio``, ``seed``) are used, the
    loaded splits are discarded and the function implicitly returns ``None``.
    """
    cifar_train, cifar_test = get_dataset()


# pylint: disable=too-many-locals, too-many-branches
def load_datasets(
    config: DictConfig,
    num_clients: int,
    val_ratio: float = 0.1,
    seed: Optional[int] = 42,
) -> Tuple[List[DataLoader], List[DataLoader], DataLoader]:
    """Create the dataloaders to be fed into the model.

    Parameters
    ----------
    config: DictConfig
        Parameterises the dataset partitioning process. Keys read here:
        ``name``; ``partitioning`` (one of ``"dirichlet"``,
        ``"label_quantity"``, ``"iid"``, ``"iid_noniid"``); the optional
        per-scheme settings ``alpha`` (default 0.5), ``labels_per_client``
        (default 2), ``similarity`` (default 0.5) and ``model_type``; and
        either ``batch_size`` (used as-is for every client) or
        ``batch_size_ratio`` (fraction of each client's own training split).
    num_clients : int
        The number of clients that hold a part of the data
    val_ratio : float, optional
        The ratio of training data that will be used for validation (between 0 and 1),
        by default 0.1
    seed : int, optional
        Used to set a fix seed to replicate experiments, by default 42. It is
        forwarded to the partitioning function and seeds every train/val
        split. When ``None``, the split falls back to torch's default
        generator.

    Returns
    -------
    Tuple[List[DataLoader], List[DataLoader], DataLoader]
        One training DataLoader and one validation DataLoader per client,
        plus a single DataLoader (batch size 128) over the test set.

    Raises
    ------
    ValueError
        If ``config`` defines neither ``batch_size`` nor ``batch_size_ratio``,
        or if ``partitioning`` is missing or not one of the supported schemes.
    """
    print(f"Dataset partitioning config: {config}")

    # Resolve the batch-size policy first so a misconfiguration fails before
    # any partitioning work is done.
    fixed_batch_size = None
    batch_size_ratio = None
    if "batch_size" in config:
        fixed_batch_size = config.batch_size
    elif "batch_size_ratio" in config:
        batch_size_ratio = config.batch_size_ratio
    else:
        raise ValueError(
            "Dataset config must define either 'batch_size' or 'batch_size_ratio'."
        )

    partitioning = config.partitioning if "partitioning" in config else ""
    model_type = config.model_type if "model_type" in config else None

    # partition the data
    # NOTE(review): only the "dirichlet" and "iid" schemes forward
    # ``model_type``; this mirrors the previous behaviour -- confirm whether
    # the other two schemes should forward it as well.
    if partitioning == "dirichlet":
        client_datasets, testset = partition_data_dirichlet(
            num_clients,
            alpha=config.alpha if "alpha" in config else 0.5,
            seed=seed,
            dataset_name=config.name,
            model_type=model_type,
        )
    elif partitioning == "label_quantity":
        client_datasets, testset = partition_data_label_quantity(
            num_clients,
            labels_per_client=(
                config.labels_per_client if "labels_per_client" in config else 2
            ),
            seed=seed,
            dataset_name=config.name,
        )
    elif partitioning == "iid":
        client_datasets, testset = partition_data(
            num_clients,
            similarity=1.0,
            seed=seed,
            dataset_name=config.name,
            model_type=model_type,
        )
    elif partitioning == "iid_noniid":
        client_datasets, testset = partition_data(
            num_clients,
            similarity=config.similarity if "similarity" in config else 0.5,
            seed=seed,
            dataset_name=config.name,
        )
    else:
        # Previously this fell through and crashed later on an unbound local.
        raise ValueError(
            f"Unsupported partitioning scheme: {partitioning!r}. Expected one of "
            "'dirichlet', 'label_quantity', 'iid', 'iid_noniid'."
        )

    # split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for dataset in client_datasets:
        len_val = int(len(dataset) / (1 / val_ratio)) if val_ratio > 0 else 0
        lengths = [len(dataset) - len_val, len_val]
        if seed is None:
            ds_train, ds_val = random_split(dataset, lengths)
        else:
            # A fresh, identically seeded generator per client keeps each
            # split reproducible independently of the other clients.
            ds_train, ds_val = random_split(
                dataset, lengths, torch.Generator().manual_seed(seed)
            )

        if fixed_batch_size is not None:
            batch_size = fixed_batch_size
        else:
            # Recomputed for every client (the ratio is relative to this
            # client's own training split); clamped to 1 because DataLoader
            # rejects a batch size of 0.
            batch_size = max(1, int(len(ds_train) * batch_size_ratio))

        trainloaders.append(DataLoader(ds_train, batch_size=batch_size, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=batch_size))
    return trainloaders, valloaders, DataLoader(testset, batch_size=128)


class ShakespeareDataset(Dataset):
    """
    [LEAF: A Benchmark for Federated Settings](https://github.com/TalwalkarLab/leaf).

    We imported the preprocessing method for the Shakespeare dataset from GitHub.

    The constructor expects a mapping with the keys ``"x"`` and ``"y"``:

    * every element of ``data["x"]`` is encoded with ``word_to_indices`` and
      the results are stored as the int64 array ``sentences_to_indices``;
    * every element of ``data["y"]`` is encoded with ``letter_to_vec`` and
      the results are stored as the int64 array ``labels``.

    Both helpers come from ``utils_helper``; see that module for the exact
    encodings they produce.
    """

    def __init__(self, data):
        raw_inputs = data["x"]
        raw_targets = data["y"]

        encoded_inputs = [word_to_indices(entry) for entry in raw_inputs]
        self.sentences_to_indices = np.array(encoded_inputs, dtype=np.int64)

        encoded_targets = [letter_to_vec(symbol) for symbol in raw_targets]
        self.labels = np.array(encoded_targets, dtype=np.int64)

    def __len__(self):
        """Return the number of labels present in the dataset.

        Returns
        -------
            int: The total number of labels.
        """
        return len(self.labels)

    def __getitem__(self, index):
        """Fetch one encoded input and its label as tensors.

        Args:
            index (int): Position of the sample to fetch.

        Returns
        -------
            tuple: (data tensor, label tensor)
        """
        encoded_row = self.sentences_to_indices[index]
        encoded_label = self.labels[index]
        return torch.tensor(encoded_row), torch.tensor(encoded_label)


class FemnistDataset(Dataset):
    """
    [LEAF: A Benchmark for Federated Settings](https://github.com/TalwalkarLab/leaf).

    We imported the preprocessing method for the Femnist dataset from GitHub.

    Parameters
    ----------
    dataset : mapping
        Must provide the keys ``"x"`` (samples, each reshapeable to 28x28)
        and ``"y"`` (labels), indexable by position.
    transform : callable, optional
        Applied to the 28x28 numpy array of each sample. When ``None`` the
        array is converted to a tensor without any further processing.
    """

    def __init__(self, dataset, transform=None):
        self.x = dataset["x"]
        self.y = dataset["y"]
        self.transform = transform

    def __getitem__(self, index):
        """Retrieve the input data and its corresponding label at a given index.

        Args:
            index (int): The index of the data item to fetch.

        Returns
        -------
            tuple:
                - input_data (torch.Tensor): float32 tensor holding the
                  reshaped and optionally transformed data.
                - target_data: Label for the input data, exactly as stored
                  in ``dataset["y"]``.
        """
        input_data = np.array(self.x[index]).reshape(28, 28)
        # Compare against None explicitly: a callable that happens to be
        # falsy (e.g. an empty container module) must still be applied.
        if self.transform is not None:
            input_data = self.transform(input_data)
        # Without a transform (or with one that does not return a tensor)
        # the data is still array-like; convert it so ``.to`` is available.
        if not isinstance(input_data, torch.Tensor):
            input_data = torch.as_tensor(input_data)
        target_data = self.y[index]
        return input_data.to(torch.float32), target_data

    def __len__(self):
        """Return the number of labels present in the dataset.

        Returns
        -------
            int: The total number of labels.
        """
        return len(self.y)