"""Dataset creation and partitioning module for federated learning."""

import os
import json
import random
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from logging import INFO
from typing import List, Tuple, Dict
from collections import defaultdict
from sklearn import preprocessing
from PIL import Image
from flwr.common.logger import log
from torch.utils.data import (
    DataLoader,
    Dataset,
    Subset,
    RandomSampler,
    BatchSampler,
    ConcatDataset,
)
from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST, SVHN
from .utils import init_data
from dataset.partition_noniid import (
    get_cifar_10,
    get_svhn,
    get_fashion_mnist,
    get_mnist,
    create_lda_partitions,
    cifar10Transformation,
    svhnTransformation,
    fmnistTransformation,
    mnistTransformation,
)
from dataset.nist_preprocessor import NISTPreprocessor
from dataset.nist_sampler import NistSampler
from dataset.nist_dataset import create_dataset, create_partition_list
from dataset.zip_downloader import ZipDownloader
from dataset.language_utils import word_to_indices, letter_to_vec

# Import the WrapDataset class and create_noniid_data function
from dataset.wrap_dataset import WrapDataset
from dataset.federated import create_noniid_data


def read_dir(data_dir):
    """Merge every ``.json`` file found directly inside ``data_dir``.

    Parameters
    ----------
    data_dir : str
        Directory to scan. Each ``.json`` file must contain a ``'user_data'``
        mapping and may contain a ``'hierarchies'`` list.

    Returns
    -------
    clients : list
        Sorted keys of the merged ``'user_data'`` mappings.
    groups : list
        Concatenation of all ``'hierarchies'`` lists, in file-name order.
    data : defaultdict
        Merged ``'user_data'`` mappings. Looking up an unknown key yields
        ``None`` instead of raising ``KeyError``.
    """
    groups = []
    data = defaultdict(lambda: None)

    # os.listdir returns entries in arbitrary order; sorting keeps the order
    # of `groups` (and of overriding duplicate users) reproducible.
    json_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.json'))
    for file_name in json_files:
        file_path = os.path.join(data_dir, file_name)
        with open(file_path, 'r', encoding='utf-8') as inf:
            cdata = json.load(inf)
        if 'hierarchies' in cdata:
            groups.extend(cdata['hierarchies'])
        data.update(cdata['user_data'])

    # The client list is derived from the merged data, so the per-file
    # 'users' entries do not need to be collected separately.
    clients = sorted(data.keys())
    return clients, groups, data


def read_data(train_data_dir, test_data_dir):
    """Read the train and test directories and check that they line up.

    Parameters
    ----------
    train_data_dir : str
        Directory passed to ``read_dir`` for the training split.
    test_data_dir : str
        Directory passed to ``read_dir`` for the test split.

    Returns
    -------
    tuple
        ``(clients, groups, train_data, test_data)`` where ``clients`` and
        ``groups`` come from the training directory.

    Raises
    ------
    AssertionError
        If the two directories disagree on clients or on groups.
    """
    train_side = read_dir(train_data_dir)
    test_side = read_dir(test_data_dir)

    # Index 0 holds the client list, index 1 the group list.
    assert train_side[0] == test_side[0]
    assert train_side[1] == test_side[1]

    clients, groups, train_data = train_side
    return clients, groups, train_data, test_side[2]


class CELEBA_Client(Dataset):
    """CELEBA Dataset Client.

    All images listed for the client are decoded in ``__init__`` and kept in
    memory as numpy arrays; ``__getitem__`` converts them back to PIL images
    before applying ``transform``.

    Parameters
    ----------
    dataclient : mapping
        Must provide ``'x'`` (image file names) and ``'y'`` (labels), indexed
        in parallel.
    train : bool
        Stored on the instance; loading is identical for both values.
    method : str
        When equal to ``'FATS'`` each item also carries its index.
    transform : callable, optional
        Applied to the PIL image. Defaults to center-crop 178x178, resize to
        128x128 and ``ToTensor``.
    target_transform : callable, optional
        Stored on the instance; not applied by ``__getitem__``.
    root : str
        Dataset root (``~`` is expanded). Images are read from
        ``<root>/raw/img_align_celeba``.
    """

    def __init__(self, dataclient, train=True, method='base', transform=None, target_transform=None, root="~/Documents/datasets/leaf/data/celeba/data"):
        super().__init__()
        if transform is None:
            self.transform = transforms.Compose([
                transforms.CenterCrop((178, 178)),
                transforms.Resize((128, 128)),
                transforms.ToTensor(),
            ])
        else:
            self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.method = method
        self.root = os.path.expanduser(root)

        # Train and test clients are loaded the same way, so a single code
        # path serves both values of `train`.
        image_dir = os.path.join(self.root, 'raw', 'img_align_celeba')
        self.data = []
        self.label = []
        cur_x = dataclient['x']
        cur_y = dataclient['y']
        for j, file_name in enumerate(cur_x):
            # The context manager releases the file handle as soon as the
            # pixels have been copied into the numpy array.
            with Image.open(os.path.join(image_dir, file_name)) as img:
                pixels = np.asarray(img)
            self.data.append(pixels)
            self.label.append(cur_y[j])

    def __getitem__(self, index):
        """Return ``(img, target)``, plus ``index`` when ``method == 'FATS'``."""
        img, target = self.data[index], self.label[index]
        img = Image.fromarray(img, mode='RGB')
        if self.transform is not None:
            img = self.transform(img)
        target = torch.tensor(target, dtype=torch.long)
        if self.method == 'FATS':
            return img, target, index
        return img, target

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.data)


class SHAKESPEARE_Client(Dataset):
    """Shakespeare Dataset Client.

    Wraps one client's ``'x'`` / ``'y'`` sequences and encodes each item on
    access with ``word_to_indices`` and ``letter_to_vec``.
    """

    def __init__(self, dataclient, train=True, method='base'):
        super().__init__()
        self.train = train
        self.method = method
        # The same two fields are read whether or not this is a train split.
        self.data, self.label = dataclient['x'], dataclient['y']

    def __len__(self):
        """Return the number of stored inputs."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(indices, target)``, plus ``index`` when ``method == 'FATS'``."""
        raw_input = self.data[index]
        raw_target = self.label[index]

        encoded_input = word_to_indices(raw_input)
        encoded_target = letter_to_vec(raw_target)

        sample = (
            torch.tensor(encoded_input, dtype=torch.long),
            torch.tensor(encoded_target, dtype=torch.float),
        )
        if self.method != 'FATS':
            return sample
        return sample + (index,)


class DatasetwithIndex(torch.utils.data.Dataset):
    """Wrap a two-field dataset so that every item also reports its index."""

    def __init__(self, source):
        # `source` must support len() and integer indexing that yields a pair.
        self.source = source

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.source)

    def __getitem__(self, index):
        """Return ``(data, target, index)`` for the wrapped item at ``index``."""
        sample, label = self.source[index]
        return (sample, label, index)


def load_vision_data(dataname: str = "mnist", method: str = 'base',
                     root: str = "~/Documents/datasets"):
    """Load vision datasets (MNIST, CIFAR-10, etc.).

    Parameters
    ----------
    dataname : str
        One of ``"cifar10"``, ``"cifar100"``, ``"mnist"``, ``"fmnist"`` or
        ``"SVHN"`` (case-sensitive).
    method : str
        When equal to ``'FATS'`` both splits are wrapped in
        ``DatasetwithIndex`` so that items also carry their index.
    root : str
        Directory handed to the torchvision dataset constructors; data is
        downloaded there if missing.

    Returns
    -------
    trainset, testset : Dataset
        The two splits, sharing one ``ToTensor`` + ``Normalize`` transform.
    num_examples : dict
        ``{"trainset": len(trainset), "testset": len(testset)}``.

    Raises
    ------
    ValueError
        If ``dataname`` is not one of the supported names.
    """
    # (mean, std) passed to transforms.Normalize for each supported dataset.
    # NOTE(review): the cifar10 entry looks like the commonly used ImageNet
    # statistics rather than CIFAR-10 ones - confirm this is intended.
    norm_stats = {
        "cifar10": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        "cifar100": ((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2761)),
        "mnist": ((0.1307,), (0.3081,)),
        "fmnist": ((0.5,), (0.5,)),
        "SVHN": ((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)),
    }
    if dataname not in norm_stats:
        raise ValueError(f"Unsupported dataset: {dataname}")

    mean, std = norm_stats[dataname]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    if dataname == "SVHN":
        # SVHN picks its split by name instead of a boolean `train` flag.
        trainset = SVHN(root, split='train',
                        download=True, transform=transform)
        testset = SVHN(root, split='test',
                       download=True, transform=transform)
    else:
        dataset_cls = {
            "cifar10": CIFAR10,
            "cifar100": CIFAR100,
            "mnist": MNIST,
            "fmnist": FashionMNIST,
        }[dataname]
        trainset = dataset_cls(root, train=True,
                               download=True, transform=transform)
        testset = dataset_cls(root, train=False,
                              download=True, transform=transform)

    if method == 'FATS':
        trainset = DatasetwithIndex(trainset)
        testset = DatasetwithIndex(testset)

    num_examples = {"trainset": len(trainset), "testset": len(testset)}
    return trainset, testset, num_examples


def transform_datasets_into_dataloaders(
    datasets: List[Dataset], batch_size, random_seed=None, **dataloader_kwargs
) -> List[DataLoader]:
    """
    Wrap each dataset in a shuffling DataLoader, optionally seeded.

    Parameters
    ----------
    datasets: List[Dataset]
        Datasets to wrap; one DataLoader is produced per entry, in order.
    batch_size: int
        Number of samples per batch; the last, smaller batch is kept.
    random_seed: int
        Seed for the shuffling generator. When None, no generator is passed
        to the samplers.
    dataloader_kwargs
        Extra keyword arguments forwarded to every DataLoader.

    Returns
    -------
    dataloaders: List[DataLoader]
        One DataLoader per input dataset.
    """
    # A single generator is created once and handed to every sampler.
    shared_generator = None
    if random_seed is not None:
        shared_generator = torch.Generator()
        shared_generator.manual_seed(random_seed)

    def _build_loader(one_dataset):
        # Shuffle without replacement, then group indices into batches.
        shuffler = RandomSampler(
            one_dataset, replacement=False, generator=shared_generator
        )
        batcher = BatchSampler(shuffler, batch_size=batch_size, drop_last=False)
        return DataLoader(one_dataset, batch_sampler=batcher, **dataloader_kwargs)

    return [_build_loader(one_dataset) for one_dataset in datasets]


def create_federated_dataloaders(
    config,
    dataset: str = "mnist",
    sampling_type: str = 'niid',
    dataset_fraction: float = 0.05,
    batch_size: int = 10,
    train_fraction: float = 0.6,
    validation_fraction: float = 0.2,
    test_fraction: float = 0.2,
    random_seed: int = 42,
    method: str = "base",
    min_samples_per_client: int = 0,
) -> Tuple[List[DataLoader], List[DataLoader], List[DataLoader], DataLoader]:
    """Create federated DataLoaders for the dataset described by ``config``.

    Parameters
    ----------
    config
        Passed to ``init_data`` and ``create_noniid_data``; it selects the
        data and how it is partitioned.
    batch_size : int
        Batch size of every returned DataLoader.
    random_seed : int
        Seed used to shuffle the per-client loaders.
    dataset, sampling_type, dataset_fraction, train_fraction,
    validation_fraction, test_fraction, method, min_samples_per_client
        Accepted for interface compatibility; the current implementation
        does not read them.

    Returns
    -------
    trainloaders : List[DataLoader]
        One shuffling loader per client training partition.
    valloaders : List[DataLoader]
        The same list object as ``testloaders``; no separate validation
        split is built.
    testloaders : List[DataLoader]
        One shuffling loader per client test partition.
    central_testloader : DataLoader
        Unshuffled loader over the concatenation of all client test
        partitions.
    """
    data = init_data(config)

    # Create non-IID data partitions
    client_train_sets = create_noniid_data(data.train_set, train=True, config=config)
    client_test_sets = create_noniid_data(data.test_set, train=False, config=config)

    # Create DataLoaders with fixed random seed
    trainloaders = transform_datasets_into_dataloaders(
        client_train_sets, batch_size=batch_size, random_seed=random_seed
    )
    testloaders = transform_datasets_into_dataloaders(
        client_test_sets, batch_size=batch_size, random_seed=random_seed
    )

    # Create a central test DataLoader
    central_test = ConcatDataset(client_test_sets)
    central_testloader = DataLoader(
        central_test, batch_size=batch_size, shuffle=False
    )

    log(INFO, "Creation of the partitioned DataLoaders is done.")
    # The test loaders are returned twice: once in the validation slot and
    # once in the test slot.
    return trainloaders, testloaders, testloaders, central_testloader


def cifar10_train_valid_test_partition_selected(
    num_clients: int,
    validation_fraction: float = 0.1,
    sampling_type: str = "niid",
    random_seed: int = None,
    method: str = "base",
    dataname: str = "cifar10",
    selected_clients: List[int] = None
) -> Tuple[List[Dataset], List[Dataset], List[Dataset]]:
    """Partition a vision dataset with LDA and build splits for selected clients.

    Despite the name, ``dataname`` may be ``"cifar10"``, ``"SVHN"``,
    ``"fmnist"`` or ``"mnist"``.

    Parameters
    ----------
    num_clients : int
        Number of LDA partitions to create and number of equal test shards.
    validation_fraction : float
        Fraction of each client's partition placed in the validation subset
        (taken from the front of the partition).
    sampling_type : str
        Not read by the current implementation.
    random_seed : int
        Not read by the current implementation.
    method : str
        Forwarded to ``CIFAR10_Client``; ``"FATS"`` also wraps the test set
        in ``DatasetwithIndex``.
    dataname : str
        Which dataset to load.
    selected_clients : List[int], optional
        Client ids to build splits for. ``None`` means no clients.

    Returns
    -------
    partitioned_train, partitioned_validation, partitioned_test : List[Dataset]
        One ``Subset`` per selected client, in the order given.

    Raises
    ------
    ValueError
        If ``dataname`` is not supported.
    """
    # LDA concentration per dataset; also defines the supported names.
    concentrations = {"cifar10": 0.5, "SVHN": 10, "fmnist": 0.5, "mnist": 0.5}
    if dataname not in concentrations:
        raise ValueError(f"Unsupported dataset: {dataname}")
    if selected_clients is None:
        selected_clients = []

    alpha = concentrations[dataname]
    fetchers = {
        "cifar10": get_cifar_10,
        "SVHN": get_svhn,
        "fmnist": get_fashion_mnist,
        "mnist": get_mnist,
    }
    transform_factories = {
        "cifar10": cifar10Transformation,
        "SVHN": svhnTransformation,
        "fmnist": fmnistTransformation,
        "mnist": mnistTransformation,
    }
    train_path, testset = fetchers[dataname]()
    make_transform = transform_factories[dataname]

    if method == "FATS":
        testset = DatasetwithIndex(testset)

    # NOTE(review): torch.load unpickles the file; only point this at
    # trusted, locally generated data.
    train_data, labels = torch.load(train_path)
    idx = np.arange(len(labels))
    dataset = [idx, labels]
    partitions, _ = create_lda_partitions(
        dataset, num_partitions=num_clients, concentration=alpha, accept_imbalanced=True
    )

    partitioned_train = []
    partitioned_validation = []
    partitioned_test = []
    # Each client gets an equal, contiguous shard of the test set.
    test_len = len(testset) // num_clients

    for cid in selected_clients:
        # presumably partitions[cid] is (sample indices, labels), mirroring
        # the [idx, labels] input - verify against create_lda_partitions.
        images_partition = train_data[partitions[cid][0]]
        labels_partition = partitions[cid][1]
        # NOTE(review): CIFAR10_Client is used for every dataset and is not
        # defined or imported in the visible part of this module - confirm.
        train_partition = CIFAR10_Client(
            images_partition, labels_partition, method=method, transform=make_transform()
        )

        total_len = len(train_partition)
        val_len = int(total_len * validation_fraction)

        validation_subset = Subset(train_partition, range(val_len))
        train_subset = Subset(train_partition, range(val_len, total_len))
        test_subset = Subset(testset, range(cid * test_len, (cid + 1) * test_len))

        partitioned_train.append(train_subset)
        partitioned_validation.append(validation_subset)
        partitioned_test.append(test_subset)

    return partitioned_train, partitioned_validation, partitioned_test


def shakespeare_train_valid_test_partition(
    num_clients: int,
    validation_fraction: float = 0.1,
    sampling_type: str = "iid",
    random_seed: int = None,
    method: str = "base",
    root: str = "./data/shakespeare/data",
) -> Tuple[List[Dataset], List[Dataset], List[Dataset]]:
    """Build per-client train/validation/test datasets for Shakespeare.

    Data is read from ``<root>/train`` and ``<root>/test``. The first
    ``num_clients`` client ids returned by ``read_data`` are used. For each
    one, the leading ``validation_fraction`` of its training samples forms
    the validation subset and the remainder the training subset; the test
    dataset is that client's own test data. ``sampling_type`` and
    ``random_seed`` are not read.

    Returns
    -------
    partitioned_train, partitioned_validation, partitioned_test : List[Dataset]
        Lists with one entry per client, in client order.
    """
    train_dir = os.path.join(root, "train")
    test_dir = os.path.join(root, "test")
    client_ids, _, train_by_client, test_by_client = read_data(train_dir, test_dir)

    train_parts, validation_parts, test_parts = [], [], []

    for cid in range(num_clients):
        client_id = client_ids[cid]
        client_train = SHAKESPEARE_Client(
            train_by_client[client_id], train=True, method=method
        )

        n_total = len(client_train)
        n_val = int(n_total * validation_fraction)

        # Validation takes the front slice, training everything after it.
        held_out = Subset(client_train, range(n_val))
        kept = Subset(client_train, range(n_val, n_total))
        client_test = SHAKESPEARE_Client(
            test_by_client[client_id], train=False, method=method
        )

        train_parts.append(kept)
        validation_parts.append(held_out)
        test_parts.append(client_test)

    return train_parts, validation_parts, test_parts
