import logging
import os
import random

import numpy as np
from torch.utils.data import Subset
from torchvision.datasets import CIFAR10, ImageFolder, CIFAR100, Places365, SVHN
from torchvision.transforms import transforms

from src.data.dataset.load_cifar10_corrupted import CIFAR10_CORRUPT, CIFAR100_CORRUPT


def create_dirichlet_distribution(alpha: float, num_client: int, num_class: int, seed: int):
    """Sample a client-by-class allocation matrix from a symmetric Dirichlet prior.

    One Dirichlet vector over clients (concentration ``alpha`` for every client)
    is drawn per class, so within each class the client shares sum to 1. The
    matrix is then divided by its grand total, so the whole matrix sums to 1
    and every class column sums to ``1 / num_class``.

    Args:
        alpha: Dirichlet concentration; smaller values give more skewed splits.
        num_client: Number of clients (rows of the result).
        num_class: Number of classes (columns of the result).
        seed: Seed for ``np.random.default_rng``, making the draw reproducible.

    Returns:
        ``np.ndarray`` of shape ``(num_client, num_class)``.
    """
    generator = np.random.default_rng(seed)
    concentration = np.repeat(alpha, num_client)
    # dirichlet(..., size=num_class) yields one row per class: (num_class, num_client).
    per_class = generator.dirichlet(concentration, size=num_class)
    per_client = per_class.T
    return per_client / per_client.sum()


def split_by_distribution(targets, distribution):
    """Partition sample indices among clients according to a client-by-class matrix.

    For each class ``c``, client ``i`` receives
    ``floor(distribution[i, c] * len(targets))`` indices of that class. Clients
    take consecutive, non-overlapping slices of each class's index list (in
    dataset order), so the per-client splits are disjoint. If a class has fewer
    samples than requested, the trailing clients get truncated (possibly empty)
    slices.

    Args:
        targets: Label of every sample. Any array-like of ints is accepted; it is
            converted with ``np.asarray``. Only labels ``0 .. num_class - 1`` are
            assigned to clients.
        distribution: Array of shape ``(num_client, num_class)``.

    Returns:
        dict mapping client index -> ``np.int32`` array of sample indices,
        grouped by class in ascending class order.
    """
    # A plain list compared with ``==`` to an int gives a single False, which
    # used to silently produce empty splits; coerce to an ndarray first.
    targets = np.asarray(targets)
    distribution = np.asarray(distribution)
    num_client, num_class = distribution.shape[0], distribution.shape[1]
    sample_number = np.floor(distribution * len(targets))
    class_idx = {class_label: np.flatnonzero(targets == class_label) for class_label in range(num_class)}

    # Row i holds, per class, the offset at which client i's slice begins;
    # row num_client is the end of the last client's slice.
    idx_start = np.zeros((num_client + 1, num_class), dtype=np.int32)
    idx_start[1:] = np.cumsum(sample_number, axis=0)

    client_samples = {}
    for client_idx in range(num_client):
        pieces = [
            class_idx[class_label][idx_start[client_idx, class_label]:idx_start[client_idx + 1, class_label]]
            for class_label in range(num_class)
        ]
        # One concatenate per client instead of growing an array class by class.
        if pieces:
            client_samples[client_idx] = np.concatenate(pieces).astype(np.int32)
        else:
            client_samples[client_idx] = np.array([], dtype=np.int32)

    return client_samples


def load_ood_dataset(dataset_path: str, ood_dataset: str):
    """Build the out-of-distribution evaluation dataset named ``ood_dataset``.

    Every dataset is converted to a tensor and normalized with the same
    per-channel mean/std constants used for the in-distribution CIFAR data in
    this module.

    Args:
        dataset_path: Root directory holding (or receiving) the OOD data.
        ood_dataset: One of ``'LSUN_C'``, ``'lsun_r'``, ``'dtd'``, ``'isun'``,
            ``'place365'`` or ``'SVHN'`` (case-sensitive).

    Returns:
        A torchvision dataset (``ImageFolder``, ``Places365`` or ``SVHN``).

    Raises:
        NotImplementedError: If ``ood_dataset`` is not a supported name.
    """
    supported = ('LSUN_C', 'lsun_r', 'dtd', 'isun', 'place365', 'SVHN')
    if ood_dataset not in supported:
        # The previous message listed only three of the six accepted names.
        raise NotImplementedError(
            f'out of distribution dataset should be one of: {", ".join(supported)}')

    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
    to_normalized_tensor = [transforms.ToTensor(), transforms.Normalize(mean, std)]
    # Shrink arbitrary-size images to the 32x32 CIFAR resolution.
    resize_to_32 = [transforms.Resize(32), transforms.CenterCrop(32)]

    if ood_dataset == 'LSUN_C':
        # RandomCrop (applied after normalization) makes these inputs stochastic.
        return ImageFolder(
            root=os.path.join(dataset_path, 'LSUN'),
            transform=transforms.Compose(to_normalized_tensor + [transforms.RandomCrop(32, padding=4)]),
        )
    if ood_dataset == 'lsun_r':
        return ImageFolder(
            root=os.path.join(dataset_path, 'LSUN_resize'),
            transform=transforms.Compose(to_normalized_tensor),
        )
    if ood_dataset == 'dtd':
        return ImageFolder(
            root=os.path.join(dataset_path, 'dtd/images'),
            transform=transforms.Compose(resize_to_32 + to_normalized_tensor),
        )
    if ood_dataset == 'isun':
        return ImageFolder(
            root=os.path.join(dataset_path, 'iSUN'),
            transform=transforms.Compose(to_normalized_tensor),
        )
    if ood_dataset == 'place365':
        # NOTE(review): no ``split``/``small`` arguments are passed, so the
        # torchvision defaults apply (with download=True) — confirm this is the
        # intended evaluation split.
        return Places365(
            root=dataset_path,
            download=True,
            transform=transforms.Compose(resize_to_32 + to_normalized_tensor),
        )
    # Only 'SVHN' remains after the validation above.
    return SVHN(
        root=dataset_path,
        split='test',
        transform=transforms.Compose(to_normalized_tensor),
        download=True,
    )


def dirichlet_load_test(dataset_path, id_dataset, num_client, alpha, corrupt_list, seed):
    """Split a CIFAR test set and its corrupted variants across clients.

    The per-client, per-class share comes from ``create_dirichlet_distribution``.
    The same allocation matrix is used for the clean test set and for the
    corrupted sets.

    Args:
        dataset_path: Root directory of the datasets (CIFAR is downloaded here
            if missing).
        id_dataset: ``'cifar10'``, ``'cifar10_fourier_aug'``, ``'cifar100'`` or
            ``'cifar100_fourier_aug'``.
        num_client: Number of clients to split into.
        alpha: Dirichlet concentration passed to ``create_dirichlet_distribution``.
        corrupt_list: Corruption type names, each passed as ``cortype`` to the
            corrupted-dataset class. May be empty.
        seed: Seed for the Dirichlet draw.

    Returns:
        Tuple ``(id_datasets, cor_datasets, num_class)``:
        ``id_datasets`` is a list of ``num_client`` ``Subset`` objects of the
        clean test set; ``cor_datasets`` is a list of ``num_client`` dicts
        mapping corruption name -> ``Subset`` (empty dicts when
        ``corrupt_list`` is empty); ``num_class`` is 10 or 100.

    Raises:
        NotImplementedError: If ``id_dataset`` is not a supported CIFAR variant.
    """
    if id_dataset in ('cifar10', 'cifar10_fourier_aug'):
        dataset_cls, corrupt_cls, num_class = CIFAR10, CIFAR10_CORRUPT, 10
    elif id_dataset in ('cifar100', 'cifar100_fourier_aug'):
        dataset_cls, corrupt_cls, num_class = CIFAR100, CIFAR100_CORRUPT, 100
    else:
        raise NotImplementedError('in distribution dataset should be CIFAR10 or CIFAR100.')

    # The same normalization constants are used for both CIFAR10 and CIFAR100.
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    test_data = dataset_cls(root=dataset_path, download=True, train=False, transform=trans)
    cor_test = [corrupt_cls(root=dataset_path, cortype=cor_type, transform=trans) for cor_type in corrupt_list]

    distribution = create_dirichlet_distribution(alpha, num_client, num_class, seed)
    id_split = split_by_distribution(np.array(test_data.targets), distribution)
    id_datasets = [Subset(test_data, id_split[idx]) for idx in range(num_client)]

    if cor_test:
        # A single split (computed from the first corrupted set) is reused for
        # every corruption type, which assumes they all share the label order
        # of cor_test[0].
        cor_split = split_by_distribution(np.array(cor_test[0].targets), distribution)
        cor_datasets = [
            {cor_type: Subset(cor_data, cor_split[client_idx]) for cor_type, cor_data in zip(corrupt_list, cor_test)}
            for client_idx in range(num_client)]
    else:
        # Previously indexing cor_test[0] raised IndexError for an empty corrupt_list.
        cor_datasets = [{} for _ in range(num_client)]

    logging.info(f'-------- dirichlet distribution with alpha = {alpha}, {num_client} clients --------')
    logging.info(f'in-distribution test datasets: {[len(dataset) for dataset in id_datasets]}')
    return id_datasets, cor_datasets, num_class


def load_test_ood(dataset_path, ood_dataset, seed, partial, fraction=0.2):
    """Load an OOD test set, optionally keeping a seeded random subset of it.

    Args:
        dataset_path: Root directory passed to ``load_ood_dataset``.
        ood_dataset: OOD dataset name passed to ``load_ood_dataset``.
        seed: Seed for ``np.random.default_rng``; it determines which samples
            are kept when ``partial`` is true.
        partial: If true, keep only ``int(fraction * len(dataset))`` randomly
            chosen samples.
        fraction: Share of the dataset kept when ``partial`` is true
            (defaults to 0.2, the previously hard-coded value).

    Returns:
        The full OOD dataset, or a ``Subset`` of it when ``partial`` is true.
    """
    ood_data = load_ood_dataset(dataset_path, ood_dataset)

    if partial:
        # The seeded generator was created before but never used: sampling went
        # through the global ``random`` module, so ``seed`` had no effect.
        rng = np.random.default_rng(seed)
        num_keep = int(fraction * len(ood_data))
        idx = rng.permutation(len(ood_data))[:num_keep].tolist()
        ood_data = Subset(ood_data, idx)

    logging.info(f'out of distribution test dataset\'s length: {len(ood_data)}')
    return ood_data
