import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
from utils import Bar, AverageMeter, RunningMeter
import loss
import models

def sigmoid_rampup(current, rampup_length):
    """Sigmoid-shaped ramp-up schedule from https://arxiv.org/abs/1610.02242.

    Computes exp(-5 * (1 - t)^2) where t = clip(current, 0, rampup_length) / rampup_length,
    so the value grows from exp(-5) at ``current <= 0`` to 1.0 once
    ``current >= rampup_length``. A ``rampup_length`` of 0 disables the ramp
    and always returns 1.0.

    Returns:
        The ramp value as a Python float.
    """
    if rampup_length == 0:
        return 1.0
    clipped = np.clip(current, 0.0, rampup_length)
    remaining = 1.0 - clipped / rampup_length
    return float(np.exp(-5.0 * remaining * remaining))

def get_current_consistency_weight(epoch, consistency_rampup):
    """Return the consistency (KD) loss weight for ``epoch``.

    Thin wrapper around ``sigmoid_rampup`` with ``consistency_rampup`` as the
    ramp-up length.
    """
    weight = sigmoid_rampup(epoch, consistency_rampup)
    return weight

## Top-1, Top-5 accuracy
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracies, in percent, of ``output`` against ``target``.

    Args:
        output: 2-D scores indexed as (batch, class); only the ranking along
            dim 1 is used.
        target: ground-truth class indices, one per row of ``output``.
        topk: the k values to evaluate.

    Returns:
        A list with one 0-dim tensor per entry of ``topk``: the percentage of
        rows whose target is among the k highest-scoring classes. A k larger
        than the number of classes is treated as "all classes" (100%) instead
        of making ``topk()`` raise.
    """
    with torch.no_grad():
        # Clamp so e.g. top-5 on a problem with fewer than 5 classes does not
        # fail inside output.topk(); correct[:k] below then simply takes all rows.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch): row j holds each sample's (j+1)-th best class
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        return [correct[:k].reshape(-1).float().sum(0).mul_(100.0 / batch_size) for k in topk]

def OkdTrain(loader,
             model,
             loss_fn,
             loss_stu,
             loss_kl,
             optimizer,
             epoch,
             comp,
             device,
             num_branches,
             consistency_rampup,
             margin
             ):
    """Train the multi-branch model with online knowledge distillation for one epoch.

    For each batch the optimized loss is the sum of the per-branch data losses
    (``loss_fn[i]``), the student data loss (``loss_stu``) and a KD term
    (``loss_kl``) between the student output and the compensated branch
    ensemble, weighted by a sigmoid ramp-up of ``epoch``.

    Args:
        loader: iterable of ``(inputs, targets)`` batches; ``len(loader)``
            sizes the progress bar.
        model: called as ``model(inputs)`` and returning ``(outputs, out_s, _)``;
            ``outputs[:, :, i]`` are the class logits of branch ``i`` and
            ``out_s`` the student logits.
        loss_fn: sequence of ``num_branches`` per-branch data losses.
        loss_stu: data loss applied to the student output.
        loss_kl: KD loss, called as ``loss_kl(out_s, out_t, en=False)``.
        optimizer: stepped once per batch.
        epoch: current epoch; drives the consistency-weight ramp-up.
        comp: compensation subtracted from the branch logits before
            ensembling, expanded over the batch (presumably shaped
            (num_classes, num_branches) — TODO confirm).
        device: if truthy, batches are moved with ``.cuda(device)``.
        num_branches: number of branches in ``outputs``.
        consistency_rampup: ramp-up length for the consistency weight.
        margin: unused.

    Returns:
        dict of the running display metrics after the last batch (empty when
        ``loader`` yields no batches).
    """
    model.train()

    # Meters: one per branch, then the student, then the ensemble teacher.
    accTop1_avg = [AverageMeter() for _ in range(num_branches + 2)]
    accTop5_avg = [AverageMeter() for _ in range(num_branches + 2)]
    losses = AverageMeter()     # total loss
    losses_kd = AverageMeter()  # kd loss
    losses_s = AverageMeter()   # student data loss
    dist_avg = RunningMeter()   # mean pairwise distance between branch posteriors

    bar = Bar('Processing', max=len(loader))
    consistency_weight = get_current_consistency_weight(epoch, consistency_rampup)
    show_metrics = {}  # bound up front so an empty loader still returns a dict
    for inputs, targets in loader:
        # NOTE(review): a device index of 0 is falsy and would skip the transfer;
        # confirm callers pass a torch.device / non-zero index or None.
        if device:
            inputs, targets = inputs.cuda(device), targets.cuda(device)
        batch_size = inputs.size(0)

        outputs, out_s, _ = model(inputs)

        ## Cross-entropy loss (1. Data loss): every branch, then the student
        loss_cross = sum(loss_fn[i](outputs[:, :, i], targets) for i in range(num_branches))
        loss_s = loss_stu(out_s, targets)
        loss_cross = loss_cross + loss_s

        ## Compensating posteriors and ensemble prediction (no gradient to the branches)
        with torch.no_grad():
            adjusted = outputs.detach() - comp.expand(outputs.size(0), comp.size(0), -1)
            # log(mean over branches of softmax) written as
            # logsumexp(log_softmax) - log(#branches): same value, but stays
            # finite where softmax underflows to 0 and log() would give -inf.
            out_t = torch.logsumexp(F.log_softmax(adjusted, dim=1), dim=2) - float(np.log(adjusted.size(2)))

        ## Hinton KD loss (2. Hinton Distillation)
        loss_kd = loss_kl(out_s, out_t, en=False)

        ## Total loss (named to avoid shadowing the imported `loss` module)
        total_loss = loss_cross + consistency_weight * loss_kd
        losses.update(total_loss.item(), batch_size)
        losses_s.update(loss_s.item(), batch_size)
        losses_kd.update(loss_kd.item(), batch_size)

        for i in range(num_branches):
            metrics = accuracy(outputs[:, :, i], targets, topk=(1, 5))
            accTop1_avg[i].update(metrics[0].item(), batch_size)
            accTop5_avg[i].update(metrics[1].item(), batch_size)

        metrics = accuracy(out_s, targets, topk=(1, 5))
        accTop1_avg[num_branches].update(metrics[0].item(), batch_size)
        accTop5_avg[num_branches].update(metrics[1].item(), batch_size)

        metrics = accuracy(out_t, targets, topk=(1, 5))
        accTop1_avg[num_branches + 1].update(metrics[0].item(), batch_size)
        accTop5_avg[num_branches + 1].update(metrics[1].item(), batch_size)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        ## Diversity monitor: mean L2 distance between posteriors of distinct branches.
        # Pure bookkeeping, so no autograd graph is built for it.
        with torch.no_grad():
            probs = F.softmax(outputs, dim=1).permute(0, 2, 1)  # (batch, branches, classes)
            pairwise = torch.cdist(probs, probs, p=2)           # (batch, branches, branches)
            off_diag = ~torch.eye(num_branches, dtype=torch.bool, device=pairwise.device)
            sim = (pairwise.masked_select(off_diag)
                   .view(pairwise.size(0), num_branches, -1)
                   .mean(2).mean(1).sum())  # per-sample mean distance, summed over the batch
        dist_avg.update2(sim.item(), probs.size(0))  # averaged distance for all samples

        show_metrics = {'Loss_stu': losses_s.avg, 'Dist': dist_avg.value()}
        for i in range(num_branches):
            show_metrics['Top1_C' + str(i)] = accTop1_avg[i].avg
        show_metrics['Top1_stu'] = accTop1_avg[num_branches].avg
        show_metrics['Top1_t'] = accTop1_avg[num_branches + 1].avg
        bar.suffix = " | ".join("{}: {:.4f}".format(k, v) for k, v in show_metrics.items())
        bar.next()
    bar.finish()
    return show_metrics

def OkdTest(loader,
            model,
            loss_fn,
            loss_stu,
            epoch,
            comp,
            device,
            num_branches,
            margin
            ):
    """Evaluate every branch, the student and the compensated ensemble on ``loader``.

    Runs the model in eval mode without gradients. It tracks the following
    running metrics:

    - the per-branch data loss (``loss_fn[i]``);
    - the student data loss (``loss_stu``);
    - top-1/top-5 accuracy of every branch, the student and the ensemble;
    - the mean pairwise distance between branch posteriors.

    Args:
        loader: iterable of ``(inputs, targets)`` batches; ``len(loader)``
            sizes the progress bar.
        model: called as ``model(inputs)`` and returning ``(outputs, out_s, _)``;
            ``outputs[:, :, i]`` are the class logits of branch ``i`` and
            ``out_s`` the student logits.
        loss_fn: sequence of ``num_branches`` per-branch data losses.
        loss_stu: data loss applied to the student output.
        epoch: unused.
        comp: compensation subtracted from the branch logits before
            ensembling, expanded over the batch (presumably shaped
            (num_classes, num_branches) — TODO confirm).
        device: if truthy, batches are moved with ``.cuda(device)``.
        num_branches: number of branches in ``outputs``.
        margin: unused.

    Returns:
        dict of the running display metrics after the last batch (empty when
        ``loader`` yields no batches).
    """
    model.eval()

    # Accuracy meters: one per branch, then the student, then the ensemble teacher.
    accTop1_avg = [AverageMeter() for _ in range(num_branches + 2)]
    accTop5_avg = [AverageMeter() for _ in range(num_branches + 2)]
    t_loss_avg = [AverageMeter() for _ in range(num_branches)]  # per-branch data loss
    losses = AverageMeter()    # student data loss
    dist_avg = RunningMeter()  # mean pairwise distance between branch posteriors

    bar = Bar('Processing', max=len(loader))
    show_metrics = {}  # bound up front so an empty loader still returns a dict
    with torch.no_grad():
        for inputs, targets in loader:
            # NOTE(review): a device index of 0 is falsy and would skip the transfer;
            # confirm callers pass a torch.device / non-zero index or None.
            if device:
                inputs, targets = inputs.cuda(device), targets.cuda(device)
            batch_size = inputs.size(0)

            outputs, out_s, _ = model(inputs)

            ## Compensating posteriors and ensemble prediction
            adjusted = outputs - comp.expand(outputs.size(0), comp.size(0), -1)
            # log(mean over branches of softmax) written as
            # logsumexp(log_softmax) - log(#branches): same value, but stays
            # finite where softmax underflows to 0 and log() would give -inf.
            out_t = torch.logsumexp(F.log_softmax(adjusted, dim=1), dim=2) - float(np.log(adjusted.size(2)))

            for i in range(num_branches):
                branch_loss = loss_fn[i](outputs[:, :, i], targets)
                t_loss_avg[i].update(branch_loss.item(), batch_size)

            loss_cross = loss_stu(out_s, targets)  # student, original data loss
            losses.update(loss_cross.item(), batch_size)

            for i in range(num_branches):
                metrics = accuracy(outputs[:, :, i], targets, topk=(1, 5))
                accTop1_avg[i].update(metrics[0].item(), batch_size)
                accTop5_avg[i].update(metrics[1].item(), batch_size)

            metrics = accuracy(out_s, targets, topk=(1, 5))
            accTop1_avg[num_branches].update(metrics[0].item(), batch_size)
            accTop5_avg[num_branches].update(metrics[1].item(), batch_size)

            metrics = accuracy(out_t, targets, topk=(1, 5))
            accTop1_avg[num_branches + 1].update(metrics[0].item(), batch_size)
            accTop5_avg[num_branches + 1].update(metrics[1].item(), batch_size)

            ## Diversity monitor: mean L2 distance between posteriors of distinct branches
            probs = F.softmax(outputs, dim=1).permute(0, 2, 1)  # (batch, branches, classes)
            pairwise = torch.cdist(probs, probs, p=2)           # (batch, branches, branches)
            off_diag = ~torch.eye(num_branches, dtype=torch.bool, device=pairwise.device)
            sim = (pairwise.masked_select(off_diag)
                   .view(pairwise.size(0), num_branches, -1)
                   .mean(2).mean(1).sum())  # per-sample mean distance, summed over the batch
            dist_avg.update2(sim.item(), probs.size(0))  # averaged distance for all samples

            show_metrics = {'Loss_stu': losses.avg, 'Dist': dist_avg.value()}
            for i in range(num_branches):
                show_metrics['loss_C' + str(i)] = t_loss_avg[i].avg
                show_metrics['Top1_C' + str(i)] = accTop1_avg[i].avg
            show_metrics['Top1_stu'] = accTop1_avg[num_branches].avg
            show_metrics['Top1_t'] = accTop1_avg[num_branches + 1].avg
            bar.suffix = " | ".join("{}: {:.4f}".format(k, v) for k, v in show_metrics.items())
            bar.next()
        bar.finish()
    return show_metrics
