import os
import torch
import torch.nn as nn
import numpy as np


from datasets import load_dataset
from sklearn.metrics import accuracy_score
from torchvision.transforms import Normalize, Resize, ToTensor, Compose
from transformers import ViTImageProcessor, ViTForImageClassification, AutoImageProcessor, AutoModelForImageClassification





# apply transforms to PIL Image and store it to 'pixels' key
 
 
def compute_metrics(eval_pred):
    """Compute top-1 accuracy for a Hugging Face ``Trainer`` evaluation step.

    Args:
        eval_pred: A ``(predictions, labels)`` pair (e.g. ``EvalPrediction``).
            ``predictions`` holds per-class scores with the class axis last.
            If it is a tuple (models that return more than logits), its first
            element is taken as the logits.

    Returns:
        ``{'accuracy': value}`` as computed by ``sklearn.metrics.accuracy_score``.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        # Trainer passes a tuple when the model returns extra outputs.
        predictions = predictions[0]
    predicted_classes = np.argmax(predictions, axis=-1)
    # sklearn's signature is accuracy_score(y_true, y_pred).
    return dict(accuracy=accuracy_score(labels, predicted_classes))



def get_data(model_name='microsoft/resnet-18', dataset_name='cifar10', norm=None, batch_size=32, num_workers=32, proxy=7890):
    """Build CIFAR train/validation DataLoaders and an image-classification model.

    Args:
        model_name: Hugging Face model id, used for both the image processor
            and the pretrained weights.
        dataset_name: ``'cifar10'`` or ``'cifar100'`` (fine labels).
        norm: Ignored; kept for backward compatibility. Normalization is done
            by the model's image processor inside ``collate_fn``.
        batch_size: Batch size for both loaders.
        num_workers: Worker processes for each loader.
        proxy: Port of an HTTP(S) proxy on 127.0.0.1. ``None`` leaves the
            proxy environment variables untouched.

    Returns:
        ``(dltrain, dlvalid, model, processor)``. Batches are dicts with
        ``'pixel_values'`` and ``'labels'`` tensors. ``dlvalid`` iterates
        the CIFAR test split.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    # dataset name -> (label column, number of classes)
    dataset_specs = {'cifar10': ('label', 10), 'cifar100': ('fine_label', 100)}
    if dataset_name not in dataset_specs:
        raise ValueError(
            f'Unsupported dataset_name {dataset_name!r}; expected one of {sorted(dataset_specs)}.'
        )
    label_key, num_labels = dataset_specs[dataset_name]

    if proxy is not None:
        os.environ['http_proxy'] = 'http://127.0.0.1:%s' % proxy
        os.environ['https_proxy'] = 'http://127.0.0.1:%s' % proxy

    dataset_cache = '/data/usr1/hgfc/dataset/'
    model_cache = '/data/usr1/hgfc/model/'
    # Kept for any code that reads these variables. However, `datasets` reads
    # HF_DATASETS_CACHE only when it is imported, and Hugging Face does not read
    # HF_MODEL_CACHE at all, so the cache dirs are also passed explicitly below.
    os.environ['HF_DATASETS_CACHE'] = dataset_cache
    os.environ['HF_MODEL_CACHE'] = model_cache

    processor = AutoImageProcessor.from_pretrained(model_name, cache_dir=model_cache)

    dstrain, dstest = load_dataset(
        dataset_name, split=['train[:50000]', 'test[:10000]'], cache_dir=dataset_cache
    )
    dsvalid = dstest

    def collate_fn(examples):
        # Preprocess the whole batch with the model's processor in one call.
        pixels = processor([example['img'] for example in examples], return_tensors='pt')['pixel_values']
        labels = torch.tensor([example[label_key] for example in examples])
        return {'pixel_values': pixels, 'labels': labels}

    class_names = dstrain.features[label_key].names
    itos = dict(enumerate(class_names))
    stoi = {name: idx for idx, name in enumerate(class_names)}

    model = AutoModelForImageClassification.from_pretrained(
        model_name, num_labels=num_labels, ignore_mismatched_sizes=True,
        id2label=itos, label2id=stoi, cache_dir=model_cache,
    )

    def _linear_head(in_features):
        # Flatten the pooled features, then apply a fresh linear layer with num_labels outputs.
        return nn.Sequential(nn.Flatten(), nn.Linear(in_features, num_labels, bias=True))

    if 'microsoft/resnet-18' in model_name:
        model.classifier = _linear_head(512)
        print(f'ResNet18 changed classifier head, dim of output = {num_labels}.')
    elif 'microsoft/resnet-50' in model_name:
        model.classifier = _linear_head(2048)
        print(f'ResNet50 changed classifier head, dim of output = {num_labels}.')
    elif 'vit' in model_name:
        # The head was already resized by from_pretrained(ignore_mismatched_sizes=True).
        print(f'{model_name} changed classifier head, dim of output = {num_labels}.')

    dltrain = torch.utils.data.DataLoader(dstrain, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn)
    dlvalid = torch.utils.data.DataLoader(dsvalid, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn)

    return dltrain, dlvalid, model, processor



def dirichlet_split_noniid(train_labels, alpha, n_clients, rng=None):
    '''
    Partition sample indices into ``n_clients`` subsets using a Dirichlet
    distribution with concentration parameter ``alpha``.

    For every class, a vector of client proportions is drawn from
    ``Dirichlet([alpha] * n_clients)``. The indices of that class are then cut
    into ``n_clients`` contiguous chunks with those proportions. Smaller
    ``alpha`` gives more skewed per-client label distributions.

    Args:
        train_labels: 1-D array-like of non-negative integer class labels.
        alpha: Dirichlet concentration parameter (NumPy requires > 0).
        n_clients: Number of subsets to produce.
        rng: Optional random source with a ``dirichlet`` method, e.g.
            ``np.random.default_rng(seed)``. Defaults to the global
            ``np.random`` state.

    Returns:
        List of ``n_clients`` integer index arrays. Together they contain
        every index of ``train_labels`` exactly once.
    '''
    labels = np.asarray(train_labels)
    if labels.size == 0:
        # Nothing to split; avoid max() on an empty array.
        return [np.array([], dtype=np.intp) for _ in range(n_clients)]
    if rng is None:
        rng = np.random

    n_classes = int(labels.max()) + 1
    # (K, N) matrix: row k holds the fraction of class k assigned to each client.
    label_distribution = rng.dirichlet([alpha] * n_clients, n_classes)
    # class_idcs[k] = indices of the samples labelled k.
    class_idcs = [np.argwhere(labels == y).flatten() for y in range(n_classes)]

    # client_idcs[i] collects, per class, the index chunks given to client i.
    client_idcs = [[] for _ in range(n_clients)]
    for k_idcs, fracs in zip(class_idcs, label_distribution):
        # Cumulative fractions -> split points; the last fraction is implied by the array end.
        split_points = (np.cumsum(fracs)[:-1] * len(k_idcs)).astype(int)
        for i, idcs in enumerate(np.split(k_idcs, split_points)):
            client_idcs[i].append(idcs)

    return [np.concatenate(idcs) for idcs in client_idcs]




def get_feddata(random_seed=4989, model_name='microsoft/resnet-18', alpha=None, dataset_name='cifar10', norm=None, batch_size=32, num_workers=32, K=5, proxy=7890):
    """Build per-client CIFAR training loaders, a validation loader and a model.

    The 50,000-image training split is divided into ``K`` client shards of
    ``len(train) // K`` samples each:

    * ``alpha is None`` or ``alpha < 0``: IID random split (any remainder when
      ``K`` does not divide the training set is dropped).
    * otherwise: client indices come from ``dirichlet_split_noniid`` and are
      resampled with replacement up to the shard size.

    Args:
        random_seed: Seeds the client split. The IID split uses a private torch
            generator; the Dirichlet split reseeds the global NumPy RNG. Also
            seeds the global torch RNG right before the model is created.
        model_name: Hugging Face model id for the processor and the weights.
        alpha: Dirichlet concentration; ``None`` or negative means IID.
        dataset_name: ``'cifar10'`` or ``'cifar100'`` (fine labels).
        norm: Ignored; kept for backward compatibility. Normalization is done
            by the model's image processor inside ``collate_fn``.
        batch_size: Batch size for every loader.
        num_workers: Worker processes for every loader.
        K: Number of clients (must be >= 1).
        proxy: Port of an HTTP(S) proxy on 127.0.0.1. ``None`` leaves the
            proxy environment variables untouched.

    Returns:
        ``(dltrain, dlvalid, model, processor)``. ``dltrain`` is a list of
        ``K`` DataLoaders. ``dlvalid`` iterates the CIFAR test split. Batches
        are dicts with ``'pixel_values'`` and ``'labels'`` tensors.

    Raises:
        ValueError: If ``dataset_name`` is unsupported, ``K < 1``, or a
            Dirichlet draw leaves some client with no samples.
    """
    # dataset name -> (label column, number of classes)
    dataset_specs = {'cifar10': ('label', 10), 'cifar100': ('fine_label', 100)}
    if dataset_name not in dataset_specs:
        raise ValueError(
            f'Unsupported dataset_name {dataset_name!r}; expected one of {sorted(dataset_specs)}.'
        )
    if K < 1:
        raise ValueError(f'K must be >= 1, got {K}.')
    label_key, num_labels = dataset_specs[dataset_name]

    if proxy is not None:
        os.environ['http_proxy'] = 'http://127.0.0.1:%s' % proxy
        os.environ['https_proxy'] = 'http://127.0.0.1:%s' % proxy

    dataset_cache = '/data/usr1/hgfc/dataset/'
    model_cache = '/data/usr1/hgfc/model/'
    # Kept for any code that reads these variables. However, `datasets` reads
    # HF_DATASETS_CACHE only when it is imported, and Hugging Face does not read
    # HF_MODEL_CACHE at all, so the cache dirs are also passed explicitly below.
    os.environ['HF_DATASETS_CACHE'] = dataset_cache
    os.environ['HF_MODEL_CACHE'] = model_cache

    processor = AutoImageProcessor.from_pretrained(model_name, cache_dir=model_cache)

    dstrain, dstest = load_dataset(
        dataset_name, split=['train[:50000]', 'test[:10000]'], cache_dir=dataset_cache
    )
    dsvalid = dstest
    shard_size = len(dstrain) // K

    if alpha is None or alpha < 0:
        # random_split needs lengths that sum to the dataset size, so any
        # remainder goes into an extra shard that is discarded.
        lengths = [shard_size] * K
        remainder = len(dstrain) - shard_size * K
        if remainder:
            lengths.append(remainder)
        split_generator = torch.Generator().manual_seed(random_seed)
        dstrains = torch.utils.data.random_split(dstrain, lengths, generator=split_generator)[:K]
    else:
        print('Making subsets according a Dirichlet distribution with alpha=%.4f' % alpha)
        # dirichlet_split_noniid and np.random.choice draw from the global NumPy RNG.
        np.random.seed(random_seed)
        # Column access reads only the labels and does not decode the images.
        train_labels = np.asarray(list(dstrain[label_key]))
        client_idcs = dirichlet_split_noniid(train_labels, alpha=alpha, n_clients=K)
        dstrains = []
        for ik, idcs in enumerate(client_idcs):
            if len(idcs) == 0:
                raise ValueError(
                    f'Dirichlet split with alpha={alpha} gave client {ik} no samples; '
                    'use a larger alpha or a different random_seed.'
                )
            chosen = np.random.choice(idcs, shard_size, replace=True)
            dstrains.append(torch.utils.data.Subset(dstrain, chosen.tolist()))

    def collate_fn(examples):
        # Preprocess the whole batch with the model's processor in one call.
        pixels = processor([example['img'] for example in examples], return_tensors='pt')['pixel_values']
        labels = torch.tensor([example[label_key] for example in examples])
        return {'pixel_values': pixels, 'labels': labels}

    class_names = dstrain.features[label_key].names
    itos = dict(enumerate(class_names))
    stoi = {name: idx for idx, name in enumerate(class_names)}

    # Seed right before the model is built so head initialization is reproducible.
    torch.manual_seed(random_seed)
    model = AutoModelForImageClassification.from_pretrained(
        model_name, num_labels=num_labels, ignore_mismatched_sizes=True,
        id2label=itos, label2id=stoi, cache_dir=model_cache,
    )

    def _linear_head(in_features):
        # Flatten the pooled features, then apply a fresh linear layer with num_labels outputs.
        return nn.Sequential(nn.Flatten(), nn.Linear(in_features, num_labels, bias=True))

    if 'microsoft/resnet-18' in model_name:
        model.classifier = _linear_head(512)
        print(f'ResNet18 changed classifier head, dim of output = {num_labels}.')
    elif 'microsoft/resnet-50' in model_name:
        model.classifier = _linear_head(2048)
        print(f'ResNet50 changed classifier head, dim of output = {num_labels}.')
    elif 'vit' in model_name:
        # The head was already resized by from_pretrained(ignore_mismatched_sizes=True).
        print(f'{model_name} changed classifier head, dim of output = {num_labels}.')

    dltrain = [torch.utils.data.DataLoader(ds, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn) for ds in dstrains]
    dlvalid = torch.utils.data.DataLoader(dsvalid, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn)

    return dltrain, dlvalid, model, processor