'''
reference: https://github.com/seungjoo-ai/FLFL-NeurIPS24
'''

import os
import copy
import time
import torch
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import confusion_matrix
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils.data import DataLoader
from algorithm.base import ClientBase, ServerBase
from collections import defaultdict

class FL2(ServerBase):
    """FL2 server: aggregates the client models trained on local pseudo-labels.

    Clients are instances of ``Client`` (defined in this module). The
    aggregation rule is selected by ``self.agg`` — presumably set by
    ``ServerBase`` from ``args``; not visible in this file.
    """

    def __init__(self, args):
        super().__init__(args, Client)

    def training_stats(self, round_idx):
        """No per-round training statistics are collected for FL2."""
        pass

    def aggregate_models(self, round_idx):
        """Collect local models from the selected clients and aggregate them.

        Only clients with ``util > 0`` (local training happened this round)
        contribute. Each contributing client gets a weight according to
        ``self.agg``:

        - ``'uniform'``: 1
        - ``'weighted'``: ``client.util``
        - ``'lsa'`` (learning-status aware): ``1 - client.learning_status``

        The weights are normalized to sum to 1. If every contributing client
        has weight 0, uniform weights are used instead. The deep-copied models
        and the weights are then passed to ``self.aggregate``.

        Raises:
            ValueError: if ``self.agg`` is not one of the values above and at
                least one client contributes.
        """
        uploaded_models = []
        weights = []
        for client_id in self.selected_clients:
            client = self.clients[client_id]
            if client.util <= 0:  # no local training happened
                continue

            if self.agg == 'uniform':
                weight = 1
            elif self.agg == 'weighted':
                weight = client.util
            elif self.agg == 'lsa':  # learning status aware
                # learning_status may be a 0-d (possibly on-device) tensor;
                # use a plain float so all weights are ordinary numbers.
                weight = 1 - float(client.learning_status)
            else:
                raise ValueError(f'invalid aggregation method: {self.agg}')

            uploaded_models.append(copy.deepcopy(client.model).to(self.device))
            weights.append(weight)

        total = sum(weights)
        if weights and total == 0:
            # e.g. 'lsa' with every learning_status == 1: avoid dividing by
            # zero and fall back to a uniform average.
            weights = [1] * len(weights)
            total = len(weights)
        weights = [w / total for w in weights]
        self.aggregate(uploaded_models, weights)


class ASAM:
    """Adaptive Sharpness-Aware Minimization (ASAM) weight-perturbation helper.

    Only the perturbation is implemented here; the parameter update itself is
    done by a separate optimizer. ``first_step`` moves the parameters along
    their current ``.grad`` to the adaptive worst-case point of radius
    ``rho``. ``second_step`` subtracts that perturbation again.

    Args:
        model: ``torch.nn.Module`` whose parameters are perturbed.
        rho: perturbation radius.
        eta: constant added to ``|w|`` when rescaling the gradient of
            parameters whose name contains ``'weight'``.
    """

    def __init__(self, model, rho=0.5, eta=0.01):
        self.model = model
        self.rho = rho
        self.eta = eta
        # Per-parameter scratch buffer; "eps" ends up holding the perturbation
        # applied by the most recent first_step().
        self.state = defaultdict(dict)
        # Parameters perturbed by the latest first_step() and not yet restored.
        self._perturbed = []

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb the parameters: ``p += rho * T_w**2 * g / ||T_w * g||``.

        ``T_w = |w| + eta`` for parameters whose name contains ``'weight'``
        and the identity otherwise; the norm is taken over all parameters that
        have a gradient. Parameters without a gradient are not touched. The
        gradients are rescaled in place. If ``zero_grad`` is true, the model
        gradients are cleared afterwards. A no-op (apart from ``zero_grad``)
        when no parameter has a gradient.
        """
        # A new ascent step replaces any perturbation still pending.
        self._perturbed = []
        with_grad = []
        wgrads = []
        for n, p in self.model.named_parameters():
            if p.grad is None:
                continue
            t_w = self.state[p].get("eps")
            if t_w is None:
                t_w = torch.clone(p).detach()
                self.state[p]["eps"] = t_w
            if 'weight' in n:
                # t_w = |w| + eta, then g <- T_w * g
                t_w[...] = p[...]
                t_w.abs_().add_(self.eta)
                p.grad.mul_(t_w)
            wgrads.append(torch.norm(p.grad, p=2))
            with_grad.append((n, p))

        if not wgrads:
            # Nothing to perturb (torch.stack would fail on an empty list).
            if zero_grad:
                self.model.zero_grad()
            return

        wgrad_norm = torch.norm(torch.stack(wgrads), p=2) + 1.e-16
        for n, p in with_grad:
            t_w = self.state[p]["eps"]
            if 'weight' in n:
                # g <- T_w * (T_w * g) = T_w**2 * g
                p.grad.mul_(t_w)
            # The scratch buffer is reused to store the perturbation itself.
            eps = t_w
            eps[...] = p.grad[...]
            eps.mul_(self.rho / wgrad_norm)
            p.add_(eps)
            self._perturbed.append(p)

        if zero_grad:
            self.model.zero_grad()

    @torch.no_grad()
    def second_step(self):
        """Undo the perturbation applied by the most recent ``first_step``.

        Only parameters perturbed by that call are touched. Calling this again
        without a new ``first_step`` is a no-op, so a stale perturbation is
        never subtracted.
        """
        for p in self._perturbed:
            p.sub_(self.state[p]["eps"])
        self._perturbed = []

class SAM(optim.Optimizer):
    """Sharpness-Aware Minimization (SAM) weight-perturbation helper.

    Subclasses ``optim.Optimizer`` to get per-parameter ``state`` and shares
    its ``param_groups`` with an internally built ``base_optimizer``.

    - ``first_step`` saves every parameter that has a gradient and moves it to
      ``p + rho * g / ||g||``. In ``adaptive`` mode the step is multiplied
      element-wise by ``p**2`` and the norm is taken over ``|p| * g``.
    - ``second_step`` puts the saved values back.

    Args:
        model: ``torch.nn.Module`` whose parameters are perturbed.
        base_optimizer: optimizer class, instantiated on the same param groups.
        rho: perturbation radius (must be non-negative).
        adaptive: use the adaptive (parameter-scaled) perturbation.
        **kwargs: extra keyword arguments for ``base_optimizer``. Any ``lr``
            is overridden by a placeholder value.
    """

    def __init__(self, model, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        # Placeholder learning rate. NOTE(review): nothing in this module calls
        # step() on this object or its base optimizer; confirm before relying on it.
        kwargs['lr'] = 0.001  # dummy value
        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super().__init__(model.parameters(), defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Share the param groups so both optimizers see the same settings.
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Save each parameter with a gradient and move it to the ascent point.

        If ``zero_grad`` is true, gradients are cleared afterwards. A no-op
        (apart from ``zero_grad``) when no parameter has a gradient.
        """
        has_grad = any(
            p.grad is not None
            for group in self.param_groups
            for p in group["params"]
        )
        if not has_grad:
            # _grad_norm would fail on an empty stack; nothing to perturb.
            if zero_grad:
                self.zero_grad()
            return

        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self):
        """Restore the parameters saved by the most recent ``first_step``.

        The saved copy is consumed, so a parameter that was not perturbed by
        the latest ``first_step`` is never reset to an older snapshot. A
        repeated call is a no-op.
        """
        for group in self.param_groups:
            for p in group["params"]:
                old_p = self.state[p].pop("old_p", None)
                if old_p is None:
                    continue
                p.data = old_p

    def _grad_norm(self):
        """L2 norm over all parameter gradients (``|p| * g`` when adaptive).

        Computed on the device of the first parameter. Requires at least one
        parameter with a gradient.
        """
        shared_device = self.param_groups[0]["params"][0].device
        norms = [
            ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        return torch.norm(torch.stack(norms), p=2)

def kldiv_with_mask(logits_s, logits_w, mask, temp=1.0):
    """Masked KL-divergence consistency loss between two sets of logits.

    Computes ``KL(softmax(logits_w) || softmax(logits_s / temp))`` per sample.
    ``logits_w`` is a fixed target: it is detached and not temperature-scaled.

    Args:
        logits_s: (batch, num_classes) logits that receive gradients.
        logits_w: (batch, num_classes) target logits.
        mask: per-sample weights of shape (batch,), bool or float, or ``None``
            for no masking.
        temp: temperature applied to ``logits_s`` only.

    Returns:
        Scalar tensor: the per-sample KL (summed over classes), multiplied by
        ``mask`` and averaged over the full batch. Masked-out samples still
        count in the denominator. With ``mask=None`` this equals PyTorch's
        ``'batchmean'`` reduction.
    """
    log_p_s = F.log_softmax(logits_s / temp, dim=1)
    p_w = F.softmax(logits_w.detach(), dim=1)

    if mask is None:
        return F.kl_div(log_p_s, p_w, reduction='batchmean')

    per_element = F.kl_div(log_p_s, p_w, reduction='none')
    # Broadcast the (batch,) mask over the class dimension; no copy is needed.
    return (per_element * mask.unsqueeze(-1)).sum(dim=1).mean()


class Client(ClientBase):
    """FL2 client: local training on self-generated pseudo-labels.

    Each round, the received model pseudo-labels the local training set
    (``make_pseudo_lb_dataset``). The confident samples are then used for
    cross-entropy training plus a KL consistency term evaluated at
    SAM/ASAM-perturbed weights (``train_one_batch``).

    NOTE(review): ``model``, ``device``, ``trainset``, ``batch_size``,
    ``local_steps``, ``clip_grad``, ``optimizer`` and ``prepare`` are used but
    not defined here. They are presumably provided by ``ClientBase``.
    """

    def __init__(self, args, id, trainset):
        super().__init__(args, id, trainset)
        self.sam_opt = args.sam  # 'sam' or 'asam' (see get_sam_optimizer)
        self.rho = args.rho      # perturbation radius for SAM/ASAM
        # Read by FL2's 'lsa' aggregation; replaced by the mean max-softmax
        # confidence in make_pseudo_lb_dataset().
        self.learning_status = 0.1
        self.sam_optimizer = self.get_sam_optimizer()

    def get_sam_optimizer(self):
        """Build the SAM or ASAM perturbation helper for ``self.model``.

        The helper keeps a reference to the model object that exists at this
        point. NOTE(review): if ``prepare`` ever replaces ``self.model`` with a
        new object, the helper would perturb the stale one; confirm.

        Raises:
            NotImplementedError: if ``self.sam_opt`` is neither 'sam' nor 'asam'.
        """
        if self.sam_opt == 'sam':
            return SAM(self.model, optim.SGD, rho=self.rho)
        if self.sam_opt == 'asam':
            return ASAM(self.model, rho=self.rho)
        raise NotImplementedError(f'unknown sam option: {self.sam_opt}')

    def make_pseudo_lb_dataset(self):
        """Pseudo-label the local training set with the current model.

        Returns:
            A deep copy of ``self.trainset`` carrying three extra attributes,
            one entry per sample in loader order:

            - ``pseudo_labels``: list of argmax predictions.
            - ``mask``: bool tensor, class-adaptive confidence mask.
            - ``fix_mask``: bool tensor, confidence >= global threshold.

            Returns ``None`` when the local training set is empty.

        Side effects: puts the model in eval mode and sets ``self.threshold``,
        ``self.learning_status``, ``self.cm`` and ``self.mask_cm``.
        """
        if len(self.trainset) == 0:
            # Nothing to pseudo-label (torch.cat would fail on empty lists);
            # train() treats None as "no local data".
            return None

        data_loader = DataLoader(self.trainset, batch_size=256, shuffle=False)
        model = self.model
        model.eval()
        batch_logits, batch_labels = [], []

        with torch.no_grad():
            for data in data_loader:
                xw, y = data['x'].to(self.device), data['y'].to(self.device)
                batch_labels.append(y)
                batch_logits.append(model(xw))
            labels = torch.cat(batch_labels, dim=0)
            probs = F.softmax(torch.cat(batch_logits, dim=0), dim=1)
            max_prob, pseudo_labels = probs.max(dim=1)

            # Global threshold: mean top-1 confidence over all local samples.
            global_t = max_prob.mean()
            # Per-class mean predicted probability, normalized so the largest
            # class equals 1; scales the threshold per predicted class.
            local_t = probs.mean(dim=0)
            mod = local_t / torch.max(local_t, dim=-1)[0]
            self.threshold = global_t
            mask = max_prob.ge(global_t * mod[pseudo_labels])
            fix_mask = max_prob.ge(self.threshold)
            self.learning_status = global_t

        # `classes` is used as an integer class count here (range / shape).
        num_classes = self.trainset.classes
        self.cm = confusion_matrix(labels.cpu().numpy(), pseudo_labels.cpu().numpy(),
                                   labels=range(num_classes))
        if torch.any(mask):
            self.mask_cm = confusion_matrix(labels[mask].cpu().numpy(),
                                            pseudo_labels[mask].cpu().numpy(),
                                            labels=range(num_classes))
        else:
            self.mask_cm = np.zeros((num_classes, num_classes), dtype=int)

        free_dataset = copy.deepcopy(self.trainset)
        free_dataset.pseudo_labels = pseudo_labels.tolist()
        free_dataset.mask = mask
        free_dataset.fix_mask = fix_mask
        return free_dataset

    def train(self, round_idx, lr, state_dict):
        """Run one round of local training.

        Calls ``self.prepare(lr, state_dict)`` (defined elsewhere), then
        pseudo-labels the local data and makes ``self.local_steps`` passes
        over it. Afterwards it sets ``self.util`` to the number of local
        samples (0 if there are none), snapshots the optimizer state into
        ``self.optimizer_dict`` and moves the model to the CPU.
        """
        self.prepare(lr, state_dict)
        free_set = self.make_pseudo_lb_dataset()
        if free_set is not None:
            # NOTE(review): the training loader is not shuffled; confirm intended.
            free_loader = DataLoader(free_set, batch_size=self.batch_size)
            for _ in range(self.local_steps):
                self.model.train(True)
                for data in free_loader:
                    self.train_one_batch(data)
        self.util = len(free_set) if free_set is not None else 0
        self.optimizer_dict = self.optimizer.state_dict()
        self.model.to('cpu')

    def train_one_batch(self, data):
        """One optimization step on a batch of the pseudo-labelled dataset.

        Reads ``data['x_s']`` (inputs), ``data['py']`` (pseudo labels),
        ``data['m']`` and ``data['fm']`` (boolean sample masks). These keys are
        presumably produced by the dataset from the attributes set in
        ``make_pseudo_lb_dataset``; verify against the dataset class.

        1. If any sample passes both masks:
           a. Run cross-entropy on those samples and take the SAM/ASAM ascent
              step.
           b. Compute a KL term between the predictions at the perturbed
              weights and the unperturbed predictions (masked by
              ``fix_mask``), and keep its gradient aside.
           c. Restore the weights.
        2. If any sample passes ``mask``, backpropagate cross-entropy on those
           samples.
        3. Add the KL gradient, optionally clip, and step ``self.optimizer``.
        """
        sam = self.sam_optimizer
        model = self.model
        model.zero_grad()
        data_s, data_y = data['x_s'].to(self.device), data['py'].to(self.device)
        mask, fix_mask = data['m'].to(self.device), data['fm'].to(self.device)
        both_mask = torch.logical_and(mask, fix_mask)
        # autograd.grad rejects tensors that do not require grad.
        params = [p for p in model.parameters() if p.requires_grad]

        kl_grad = None
        # Tensor .any() instead of builtin any(): one reduction, not one
        # host sync per element.
        if both_mask.any():
            with torch.no_grad():
                logits_clean = model(data_s)
            logits1 = model(data_s[both_mask])
            loss1 = F.cross_entropy(logits1, data_y[both_mask])
            loss1.backward()
            sam.first_step(zero_grad=True)
            logits_perturbed = model(data_s)
            kl_loss = kldiv_with_mask(logits_perturbed, logits_clean, fix_mask)
            # allow_unused: parameters outside the graph get None instead of
            # raising; they are skipped below.
            kl_grad = torch.autograd.grad(kl_loss, params, allow_unused=True)
            sam.second_step()
            model.zero_grad()

        if mask.any():
            logits2 = model(data_s[mask])
            loss2 = F.cross_entropy(logits2, data_y[mask])
            loss2.backward()

        if kl_grad is not None:
            for param, grad in zip(params, kl_grad):
                if grad is None:
                    continue
                # After zero_grad a parameter's .grad may be None if loss2
                # did not reach it.
                if param.grad is None:
                    param.grad = grad
                else:
                    param.grad += grad

        if self.clip_grad > 0:
            clip_grad_norm_(model.parameters(), self.clip_grad)
        self.optimizer.step()




