#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.9

from torchvision import datasets
from torchvision.transforms import transforms
from options import args_parser
from utils.long_tailed_setting import train_long_tail
from utils.dataset import classify_label, show_clients_data_distribution, Indices2Dataset, get_class_num
from utils.sample_dirichlet import clients_indices
from utils.gradient_matching import match_loss
from utils.param_aug import DiffAugment
import numpy as np
from torch import max, eq, no_grad, unsqueeze
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
from torch.utils.data.dataloader import DataLoader
from backbone.ResNet10 import resnet10
from tqdm import tqdm
import copy
import torch
import random
import torch.nn.functional as F
import time


class Global(object):
    """Server-side state: the global model plus the learnable hyper-prototypes.

    The hyper-prototypes are one trainable tensor of shape
    ``(num_classes * length_of_hp, dim)``; rows
    ``[c * length_of_hp, (c + 1) * length_of_hp)`` belong to class ``c``.
    """

    def __init__(self,
                 num_classes: int,
                 device: str,
                 args,
                 length_of_hp):
        """Create the hyper-prototypes, their optimizer and the global model.

        Args:
            num_classes: Number of classes handled by the server.
            device: Device the global model is moved to.
            args: Parsed options; reads ``dim``, ``rounds_of_hp``,
                ``num_classes``, ``device`` and ``lr_hp``.
            length_of_hp: Number of hyper-prototypes kept per class.
        """
        self.device = device
        self.num_classes = num_classes
        self.dim = args.dim
        # Hyper-Prototypes
        self.length_of_hp = length_of_hp
        self.rounds_of_hp = args.rounds_of_hp
        self.hyper_hp = torch.randn(size=(args.num_classes * self.length_of_hp, self.dim), dtype=torch.float,
                                    requires_grad=True, device=args.device)
        # One label per hyper-prototype row: [0, ..., 0, 1, ..., 1, ...].
        # Built directly on the target device instead of going through a list of
        # NumPy arrays (slow, and PyTorch warns about that conversion).
        self.label_syn = torch.arange(args.num_classes, dtype=torch.long,
                                      device=args.device).repeat_interleave(self.length_of_hp)
        self.optimizer_hp = SGD([self.hyper_hp], lr=args.lr_hp)  # optimizing hyper_prototypes
        self.criterion = CrossEntropyLoss().to(args.device)
        # the global model
        self.syn_model = resnet10(nclasses=args.num_classes).to(device)

    def obtain_hyper_prototypes(self, args, list_clients_gradient):
        """Update the hyper-prototypes so their gradients match the clients' gradients.

        Args:
            args: Parsed options; reads ``device`` and is forwarded to
                ``match_loss``.
            list_clients_gradient: One dict per client mapping a class index to
                that client's averaged gradient tensor for the class.

        For every class that at least one client reported, the client gradients
        are averaged. Then, for ``rounds_of_hp`` steps, the summed ``match_loss``
        between each class's hyper-prototype gradient and that average is
        minimised with ``optimizer_hp``. If no client reported any gradient the
        method returns without touching the hyper-prototypes.
        """
        self.syn_model.eval()
        global_cls = self.syn_model.classifier

        # Group the uploaded gradients by class.
        gw_real_all = {class_index: [] for class_index in range(self.num_classes)}
        for gradient_one in list_clients_gradient:
            for class_num, gradient in gradient_one.items():
                gw_real_all[class_num].append(gradient)

        # Average over clients; classes nobody reported are simply left out.
        gw_real_avg = {
            class_index: torch.stack(gradients, dim=0).mean(dim=0)
            for class_index, gradients in gw_real_all.items()
            if gradients
        }
        if not gw_real_avg:
            # Nothing to match: a constant zero loss has no graph, so calling
            # backward() on it would raise a RuntimeError.
            return

        # update hyper-prototypes
        for _ in range(self.rounds_of_hp):
            loss_virtual = torch.zeros((), device=args.device)
            for c, gw_real in gw_real_avg.items():
                start = c * self.length_of_hp
                hyper_hp = self.hyper_hp[start:start + self.length_of_hp].reshape(
                    (self.length_of_hp, self.dim))
                lab_syn = torch.full((self.length_of_hp,), c, dtype=torch.long, device=args.device)
                output_syn = global_cls(hyper_hp)
                loss_syn = self.criterion(output_syn, lab_syn)
                # Gradient of the classification loss w.r.t. the class-c
                # hyper-prototypes; create_graph=True keeps it differentiable so
                # the matching loss can be back-propagated into hyper_hp.
                gw_syn = torch.autograd.grad(loss_syn, hyper_hp, create_graph=True)[0]
                loss_virtual = loss_virtual + match_loss(gw_syn, gw_real, args)
            self.optimizer_hp.zero_grad()
            loss_virtual.backward()
            self.optimizer_hp.step()

    def initialize_for_model_fusion(self, list_dicts_local_params: list, list_nums_local_data: list):
        """Return the data-size-weighted average of the clients' parameter dicts.

        Args:
            list_dicts_local_params: One parameter dict per client, all with the
                same keys.
            list_nums_local_data: Number of local samples per client, in the same
                order; used as averaging weights.

        Returns:
            A deep copy of the first client's dict whose every entry has been
            replaced by the weighted average over all clients.
        """
        total_samples = sum(list_nums_local_data)
        # Deep-copying the first dict keeps its container type and any
        # attributes attached to it.
        aggregated_global_params = copy.deepcopy(list_dicts_local_params[0])
        for name_param in list_dicts_local_params[0]:
            weighted_values = [
                dict_local_params[name_param] * num_local_data
                for dict_local_params, num_local_data in zip(list_dicts_local_params, list_nums_local_data)
            ]
            aggregated_global_params[name_param] = sum(weighted_values) / total_samples
        return aggregated_global_params

    def global_eval(self, data_test, batch_size_test):
        """Return the top-1 accuracy of the global model on ``data_test``.

        Args:
            data_test: Dataset yielding ``(image, label)`` pairs.
            batch_size_test: Batch size used for evaluation.
        """
        self.syn_model.eval()
        num_corrects = 0
        with torch.no_grad():
            test_loader = DataLoader(data_test, batch_size_test)
            for images, labels in test_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                _, outputs = self.syn_model(images)
                _, predicts = torch.max(outputs, -1)
                # Tensor.sum() reduces on-device in one call; the builtin sum()
                # would iterate over the tensor element by element.
                num_corrects += torch.eq(predicts, labels).sum().item()
        return num_corrects / len(data_test)

    def download_params(self):
        """Return the global model's ``state_dict``."""
        return self.syn_model.state_dict()

    def get_hp(self):
        """Return a detached copy of the current hyper-prototypes."""
        return self.hyper_hp.detach().clone()


class Local(object):
    """Client-side trainer: local model, local data and the received hyper-prototypes."""

    # How many random mini-batches are drawn per class when estimating the
    # class-wise feature gradients in compute_gradient().
    num_gradient_repeats = 10

    def __init__(self,
                 data_client,
                 class_list: int,
                 hyper_prototypes,
                 args=None):
        """Build the local model/optimizer and split the hyper-prototypes per class.

        Args:
            data_client: Dataset yielding ``(image, label)`` pairs for this client.
            class_list: Per-client class composition, forwarded as-is to
                ``get_class_num``. NOTE(review): annotated ``int`` but presumably
                a per-class count collection; confirm against the caller.
            hyper_prototypes: Tensor whose rows
                ``[c * length_of_hp, (c + 1) * length_of_hp)`` are the
                hyper-prototypes of class ``c``.
            args: Optional pre-parsed options. When omitted the options are
                obtained from ``args_parser()``, as before.
        """
        if args is None:
            args = args_parser()
        self.args = args
        self.data_client = data_client
        self.device = args.device
        self.class_compose = class_list
        self.criterion = CrossEntropyLoss().to(args.device)
        self.local_model = resnet10(nclasses=args.num_classes).to(args.device)
        self.optimizer = SGD(self.local_model.parameters(), lr=args.lr_local_training, momentum=0.9, weight_decay=1e-5)
        self.hp_info_t = args.hp_info_t
        # hp[c] has shape (length_of_hp, dim)
        self.hp = {
            c: hyper_prototypes[c * args.length_of_hp:(c + 1) * args.length_of_hp].reshape(
                (args.length_of_hp, args.dim))
            for c in range(args.num_classes)
        }

    @torch.no_grad()
    def compute_margin(self):
        """Return the scaled sum of pairwise dot products between local class prototypes.

        A local prototype is the mean feature the local model produces for one
        class of the client's data. The off-diagonal entries of the prototype
        Gram matrix are divided by ``(num_classes - 1) ** 2`` and summed.

        Returns:
            A 0-dim tensor. It is zero when fewer than two classes are present
            locally: there are then no off-diagonal pairs, and this also avoids
            a 0/0 division when ``num_classes == 1`` and a ``torch.stack``
            failure on an empty client.
        """
        num_classes = self.args.num_classes
        client_lp = {class_index: [] for class_index in range(num_classes)}
        lp_loader = DataLoader(dataset=self.data_client, batch_size=100)
        for images, labels in lp_loader:
            images, labels = images.to(self.device), labels.to(self.device)
            features, _ = self.local_model(images)
            for f_c, c in zip(features, labels):
                client_lp[c.item()].append(f_c)

        # Mean feature per locally present class, in class order.
        lp_list = [
            torch.stack(client_lp[i], dim=0).mean(dim=0)
            for i in range(num_classes)
            if client_lp[i]
        ]
        if len(lp_list) < 2:
            if lp_list:
                return lp_list[0].new_zeros(())
            return torch.zeros((), device=self.device)

        lp_stack = torch.stack(lp_list, dim=0)
        pairwise_products = torch.matmul(lp_stack, lp_stack.T)
        pairwise_products.fill_diagonal_(0)  # drop each prototype's product with itself
        return torch.sum(pairwise_products / ((num_classes - 1) ** 2))

    def compute_gradient(self, args):
        """Estimate, per local class, the average loss gradient w.r.t. the features.

        Args:
            args: Parsed options; reads ``device``, ``batch_real``, ``dsa``,
                ``dsa_strategy`` and ``dsa_param``.

        Returns:
            Dict mapping each class that has local samples to its gradient
            averaged over ``num_gradient_repeats`` random batches and over the
            samples of each batch.
        """
        list_class, per_class_compose = get_class_num(self.class_compose)  # class
        indices_class = {class_index: [] for class_index in list_class}

        # Single pass over the dataset: every __getitem__ call (and whatever
        # transform it applies) runs once per sample instead of twice.
        image_list = []
        for position in range(len(self.data_client)):
            sample = self.data_client[position]
            image_list.append(torch.unsqueeze(sample[0], dim=0))
            indices_class[sample[1]].append(position)
        images_all = torch.cat(image_list, dim=0).to(args.device)

        def get_images(c, n):  # get random from class
            idx_shuffle = np.random.permutation(indices_class[c])[:n]
            return images_all[idx_shuffle]

        self.local_model.eval()
        criterion = CrossEntropyLoss().to(args.device)

        # gradients of all classes
        truth_gradient_all = {index: [] for index in list_class}

        # random choose to repeat
        for _ in range(self.num_gradient_repeats):
            for c, _count in zip(list_class, per_class_compose):
                if not indices_class[c]:
                    # No local sample of this class, so there is nothing to
                    # differentiate. NOTE(review): presumably get_class_num only
                    # lists classes that are present; confirm.
                    continue
                img_real = get_images(c, args.batch_real)
                if args.dsa:  # data transform
                    seed = int(time.time() * 1000) % 100000
                    img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                lab_real = torch.ones((img_real.shape[0],), device=args.device, dtype=torch.long) * c
                feature_real, output_real = self.local_model(img_real)
                loss_real = criterion(output_real, lab_real)
                # compute the gradients from single class
                gw_real = torch.autograd.grad(loss_real, feature_real)[0].detach().clone()
                truth_gradient_all[c].append(gw_real)

        truth_gradient_avg = {}
        for c in list_class:
            gradient_all = truth_gradient_all[c]  # list of tensors
            if not gradient_all:
                continue
            # Average over repeats (dim 0) and over the samples of a batch (dim 1).
            truth_gradient_avg[c] = torch.stack(gradient_all, dim=0).mean(dim=(0, 1))

        return truth_gradient_avg

    def calculate_infonce(self, f_now, pos_hp, neg_hp, k_margin):
        """Return the margin-augmented InfoNCE loss of one feature against the hyper-prototypes.

        Args:
            f_now: Feature of the current sample, shape ``[1, dim]``.
            pos_hp: Hyper-prototypes of the sample's class, ``[length_of_hp, dim]``.
            neg_hp: Hyper-prototypes of the other classes,
                ``[num_neg, length_of_hp, dim]``.
            k_margin: Value added to the denominator of the softmax ratio.

        Returns:
            Tensor of shape ``[1]``.
        """
        f_pos = pos_hp.to(self.device)
        f_neg = neg_hp.to(self.device)
        # Row 0 is the positive class, the remaining rows are the negatives.
        f_combin_hp = torch.cat((f_pos.unsqueeze(0), f_neg), dim=0)
        f_now_exp = f_now.unsqueeze(0).expand(f_combin_hp.size(0), f_combin_hp.size(1), -1)
        cos_sim = F.cosine_similarity(f_now_exp, f_combin_hp, dim=2)
        # Mean similarity to each class's prototypes, scaled by the temperature.
        logits = cos_sim.mean(dim=1) / self.hp_info_t
        exp_logits = torch.exp(logits).view(1, -1)
        sum_pos_l = exp_logits[:, 0]  # the positive class sits in column 0
        sum_exp_l = exp_logits.sum(1) + k_margin
        return -torch.log(sum_pos_l / sum_exp_l)

    def local_train(self, args, global_params):
        """Train the local model from ``global_params`` and return its ``state_dict``.

        Args:
            args: Parsed options; reads ``num_epochs_local_training`` and
                ``batch_size_local_training``.
            global_params: State dict loaded into the local model before training.

        The per-batch loss is the cross-entropy plus the batch mean of
        (InfoNCE loss against the hyper-prototypes + smooth-L1 distance between
        the feature and the mean hyper-prototype of its class).
        """
        # NOTE(review): these transforms are applied to the whole batch tensor at
        # once, so presumably one crop offset / flip decision is shared by every
        # image of a batch; confirm that this is intended.
        transform_train = transforms.Compose(
            [transforms.RandomCrop(32, padding=4),
             transforms.RandomHorizontalFlip()])
        self.local_model.load_state_dict(global_params)
        self.local_model.train()
        # NOTE(review): compute_margin() runs forward passes while the model is
        # in training mode; confirm that this is intended.
        k_margin = self.compute_margin()

        # self.hp does not change during training, so the stacked negatives and
        # the mean positive prototype are built once per class instead of once
        # per sample.
        neg_hp_cache = {}
        avg_hp_cache = {}

        for _ in range(args.num_epochs_local_training):
            data_loader = DataLoader(dataset=self.data_client,
                                     batch_size=args.batch_size_local_training,
                                     shuffle=True)
            for images, labels in data_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                images = transform_train(images)
                features, outputs = self.local_model(images)

                loss_ce = self.criterion(outputs, labels)
                loss_hp = None
                # tolist() fetches all labels in one transfer instead of one
                # .item() call per sample.
                for idx, label in enumerate(labels.tolist()):
                    if label not in neg_hp_cache:
                        neg_hp_cache[label] = torch.stack(
                            [self.hp[c] for c in range(self.args.num_classes) if c != label], dim=0)
                        avg_hp_cache[label] = self.hp[label].mean(dim=0)  # [dim]
                    f_now = features[idx].unsqueeze(0)
                    loss_hpcl = self.calculate_infonce(f_now, self.hp[label], neg_hp_cache[label], k_margin)
                    loss_hpal = F.smooth_l1_loss(features[idx], avg_hp_cache[label], reduction='mean')
                    loss_instance = loss_hpcl + loss_hpal
                    loss_hp = loss_instance if loss_hp is None else loss_hp + loss_instance

                loss_hp = loss_hp / images.size(0)
                loss = loss_ce + loss_hp
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        return self.local_model.state_dict()


def FedHPro():
    """Run federated training on a long-tailed, non-IID split of CIFAR-10.

    Each round: sample the online clients, train each one locally from the
    current global parameters and hyper-prototypes, collect their parameters
    and class-wise gradients, fuse the parameters into the global model, update
    the hyper-prototypes and record the global test accuracy.
    """
    args = args_parser()
    banner = (
        'imbalance_factor: {ib}, non_iid_factor: {non_iid}\n'
        'length_of_hp: {length_of_hp}, rounds_of_hp: {rounds_of_hp}\n'
    )
    print(banner.format(ib=args.imb_factor,
                        non_iid=args.non_iid_alpha,
                        length_of_hp=args.length_of_hp,
                        rounds_of_hp=args.rounds_of_hp))

    random_state = np.random.RandomState(args.seed)

    # load dataset
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_all = transforms.Compose([transforms.ToTensor(), normalize])
    data_local_training = datasets.CIFAR10(args.path_cifar10, train=True, download=True, transform=transform_all)
    data_global_test = datasets.CIFAR10(args.path_cifar10, train=False, transform=transform_all)

    # distribute data: heterogeneous and long-tailed setting
    label_to_indices = classify_label(data_local_training, args.num_classes)
    _, long_tailed_indices = train_long_tail(copy.deepcopy(label_to_indices), args.num_classes,
                                             args.imb_factor, args.imb_type)
    client_to_indices = clients_indices(copy.deepcopy(long_tailed_indices), args.num_classes,
                                        args.num_clients, args.non_iid_alpha, args.seed)
    class_counts_per_client = show_clients_data_distribution(data_local_training, client_to_indices,
                                                             args.num_classes)

    server = Global(num_classes=args.num_classes,
                    device=args.device,
                    args=args,
                    length_of_hp=args.length_of_hp)

    all_client_ids = list(range(args.num_clients))
    # One dataset view shared by all clients; load() re-points it at a client's indices.
    indices2data = Indices2Dataset(data_local_training)

    accuracy_history = []

    # FL global training
    for round_idx in tqdm(range(1, args.num_rounds + 1), desc='FedHPro-training'):
        global_params = server.download_params()
        global_hp = server.get_hp()
        selected_clients = random_state.choice(all_client_ids, args.num_online_clients, replace=False)

        client_gradients = []
        client_params = []
        client_sizes = []

        # client local training
        for client_id in selected_clients:
            indices2data.load(client_to_indices[client_id])
            client_sizes.append(len(indices2data))
            trainer = Local(data_client=indices2data,
                            class_list=class_counts_per_client[client_id],
                            hyper_prototypes=global_hp)

            # local update
            trained_params = trainer.local_train(args, copy.deepcopy(global_params))
            client_params.append(copy.deepcopy(trained_params))

            # compute gradients from clients
            class_gradients = trainer.compute_gradient(args)
            client_gradients.append(copy.deepcopy(class_gradients))

        # aggregating
        fused_params = server.initialize_for_model_fusion(client_params, client_sizes)
        server.syn_model.load_state_dict(copy.deepcopy(fused_params))
        # updating hyper-prototypes
        server.obtain_hyper_prototypes(args, client_gradients)
        # global eval
        accuracy_history.append(server.global_eval(data_global_test, args.batch_size_test))

        # Report the accuracy curve every tenth round.
        if not round_idx % 10:
            print(accuracy_history)

    print(accuracy_history)


if __name__ == '__main__':
    # Seed, in order: torch (CPU), torch (CUDA), NumPy, Python's random
    # (which also covers the transforms).
    for seed_rng in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
        seed_rng(55)
    # Ask cuDNN for deterministic algorithms.
    torch.backends.cudnn.deterministic = True
    args = args_parser()
    FedHPro()