from __future__ import annotations

import copy
import torch
import numpy as np
# import models
import datasets
from config import cfg
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from utils.api import (
    collate, 
    to_device
)

from _typing import (
    DatasetType
)

from models.api import make_batchnorm

# Per-dataset normalization constants, keyed by dataset name.
# Each value is a pair of per-channel tuples that is unpacked positionally
# into ``transforms.Normalize`` (i.e. mean first, std second).
data_stats = dict(
    MNIST=((0.1307,), (0.3081,)),
    FashionMNIST=((0.2859,), (0.3530,)),
    CIFAR10=((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    CIFAR100=((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
    SVHN=((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)),
    CINIC10=((0.4789, 0.4723, 0.4305), (0.2421, 0.2383, 0.2587)),
)


def fetch_dataset(data_name):
    """Build the train and test datasets for ``data_name``.

    The dataset class is looked up by name on the imported ``datasets``
    module and instantiated once per split under ``./data/<data_name>``.
    After construction the ``transform`` attribute of each split is replaced:

    * train: dataset-specific augmentation (none for MNIST/FashionMNIST,
      random horizontal flip + reflect-padded 32x32 random crop for
      CIFAR10/CIFAR100/CINIC10, random crop only for SVHN), followed by
      ``ToTensor`` and ``Normalize`` with the ``data_stats`` entry;
    * test: ``ToTensor`` and ``Normalize`` only.

    The label-flipping settings in ``cfg`` are not consulted here; the train
    pipeline is the same whether or not malicious clients are configured.

    Args:
        data_name: Name of the dataset; must be one of 'MNIST',
            'FashionMNIST', 'CIFAR10', 'CIFAR100', 'CINIC10', 'SVHN'.

    Returns:
        dict with the keys ``'train'`` and ``'test'`` mapping to the
        constructed dataset objects.

    Raises:
        ValueError: If ``data_name`` is not one of the supported names.
    """
    print('fetching data {}...'.format(data_name))
    root = './data/{}'.format(data_name)

    # Train-time augmentation applied before ToTensor/Normalize.
    if data_name in ['MNIST', 'FashionMNIST']:
        train_augmentation = []
    elif data_name in ['CIFAR10', 'CIFAR100', 'CINIC10']:
        train_augmentation = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        ]
    elif data_name in ['SVHN']:
        train_augmentation = [
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        ]
    else:
        raise ValueError('Not valid dataset name')

    # data_name has been validated against the lists above, so a plain
    # attribute lookup replaces building and eval()-ing a source string.
    dataset_cls = getattr(datasets, data_name)
    dataset = {}
    for split in ('train', 'test'):
        # Construct with a bare ToTensor pipeline first; the final transform
        # is assigned afterwards, mirroring the construct-then-assign order.
        dataset[split] = dataset_cls(
            root=root,
            split=split,
            transform=datasets.Compose([transforms.ToTensor()]),
        )

    dataset['train'].transform = datasets.Compose(
        train_augmentation + [
            transforms.ToTensor(),
            transforms.Normalize(*data_stats[data_name]),
        ]
    )
    dataset['test'].transform = datasets.Compose([
        transforms.ToTensor(),
        transforms.Normalize(*data_stats[data_name]),
    ])
    print('data ready')
    return dataset


def input_collate(batch):
    """Collate a list of samples into one batch.

    When the first sample is a dict, the batch is turned into a dict of
    plain Python lists: for every key of the first sample an empty list is
    created, and each sample's values are appended under their keys in
    batch order. No stacking or tensor conversion is done in this case.
    Any other sample type is delegated to ``default_collate``.

    Args:
        batch: Non-empty sequence of samples, as handed over by a
            ``DataLoader``.

    Returns:
        ``dict[key, list]`` for dict samples, otherwise whatever
        ``default_collate`` returns.

    Raises:
        KeyError: If a later dict sample has a key the first sample lacks.
    """
    if isinstance(batch[0], dict):
        output = {key: [] for key in batch[0].keys()}
        for sample in batch:
            for key in sample:
                output[key].append(sample[key])
        return output
    return default_collate(batch)


def _seed_worker(base_seed, worker_id):
    """Seed NumPy's global RNG inside a DataLoader worker process.

    Every worker uses ``base_seed + worker_id`` so that the workers do not
    all start from the same NumPy random state.
    """
    np.random.seed(base_seed + worker_id)


def make_data_loader(dataset, tag, batch_size=None, shuffle=None, sampler=None, batch_sampler=None):
    """Create one ``DataLoader`` per entry of ``dataset``.

    Args:
        dataset: Mapping from split name to dataset object.
        tag: Key into ``cfg`` under which the default ``'batch_size'`` and
            ``'shuffle'`` mappings (keyed by split name) are looked up.
        batch_size: Optional mapping split name -> batch size overriding
            ``cfg[tag]['batch_size']``.
        shuffle: Optional mapping split name -> bool overriding
            ``cfg[tag]['shuffle']``.
        sampler: Optional mapping split name -> sampler. When given it takes
            precedence and ``shuffle`` is not passed to the loader.
        batch_sampler: Optional mapping split name -> batch sampler. Used
            only when ``sampler`` is None; batch size and shuffle are then
            not passed to the loader.

    Returns:
        dict mapping each split name to its ``DataLoader``.
    """
    from functools import partial

    seed = cfg['seed']
    # worker_init_fn must be a callable taking the worker id. Passing the
    # *result* of np.random.seed(...) hands the loader None, i.e. no hook.
    # A partial over a module-level function stays picklable for
    # spawn-based worker start-up.
    worker_init_fn = None if seed is None else partial(_seed_worker, seed)

    data_loader = {}
    for k in dataset:
        _batch_size = cfg[tag]['batch_size'][k] if batch_size is None else batch_size[k]
        _shuffle = cfg[tag]['shuffle'][k] if shuffle is None else shuffle[k]
        # Reseeding NumPy in the calling process each time a loader is built
        # was a side effect of the previous argument expression; it is kept
        # explicitly so that existing runs stay reproducible.
        np.random.seed(seed)
        common = dict(
            dataset=dataset[k],
            pin_memory=cfg['pin_memory'],
            num_workers=cfg['num_workers'],
            collate_fn=input_collate,
            worker_init_fn=worker_init_fn,
        )
        if sampler is not None:
            data_loader[k] = DataLoader(batch_size=_batch_size, sampler=sampler[k], **common)
        elif batch_sampler is not None:
            data_loader[k] = DataLoader(batch_sampler=batch_sampler[k], **common)
        else:
            data_loader[k] = DataLoader(batch_size=_batch_size, shuffle=_shuffle, **common)

    return data_loader

def split_dataset(
    dataset: dict[str, DatasetType], 
    num_users: int, 
    data_split_mode: str
) -> dict[str, dict[int, list[int]]]:
    """Partition the train and test datasets among ``num_users`` users.

    The strategy is selected from ``cfg['data_split_mode']``: a mode string
    containing ``'non-iid'`` uses ``non_iid``, any other string containing
    ``'iid'`` uses ``iid``.

    Args:
        dataset: Mapping with the keys ``'train'`` and ``'test'``.
        num_users: Number of users to split the data among.
        data_split_mode: Accepted for interface compatibility; this function
            does not read it and relies on ``cfg['data_split_mode']``.

    Returns:
        ``{'train': split, 'test': split}`` where each ``split`` is the
        result of the selected splitting function.

    Raises:
        ValueError: If the configured mode contains neither ``'iid'`` nor
            ``'non-iid'``.
    """
    mode = cfg['data_split_mode']
    # 'non-iid' contains 'iid' as a substring, so it has to be tested first.
    if 'non-iid' in mode:
        splitter = non_iid
    elif 'iid' in mode:
        splitter = iid
    else:
        raise ValueError('Not valid data split mode')
    return {
        'train': splitter(dataset['train'], num_users),
        'test': splitter(dataset['test'], num_users),
    }

def get_malicious_data_ids(
    data_split,
    dataset,
    logger
):
    """Pick the malicious clients at random, log them and return their ids.

    ``int(cfg['malicious_clients_ratio'] * cfg['num_clients'])`` client ids
    are drawn without replacement from ``range(cfg['num_clients'])`` using
    NumPy's global RNG. Despite the function name, the return value holds
    client ids, not data ids.

    Args:
        data_split: Not used by this function.
        dataset: Object whose ``malicious_data_ids`` attribute is set to an
            empty list.
        logger: Project logger; ``safe``, ``append`` and ``reset`` are
            called on it to record the chosen ids under the 'train' tag.

    Returns:
        list of the selected client ids.
    """
    all_client_ids = list(range(cfg['num_clients']))
    num_malicious = int(cfg['malicious_clients_ratio'] * cfg['num_clients'])
    drawn = np.random.choice(all_client_ids, size=num_malicious, replace=False)
    malicious_client_ids = list(drawn)

    print(f'malicious_client_ids: {malicious_client_ids}', flush=True)

    record = {'malicious_client_ids': malicious_client_ids}
    logger.safe(True)
    logger.append(record, 'train', n=len(malicious_client_ids))
    logger.safe(False)
    logger.reset()

    # No per-sample ids are collected here; the attribute is only reset.
    dataset.malicious_data_ids = []
    return malicious_client_ids

def change_labels_to_fault(
    data_split,
    dataset,
    malicious_client_ids
):  
    """Flip labels in place for a share of each malicious client's samples.

    For every client in ``malicious_client_ids`` a random subset of its
    sample indices is chosen (NumPy global RNG) and the label of each chosen
    sample is replaced by its mirror image ``target_size - 1 - label``.
    With an odd ``target_size`` the middle label maps onto itself.

    Args:
        data_split: Mapping client id -> sequence of indices into
            ``dataset.target``.
        dataset: Object exposing ``target`` (indexable, mutated in place)
            and ``target_size``.
        malicious_client_ids: Iterable of client ids whose data is altered.

    Returns:
        None. ``dataset.target`` is modified in place.
    """
    print('===', f"{cfg['model_tag']}")
    for client_id in malicious_client_ids:
        client_indices = data_split[client_id]
        idx = np.random.permutation(len(client_indices))

        # At least one sample is requested per client, even for tiny ratios.
        fault_injection_num = max(
            1, int(len(client_indices) * cfg['malicious_data_in_clients_ratio'])
        )

        for item in idx[:fault_injection_num]:
            sub_idx = client_indices[item]
            original_label = dataset.target[sub_idx]
            # Mirror the label across the label range (0 <-> target_size - 1).
            dataset.target[sub_idx] = dataset.target_size - 1 - original_label

def iid(
    dataset: DatasetType, 
    num_users: int
) -> dict[int, list[int]]:
    """Assign sample indices to users uniformly at random, without overlap.

    The per-user quota depends on ``cfg['data_split_mode']``:

    * contains ``'nonequal'``: user ``i`` is offered ``(i + 1) * interval``
      samples, with ``interval = int(int(len(dataset) / num_users)
      / num_users * 2)``;
    * otherwise contains ``'equal'``: every user is offered
      ``int(len(dataset) / num_users * cfg['train_sample_portion'])``.

    A quota is capped by the number of indices still unassigned. Sampling
    uses ``torch.randperm``.

    Args:
        dataset: Sized dataset; only ``len(dataset)`` is used.
        num_users: Number of users.

    Returns:
        dict mapping user index to the list of sample indices it received.

    Raises:
        ValueError: If the configured mode contains neither ``'nonequal'``
            nor ``'equal'``.
    """
    mode = cfg['data_split_mode']
    if 'nonequal' in mode:
        total = len(dataset)
        step = int(int(total / num_users) / num_users * 2)
        # Quotas grow linearly with the user index.
        quotas = [step * (rank + 1) for rank in range(num_users)]
    elif 'equal' in mode:
        total = len(dataset)
        share = int(total / num_users * cfg['train_sample_portion'])
        quotas = [share for _ in range(num_users)]
    else:
        raise ValueError('wrong data_split_mode')

    data_split = {}
    remaining = list(range(total))
    for user, quota in enumerate(quotas):
        take = min(len(remaining), quota)
        pool = torch.tensor(remaining)
        order = torch.randperm(len(remaining))
        picked = pool[order[:take]].tolist()
        data_split[user] = picked
        # Remove what was just handed out from the pool of free indices.
        remaining = list(set(remaining) - set(picked))
    return data_split


def non_iid(
    dataset: DatasetType, 
    num_users: int
) -> dict[int, list[int]]:
    """Split sample indices among users in a label-skewed (non-IID) way.

    ``cfg['data_split_mode']`` is split on ``'-'``; the second-to-last part
    is the mode tag and the last part its parameter:

    * tag ``'l'`` (label shards): the parameter is the number of shards per
      user. The indices of every class are cut into
      ``int(shards_per_user * num_users / cfg['target_size'])`` shards, the
      class ids are shuffled, and each user receives one randomly chosen
      shard for every class id assigned to it.
    * tag ``'d'`` (Dirichlet): the parameter is the concentration ``beta``.
      For each class, user proportions are sampled from
      ``Dirichlet(beta, ..., beta)``; users already holding at least
      ``N / num_users`` samples get proportion 0 for that class. The whole
      split is redrawn until every user holds at least 10 samples.

    Args:
        dataset: Object whose ``target`` attribute can be turned into a
            tensor of class labels.
        num_users: Number of users.

    Returns:
        dict mapping user index to the list of sample indices it received.

    Raises:
        ValueError: If the mode tag is neither ``'l'`` nor ``'d'``.
    """
    target = torch.tensor(dataset.target)
    data_split_mode_list = cfg['data_split_mode'].split('-')
    data_split_mode_tag = data_split_mode_list[-2]
    if data_split_mode_tag == 'l':
        # Each user only gets data from a limited number of classes.
        data_split = {i: [] for i in range(num_users)}
        shard_per_user = int(data_split_mode_list[-1])
        target_idx_split = {}
        shard_per_class = int(shard_per_user * num_users / cfg['target_size'])
        for target_i in range(cfg['target_size']):
            target_idx = torch.where(target == target_i)[0]
            # Indices that do not divide evenly into shards are set aside
            # and handed out one per shard afterwards.
            num_leftover = len(target_idx) % shard_per_class
            leftover = target_idx[-num_leftover:] if num_leftover > 0 else []
            new_target_idx = target_idx[:-num_leftover] if num_leftover > 0 else target_idx
            new_target_idx = new_target_idx.reshape((shard_per_class, -1)).tolist()
            for i, leftover_target_idx in enumerate(leftover):
                new_target_idx[i] = new_target_idx[i] + [leftover_target_idx.item()]
            target_idx_split[target_i] = new_target_idx
        # Shuffle the class ids and deal them out row by row, one row per user.
        target_split = list(range(cfg['target_size'])) * shard_per_class
        target_split = torch.tensor(target_split)[torch.randperm(len(target_split))].tolist()
        target_split = torch.tensor(target_split).reshape((num_users, -1)).tolist()
        for i in range(num_users):
            for target_i in target_split[i]:
                # pop() removes the shard so it cannot be given out twice.
                idx = torch.randint(len(target_idx_split[target_i]), (1,)).item()
                data_split[i].extend(target_idx_split[target_i].pop(idx))
    elif data_split_mode_tag == 'd':
        # Dirichlet distribution over users, sampled once per class.
        beta = float(data_split_mode_list[-1])
        dirichlet = torch.distributions.dirichlet.Dirichlet(torch.tensor(beta).repeat(num_users))
        min_size = 0
        required_min_size = 10
        N = target.size(0)
        while min_size < required_min_size:
            data_split = [[] for _ in range(num_users)]
            for target_i in range(cfg['target_size']):
                target_idx = torch.where(target == target_i)[0]
                proportions = dirichlet.sample()
                # Zero out users that already hold their N / num_users share.
                proportions = torch.tensor(
                    [p * (len(data_split_idx) < (N / num_users)) for p, data_split_idx in zip(proportions, data_split)])
                proportions = proportions / proportions.sum()
                # The cumulative sum along the last dimension, scaled by the
                # class size, yields the cut points between users.
                split_idx = (torch.cumsum(proportions, dim=-1) * len(target_idx)).long().tolist()[:-1]
                split_idx = torch.tensor_split(target_idx, split_idx)
                data_split = [data_split_idx + idx.tolist() for data_split_idx, idx in zip(data_split, split_idx)]
            min_size = min([len(data_split_idx) for data_split_idx in data_split])
        data_split = {i: data_split[i] for i in range(num_users)}
    else:
        raise ValueError('Not valid data split mode tag')
    return data_split


def separate_dataset(
    dataset: DatasetType, 
    idx: list[int]
) -> DatasetType:
    """Return a deep copy of ``dataset`` restricted to the samples in ``idx``.

    The copy's ``data`` and ``target`` become lists holding the selected
    entries in the order given by ``idx``, and its ``id`` is renumbered
    ``0 .. len(idx) - 1``. The input dataset is left untouched.

    Args:
        dataset: Object with indexable ``data`` and ``target`` attributes.
        idx: Positions of the samples to keep.

    Returns:
        The restricted copy.
    """
    subset = copy.deepcopy(dataset)
    kept_data = [dataset.data[position] for position in idx]
    kept_target = [dataset.target[position] for position in idx]
    subset.data = kept_data
    subset.target = kept_target
    # Ids are local to the subset, not the positions in the source dataset.
    subset.id = [*range(len(kept_data))]
    return subset


def make_batchnorm_dataset(dataset):
    """Return a deep copy of ``dataset`` with an augmentation-free transform.

    The copy's ``transform`` is set to ``ToTensor`` followed by ``Normalize``
    using the ``data_stats`` entry for ``cfg['data_name']``. The input
    dataset is not modified.
    """
    clone = copy.deepcopy(dataset)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(*data_stats[cfg['data_name']])
    clone.transform = datasets.Compose([to_tensor, normalize])
    return clone


def make_batchnorm_stats(dataset, model, tag):
    """Run one pass over ``dataset`` through a copy of ``model`` in train mode.

    A deep copy of ``model`` has ``make_batchnorm(..., momentum=None,
    track_running_stats=True)`` applied to every submodule, is put into
    training mode, and is fed every batch of an unshuffled loader built with
    ``make_data_loader``. Everything runs under ``torch.no_grad()``. The
    original ``model`` is not touched.

    NOTE(review): presumably the pass exists to refresh batch-norm running
    statistics; confirm against ``models.api.make_batchnorm``.

    Args:
        dataset: Dataset to iterate over.
        model: Model to copy.
        tag: Config tag forwarded to ``make_data_loader``.

    Returns:
        The copied model after the pass.
    """
    def _reset_batchnorm(module):
        return make_batchnorm(module, momentum=None, track_running_stats=True)

    with torch.no_grad():
        test_model = copy.deepcopy(model)
        test_model.apply(_reset_batchnorm)
        loaders = make_data_loader({'train': dataset}, tag, shuffle={'train': False})
        data_loader = loaders['train']
        test_model.train(True)
        for batch in data_loader:
            batch = collate(batch)
            batch = to_device(batch, cfg['device'])
            test_model(batch)
    return test_model
