from datasets import load_dataset
import json 
import numpy as np
import random
from scipy import stats
from sklearn.datasets import fetch_20newsgroups
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset

# Root directory for dataset storage. The builders below append a
# dataset-specific sub-path to it (e.g. f"{DATA}/cifar10", f"{DATA}/MRQA/train.json").
DATA = "./"

def custom_collate_fn(batch):
    """Identity collate: return the sampled items as-is, without stacking them."""
    samples = batch
    return samples

def build_dataset(dataset, batch_size, n_clients, alpha=-1, seed=0):
    """
    Build one training DataLoader per client plus a single shared test DataLoader.

    Args:
        dataset: Dataset name. "glue_<task>" is routed to build_glue; otherwise
            one of 'cifar10', 'cifar100', 'svhn', '20newsgroups', 'mrqa'.
        batch_size: Batch size used for every DataLoader.
        n_clients: Number of client partitions to create.
        alpha: Forwarded to the non-GLUE builders (-1 is their IID setting).
            It is not forwarded to build_glue.
        seed: Forwarded to the selected builder.

    Returns:
        (clientloaders, testloader): a list of shuffled per-client loaders and
        one non-shuffled loader over the test set.

    Raises:
        NotImplementedError: If the dataset name is not recognised.
    """
    # Only MRQA overrides collation; None means "leave the DataLoader default".
    batch_collate = None

    if dataset.startswith("glue_"):
        glue_task = dataset.split("_", 1)[1]
        clients, testset = build_glue(glue_task, n_clients, seed)
    elif dataset == 'cifar10':
        clients, testset = build_cifar10(n_clients, alpha, seed)
    elif dataset == 'cifar100':
        clients, testset = build_cifar100(n_clients, alpha, seed)
    elif dataset == 'svhn':
        clients, testset = build_svhn(n_clients, alpha, seed)
    elif dataset == '20newsgroups':
        clients, testset = build_20newsgroups(n_clients, alpha, seed)
    elif dataset == 'mrqa':
        clients, testset = build_mrqa(n_clients, alpha, seed)
        # MRQA items are dictionaries, so they are passed through un-stacked.
        batch_collate = custom_collate_fn
    else:
        raise NotImplementedError(f"Dataset {dataset} is not implemented.")

    n_train = 0
    for client in clients:
        n_train += len(client)
    print(f"Total train examples: {n_train:,}")
    print(f"Test-set examples   : {len(testset):,}")

    # Options shared by all loaders; collate_fn is only passed when overridden.
    loader_kwargs = {"batch_size": batch_size, "num_workers": 0}
    if batch_collate is not None:
        loader_kwargs["collate_fn"] = batch_collate

    clientloaders = [
        DataLoader(client, shuffle=True, **loader_kwargs) for client in clients
    ]
    testloader = DataLoader(testset, shuffle=False, **loader_kwargs)

    return clientloaders, testloader

def partition_dataset(dataset, Y, n_classes, n_clients, alpha, seed):
    """
    Split `dataset` into `n_clients` torch Subsets.

    Args:
        dataset: Indexable dataset with a length.
        Y: Per-example class labels aligned with `dataset`; they are compared
            against the integers 0..n_classes-1. Unused when alpha == -1.
        n_classes: Number of classes. Unused when alpha == -1.
        n_clients: Number of subsets to produce.
        alpha: -1 selects the IID split. Any other value is used as the
            symmetric Dirichlet concentration for the non-IID split.
        seed: Seed for the IID permutation and for the Dirichlet draw.

    Returns:
        A list of `n_clients` torch.utils.data.Subset objects over `dataset`.
    """
    clients = []

    # IID case: equal-sized random chunks.
    if alpha == -1:
        n_total = len(dataset)
        # A private, seeded generator makes the split reproducible for a given
        # seed without touching the global torch RNG state.
        rng = torch.Generator().manual_seed(int(seed))
        shuffled = torch.randperm(n_total, generator=rng).tolist()
        # Integer division: the n_total % n_clients leftover examples are unassigned.
        per_client = n_total // n_clients
        for i in range(n_clients):
            chunk = shuffled[per_client * i:per_client * (i + 1)]
            clients.append(torch.utils.data.Subset(dataset, chunk))
        return clients

    # Non-IID case.
    labels = np.asarray(Y)

    # Example indices grouped by class, and the real size of every class.
    class_to_idx = {c: np.where(labels == c)[0] for c in range(n_classes)}
    class_sizes = np.array([len(class_to_idx[c]) for c in range(n_classes)])

    # One Dirichlet draw per client, then normalise each class column so the
    # clients' shares of that class sum to 1.
    shares = stats.dirichlet.rvs(np.repeat(alpha, n_classes), size=n_clients, random_state=seed)
    shares = shares / shares.sum(axis=0)
    # Scale by the actual class sizes rather than len(dataset) // n_classes, so
    # imbalanced label sets are neither truncated nor exhausted early. For
    # perfectly balanced data both scalings are the same.
    quotas = (shares * class_sizes).round().astype(int)

    # cursor[c] is the next unused position inside class_to_idx[c].
    cursor = np.zeros(n_classes, dtype=int)
    for client_quota in quotas:
        stop = cursor + client_quota
        # Slicing past the end of a class (possible through rounding) just
        # yields fewer examples.
        client_idx = np.concatenate(
            [class_to_idx[c][cursor[c]:stop[c]] for c in range(n_classes)]
        )
        cursor = stop
        clients.append(torch.utils.data.Subset(dataset, client_idx))
    return clients

""" Vision Tasks """

def build_cifar10(n_clients, alpha, seed):
    """
    Create CIFAR-10 client partitions and the shared test set.

    Training images get a random resized crop to 224 and a random horizontal
    flip; test images are only resized to 224. Both are converted to tensors
    and normalised with the same per-channel statistics.

    Args:
        n_clients: Number of client subsets to create from the training set.
        alpha: Passed to partition_dataset.
        seed: Passed to partition_dataset.

    Returns:
        (clients, testset): list of training subsets and the full test set.
    """
    root = f"{DATA}/cifar10"
    channel_norm = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        channel_norm,
    ])
    eval_tf = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        channel_norm,
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True, transform=train_tf
    )
    testset = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=eval_tf
    )

    # Only the training set is partitioned; the test set stays a single universal set.
    train_labels = np.array(trainset.targets)
    clients = partition_dataset(trainset, train_labels, 10, n_clients, alpha, seed)
    return clients, testset

def build_cifar100(n_clients, alpha, seed):
    """
    Create CIFAR-100 client partitions and the shared test set.

    Training images get a random resized crop to 224 and a random horizontal
    flip; test images are only resized to 224. Both are converted to tensors
    and normalised with the same per-channel statistics.

    Args:
        n_clients: Number of client subsets to create from the training set.
        alpha: Passed to partition_dataset.
        seed: Passed to partition_dataset.

    Returns:
        (clients, testset): list of training subsets and the full test set.
    """
    root = f"{DATA}/cifar100"
    channel_norm = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        channel_norm,
    ])
    eval_tf = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        channel_norm,
    ])

    trainset = torchvision.datasets.CIFAR100(
        root=root, train=True, download=True, transform=train_tf
    )
    testset = torchvision.datasets.CIFAR100(
        root=root, train=False, download=True, transform=eval_tf
    )

    # Only the training set is partitioned; the test set stays a single universal set.
    train_labels = np.array(trainset.targets)
    clients = partition_dataset(trainset, train_labels, 100, n_clients, alpha, seed)
    return clients, testset

def build_svhn(n_clients, alpha, seed):
    """
    Create SVHN client partitions and the shared test set.

    Training images get a random resized crop to 224 and a random horizontal
    flip; test images are only resized to 224. Both are converted to tensors
    and normalised with the same per-channel statistics.

    Args:
        n_clients: Number of client subsets to create from the training split.
        alpha: Passed to partition_dataset.
        seed: Passed to partition_dataset.

    Returns:
        (clients, testset): list of training subsets and the full test split.
    """
    root = f"{DATA}/svhn"
    # NOTE(review): these constants are identical to the ones used in
    # build_cifar10 -- confirm they are the intended statistics for SVHN.
    channel_norm = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        channel_norm,
    ])
    eval_tf = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        channel_norm,
    ])

    trainset = torchvision.datasets.SVHN(
        root=root, split='train', download=True, transform=train_tf
    )
    testset = torchvision.datasets.SVHN(
        root=root, split='test', download=True, transform=eval_tf
    )

    # SVHN exposes its labels as `.labels` (the CIFAR builders read `.targets`).
    train_labels = trainset.labels
    clients = partition_dataset(trainset, train_labels, 10, n_clients, alpha, seed)
    return clients, testset

""" Language Tasks """

class TwentyNewsGroupsDataset(Dataset):
    """Map-style dataset that pairs each text with its label as (text, label)."""

    def __init__(self, data, targets):
        # Parallel sequences: data[i] is labelled by targets[i].
        self.data, self.targets = data, targets

    def __len__(self):
        # The length follows the text sequence.
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]

def build_20newsgroups(n_clients, alpha, seed):
    """
    Create 20 Newsgroups client partitions and the shared test set.

    Both subsets are fetched with headers, footers and quotes removed.

    Args:
        n_clients: Number of client subsets to create from the training subset.
        alpha: Passed to partition_dataset.
        seed: Passed to partition_dataset.

    Returns:
        (clients, testset): list of training subsets and the full test dataset.
    """
    stripped_parts = ('headers', 'footers', 'quotes')
    train_bunch = fetch_20newsgroups(subset='train', remove=stripped_parts)
    test_bunch = fetch_20newsgroups(subset='test', remove=stripped_parts)

    # Wrap the raw texts and labels as torch datasets.
    trainset = TwentyNewsGroupsDataset(train_bunch.data, train_bunch.target)
    testset = TwentyNewsGroupsDataset(test_bunch.data, test_bunch.target)

    # Partition the training set over the 20 newsgroup classes.
    train_labels = np.array(train_bunch.target)
    clients = partition_dataset(trainset, train_labels, 20, n_clients, alpha, seed)

    return clients, testset

class MRQADataset(Dataset):
    """Map-style dataset over a sequence of examples, returned unmodified."""

    def __init__(self, examples):
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        # Hand back the stored example object itself (no copy, no unpacking).
        example = self.examples[idx]
        return example

def build_mrqa(n_clients, alpha, seed):
    """
    Create MRQA client partitions and the shared test set from local JSON files.

    Reads f"{DATA}/MRQA/train.json" and f"{DATA}/MRQA/test.json", shuffles both
    example lists in place, and partitions the training examples by their
    'subtask' field.

    Args:
        n_clients: Number of client subsets to create from the training examples.
        alpha: Passed to partition_dataset.
        seed: Seeds the global `random` module before shuffling, and is also
            passed to partition_dataset.

    Returns:
        (clients, testset): list of training subsets and the full test dataset.
    """
    train_file = f"{DATA}/MRQA/train.json"
    test_file = f"{DATA}/MRQA/test.json"

    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # default text encoding.
    with open(train_file, 'r', encoding="utf-8") as f:
        train_examples = json.load(f)
    with open(test_file, 'r', encoding="utf-8") as f:
        test_examples = json.load(f)

    # NOTE: this reseeds the module-level `random` generator as a side effect.
    random.seed(seed)
    random.shuffle(train_examples)
    random.shuffle(test_examples)

    # Create dataset objects for training and testing.
    trainset = MRQADataset(train_examples)
    testset = MRQADataset(test_examples)

    # Partition the training set based on the 'subtask' field.
    # NOTE(review): partition_dataset matches labels against the integers
    # 0..n_subtasks-1, so 'subtask' is presumably an integer id -- confirm
    # against the JSON files.
    Y = np.array([ex['subtask'] for ex in train_examples])
    n_subtasks = 6  # There are 6 subtasks in MRQA.
    clients = partition_dataset(trainset, Y, n_subtasks, n_clients, alpha, seed)

    return clients, testset

""" GLUE Tasks """ 

# Names of the two text columns for GLUE sentence-pair tasks, keyed by
# lower-cased task name. GlueDataset looks the task up here and falls back to
# ("sentence1", "sentence2") for any task that is not listed.
PAIR_KEYS = {
    "mnli"  : ("premise", "hypothesis"),
    # the two "pseudo-tasks" map to the canonical mnli column names
    "mnli_matched"   : ("premise", "hypothesis"),
    "mnli_mismatched": ("premise", "hypothesis"),
    "ax"   : ("premise", "hypothesis"),   # diagnostic set (eval only - no train split)
    "qqp"  : ("question1", "question2"),
    "qnli" : ("question",  "sentence"),
    # everything else that needs pairs falls back to ("sentence1", "sentence2")
}

class GlueDataset(Dataset):
    """
    Adapter around a HF GLUE split whose items are (text, label) tuples.

    Pair tasks join their two text columns with " [SEP] "; "cola" and "sst2"
    read the single "sentence" column. The label is a float for "stsb" and an
    int for every other task.
    """
    def __init__(self, hf_split, task):
        self.ds = hf_split
        self.task = task.lower()

        # Column names of the two texts; unlisted tasks use the generic pair.
        self.pair_cols = PAIR_KEYS.get(self.task, ("sentence1", "sentence2"))
        self.is_pair = not (self.task in ["cola", "sst2"])
        self.is_reg = (self.task == "stsb")

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        # Indices may arrive as tensors (e.g. via a Subset); use a plain int.
        position = idx.item() if torch.is_tensor(idx) else idx
        record = self.ds[int(position)]

        if not self.is_pair:
            text = record["sentence"]
        else:
            first_col, second_col = self.pair_cols
            text = record[first_col] + " [SEP] " + record[second_col]

        raw_label = record["label"]
        if self.is_reg:
            return text, float(raw_label)
        return text, int(raw_label)

def build_glue(task: str, n_clients: int, seed: int):
    """
    Return (clients, dev_set) for a GLUE task.

    For mnli_matched/mismatched the canonical "mnli" builder is loaded and
    only the validation split used for evaluation differs.

    Args:
        task: GLUE task name (case-insensitive), e.g. "sst2", "qnli",
            "mnli_matched".
        n_clients: Number of IID client subsets to create from the train split.
        seed: Passed to partition_dataset.

    Returns:
        (clients, testset): list of training subsets and a GlueDataset over
        the chosen validation split.

    Raises:
        ValueError: If the task has no training split, or no validation split
            can be found.
    """
    task = task.lower()

    # Tasks that reuse the MNLI builder
    if task in {"mnli_matched", "mnli_mismatched"}:
        canonical_task = "mnli"
    else:
        canonical_task = task

    # Load the HF dataset *once*
    dataset = load_dataset("glue", canonical_task)

    # Guard against tasks that have *no* training split (e.g. AX). This runs
    # before the dev-split lookup so such tasks get this error rather than
    # failing while searching for a validation split.
    if "train" not in dataset:
        raise ValueError(
            f"The GLUE task “{task}” has no training split – it is evaluation‑only."
        )

    # Decide which dev/validation split to evaluate on
    if task == "mnli_matched":
        dev_split = "validation_matched"
    elif task == "mnli_mismatched":
        dev_split = "validation_mismatched"
    elif "validation" in dataset:
        dev_split = "validation"
    elif "validation_matched" in dataset:
        dev_split = "validation_matched"
    else:
        # fall back to the first split that starts with "validation"
        dev_split = next((k for k in dataset if k.startswith("validation")), None)
        if dev_split is None:
            raise ValueError(
                f"The GLUE task “{task}” has no validation split."
            )

    trainset = GlueDataset(dataset["train"], task)
    testset = GlueDataset(dataset[dev_split], task)

    # GLUE is always split IID (alpha = -1), so no labels or class count are needed.
    clients = partition_dataset(trainset, None, None,
                                n_clients, -1, seed)
    return clients, testset