import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from model import LossWeightLearner

from utils.common import select_prob
from utils.evaluation import (evaluate_net_reg_Metric_logger,
                              evaluation_logPrint)
from utils.loss import (gaussian_smooth_kl_loss, partial_label_loss)
from utils.optim_utils import fea_extractor_bn_setup, grad_update_switch

txt_logger = logging.getLogger("sfda_reg")

# for feature norm
def ema_update_full_safe(net_ref, net, alpha=0.999):
    """Blend ``net`` into ``net_ref`` in place using an exponential moving average.

    For every parameter name present in both modules, and every
    floating-point buffer name present in both modules::

        ref <- alpha * ref + (1 - alpha) * src

    Names that exist only in ``net_ref`` are left unchanged, and
    non-floating-point buffers (e.g. integer counters) are skipped, which is
    what makes the update "safe" for modules whose structures differ slightly.

    Args:
        net_ref: module updated in place.
        net: module providing the new values; it is not modified.
        alpha: weight kept on the current ``net_ref`` values.
    """
    src_params = dict(net.named_parameters())
    src_buffers = dict(net.named_buffers())
    keep, take = alpha, 1 - alpha

    with torch.no_grad():
        for key, target in dict(net_ref.named_parameters()).items():
            if key not in src_params:
                continue
            target.data.mul_(keep).add_(src_params[key].data, alpha=take)

        for key, target in dict(net_ref.named_buffers()).items():
            # Integer buffers such as BatchNorm's num_batches_tracked cannot
            # be averaged meaningfully, so only float buffers are blended.
            if key not in src_buffers or not torch.is_floating_point(target):
                continue
            target.data.mul_(keep).add_(src_buffers[key].data, alpha=take)
             


def get_point_weight_batch(y_hist, batch_idx, loss_config, device="cuda"):
    """Return the per-sample loss weights for one batch.

    Behaviour depends on ``loss_config['cls_weight']``:

    * falsy or missing -> a vector of ones (uniform weighting);
    * ``'norm'``       -> ``y_hist.point_weight[batch_idx]`` rescaled so the
      weights sum to the batch size; a batch whose weights sum to zero is
      returned unchanged (all zeros) instead of producing 0/0 = NaN;
    * any other truthy value -> ``y_hist.point_weight[batch_idx]`` as-is.

    Args:
        y_hist: object exposing a ``point_weight`` tensor indexable by
            ``batch_idx``.
        batch_idx: dataset indices of the batch (tensor or sequence).
        loss_config: per-loss config mapping.
        device: device for the returned tensor (default ``"cuda"``, matching
            the previous hard-coded ``.cuda()``).

    Returns:
        1-D tensor with one weight per sample, on ``device``.
    """
    mode = loss_config.get('cls_weight', False)
    batch_size = len(batch_idx)

    if not mode:
        return torch.ones(batch_size, device=device)

    weights = y_hist.point_weight[batch_idx].to(device)
    if mode != 'norm':
        return weights

    total = weights.sum()
    # An all-zero batch would otherwise yield NaN weights that propagate into
    # the loss and the gradients; keep the zeros instead.
    return torch.where(total != 0, weights * batch_size / total, weights)



# Loss terms understood by ``_compute_batch_losses``; any other name in a
# stage's ``losses`` list is rejected before training starts.
_SUPPORTED_LOSSES = ('partial_loss', 'kl_loss', 'feaNorm_loss', 'reg_sum_loss')


def _compute_batch_losses(net, net_ref, x, idx, y_hist, y_cls_value, losses,
                          train_m_config):
    """Compute every loss term named in ``losses`` for one batch.

    Returns:
        ``(loss_list, batch_loss_str)``, where ``loss_list[k]`` is the scalar
        loss for ``losses[k]`` and ``batch_loss_str`` is a log-friendly summary.

    Raises:
        ValueError: if a name in ``losses`` is not a supported loss term.
    """
    feature = net.feature(x)
    cls_feature = net.cls_feature_from_feature(feature)
    y_pred_cls_batch = net.predict_classification_from_cls_feature(cls_feature)
    y_pred_reg_batch = net.predict_from_feature(feature)

    loss_list = []
    batch_loss_str = '[loss] '
    for l_term in losses:
        if l_term == 'partial_loss':
            cfg = train_m_config["partial_loss"]
            point_weight_batch = get_point_weight_batch(y_hist, idx, cfg)
            smooth = cfg.get("smooth", -1)
            partial_label_batch = y_hist.y_pred_partial_cls[idx].cuda()
            loss = partial_label_loss(
                y_pred_cls_batch,
                partial_label_batch,
                smooth=smooth,
                mask=point_weight_batch)
            batch_loss_str += f"partial: {loss.item():.3f} | "

        elif l_term == 'kl_loss':
            cfg = train_m_config["kl_loss"]
            sigma_div = cfg['sigma_div']
            loss_mode = cfg.get("loss_mode", 'kl')
            point_weight_batch = get_point_weight_batch(y_hist, idx, cfg)
            # Hard pseudo-label: argmax over the last dim of the stored
            # y_pred_partial_cls_ratio rows for this batch.
            _, pl_batch = torch.max(y_hist.y_pred_partial_cls_ratio[idx], dim=-1)
            loss = gaussian_smooth_kl_loss(
                y_pred_cls_batch, pl_batch.cuda(), idx, y_hist,
                point_weight_batch, sigma_div, loss_mode)
            batch_loss_str += f"kl: {loss.item():.3f} | "

        elif l_term == 'feaNorm_loss':
            # The reference net only supplies a target; its output is never
            # differentiated, so skip building an autograd graph for it.
            with torch.no_grad():
                feature_ref = net_ref.feature(x)
            norm_difference = (torch.norm(feature, p=2, dim=1)
                               - torch.norm(feature_ref, p=2, dim=1))
            loss = torch.mean(norm_difference**2)
            batch_loss_str += f"feaNorm: {loss.item():.3f} | "

        elif l_term == 'reg_sum_loss':
            cfg = train_m_config["reg_sum_loss"]
            tau = cfg.get("tau", 0.68)
            reg_loss_type = cfg.get("reg_loss_type", 'mse')
            sum_prob = select_prob(
                y_hist.y_pred_partial_cls_ratio[idx], threshold=tau)
            # Regression target: y_cls_value averaged with the weights
            # returned by select_prob (normalised by their row sum).
            sum_reg = (sum_prob @ y_cls_value) / sum_prob.sum(-1)
            if reg_loss_type == 'rmse':
                loss = torch.sqrt(F.mse_loss(
                    y_pred_reg_batch, sum_reg, reduction='mean') + 1.0e-8)
            else:
                loss = F.mse_loss(y_pred_reg_batch, sum_reg)
            batch_loss_str += f"reg({tau}) - [{reg_loss_type}]: {loss.item():.3f} | "

        else:
            raise ValueError(f"Unsupported loss term: {l_term!r}")

        loss_list.append(loss)

    return loss_list, batch_loss_str


def train_module_m(
        net=None,
        net_ref=None,
        opt=None,
        scheduler=None,
        module_state=None,
        bn_mode=None,
        tr_dl=None,
        val_dl=None,
        tr_dl_name="VAL",
        val_dl_name="VAL",
        train_m_config=None,
        train_idx=0,
        y_hist=None,
        tau=0.68):
    """Train ``net`` for one stage of ``train_m_config["epoch"]`` epochs.

    Each epoch does the following:

    * configures trainable/frozen modules and BN mode;
    * optimises the weighted sum of the loss terms in
      ``train_m_config["losses"]`` over ``tr_dl``;
    * logs the last processed batch and evaluates on ``val_dl``;
    * optionally refreshes ``y_hist`` from the network predictions
      (``update_y_mode``);
    * when ``feaNorm_loss`` is used, EMA-updates ``net_ref`` towards ``net``.

    Loss weights: with ``coef_type == 'auto'`` a ``LossWeightLearner``
    (optimised by Adam at ``coef_lr``) produces them. Otherwise the fixed list
    ``train_m_config["coef"]`` is used, aligned index-by-index with ``losses``.

    Args:
        scheduler: accepted but not stepped inside this function.
        tr_dl_name, val_dl_name, train_idx: accepted for interface
            compatibility; unused here.
        tau: unused; ``reg_sum_loss`` reads its threshold from
            ``train_m_config["reg_sum_loss"]["tau"]``.

    Raises:
        ValueError: if ``losses`` names an unsupported loss term.
    """
    # NOTE(review): `scheduler` is never stepped here — confirm whether LR
    # scheduling is meant to happen elsewhere.
    losses = train_m_config['losses']
    epoch = train_m_config["epoch"]

    # Unknown names used to be skipped silently, shifting every later loss
    # out of alignment with its coefficient / learned weight.
    unsupported = [name for name in losses if name not in _SUPPORTED_LOSSES]
    if unsupported:
        raise ValueError(
            f"Unsupported loss term(s) {unsupported}; "
            f"supported: {list(_SUPPORTED_LOSSES)}")

    for e in range(epoch):

        net.train()
        net_ref.eval()
        grad_update_switch(net, True, module_state["train"], e=e)
        grad_update_switch(net, False, module_state["frozen"], e=e)
        grad_update_switch(net_ref, False, e=e)
        fea_extractor_bn_setup(net.get_feature_extractor(), bn_mode, e=e)

        coef_type = train_m_config["coef_type"]
        if coef_type == 'auto':
            # NOTE(review): a fresh learner (and Adam state) is created every
            # epoch, so learned loss weights reset each epoch — confirm intent.
            weight_learner = LossWeightLearner(
                num_losses=len(losses), init_value=0.0).cuda()
            weight_opt = torch.optim.Adam(
                weight_learner.parameters(), lr=train_m_config["coef_lr"])
            weight_learner.train()
        else:
            weight_learner = None
            weight_opt = None
            coef_t = train_m_config['coef']

        y_cls_value = y_hist.y_cls_value.cuda()

        # Bound up front so the end-of-epoch log never hits a NameError (or
        # reports a previous epoch's values) when no batch was processed.
        batch_loss_str = '[loss] (no batch processed) '
        weights_log = None

        for batch in tr_dl:
            opt.zero_grad()
            if weight_opt is not None:
                # Without this the learner's gradients accumulate across
                # every batch and Adam steps on their running sum.
                weight_opt.zero_grad()
            x, _, idx = batch
            x = x.cuda()
            # Single-sample batches are skipped (presumably to avoid
            # BatchNorm failing on batch size 1 — confirm).
            if x.size(0) == 1:
                continue

            loss_list, batch_loss_str = _compute_batch_losses(
                net, net_ref, x, idx, y_hist, y_cls_value, losses,
                train_m_config)

            if weight_learner is None:
                total_loss = torch.tensor(0.0).cuda()
                for l_id, l in enumerate(loss_list):
                    total_loss += l * coef_t[l_id]
                weights_log = coef_t
            else:
                total_loss, weights = weight_learner.weighted_loss(*loss_list)
                weights_log = weights.detach().cpu().tolist()

            total_loss.backward()
            opt.step()
            if weight_opt is not None:
                weight_opt.step()

        txt_logger.info(
            f"Last batch info:\n{batch_loss_str}[weight] [{coef_type}] {losses} - {weights_log}")

        metric_logger_dict, net_pred_dict = evaluate_net_reg_Metric_logger(
            net, val_dl, y_hist, info='',  return_net_pred=True)
        evaluation_logPrint(metric_logger_dict, '')

        if train_m_config.get('update_y_mode', False):
            update_y_mv_coef = train_m_config.get('update_y_mv_coef', 0.5)
            y_hist.update_y(net_pred_dict, update_y_mv_coef)
        if "feaNorm_loss" in losses:
            ema_update_full_safe(net_ref, net, alpha=0.8)
            txt_logger.info(
                f"reference network has been partially EMA-updated by main model at EPO [{e}]")

    return


def train_hist(
    net=None,
    net_ref=None,
    y_hist=None,
    opt=None,
    schedulers=None,
    module_state=None,
    bn_mode=None,
    tr_dl=None,
    val_dl=None,
    tr_dl_name="VAL",
    val_dl_name="VAL",
    train_config=None,
):
    """Run the training stages listed in ``train_config["order"]`` in sequence.

    The network is evaluated on ``val_dl`` once before any training. Then
    ``train_module_m`` runs once per stage name. Each run gets that stage's
    entries from ``opt``, ``schedulers``, ``module_state``, ``bn_mode`` and
    ``train_config``, plus the shared nets, loaders and ``y_hist``.
    """
    stage_order = train_config["order"]

    # tau forwarded to every stage: taken from the last stage (in order)
    # whose losses include reg_sum_loss, defaulting to 0.68.
    stage_tau = 0.68
    for stage_name in stage_order:
        stage_cfg = train_config[stage_name]
        if "reg_sum_loss" in stage_cfg['losses']:
            stage_tau = stage_cfg["reg_sum_loss"].get("tau", 0.68)

    # Baseline evaluation of the untouched network.
    header = f"[{val_dl_name}] - Before training "
    baseline_metrics = evaluate_net_reg_Metric_logger(
        net, val_dl, y_hist, info=header, return_net_pred=False)
    evaluation_logPrint(baseline_metrics, header)

    # Train stage by stage; each stage evaluates at the end of its epochs.
    for stage_idx, stage_name in enumerate(stage_order):
        train_module_m(
            net=net,
            net_ref=net_ref,
            opt=opt[stage_name],
            scheduler=schedulers[stage_name],
            module_state=module_state[stage_name],
            bn_mode=bn_mode[stage_name],
            tr_dl=tr_dl,
            tr_dl_name=tr_dl_name,
            val_dl=val_dl,
            val_dl_name=val_dl_name,
            y_hist=y_hist,
            train_m_config=train_config[stage_name],
            train_idx=stage_idx,
            tau=stage_tau,
        )

    return
