
import numpy as np
# Module-level NumPy generator created without a seed, so its draws differ
# between runs. Not referenced in this chunk — NOTE(review): confirm its users.
rng = np.random.default_rng()
import torch, torchvision
import torchvision.transforms as transforms

from torch.utils.data import DataLoader

from flwr_datasets import FederatedDataset

import isik_partitioner as isik

# Not referenced in this chunk; presumably the number of CPUs reserved for one
# experiment/simulation run — TODO confirm against the code that reads it.
CPUS_PER_EXPERIMENT = 11

def load_datasets(wandb_config):
    """Build the CIFAR-10 client train loaders and the shared evaluation loader.

    ``wandb_config['non_iid']`` selects how the training split is partitioned
    across ``wandb_config['num_clients']`` clients:

    * ``'iid'``: Flower ``FederatedDataset`` partitions of the train split.
      Batches are dicts with ``"img"``/``"label"`` entries. Augmentation is not
      supported (``wandb_config['transforms']`` must be falsy).
    * ``'isik'``: client loaders from ``isik.get_data_loaders`` with
      ``classes_pc=int(wandb_config['non_iid_param'])``.
    * ``'dirichlet'``: torchvision CIFAR-10 (downloaded to ``./data``), split by
      fedlab's unbalanced Dirichlet partitioner with
      ``dir_alpha=wandb_config['non_iid_param']``. Batches are
      ``(images, labels)`` tuples. Training images are augmented when
      ``wandb_config['transforms']`` is truthy.

    Args:
        wandb_config: Mapping with ``num_clients``, ``batch_size``, ``non_iid``,
            ``transforms``, ``non_iid_param`` and, optionally, ``seed``
            (Dirichlet partition seed, default 42).

    Returns:
        ``(trainloaders, valloader, testloader)``. ``valloader`` is the same
        object as ``testloader``; no separate validation split is created.

    Raises:
        ValueError: ``non_iid`` is not one of ``'iid'``, ``'isik'``, ``'dirichlet'``.
        NotImplementedError: ``'iid'`` was requested with ``transforms`` enabled.
    """
    mode = wandb_config['non_iid']
    # Validate before any dataset download. An unrecognised mode would
    # otherwise fall through every branch and leave ``testloader`` unbound.
    if mode not in ('iid', 'isik', 'dirichlet'):
        raise ValueError(
            f"Unsupported non_iid mode {mode!r}; expected 'iid', 'isik' or 'dirichlet'."
        )
    if mode == 'iid' and wandb_config['transforms']:
        raise NotImplementedError("Not done atm for new transforms.")

    num_clients = wandb_config['num_clients']
    batch_size = wandb_config['batch_size']

    train_transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            # The per-channel CIFAR-10 statistics ((0.4914, 0.4822, 0.4465),
            # (0.247, 0.243, 0.261)) seemed to lower training accuracy, so a
            # symmetric normalisation mapping [0, 1] to [-1, 1] is used instead.
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    def apply_transforms(batch):
        # Flower datasets hand over dict batches, so the "img" column is
        # transformed element-wise instead of via a torchvision Dataset transform.
        batch["img"] = [test_transform(img) for img in batch["img"]]
        return batch

    def flower_test_loader(fds):
        # Full CIFAR-10 test split, in the same dict-batch format as the iid loaders.
        testset = fds.load_full("test").with_transform(apply_transforms)
        return DataLoader(testset, batch_size=batch_size)

    if mode == 'iid':
        fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_clients})
        trainloaders = [
            DataLoader(
                fds.load_partition(partition_id, "train").with_transform(apply_transforms),
                batch_size=batch_size,
                shuffle=True,
            )
            for partition_id in range(num_clients)
        ]
        testloader = flower_test_loader(fds)

    elif mode == 'isik':
        trainloaders = isik.get_data_loaders(
            num_clients, batch_size, classes_pc=int(wandb_config['non_iid_param'])
        )
        # The isik helper's result is used only as the client train loaders, so
        # the evaluation loader must be built here. test() reads dict batches
        # (batch["img"]) for every mode except 'dirichlet', hence the Flower
        # test split. NOTE(review): confirm isik's loaders use that format too.
        testloader = flower_test_loader(
            FederatedDataset(dataset="cifar10", partitioners={"train": num_clients})
        )

    else:  # mode == 'dirichlet'
        trainset = torchvision.datasets.CIFAR10(
            root='./data',
            train=True,
            download=True,
            transform=train_transform if wandb_config['transforms'] else test_transform,
        )
        testset = torchvision.datasets.CIFAR10(
            root='./data', train=False, download=True, transform=test_transform
        )

        # fedlab adds noisy logging on import, so only import it when needed.
        from fedlab.utils.dataset.partition import CIFAR10Partitioner

        partitioner = CIFAR10Partitioner(
            trainset.targets,
            num_clients,
            balance=False,
            partition="dirichlet",
            unbalance_sgm=0.3,
            dir_alpha=wandb_config['non_iid_param'],
            seed=wandb_config['seed'] if 'seed' in wandb_config else 42,
        )
        trainloaders = [
            DataLoader(
                torch.utils.data.Subset(trainset, indices),
                batch_size=batch_size,
                shuffle=True,
            )
            for indices in partitioner.client_dict.values()
        ]
        testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

    valloader = testloader
    return trainloaders, valloader, testloader
    
def test(wandb_config, net, testloader, device):
    """Evaluate ``net`` on every batch of ``testloader``.

    Args:
        wandb_config: Mapping whose ``'non_iid'`` entry selects the batch format.
            ``'dirichlet'`` loaders yield ``(images, labels)`` tuples; every
            other mode yields dicts with ``"img"`` and ``"label"`` entries.
        net: Model mapping a batch of images to class logits. It is put into
            eval mode.
        testloader: Iterable of evaluation batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(loss, accuracy)``: the mean per-sample cross-entropy and the fraction
        of correctly classified samples over all batches.
    """
    # Sum per-sample losses so dividing by the sample count yields a true
    # per-sample mean. Summing per-batch *means* and dividing by the dataset
    # size would under-report the loss by roughly a factor of batch_size.
    criterion = torch.nn.CrossEntropyLoss(reduction="sum")
    correct, total, loss = 0, 0, 0.0
    dict_batches = wandb_config['non_iid'] != 'dirichlet'

    net.eval()
    with torch.no_grad():
        for batch in testloader:
            if dict_batches:
                images, labels = batch["img"].to(device), batch["label"].to(device)
            else:
                images, labels = batch
                # NOTE(review): the float64 cast presumes the model runs in
                # double precision for this mode — confirm.
                images = images.to(device, dtype=torch.float64)
                labels = labels.to(device, dtype=torch.long)

            outputs = net(images)
            loss += criterion(outputs, labels).item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    loss /= total
    accuracy = correct / total
    return loss, accuracy


# This module-level helper is named ``test``; stop pytest from collecting it as
# a test case if it is ever imported into a test module.
test.__test__ = False