import math
from collections import defaultdict

import torch

from federatedscope.core.trainers import BaseTrainer
from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer


def copy_params(src):
    """Snapshot the trainable parameters of a module.

    Args:
        src: module whose ``named_parameters()`` are copied.

    Returns:
        dict mapping parameter name to a detached clone of that parameter,
        for every parameter with ``requires_grad``; frozen parameters are
        omitted.
    """
    return {
        name: param.detach().clone()
        for name, param in src.named_parameters() if param.requires_grad
    }


def prox_term(cur, last):
    """Return half the squared L2 distance between ``cur`` and ``last``.

    Only trainable parameters (``requires_grad``) of ``cur`` are compared.

    Args:
        cur: module whose ``named_parameters()`` are compared.
        last: mapping from parameter name to reference tensor, e.g. the
            output of ``copy_params``.

    Returns:
        ``0.5 * sum ||w - last[name]||^2`` over the trainable parameters of
        ``cur`` as a tensor, or the float ``0.0`` if there are none.
    """
    loss = .0
    for name, w in cur.named_parameters():
        # Frozen parameters are skipped: `copy_params` does not snapshot
        # them (so `last[name]` would raise KeyError), and they cannot
        # receive a gradient from this term anyway.
        if not w.requires_grad:
            continue
        loss += 0.5 * torch.sum((w - last[name])**2)
    return loss


def add_noise(model, sigma):
    """Perturb the trainable parameters of ``model`` in place.

    Each parameter with ``requires_grad`` gets ``sigma`` times a fresh
    standard-normal sample of the same shape (drawn on the parameter's
    device) added to it. Frozen parameters are left untouched.
    """
    trainable = (param for param in model.parameters()
                 if param.requires_grad)
    for param in trainable:
        noise = torch.randn(size=param.shape, device=param.device)
        param.data.add_(sigma * noise)


def moving_avg(cur, new, alpha):
    """Blend ``new`` into ``cur`` in place as an exponential moving average.

    For every key of ``cur`` the stored tensor's data becomes
    ``(1 - alpha) * cur[key] + alpha * new[key]``. Keys present only in
    ``new`` are ignored. The tensor objects held by ``cur`` are kept (their
    ``.data`` is reassigned), so outside references remain valid.
    """
    keep = 1 - alpha
    for key in cur:
        tensor = cur[key]
        tensor.data = keep * tensor + alpha * new[key]


class LocalEntropyTrainer(BaseTrainer):
    """Client trainer performing local-entropy (Langevin-style) updates.

    Each call to :meth:`train` makes one pass over ``self.data['train']``.
    For every batch it takes an optimizer step on cross-entropy plus a
    proximal term pulling the weights toward those held at the start of the
    call, adds Gaussian noise to the trainable weights, and folds the
    result into a moving average ``mu``. After the pass, the trainable
    weights are replaced by ``mu``.
    """
    def __init__(self, model, data, device, **kwargs):
        # NN modules
        self.model = model
        # FS `ClientData` or your own data; indexed by split name
        # ('train', 'test') in the methods below
        self.data = data
        # Device name
        self.device = device
        # configs
        self.kwargs = kwargs
        self.config = kwargs['config']
        self.optim_config = self.config.train.optimizer
        self.local_entropy_config = self.config.trainer.local_entropy
        # Weight of the proximal term; multiplied by `inc_factor` after
        # every training batch in `run_epoch`.
        self._thermal = self.local_entropy_config.gamma

    @staticmethod
    def _safe_div(numerator, denominator):
        """Return ``numerator / denominator`` as a float, or ``0.0`` when
        ``denominator`` is 0 (e.g. an empty data split)."""
        return numerator / float(denominator) if denominator else 0.0

    def train(self):
        """Run one local-entropy pass over the training split.

        Returns:
            tuple ``(num_samples, state_dict, metrics)``: the number of
            training samples seen, the model's ``state_dict`` after moving
            the model to CPU, and a dict with ``'loss_total'`` (sum of
            per-batch cross-entropy weighted by batch size) and
            ``'avg_loss'`` (``loss_total / num_samples``, or ``0.0`` if no
            samples were seen).
        """
        # _hook_on_fit_start_init
        # Move the model before building the optimizer, as PyTorch
        # recommends, so the optimizer is bound to on-device parameters.
        self.model.to(self.device)

        # Criterion & Optimizer
        criterion = torch.nn.CrossEntropyLoss().to(self.device)
        optimizer = get_optimizer(self.model, **self.optim_config)

        # Reference weights for the proximal term, and the running average
        # of the local iterates (both start from the current weights).
        current_global_model = copy_params(self.model)
        mu = copy_params(self.model)
        self.model.train()

        num_samples, total_loss = self.run_epoch(optimizer, criterion,
                                                 current_global_model, mu)
        # Replace the trainable weights by their moving average.
        for name, param in self.model.named_parameters():
            if name in mu:
                param.data = mu[name]

        # _hook_on_fit_end
        return num_samples, self.model.cpu().state_dict(), {
            'loss_total': total_loss,
            'avg_loss': self._safe_div(total_loss, num_samples),
        }

    def run_epoch(self, optimizer, criterion, current_global_model, mu):
        """Iterate once over ``self.data['train']`` with noisy prox updates.

        Args:
            optimizer: optimizer over ``self.model``'s parameters.
            criterion: classification loss applied to model outputs.
            current_global_model: name -> tensor reference weights for the
                proximal term.
            mu: name -> tensor moving average, updated in place.

        Returns:
            ``(num_samples, running_loss)`` where ``running_loss`` is the
            sum of per-batch criterion values weighted by batch size.
        """
        running_loss = 0.0
        num_samples = 0
        for inputs, targets in self.data['train']:
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            # Descent step on criterion + thermal-weighted proximal term
            optimizer.zero_grad()
            outputs = self.model(inputs)
            ce_loss = criterion(outputs, targets)
            loss = ce_loss + self._thermal * prox_term(self.model,
                                                       current_global_model)
            loss.backward()
            optimizer.step()

            # add noise for langevin dynamics: scale sqrt(lr) * eps
            add_noise(
                self.model,
                math.sqrt(self.optim_config.lr) *
                self.local_entropy_config.eps)

            # acc local updates
            moving_avg(mu, self.model.state_dict(),
                       self.local_entropy_config.alpha)

            # `.item()` already yields a plain Python float, so no
            # autograd context is needed here.
            batch_size = targets.shape[0]
            running_loss += batch_size * ce_loss.item()
            num_samples += batch_size

            # Grow the proximal weight after every batch.
            # NOTE(review): `_thermal` is only reset in __init__, so this
            # growth carries over across train() calls -- confirm intended.
            self._thermal *= self.local_entropy_config.inc_factor

        return num_samples, running_loss

    def evaluate(self, target_data_split_name='test'):
        """Evaluate cross-entropy and accuracy on a data split.

        Only the ``'test'`` split is evaluated; any other name returns ``{}``.

        Returns:
            dict with ``<split>_acc``, ``<split>_loss`` (sum of batch losses
            weighted by batch size), ``<split>_total`` (sample count) and
            ``<split>_avg_loss``. Ratios are ``0.0`` for an empty split.
        """
        if target_data_split_name != 'test':
            return {}

        with torch.no_grad():
            criterion = torch.nn.CrossEntropyLoss().to(self.device)

            self.model.to(self.device)
            self.model.eval()
            total_loss = num_samples = num_corrects = 0
            # _hook_on_batch_start_init
            for x, y in self.data[target_data_split_name]:
                # _hook_on_batch_forward
                x, y = x.to(self.device), y.to(self.device)
                pred = self.model(x)
                loss = criterion(pred, y)
                cor = torch.sum(torch.argmax(pred, dim=-1).eq(y))

                # _hook_on_batch_end
                total_loss += loss.item() * y.shape[0]
                num_samples += y.shape[0]
                num_corrects += cor.item()

            # _hook_on_fit_end
            return {
                f'{target_data_split_name}_acc':
                self._safe_div(num_corrects, num_samples),
                f'{target_data_split_name}_loss': total_loss,
                f'{target_data_split_name}_total': num_samples,
                f'{target_data_split_name}_avg_loss':
                self._safe_div(total_loss, num_samples),
            }

    def update(self, model_parameters, strict=False):
        """Load ``model_parameters`` into the model (non-strict by default)."""
        self.model.load_state_dict(model_parameters, strict)

    def get_model_para(self):
        """Move the model to CPU and return its ``state_dict``."""
        return self.model.cpu().state_dict()
