import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from datasets.utils import idx_to_mask

def train(model, train_idx, labels, device, optimizer, loss_fn):
    """Perform a single supervised optimization step on the training subset.

    Args:
        model: Module exposing ``model_forward(idx, device)``.
        train_idx: Selector for the training samples; it is passed to
            ``model_forward`` and also used to index ``labels``.
        labels: Ground-truth labels, indexable by ``train_idx``.
        device: Device argument forwarded to ``model_forward``.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the update.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar tensor.

    Returns:
        A ``(loss, acc)`` tuple: the loss as a Python float and the training
        accuracy as computed by ``accuracy``.
    """
    model.train()
    optimizer.zero_grad()

    targets = labels[train_idx]
    logits = model.model_forward(train_idx, device)

    step_loss = loss_fn(logits, targets)
    # Accuracy is measured on the pre-update predictions.
    step_acc = accuracy(logits, targets)

    step_loss.backward()
    optimizer.step()

    return step_loss.item(), step_acc

def _ensemble_soft_targets(model_list, train_idx, device, temperature):
    """Return the mean temperature-softened softmax over all models in ``model_list``."""
    soft_sum = None
    # The teacher outputs are only used as fixed targets, so no autograd graph
    # is needed for them.
    with torch.no_grad():
        for teacher in model_list:
            teacher_logits = teacher.model_forward(train_idx, device).detach()
            soft = F.softmax(teacher_logits / temperature, dim=1)
            soft_sum = soft if soft_sum is None else soft_sum + soft
    return soft_sum / len(model_list)


def kd_el_train(model, model_list, train_idx, labels, device, optimizer, loss_ce, loss_kd,
                temperature=1.5, entropy_thre=5, kd_weight=0.01):
    """Run one optimization step, optionally distilling from an ensemble.

    The model is always trained with ``loss_ce`` against ``labels[train_idx]``.
    When ``model_list`` is non-empty, the temperature-softened softmax outputs
    of its models are averaged, samples whose averaged distribution has
    entropy above ``entropy_thre`` are dropped, and
    ``kd_weight * loss_kd(log_softmax(student_output), ensemble_probs)`` over
    the remaining samples is added to the loss.

    Args:
        model: Module being trained; must expose ``model_forward(idx, device)``.
        model_list: Sequence of modules providing the distillation targets
            (may be empty, in which case only ``loss_ce`` is used).
        train_idx: Selector for the training samples; passed to every
            ``model_forward`` call and used to index ``labels``.
        labels: Ground-truth labels, indexable by ``train_idx``.
        device: Device argument forwarded to ``model_forward``.
        optimizer: Optimizer for ``model``.
        loss_ce: Callable ``loss_ce(output, target)`` returning a scalar tensor.
        loss_kd: Callable receiving student log-probabilities and ensemble
            probabilities, returning a scalar tensor.
        temperature: Divisor applied to the ensemble members' outputs before
            the softmax.
        entropy_thre: Samples whose ensemble entropy exceeds this value are
            excluded from the distillation term.
        kd_weight: Multiplier for the distillation term.

    Returns:
        A ``(loss, acc)`` tuple: the total loss as a Python float and the
        training accuracy as computed by ``accuracy``.
    """
    model.train()
    optimizer.zero_grad()

    targets = labels[train_idx]
    train_output = model.model_forward(train_idx, device)
    loss_train = loss_ce(train_output, targets)

    if len(model_list) != 0:
        z_ensemble = _ensemble_soft_targets(model_list, train_idx, device, temperature)

        # Keep only samples on which the ensemble is confident enough.
        high_entropy_idx = torch.where(Categorical(z_ensemble).entropy() > entropy_thre)[0]
        keep_mask = idx_to_mask(high_entropy_idx, train_output.shape[0]).logical_not()

        # If every sample was filtered out, skip the distillation term: a
        # mean-style reduction over zero rows would turn the whole loss into NaN.
        if keep_mask.any():
            # NOTE(review): the student log-probabilities are not divided by
            # ``temperature`` while the ensemble targets are - confirm intended.
            student_log_probs = F.log_softmax(train_output, dim=1)
            loss_distill = loss_kd(student_log_probs[keep_mask], z_ensemble[keep_mask])
            loss_train = loss_train + kd_weight * loss_distill

    acc_train = accuracy(train_output, targets)
    loss_train.backward()
    optimizer.step()

    return loss_train.item(), acc_train


def evaluate(model, val_idx, test_idx, labels, device):
    """Compute validation and test accuracy of ``model``.

    Args:
        model: Module exposing ``model_forward(idx, device)``.
        val_idx: Selector for the validation samples.
        test_idx: Selector for the test samples.
        labels: Ground-truth labels, indexable by ``val_idx``/``test_idx``.
        device: Device argument forwarded to ``model_forward``.

    Returns:
        A ``(acc_val, acc_test)`` tuple as computed by ``accuracy``.
    """
    model.eval()
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        # NOTE(review): the forward pass covers indices 0..len(val_idx)-1 and
        # its result is then indexed by val_idx/test_idx. That is consistent
        # when both selectors are full-length boolean masks - confirm against
        # the callers.
        output = model.model_forward(range(len(val_idx)), device)

    acc_val = accuracy(output[val_idx], labels[val_idx])
    acc_test = accuracy(output[test_idx], labels[test_idx])
    return acc_val, acc_test

def kd_el_evaluate(model_list, val_idx, test_idx, labels, device):
    """Compute validation and test accuracy of the averaged ensemble output.

    Every model in ``model_list`` is switched to eval mode and run once; the
    raw outputs are averaged element-wise and the averaged output is scored
    on the validation and test selections.

    Args:
        model_list: Non-empty sequence of modules exposing
            ``model_forward(idx, device)``.
        val_idx: Selector for the validation samples.
        test_idx: Selector for the test samples.
        labels: Ground-truth labels, indexable by ``val_idx``/``test_idx``.
        device: Device argument forwarded to ``model_forward``.

    Returns:
        A ``(acc_val, acc_test)`` tuple as computed by ``accuracy``.

    Raises:
        ValueError: If ``model_list`` is empty.
    """
    num_model = len(model_list)
    if num_model == 0:
        raise ValueError("kd_el_evaluate requires at least one model in model_list")

    ensemble_sum = None
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for member in model_list:
            member.eval()
            # NOTE(review): the forward pass covers indices 0..len(val_idx)-1
            # and is later indexed by val_idx/test_idx; consistent when both
            # selectors are full-length boolean masks - confirm against callers.
            output = member.model_forward(range(len(val_idx)), device)
            # Out-of-place accumulation so no model output is mutated.
            ensemble_sum = output if ensemble_sum is None else ensemble_sum + output

    z_ensemble = ensemble_sum / num_model

    acc_val = accuracy(z_ensemble[val_idx], labels[val_idx])
    acc_test = accuracy(z_ensemble[test_idx], labels[test_idx])
    return acc_val, acc_test
    
def accuracy(output, labels):
    """Return the fraction of rows of ``output`` whose arg-max equals the label.

    Args:
        output: 2-D tensor of per-class scores; the prediction for each row is
            the index of its maximum along dimension 1.
        labels: Tensor of target class indices, one per row of ``output``.

    Returns:
        The share of correct predictions as a Python float.
    """
    _, predicted = output.max(1)
    # Cast predictions to the label dtype/device family before comparing.
    predicted = predicted.type_as(labels)
    hits = (predicted == labels).double().sum()
    fraction = hits / len(labels)
    return fraction.item()
