import copy, logging
import json
from dataclasses import dataclass
import os
import random
from typing import Dict, List, Optional, Tuple, Union
from torch.utils.data import Dataset
from src.datasets import *
import torchvision
from omegaconf import DictConfig
import numpy as np
from src.utils.dirichlet_non_iid import non_iid_partition_with_dirichlet_distribution

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


@dataclass(frozen=True)
class RawSplittedDataset:
    """Immutable container for raw data that is already partitioned into splits.

    ``samples[i]`` and ``targets[i]`` together form the i-th split.
    """
    # Annotated with np.ndarray (the array type); np.array is only the factory function.
    samples: List[np.ndarray]  # one entry per split
    targets: List[np.ndarray]  # labels of each split, parallel to ``samples``


@dataclass(frozen=True)
class RawDataset:
    """Immutable container for one flat (not pre-split) set of samples and labels."""
    # Annotated with np.ndarray (the array type); np.array is only the factory function.
    samples: np.ndarray  # indexed along the first axis together with ``targets``
    targets: np.ndarray  # label of each sample


# Either shape of raw data a getter may return: already split, or one flat dataset.
GeneralRawDataset = Union[RawSplittedDataset, RawDataset]

def get_emnist_data(root: str, train: bool) -> GeneralRawDataset:
    """Load the EMNIST data stored as JSON under ``<root>/EMNIST``.

    Args:
        root: Directory that contains the ``EMNIST`` folder.
        train: Read ``train.json`` when true, ``test.json`` otherwise.

    Returns:
        For training, a ``RawSplittedDataset`` keeping one entry per element of
        the file's ``x`` / ``y`` lists. For testing, a single ``RawDataset``
        with all entries stacked together.
    """
    filename = 'train.json' if train else 'test.json'
    path = os.path.join(root, 'EMNIST', filename)
    # Context manager so the handle is closed even when parsing fails.
    with open(path) as fp:
        data = json.load(fp)
    samples = [np.array(d) for d in data['x']]  # already splitted
    targets = [np.array(d).flatten() for d in data['y']]
    if train:
        return RawSplittedDataset(samples, targets)
    # The test set is not used per split: merge everything into one dataset.
    return RawDataset(np.vstack(samples), np.concatenate(targets))

def get_shakespeare_data(root: str, train: bool, version: str = 'iid') -> GeneralRawDataset:
    """Load the Shakespeare data stored as JSON under ``<root>/shakespeare``.

    Args:
        root: Directory that contains the ``shakespeare`` folder.
        train: Read ``train_sampled_<version>.json`` when true,
            ``test_sampled.json`` otherwise.
        version: ``'iid'`` or ``'niid'``; only affects the training file name
            but is validated in both cases.

    Returns:
        For training, a ``RawSplittedDataset`` keeping one entry per element of
        the file's ``x`` / ``y`` lists. For testing, a single ``RawDataset``
        with all entries stacked together.

    Raises:
        AssertionError: If ``version`` is neither ``'iid'`` nor ``'niid'``.
    """
    assert version in ['niid', 'iid'], "Unknown version"
    filename = f'train_sampled_{version}.json' if train else 'test_sampled.json'
    path = os.path.join(root, 'shakespeare', filename)
    # Context manager so the handle is closed even when parsing fails.
    with open(path) as fp:
        data = json.load(fp)
    samples = [np.array(d) for d in data['x']]  # already splitted
    targets = [np.array(d).flatten() for d in data['y']]
    if train:
        return RawSplittedDataset(samples, targets)
    # The test set is not used per split: merge everything into one dataset.
    return RawDataset(np.vstack(samples), np.vstack(targets).flatten())


def _read_xy_files(directory: str, filenames) -> Tuple[list, list]:
    """Concatenate the ``x`` and ``y`` lists of every JSON file named in ``filenames``."""
    samples, targets = [], []
    for name in filenames:
        # Context manager: each file is closed before the next one is opened.
        with open(os.path.join(directory, name)) as fp:
            content = json.load(fp)
        samples += content['x']
        targets += content['y']
    return samples, targets


def get_stackoverflow_data(root: str, train: bool, version: str = 'iid',
                           subsample_dim: Optional[int] = None) -> GeneralRawDataset:
    """Load the StackOverflow data stored as JSON under ``<root>/stackoverflow``.

    Args:
        root: Directory that contains the ``stackoverflow`` folder.
        train: Load from ``train/soverflow_<version>`` when true, from ``test``
            otherwise.
        version: ``'iid'`` or ``'niid'``; selects the training sub-folder.
        subsample_dim: Test only. Number of test entries to keep, drawn without
            replacement; a falsy value keeps all loaded entries (in a random
            order).

    Returns:
        For training, a ``RawSplittedDataset`` holding the concatenated ``x`` /
        ``y`` lists of the loaded files. For testing, a ``RawDataset`` built
        from at most four files and then subsampled.

    Raises:
        AssertionError: If ``version`` is neither ``'iid'`` nor ``'niid'``.
    """
    from tqdm import tqdm
    assert version in ['niid', 'iid']
    datasets_path = os.path.join(root, 'stackoverflow')
    if train:
        train_dir = os.path.join(datasets_path, 'train', f'soverflow_{version}')
        files = [f for f in os.listdir(train_dir) if f.endswith('.json')]
        if len(files) > 8:
            # With many files present, only these four fixed ones are loaded.
            files = ['train_0.json', 'train_1.json', 'train_2.json', 'train_3.json']
        samples, targets = _read_xy_files(train_dir, tqdm(files, desc='Loading training files'))
        return RawSplittedDataset(samples, targets)

    test_dir = os.path.join(datasets_path, 'test')
    files = [f for f in os.listdir(test_dir) if f.endswith('.json')]
    # NOTE: os.listdir returns names in arbitrary order, so which four files
    # end up being read is not guaranteed to be the same on every machine.
    samples, targets = _read_xy_files(test_dir, tqdm(files[:4], desc='Loading test files'))
    # subsample
    subsample_dim = subsample_dim if subsample_dim else len(samples)
    chosen = np.random.choice(len(samples), subsample_dim, False)
    return RawDataset(np.array(samples)[chosen], np.array(targets)[chosen])


def get_from_torchvision(torchvision_classname: str, **torch_kwargs) -> RawDataset:
    """Instantiate a dataset class given by name and wrap its data as a ``RawDataset``.

    Args:
        torchvision_classname: Python expression naming the class to build; it
            is evaluated with ``torchvision`` in scope.
        **torch_kwargs: Forwarded unchanged to the class constructor.

    Returns:
        A ``RawDataset`` whose samples are the instance's ``data`` attribute
        and whose targets are its ``targets`` attribute converted to an array.
    """
    import torchvision  # must be in local scope for the eval below
    # NOTE(review): eval executes whatever expression the caller supplies;
    # only pass trusted configuration values here.
    loaded = eval(torchvision_classname)(**torch_kwargs)
    return RawDataset(samples=loaded.data, targets=np.array(loaded.targets))


def get_dataset(getter_fn: DictConfig) -> Dict[str, GeneralRawDataset]:
    """Load the train, test and (optionally) validation data with one getter.

    The getter named by ``getter_fn.name`` is called once per subset. Each call
    receives ``getter_fn.args`` overridden by the subset-specific mapping
    (``arg_train``, ``arg_test``, ``arg_val``).

    Args:
        getter_fn: Config with ``name``, ``args``, ``arg_train``, ``arg_test``
            and ``arg_val`` entries.

    Returns:
        Dict with keys ``'train'`` and ``'test'``, plus ``'val'`` when
        ``getter_fn.arg_val`` is truthy.
    """
    # NOTE(review): eval resolves the getter from a config string; the config
    # must be trusted.
    loader = eval(getter_fn.name)
    shared_kwargs = getter_fn.args
    train_kwargs = {**shared_kwargs, **getter_fn.arg_train}
    test_kwargs = {**shared_kwargs, **getter_fn.arg_test}
    # Train is loaded before test; the dict literal evaluates in that order.
    loaded = {'train': loader(**train_kwargs), 'test': loader(**test_kwargs)}
    if getter_fn.arg_val:
        val_kwargs = {**shared_kwargs, **getter_fn.arg_val}
        loaded['val'] = loader(**val_kwargs)
    return loaded


def create_non_iid(dataset: RawDataset, num_clients: int, shard_size: int) -> List[RawDataset]:
    """Partition a dataset into label-sorted shards and deal them out to clients.

    The data is sorted by target, cut into consecutive shards of ``shard_size``
    items (the last one may be shorter), and the shards are shuffled with the
    ``random`` module. Every client then receives
    ``number_of_shards // num_clients`` shards; shards left over after that
    division are not assigned to anyone.

    Args:
        dataset: Data to partition.
        num_clients: Number of local datasets to produce.
        shard_size: Number of items per shard.

    Returns:
        One ``RawDataset`` per client, in client order.
    """
    order = np.argsort(dataset.targets)
    sorted_samples = dataset.samples[order]
    sorted_targets = dataset.targets[order]

    shard_starts = list(range(0, len(sorted_samples), shard_size))
    random.shuffle(shard_starts)
    log.info(f"divide data into {len(shard_starts)} shards of size {shard_size}")

    shards_per_client = len(shard_starts) // num_clients
    local_datasets = []
    for client_id in range(num_clients):
        first = shards_per_client * client_id
        # Start offsets of the shards owned by this client.
        owned_starts = shard_starts[first:first + shards_per_client]
        client_samples = np.concatenate(
            [sorted_samples[start:start + shard_size] for start in owned_starts], axis=0)
        client_targets = np.concatenate(
            [sorted_targets[start:start + shard_size] for start in owned_starts], axis=0)
        local_datasets.append(RawDataset(client_samples, client_targets))

    return local_datasets


def create_using_dirichlet_distr(dataset: RawDataset, num_clients: int, shard_size: int, dataset_num_classes: int,
                                 alpha: float, max_iter: int, rebalance: bool) -> List[RawDataset]:
    """Partition a dataset among clients using a Dirichlet-based index assignment.

    The per-client index lists come from
    ``non_iid_partition_with_dirichlet_distribution``. With ``rebalance`` set,
    clients holding more than ``shard_size`` indices give the surplus to a
    shared pool, and clients holding fewer are topped up from that pool.

    Args:
        dataset: Data to partition.
        num_clients: Number of clients, forwarded to the partition helper.
        shard_size: Target number of samples per client when rebalancing.
        dataset_num_classes: Forwarded to the partition helper.
        alpha: Forwarded to the partition helper.
        max_iter: Forwarded to the partition helper.
        rebalance: Whether to even out the client sizes afterwards.

    Returns:
        One ``RawDataset`` per client, in the key order of the partition.
    """
    d = non_iid_partition_with_dirichlet_distribution(
        np.array(dataset.targets), num_clients, dataset_num_classes, alpha, max_iter)

    if rebalance:
        # Pass 1: trim oversized clients and pool the removed indices.
        # Clients are visited as d[0] .. d[len(d) - 1], as the partition is keyed.
        storage = []
        for i in range(len(d)):
            surplus = len(d[i]) - shard_size
            if surplus > 0:
                to_switch = np.random.choice(d[i], surplus, replace=False).tolist()
                storage += to_switch
                d[i] = list(set(d[i]) - set(to_switch))

        # Pass 2: top up undersized clients from the pool.
        for i in range(len(d)):
            # The pool may hold fewer indices than are missing; asking
            # np.random.choice for more than it has would raise ValueError.
            take = min(shard_size - len(d[i]), len(storage))
            if take > 0:
                to_switch = np.random.choice(storage, take, replace=False).tolist()
                d[i] += to_switch
                storage = list(set(storage) - set(to_switch))

        unbalanced = sum(1 for i in range(len(d)) if len(d[i]) != shard_size)
        if unbalanced:
            log.warning(f'{unbalanced} clients do not hold exactly {shard_size} samples after rebalancing')

    # One RawDataset per client, built from that client's index list.
    return [RawDataset(dataset.samples[d[client_id]], dataset.targets[d[client_id]])
            for client_id in d.keys()]


def split_dirichlet(dataset: RawDataset, num_splits: int, shard_size: int, alpha: float, max_iter: int,
                    dataset_num_classes: int, rebalance: bool) -> List[RawDataset]:
    """Split a flat dataset into ``num_splits`` local datasets.

    A non-zero ``alpha`` delegates to ``create_using_dirichlet_distr``; an
    ``alpha`` equal to zero selects the label-sorted shard split of
    ``create_non_iid``, which ignores the remaining Dirichlet arguments.

    Returns:
        One ``RawDataset`` per split.
    """
    if alpha != 0:
        return create_using_dirichlet_distr(dataset, num_splits, shard_size, dataset_num_classes, alpha,
                                            max_iter, rebalance)
    # alpha == 0: Non-IID shard split
    return create_non_iid(dataset, num_splits, shard_size)


def split_fixed(dataset: RawSplittedDataset, num_splits: int, shard_size: int) -> List[RawDataset]:
    """Turn an already-split dataset into a list of ``RawDataset`` objects.

    Args:
        dataset: Pre-split data; ``samples[i]`` / ``targets[i]`` form split i.
        num_splits: ``1`` to merge everything into one dataset, otherwise the
            exact number of splits stored in ``dataset``.
        shard_size: Unused; accepted so all split functions share a signature.

    Returns:
        A single-element list when ``num_splits == 1``, otherwise one
        ``RawDataset`` per stored split.

    Raises:
        AssertionError: If ``num_splits`` is neither 1 nor the number of
            stored splits.
    """
    if num_splits == 1:
        # merge all samples
        # NOTE(review): np.array stacks the per-split entries along a new
        # leading axis (and needs equally shaped entries); if a flat dataset
        # is intended, np.concatenate may be what is wanted - confirm.
        merged_samples = np.array(list(dataset.samples))
        # Targets must come from the targets, not from the samples again.
        merged_targets = np.array(list(dataset.targets))
        return [RawDataset(merged_samples, merged_targets)]
    assert len(dataset.samples) == num_splits, "Number of requested splits is not equal to number of actual splits"
    return [RawDataset(samples, targets) for samples, targets in zip(dataset.samples, dataset.targets)]


def split_dataset(split_fn: DictConfig, num_samples: int, dataset: GeneralRawDataset, num_splits: int) \
        -> List[RawDataset]:
    """Split ``dataset`` with the function named in the config.

    Args:
        split_fn: Config with a ``name`` entry (the split function to call) and
            an ``args`` mapping of extra keyword arguments for it.
        num_samples: Total sample count used to derive the per-split size.
        dataset: Data handed to the split function.
        num_splits: Number of splits requested.

    Returns:
        Whatever the split function returns: one ``RawDataset`` per split.

    Raises:
        AssertionError: If ``num_samples // num_splits`` is zero.
    """
    per_split = num_samples // num_splits
    assert per_split > 0, f"Not enough samples for {num_splits} splits"
    # NOTE(review): eval resolves the splitter from a config string; the
    # config must be trusted.
    splitter = eval(split_fn.name)
    extra_kwargs = split_fn.args
    return splitter(dataset, num_splits, per_split, **extra_kwargs)


def extract_exemplars(dataset: RawDataset, num_exemplars: int, num_classes: int) -> \
        Tuple[RawDataset, Optional[RawDataset]]:
    """Move a per-class random subset of ``dataset`` into a separate exemplar set.

    ``num_exemplars // num_classes`` items are drawn without replacement from
    every class present in ``dataset.targets``. Class boundaries are taken from
    the labels themselves, so classes need not be equally sized.

    Args:
        dataset: Data to draw exemplars from.
        num_exemplars: Total number of exemplars requested; ``<= 0`` disables
            extraction.
        num_classes: Number of classes used to derive the per-class quota.

    Returns:
        ``(remaining, exemplars)``. When extraction is disabled this is the
        unchanged ``dataset`` and ``None``; otherwise both parts are ordered by
        label and the exemplars are removed from ``remaining``.

    Raises:
        AssertionError: If ``num_exemplars`` is positive but smaller than
            ``num_classes``.
    """
    if num_exemplars <= 0:
        return dataset, None
    exemplars_per_class = num_exemplars // num_classes
    assert exemplars_per_class > 0, "Number of examplars is less than classes"
    sorted_index = np.argsort(dataset.targets)
    samples = dataset.samples[sorted_index]
    targets = dataset.targets[sorted_index]

    # After sorting, each class occupies one contiguous run. Locate the runs
    # from the labels instead of assuming len // num_classes items per class,
    # which picks from the wrong class as soon as the classes are unbalanced.
    labels, run_starts, run_sizes = np.unique(targets, return_index=True, return_counts=True)
    subset_indexes = []
    for label, start, size in zip(labels, run_starts, run_sizes):
        take = min(exemplars_per_class, size)
        if take < exemplars_per_class:
            log.warning(f'Class {label} has only {size} samples, fewer than the '
                        f'{exemplars_per_class} exemplars requested')
        subset_indexes += list(np.random.choice(size, take, replace=False) + start)

    exemplars_samples = samples[subset_indexes]
    exemplars_targets = targets[subset_indexes]

    # Exemplars must not stay in the remaining data.
    samples = np.delete(samples, subset_indexes, axis=0)
    targets = np.delete(targets, subset_indexes)

    return RawDataset(samples, targets), RawDataset(exemplars_samples, exemplars_targets)


def create_datasets(dataset_info: DictConfig, num_splits: int = 1) \
        -> Tuple[List[Dataset], Dataset, Optional[Dataset], Optional[RawDataset]]:
    """Build the train splits, test set, validation set and exemplars from a config.

    Args:
        dataset_info: Config with ``classname`` (dataset class to instantiate),
            ``getter_fn`` (see ``get_dataset``), ``split_fn`` (see
            ``split_dataset``) and ``args`` providing ``num_classes``,
            ``num_exemplars`` and ``num_train_samples``.
        num_splits: Number of training splits to create.

    Returns:
        ``(train_datasets, test_dataset, val_dataset, exemplars)``.
        ``val_dataset`` is ``None`` when the getter produced no ``'val'``
        entry; ``exemplars`` is ``None`` when no exemplars were extracted.
    """
    # NOTE(review): eval resolves the dataset class from a config string; the
    # config must be trusted.
    dataset_class = eval(dataset_info.classname)
    num_classes = dataset_info.args.num_classes
    raw_set = get_dataset(dataset_info.getter_fn)
    # Exemplars are carved out of the test data, which shrinks accordingly.
    raw_set['test'], exemplars = extract_exemplars(raw_set['test'], dataset_info.args.num_exemplars, num_classes)
    splits = split_dataset(dataset_info.split_fn, dataset_info.args.num_train_samples, raw_set['train'], num_splits)
    datasets_train = [dataset_class(d.samples, d.targets, num_classes, train=True) for d in splits]
    dataset_test = dataset_class(raw_set['test'].samples, raw_set['test'].targets, num_classes, train=False)

    dataset_val = None
    if 'val' in raw_set:
        # Built with the same arguments as the test set; previously the raw
        # container itself was passed as the only positional argument.
        raw_val = raw_set['val']
        dataset_val = dataset_class(raw_val.samples, raw_val.targets, num_classes, train=False)

    return datasets_train, dataset_test, dataset_val, exemplars
