import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F
import utils.utils
import wandb


def compute_loss(cls_criterion, preds, targets, score_list=None, lambda_sqrt_inv_list=None, kl_list=None, eta_ksvd=1., eta_kl=1.):
    """Return the classification loss, optionally with KSVD and KL regularisers.

    Args:
        cls_criterion: Callable applied as ``cls_criterion(preds, targets)``.
        preds: Model predictions passed straight to ``cls_criterion``.
        targets: Ground-truth labels passed straight to ``cls_criterion``.
        score_list: Optional per-layer sequence. Entry ``k`` is indexed at
            positions 0-3: entries 0 and 1 are projected through
            ``lambda_sqrt_inv_list[k]`` and entries 2 and 3 form the trace term.
        lambda_sqrt_inv_list: Per-layer tensors, each unsqueezed at dim 0 and
            contracted with score entries 0 and 1.
        kl_list: Per-layer KL terms that are summed.
        eta_ksvd: Weight of the accumulated KSVD term.
        eta_kl: Weight of the accumulated KL term.

    Returns:
        ``loss_ce`` alone when ``score_list`` is None; otherwise the tuple
        ``(loss_total, loss_ce, loss_ksvd, loss_kl)``.
    """
    loss_ce = cls_criterion(preds, targets)
    if score_list is None:
        return loss_ce

    ksvd_acc = 0
    kl_acc = 0
    for layer, scores in enumerate(score_list):
        lam = lambda_sqrt_inv_list[layer].unsqueeze(0)

        e_proj = torch.einsum('...nd,...ds->...ns', scores[0], lam)
        half_sq_e = torch.mean(e_proj.norm(dim=-1, p=2) ** 2) / 2

        r_proj = torch.einsum('...nd,...ds->...ns', scores[1], lam)
        half_sq_r = torch.mean(r_proj.norm(dim=-1, p=2) ** 2) / 2

        # Cross term: batch-averaged (s x d) product of score entries 2 and 3, then its trace.
        cross = torch.einsum('...ps,...pd->...sd', scores[2], scores[3].type_as(scores[2]))
        trace_term = cross.mean(dim=0).trace()

        ksvd_acc = ksvd_acc + (half_sq_e + half_sq_r - trace_term) ** 2
        kl_acc = kl_acc + kl_list[layer]

    loss_total = loss_ce + eta_ksvd * ksvd_acc + eta_kl * kl_acc
    return loss_total, loss_ce, ksvd_acc, kl_acc
    
def compute_loss_diffusion(args, mse_criterion, means_from_diffusion, means_x_minus, stds_from_diffusion, covariances_x_minus):
    """
    Compute layer-averaged MSE losses between diffusion-model outputs and ViT targets.

    Parameters:
        args (Namespace): Not read by this function; kept for call-site
            compatibility. An earlier, since-removed variant used
            ``args.attn_type``, ``args.depth`` and ``args.ksvd_layers`` to
            restrict the second loss to a subset of layers.
        mse_criterion (nn.Module): Loss called as ``mse_criterion(pred, target)``.
        means_from_diffusion (sequence of tensors): Per-layer diffusion outputs
            compared against ``means_x_minus``.
        means_x_minus (sequence of tensors): Per-layer targets for the first loss.
        stds_from_diffusion (sequence of tensors): Per-layer diffusion outputs
            compared against ``covariances_x_minus``.
        covariances_x_minus (sequence of tensors): Per-layer targets for the second loss.

    Returns:
        means_mse: Sum over paired layers of ``mse_criterion(means_from_diffusion[k],
            means_x_minus[k]) / len(means_from_diffusion)``. This is ``0`` (an
            int) when there are no pairs.
        stds_mse: The same for ``stds_from_diffusion`` vs ``covariances_x_minus``,
            divided by ``len(stds_from_diffusion)``.

    Note:
        Layers are paired with ``zip``, which stops at the shorter list, while
        the divisor is the length of the diffusion-side list. If the lists
        differ in length, the result is therefore not a plain mean over the
        compared pairs.
    """
    n_mean_layers = len(means_from_diffusion)
    means_mse = 0
    for mean_diff_out, mean_vit_out in zip(means_from_diffusion, means_x_minus):
        means_mse += mse_criterion(mean_diff_out, mean_vit_out) / n_mean_layers

    n_std_layers = len(stds_from_diffusion)
    stds_mse = 0
    for std_diff_out, cov_vit_out in zip(stds_from_diffusion, covariances_x_minus):
        stds_mse += mse_criterion(std_diff_out, cov_vit_out) / n_std_layers

    return means_mse, stds_mse

def train(train_loader, net, optimizer, epoch, logger, args):
    """Run one training epoch of ``net`` and log its metrics.

    Parameters:
        train_loader: Batch source exposing ``num_batches`` and ``__load_next__()``,
            which yields ``(data, inputs, inputs_mask, positional, answers)``.
        net (nn.Module): Model to train. When ``args.attn_type == 'sgpa'`` its
            ``loss`` method provides the objective; otherwise it is called as
            ``net(inputs, positional, inputs_mask, data)``.
        optimizer (Optimizer): Stepped once per batch.
        epoch (int): Current epoch. Used in log messages, as the wandb step,
            and (for SGPA) in the KL annealing factor
            ``min(1, 2 * epoch / args.nb_epochs)``.
        logger (Logger): Receives an epoch header and a running summary every 100 batches.
        args (Namespace): Reads ``attn_type`` and ``gpu``, plus ``nb_epochs``
            (sgpa) or ``eta_ksvd`` / ``eta_kl`` (kep_svgp).

    Raises:
        ValueError: If ``args.attn_type`` is not ``'softmax'``, ``'kep_svgp'`` or ``'sgpa'``.
    """
    net.train()

    # define criterion
    cls_criterion = nn.CrossEntropyLoss()

    if args.attn_type == "softmax":
        metric_keys = ['MCC', 'Top1 Acc.', 'Tot. Loss', 'LR']
    elif args.attn_type == "kep_svgp":
        metric_keys = ['MCC', 'Top1 Acc.', 'CE Loss', 'KSVD Loss', 'KL Loss', 'Tot. Loss', 'LR']
    elif args.attn_type == "sgpa":
        metric_keys = ['Tot. Loss', 'LR']
    else:
        raise ValueError(f"Unsupported attn_type: {args.attn_type!r}")

    # ``train_log`` is a 100-batch window for console output and is reset after
    # each print. ``epoch_log`` is never reset, so wandb gets whole-epoch
    # averages rather than only the window left over after the last reset.
    train_log = {key: utils.utils.AverageMeter() for key in metric_keys}
    epoch_log = {key: utils.utils.AverageMeter() for key in metric_keys}

    msg = '####### --- Training Epoch {:d} --- #######'.format(epoch)
    logger.info(msg)

    device = f'cuda:{args.gpu}'
    for i in range(train_loader.num_batches):
        data, inputs, inputs_mask, positional, answers = train_loader.__load_next__()
        inputs = inputs.to(device)
        inputs_mask = inputs_mask.to(device)
        positional = positional.to(device)
        answers = answers.to(device)

        optimizer.zero_grad()

        if args.attn_type == 'sgpa':
            # The KL weight grows linearly and reaches 1 at epoch nb_epochs / 2.
            loss = net.loss(inputs, answers, positional, inputs_mask, data,
                            anneal_kl=min(1.0, epoch * 2 / args.nb_epochs))
        else:
            outs = net(inputs, positional, inputs_mask, data)
            if args.attn_type == "softmax":
                batch_preds = outs
                loss = compute_loss(cls_criterion, batch_preds, answers)
            else:
                # kep_svgp: outs[0] are predictions; outs[1:4] feed the KSVD/KL terms.
                batch_preds = outs[0]
                loss, loss_ce, loss_ksvd, loss_kl = compute_loss(
                    cls_criterion, batch_preds, answers,
                    outs[1], outs[2], outs[3], args.eta_ksvd, args.eta_kl)

        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        batch_metrics = {'Tot. Loss': loss.item(), 'LR': optimizer.param_groups[0]["lr"]}
        if args.attn_type in ("softmax", "kep_svgp"):
            prec, _ = utils.utils.accuracy(batch_preds, answers)
            batch_metrics['MCC'] = utils.utils.mcc(batch_preds, answers)
            batch_metrics['Top1 Acc.'] = prec.item()
        if args.attn_type == "kep_svgp":
            batch_metrics['CE Loss'] = loss_ce.item()
            batch_metrics['KSVD Loss'] = loss_ksvd.item()
            batch_metrics['KL Loss'] = loss_kl.item()

        for key, value in batch_metrics.items():
            train_log[key].update(value, batch_size)
            epoch_log[key].update(value, batch_size)

        if i % 100 == 99:
            log = ['LR : {:.5f}'.format(train_log['LR'].avg)] + [key + ': {:.2f}'.format(train_log[key].avg) for key in train_log if key != 'LR']
            msg = 'Epoch {:d} \t Batch {:d}\t'.format(epoch, i) + '\t'.join(log)
            logger.info(msg)
            for key in train_log:
                train_log[key] = utils.utils.AverageMeter()

    # Whole-epoch averages, logged once per epoch.
    wandb.log({f"Train/{key}": meter.avg for key, meter in epoch_log.items()}, step=epoch)

def train_diffusion(train_loader, diffusion_model, optimizer, epoch, logger, args, vit_model):
    """
    Train the Diffusion model for one epoch against a frozen ViT model.

    The per-batch objective is
    ``args.lambda_mean * means_loss + args.lambda_var * stds_loss + args.lambda_ce * ce_loss``.
    ``ce_loss`` is the cross-entropy of the diffusion model's own predictions.
    ``means_loss`` and ``stds_loss`` come from ``compute_loss_diffusion`` and
    compare the diffusion model's outputs on the ViT's intermediate
    ``x_t_from_ViT`` with the ViT's returned targets.

    Parameters:
        train_loader: Batch source exposing ``num_batches`` and ``__load_next__()``,
            which yields ``(data, inputs, inputs_mask, positional, answers)``.
        diffusion_model (nn.Module): Diffusion model to be trained.
        optimizer (Optimizer): Optimizer for the Diffusion model.
        epoch (int): Current epoch number (also used as the wandb step).
        logger (Logger): Logger for logging information.
        args (Namespace): Reads ``gpu``, ``lambda_mean``, ``lambda_var`` and
            ``lambda_ce``, and is forwarded to ``compute_loss_diffusion``.
        vit_model (nn.Module): Pre-trained ViT model for layer alignment.

    Side effects:
        Sets ``requires_grad = False`` on every ``vit_model`` parameter. This
        persists after the function returns.
    """
    diffusion_model.train()
    vit_model.eval()  # Ensure ViT is in evaluation mode

    # Freeze the ViT: only the diffusion model's parameters are optimised here.
    for param in vit_model.parameters():
        param.requires_grad = False

    mse_criterion = nn.MSELoss()
    ce_criterion = nn.CrossEntropyLoss()

    metric_keys = ['CE Loss', 'Mean Loss', 'Var Loss', 'Tot. Loss', 'LR']
    # ``train_log`` is a 100-batch window for console output and is reset after
    # each print. ``epoch_log`` is never reset, so wandb gets whole-epoch
    # averages rather than only the window left over after the last reset.
    train_log = {key: utils.utils.AverageMeter() for key in metric_keys}
    epoch_log = {key: utils.utils.AverageMeter() for key in metric_keys}

    msg = '####### --- Training Epoch {:d} --- #######'.format(epoch)
    logger.info(msg)

    device = f'cuda:{args.gpu}'
    for i in range(train_loader.num_batches):
        data, inputs, inputs_mask, positional, answers = train_loader.__load_next__()
        inputs = inputs.to(device)
        inputs_mask = inputs_mask.to(device)
        positional = positional.to(device)
        answers = answers.to(device)

        optimizer.zero_grad()

        # Supervised term on the diffusion model's own predictions.
        outs = diffusion_model(inputs, positional, data)
        ce_loss = compute_loss(ce_criterion, outs, answers)

        # ViT targets need no autograd graph.
        with torch.no_grad():
            _, x_t_from_ViT, means_x_minus, covariances_x_minus = vit_model(inputs, positional, inputs_mask, data)
        means_from_diffusion, stds_from_diffusion = diffusion_model(x_t_from_ViT, train=True)
        means_loss, stds_loss = compute_loss_diffusion(
            args, mse_criterion, means_from_diffusion, means_x_minus,
            stds_from_diffusion, covariances_x_minus)

        loss = args.lambda_mean * means_loss + args.lambda_var * stds_loss + args.lambda_ce * ce_loss
        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        batch_metrics = {
            'CE Loss': ce_loss.item(),
            'Mean Loss': means_loss.item(),
            'Var Loss': stds_loss.item(),
            'Tot. Loss': loss.item(),
            'LR': optimizer.param_groups[0]["lr"],
        }
        for key, value in batch_metrics.items():
            train_log[key].update(value, batch_size)
            epoch_log[key].update(value, batch_size)

        if i % 100 == 99:
            log = ['LR : {:.5f}'.format(train_log['LR'].avg)] + [
                key + ': {:.2f}'.format(train_log[key].avg) for key in train_log if key != 'LR'
            ]
            msg = 'Epoch {:d} \t Batch {:d}\t'.format(epoch, i) + '\t'.join(log)
            logger.info(msg)
            for key in train_log:
                train_log[key] = utils.utils.AverageMeter()

    # Whole-epoch averages, logged once per epoch.
    wandb.log({f"Train/{key}": meter.avg for key, meter in epoch_log.items()}, step=epoch)