# Modified from: https://github.com/pliang279/LG-FedAvg/blob/master/models/Update.py
# credit goes to: Paul Pu Liang

# !/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import math
import numpy as np
import time
import copy
from models.test import test_img_local
from einops import rearrange, reduce, repeat


class DatasetSplit(Dataset):
    """Dataset view restricted to a list of sample indices.

    The length of the view is always ``len(idxs)``. How a sample is fetched
    depends on ``name``:

    * ``None`` or any name containing neither ``'femnist'`` nor ``'sent140'``:
      the sample is ``dataset[idxs[item]]`` and must unpack into
      ``(image, label)``.
    * a name containing ``'femnist'``: ``dataset`` is read as a mapping with
      ``'x'`` and ``'y'`` entries; ``x[item]`` is turned into a tensor and
      reshaped to ``(1, 28, 28)``, ``y[item]`` is turned into a tensor.
    * a name containing ``'sent140'``: ``dataset['x'][item]`` and
      ``dataset['y'][item]`` are returned untouched.

    NOTE: the ``'femnist'`` and ``'sent140'`` paths index with ``item``
    directly and do not go through ``idxs``.
    """

    def __init__(self, dataset, idxs, name=None):
        self.name = name
        self.dataset = dataset
        # Materialise the indices so they support len() and random access.
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        tag = self.name

        if tag is not None and 'femnist' in tag:
            pixels = torch.tensor(self.dataset['x'][item])
            target = torch.tensor(self.dataset['y'][item])
            return torch.reshape(pixels, (1, 28, 28)), target

        if tag is not None and 'sent140' in tag:
            return self.dataset['x'][item], self.dataset['y'][item]

        # Default path: look the sample up through the index list.
        image, label = self.dataset[self.idxs[item]]
        return image, label


class LocalUpdate(object):
    """Local training driver over the ``idxs`` subset of ``dataset``.

    The instance owns a shuffled ``DataLoader`` (``args.local_bs`` samples per
    batch, incomplete last batch dropped) and a cross-entropy criterion.

    Fields of ``args`` read by this class: ``local_bs``, ``local_ep``,
    ``local_rep_ep``, ``local_updates``, ``alg``, ``device`` and ``lr``.
    """

    def __init__(self, args, dataset=None, idxs=None, indd=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True,
                                    drop_last=True)

        # Stored as given and handed back unchanged by train().
        self.indd = indd

        self.dataset = dataset
        self.idxs = idxs
        self.relu = nn.ReLU()

    def exchange(self, net_glob, net_local, grad_glob, grad_local, lamda):
        """Pull selected conv filters of the local model towards the global one.

        For each of the pairs ``h1/conv1.weight``, ``h2/conv2.weight`` and
        ``h3/conv3.weight`` the activation gradients are averaged over
        dimensions 0, 2 and 3, leaving one value per channel. A channel is
        "selected" for a model when its pooled gradient exceeds three times
        the mean of the ReLU-ed pooled gradients of that model. Filters whose
        channel is selected for the global model but not for the local one
        are moved by ``args.lr * (w_glob - w_local)``.

        Args:
            net_glob: model providing the reference weights (``state_dict``).
            net_local: model whose weights are adjusted.
            grad_glob: mapping ``'h1'|'h2'|'h3'`` -> activation gradient of
                ``net_glob`` (indexed as a 4-D tensor).
            grad_local: same mapping for ``net_local``.
            lamda: accepted for interface compatibility; currently unused.

        Returns:
            ``(weights_local, replace_mask, total_mask)``: the adjusted state
            dict of ``net_local`` (not loaded into the model here), the
            number of compensated channels (a 0-dim tensor) and the number of
            channels inspected (an int).
        """
        weights_local = net_local.state_dict()
        weights_glob = net_glob.state_dict()
        keys_dict = {'h3': 'conv3.weight', 'h2': 'conv2.weight', 'h1': 'conv1.weight'}
        total_mask = 0
        replace_mask = 0
        for grad_key, weight_key in keys_dict.items():
            pooled_grad_glob = torch.mean(grad_glob[grad_key], dim=[0, 2, 3])
            pooled_grad_local = torch.mean(grad_local[grad_key], dim=[0, 2, 3])

            mean_glob = torch.mean(self.relu(pooled_grad_glob))
            mean_local = torch.mean(self.relu(pooled_grad_local))

            zero, one = torch.zeros_like(pooled_grad_local), torch.ones_like(pooled_grad_local)

            mask_glob = torch.where(pooled_grad_glob > 3 * mean_glob, one, zero)
            mask_local = torch.where(pooled_grad_local > 3 * mean_local, one, zero)

            # 1 only where the global mask is set and the local mask is not.
            compensate_values = mask_glob - mask_local
            compensate_masks = torch.where(compensate_values > 0.5, one, zero)

            total_mask += mask_glob.size(0)
            replace_mask += torch.sum(compensate_masks)

            weights_diff = weights_glob[weight_key] - weights_local[weight_key]
            # Broadcast the per-filter mask over the remaining weight
            # dimensions; numerically identical to flattening the weights,
            # tiling the mask and reshaping back.
            filter_mask = compensate_masks.reshape(-1, 1, 1, 1)
            weights_diff = self.args.lr * filter_mask * weights_diff
            weights_local[weight_key] = weights_local[weight_key] + weights_diff

        return weights_local, replace_mask, total_mask

    def train(self, net, net_trans, net_global, w_glob_keys, last=False, dataset_test=None, ind=-1, idx=-1, lr=0.01):
        """Train ``net`` and ``net_trans`` on the local data loader.

        Every batch is passed through both ``net`` and ``net_global``; the two
        feature batches are concatenated along dimension 0, fed to
        ``net_trans`` and scored against the labels repeated twice.

        Schedule when ``args.alg == 'fedAKIE'``:

        * epochs ``< local_ep - local_rep_ep`` (and every epoch when ``last``
          is true): ``net`` is frozen, ``net_trans`` is trainable;
        * remaining epochs (only when ``last`` is false): ``net`` is
          trainable, ``net_trans`` is frozen, and after every optimizer step
          ``exchange`` is applied and its result loaded into ``net``.

        For any other ``args.alg`` the parameters of ``net`` are marked
        trainable and the plain update is used for every batch.

        When ``last`` is true the epoch count becomes
        ``max(10, local_ep - local_rep_ep)``. Training stops early once
        ``args.local_updates`` batches have been processed.

        Args:
            net: local feature extractor; must offer
                ``get_activations_gradient()`` for the exchange phase.
            net_trans: module mapping the concatenated features to logits.
            net_global: global feature extractor (put in eval mode, not
                registered with the optimizer).
            w_glob_keys, dataset_test, ind, idx: unused here; kept so existing
                callers keep working.
            last: selects the head-only schedule described above.
            lr: SGD learning rate.

        Returns:
            ``(net.state_dict(), net_trans.state_dict(), mean_epoch_loss,
            self.indd, all_replace, all_total)`` where the last two are the
            accumulated ``exchange`` counters divided by the epoch count.
        """
        net_global.eval()
        net.train()
        net_trans.train()

        # Weight decay is applied to weights only, never to biases.
        bias_p = []
        weight_p = []
        for model in (net, net_trans):
            for name, p in model.named_parameters():
                if 'bias' in name:
                    bias_p.append(p)
                else:
                    weight_p.append(p)

        optimizer = torch.optim.SGD(
            [
                {'params': weight_p, 'weight_decay': 0.0001},
                {'params': bias_p, 'weight_decay': 0}
            ],
            lr=lr, momentum=0.5
        )

        local_eps = self.args.local_ep

        if last:
            local_eps = max(10, local_eps - self.args.local_rep_ep)

        head_eps = local_eps - self.args.local_rep_ep
        epoch_loss = []
        num_updates = 0
        total_iters = local_eps * len(self.ldr_train)
        is_akie = self.args.alg == 'fedAKIE'

        all_replace, all_total = 0, 0
        for epoch in range(local_eps):
            done = False
            # True for the representation phase (update net, then exchange).
            rep_phase = epoch >= head_eps and is_akie and not last

            if (epoch < head_eps and is_akie) or last:
                # head phase: only net_trans is trainable
                for name, param in net.named_parameters():
                    param.requires_grad = False

                for name, param in net_trans.named_parameters():
                    param.requires_grad = True

            # then do local epochs for the representation
            elif rep_phase:
                for name, param in net.named_parameters():
                    param.requires_grad = True

                for name, param in net_trans.named_parameters():
                    param.requires_grad = False

            # all other methods update all parameters simultaneously
            elif not is_akie:
                for name, param in net.named_parameters():
                    param.requires_grad = True

            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):

                images, labels = images.to(self.args.device), labels.to(self.args.device)

                # Local and global features are stacked along dim 0, so each
                # label appears twice.
                labels = torch.cat([labels, labels], dim=0)

                if rep_phase:

                    net.zero_grad()
                    net_trans.zero_grad()
                    net_global.zero_grad()

                    feat_glob = net_global(images)
                    feat_local = net(images)
                    feat_all = torch.cat([feat_local, feat_glob], dim=0)
                    cur_out = net_trans(feat_all)
                    # Bound to the same name as in the other branch so the
                    # bookkeeping below always records this batch's loss.
                    loss = self.loss_func(cur_out, labels)
                    loss.backward()
                    grad_glob = net_global.get_activations_gradient()
                    grad_local = net.get_activations_gradient()
                    optimizer.step()

                    # Decays linearly from 1 as updates accumulate; exchange()
                    # accepts it but does not use it at present.
                    lamda = 1 - (num_updates / total_iters)
                    weights_local, cur_replace, cur_total = self.exchange(net_global, net, grad_glob, grad_local,
                                                                         lamda)
                    all_replace += cur_replace
                    all_total += cur_total
                    net.load_state_dict(weights_local)

                else:
                    net.zero_grad()
                    net_trans.zero_grad()
                    local_feats = net(images)
                    global_feats = net_global(images)
                    feats = torch.cat([local_feats, global_feats], dim=0)
                    log_probs = net_trans(feats)
                    loss = self.loss_func(log_probs, labels)
                    loss.backward()
                    optimizer.step()

                num_updates += 1
                batch_loss.append(loss.item())
                if num_updates == self.args.local_updates:
                    done = True
                    break

            # Exactly one entry per epoch, including the one cut short by the
            # update budget.
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
            if done:
                break

        all_replace = all_replace / local_eps
        all_total = all_total / local_eps
        return net.state_dict(), net_trans.state_dict(), \
               sum(epoch_loss) / len(epoch_loss), self.indd, all_replace, all_total