import copy
import random
from typing import Optional, List, Callable

import numpy as np
import torch
from hydra.utils import instantiate
from torchmetrics import CalibrationError

from src.util import set_net_alpha


def train(server_round, global_net, local_net, trainloader, cfg, client_alphas: Optional[List]):
    """Run one round of client-side training and return the reported metrics.

    ``global_net`` is always trained (in place) on ``trainloader``; when
    ``cfg.strategy.strategy_name == "FedProx"`` a proximal term towards the
    weights ``global_net`` held before training is added, weighted by
    ``cfg.dataset_model["fedprox_mu"]``. If ``cfg.strategy.local_model`` is set,
    ``local_net`` is then trained too, with a proximal term towards that same
    pre-training snapshot weighted by ``cfg.strategy.local_prox_mu``, and its
    metrics are returned instead of the global model's.

    Args:
        server_round: Current federated round, forwarded to ``_train``.
        global_net: Model holding the global weights; trained in place.
        local_net: Personalised model; only used when ``cfg.strategy.local_model``.
        trainloader: Iterable of ``(data, target)`` batches.
        cfg: Experiment config.
        client_alphas: Optional list forwarded to ``_train`` (one entry is
            sampled per batch there).

    Returns:
        Dict with keys ``train_loss``, ``train_acc``, ``train_ece``, ``ce_loss``
        and ``prox_loss`` for the model whose metrics are reported.
    """
    ece_fn = CalibrationError(task="multiclass", num_classes=cfg.dataset_model.num_classes)

    # Move the model BEFORE snapshotting: the snapshot is later subtracted from
    # the live parameters in the proximal term, so both must share a device.
    global_net.to(device=cfg.device).train()

    # Weights of global_model before local training, used as the anchor of the
    # proximal term for FedProx OR personalized FL.
    previous_global_model_parameters = copy.deepcopy(list(global_net.named_parameters()))

    is_fedprox = cfg.strategy.strategy_name == "FedProx"
    reported_metrics = _train(
        server_round,
        cfg,
        global_net,
        trainloader,
        client_alphas=client_alphas,
        ece_fn=ece_fn,
        previous_global_model_parameters=previous_global_model_parameters if is_fedprox else None,
        mu=cfg.dataset_model.get("fedprox_mu") if is_fedprox else None,  # TODO hacky
    )

    if cfg.strategy.local_model:
        local_net.to(device=cfg.device).train()
        reported_metrics = _train(
            server_round,
            cfg,
            local_net,
            trainloader,
            client_alphas=client_alphas,
            ece_fn=ece_fn,
            previous_global_model_parameters=previous_global_model_parameters,
            mu=cfg.strategy.local_prox_mu,
        )

    metric_names = ("train_loss", "train_acc", "train_ece", "ce_loss", "prox_loss")
    return dict(zip(metric_names, reported_metrics))


def _train(
    server_round,
    cfg,
    net: torch.nn.Module,
    trainloader: torch.utils.data.DataLoader,
    client_alphas: Optional[List],
    ece_fn: Callable,
    previous_global_model_parameters=None,  # if supplied, trains proximal loss
    mu=None
):
    simplex_optimizer = getattr(torch.optim, cfg.dataset_model.optimizer)(net.parameters(), **cfg.dataset_model.opt_args)
    if hasattr(cfg.dataset_model, "lr_scheduler"):
        scheduler = instantiate(cfg.dataset_model.lr_scheduler, optimizer=simplex_optimizer)
        scheduler.step(server_round)

    for e in range(cfg.dataset_model.local_epochs):
        accumulated_ce_loss = 0.0  # only for reporting
        accumulated_prox_loss = 0.0  # only for reporting
        accumulated_loss = 0.0
        correct_samples = 0
        total_samples = 0
        pred_vector = []
        label_vector = []
        for data, target in trainloader:
            data, target = data.to(device=cfg.device), target.to(device=cfg.device)

            simplex_optimizer.zero_grad()
        
            if client_alphas is not None:
                sampled_alpha = client_alphas[np.random.choice(len(client_alphas))]
                set_net_alpha(net, sampled_alpha)
        
            output = net(data)
            loss = torch.nn.CrossEntropyLoss()(output, target)

            accumulated_ce_loss += loss.item()  # only for reporting

            if previous_global_model_parameters:
                proximal_loss = _compute_proximal_loss(net, previous_global_model_parameters, mu=mu)
                loss += proximal_loss
                accumulated_prox_loss += proximal_loss.item()  # only for reporting
            accumulated_loss += loss.item()

            if cfg.strategy.strategy_name == "FLOCO":
                if cfg.strategy.reg_hp != 0.0 and cfg.strategy.subspace_start <= server_round:
                    tmp_round_reg_hp = cfg.strategy.reg_hp / ((server_round - cfg.strategy.subspace_start) + 1 )
                    out = random.sample([i for i in range(cfg.rule.num_points)], 2)
                    i, j = out[0], out[1]
                    num = 0.0
                    normi = 0.0
                    normj = 0.0
                    for m in reversed(list(net.modules())):
                        if isinstance(m, torch.nn.ParameterList):
                            vi = m[i]
                            vj = m[j]
                            num += (vi * vj).sum()
                            normi += vi.pow(2).sum()
                            normj += vj.pow(2).sum()
                        reg_loss = cfg.strategy.reg_hp * (num.pow(2) / (normi * normj))
                        loss += reg_loss

            loss.backward()

            simplex_optimizer.step()

            pred = output.argmax(dim=1)
            correct_samples += torch.sum(pred == target).item()
            pred_vector.append(torch.softmax(output, dim=1).detach().cpu())
            label_vector.append(target.detach().cpu())
            total_samples += target.size(0)

    ce_loss = accumulated_ce_loss / cfg.dataset_model.local_epochs
    prox_loss = accumulated_prox_loss / cfg.dataset_model.local_epochs
    train_loss = accumulated_loss / cfg.dataset_model.local_epochs
    train_acc = correct_samples / total_samples
    train_ece = ece_fn(torch.cat(pred_vector), torch.cat(label_vector))

    return train_loss, train_acc, train_ece, ce_loss, prox_loss


def _compute_proximal_loss(net, previous_global_model_parameters, mu):
    """Calculate regularization term toward optimal global model."""
    for _, v in previous_global_model_parameters:
        v.requires_grad = False
    proximal_term = 0.0
    for (n, w), (n_t, w_t) in zip(net.named_parameters(), previous_global_model_parameters):
        proximal_term += (w - w_t).norm(2)
    proximal_loss = mu * proximal_term
    return proximal_loss
