import time
import torch
import torch.nn.functional as F
import gc

#######################################################################
# Per-epoch Train/Test Entry Point
#######################################################################

def train_epoch(exp):
    """
    Run one training epoch and return its report dict.

    Steps, in order:
        1. Call `exp.r.moe_tracker.check()` (validation hook; what it
           verifies is defined elsewhere).
        2. Record the start time under `'t1'`.
        3. Move the gate (`exp.r.moe_gates`) and every expert in
           `exp.r.all_models` to `exp.r.device` and put them in train mode.
        4. Delegate the mini-batch loop to `_train_minibatch`.
        5. Trigger a garbage-collection pass.

    Args:
        exp: experiment object exposing the `exp.r.*` attributes listed
            above, plus whatever `_train_minibatch` reads from it.

    Returns:
        dict: the report returned by `_train_minibatch`, which was seeded
        with the `'t1'` start timestamp.
    """
    # Validation
    exp.r.moe_tracker.check()

    target_device = exp.r.device
    router = exp.r.moe_gates

    # Timestamp is taken before any module is moved to the device
    report = {'t1': time.time()}

    # Gate first, then every expert: move to device and enable train mode
    router.to(target_device).train()
    for expert in exp.r.all_models:
        expert.to(target_device).train()

    report = _train_minibatch(exp, report)

    gc.collect()
    return report

#######################################################################
# Active Training Loop
#######################################################################

def _train_minibatch(exp, _ep_report):
    """
    Run one pass over the training loader, updating the gate and the experts.

    For every mini-batch the gate routes each sample to one expert, every
    expert is evaluated only on the samples routed to it, and one combined
    loss is back-propagated through `exp.r.fg_optimizer`:

        batch_loss = expert_loss
                     + lambda_r * router_loss
                     + lambda_e * entropy_loss

    where
        expert_loss  = mean per-sample cross-entropy of the routed experts,
        router_loss  = mean of pi[b, m_log[b]] * CE(detached expert logits),
        entropy_loss = mean entropy of the routing distribution `pi`.

    `lambda_e` is forced to 0 when `exp.args.dataset == 'cifar_100'` or when
    `exp.r.roundID >= exp.ds.vit_param['fg_round']`; in the latter case the
    gating noise `r` is also all zeros.

    Shape of key variables (example sizes):
        `r`: random noise | [B, M] = [256, 4]
        `X`: batch input sample | [B, L+1, d] = [256, 65, 48]
        `m`/`m_log`: selected expert ID | [B] = [256]
        `pi`: soft routing distribution | [B, M] = [256, M]

    Args:
        exp: experiment object. Reads `exp.r.{device, train_loader,
            moe_gates, fg_optimizer, all_models, roundID}`,
            `exp.ds.{n_tot_class, lr_param, vit_param, t_param}`,
            `exp.args.{n_expert, dataset}` and calls `exp.clip_gradients()`.
        _ep_report: dict to extend; it is mutated in place.

    Returns:
        dict: `_ep_report` with three added keys:
            't2'      - wall-clock time after the last batch,
            'tr_loss' - sample-weighted mean of `batch_loss`,
            'tr_acc'  - accuracy of the routed experts, in percent.
        If the loader yields no samples, 'tr_loss' and 'tr_acc' are 0.0
        instead of raising ZeroDivisionError.
    """

    # Setup alias
    device = exp.r.device
    trainloader = exp.r.train_loader
    gating = exp.r.moe_gates
    optimizer = exp.r.fg_optimizer
    C = exp.ds.n_tot_class
    ROUTER_LAMBDA = exp.ds.lr_param['lambda_r']

    total_loss = 0.0
    total_dataCnt = 0
    total_correct = 0
    for inputs, labels in trainloader:
        B = len(inputs)

        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        # forward through gating
        past_fg_round = exp.r.roundID >= exp.ds.vit_param['fg_round']
        if past_fg_round:
            r = torch.zeros(B, exp.args.n_expert, device=device)
        else:
            # Sampled on the default device and then moved, so seeded runs
            # keep drawing from the same random generator.
            r = torch.rand(B, exp.args.n_expert).to(device) * exp.ds.t_param['gating_r']
        X, m_log, pi = gating(inputs, r)

        if exp.args.dataset == 'cifar_100' or past_fg_round:
            ENTROPY_LAMBDA = 0
        else:
            ENTROPY_LAMBDA = exp.ds.vit_param['lambda_e']
        # 1e-8 keeps log() finite when a routing probability is exactly 0
        entropy_loss = -(pi * torch.log(pi + 1e-8)).sum(dim=1).mean()

        # forward through the experts; each one only sees its own samples
        expert_loss_sum, pred_logits = _run_routed_experts(
            exp.r.all_models, X, m_log, labels, C, device
        )
        expert_loss = expert_loss_sum / B

        # compute router loss: probability the gate gave to the chosen
        # expert, weighted by how badly that expert did. `pred_logits`
        # carries no graph, so only `pi` receives gradient from this term.
        router_weights = pi[torch.arange(B, device=m_log.device), m_log]  # [B]
        router_ce_loss = F.cross_entropy(pred_logits, labels, reduction='none')  # [B]
        router_loss = (router_weights * router_ce_loss).mean()

        # compute final loss
        batch_loss = expert_loss + ROUTER_LAMBDA * router_loss + ENTROPY_LAMBDA * entropy_loss
        batch_loss.backward()
        # NOTE: the round threshold for gradient clipping is hard-coded
        if exp.r.roundID < 50:
            exp.clip_gradients()

        optimizer.step()

        total_loss += batch_loss.item() * B
        total_dataCnt += B
        total_correct += (pred_logits.argmax(dim=1) == labels).sum().item()

    _ep_report['t2'] = time.time()
    # Guard against an empty loader: report zeros rather than divide by zero
    n_samples = max(total_dataCnt, 1)
    _ep_report['tr_loss'] = total_loss / n_samples
    _ep_report['tr_acc'] = total_correct / n_samples * 100
    return _ep_report


def _run_routed_experts(models, X, m_log, labels, n_class, device):
    """
    Evaluate every expert on the samples the gate routed to it.

    Args:
        models: iterable of experts, each exposing `.encoder(x)` (returning a
            tuple whose first item is indexed as `[:, 0]` to get the
            classifier input) and `.classifier(x)`.
        X: gate output for the batch; indexed along dim 0 with a [B] mask.
        m_log: [B] expert ID chosen for each sample.
        labels: [B] ground-truth class indices.
        n_class: number of columns of the returned logit matrix.
        device: device on which the outputs are allocated.

    Returns:
        tuple `(loss_sum, pred_logits)`:
            loss_sum    - scalar tensor, sum over experts of
                          (mean cross-entropy * number of routed samples);
                          keeps the autograd graph of the experts.
            pred_logits - [B, n_class] detached logits written by each
                          sample's expert; rows whose expert ID matches no
                          model stay zero.
    """
    loss_sum = torch.tensor(0.0, device=device)
    pred_logits = torch.zeros(len(labels), n_class, device=device)
    for m, model_m in enumerate(models):
        routed_idx = (m_log == m)  # [B], boolean
        n_routed = int(routed_idx.sum())
        if n_routed == 0:
            continue  # skip expert m if no sample routed to it
        encoder_out, _ = model_m.encoder(X[routed_idx])  # [B_m, L+1, d]
        y_hats_m = model_m.classifier(encoder_out[:, 0])  # first token only
        loss_m = F.cross_entropy(y_hats_m, labels[routed_idx], reduction='mean')
        # Re-weight the per-expert mean by its sample count so that dividing
        # the total by B yields a per-sample mean over the whole batch.
        loss_sum = loss_sum + loss_m * n_routed
        pred_logits[routed_idx] = y_hats_m.detach()
    return loss_sum, pred_logits
