
import torch
import torch.nn as nn
import numpy as np
from sklearn.model_selection import train_test_split

def validate(model, X_test, y_test, criterion):
    """Evaluate ``model`` on ``(X_test, y_test)`` without tracking gradients.

    Models whose class name contains ``"moe"`` (case-insensitive) are treated
    as mixture-of-experts models: calling them must return
    ``(expert_outputs, gating_weights)``, where ``expert_outputs`` is indexed
    ``(expert, batch, output)`` and ``gating_weights`` is indexed
    ``(batch, expert)``. Any other model must return the predictions directly.

    Args:
        model: Callable model to evaluate.
        X_test: Inputs passed to ``model`` in a single call.
        y_test: Targets; its first dimension is the number of samples.
        criterion: Loss callable. In the mixture-of-experts branch it must
            return one unreduced value per (expert, sample) pair, because the
            result is reshaped to ``(num_experts, batch_size)``.

    Returns:
        A ``(loss, accuracy)`` tuple: ``loss`` is the mean loss as a tensor and
        ``accuracy`` is the fraction of samples for which every rounded output
        equals the corresponding target.
    """
    with torch.no_grad():
        if "moe" in model.__class__.__name__.lower():
            expert_outputs, gating_weights = model(X_test)
            num_experts, batch_size, _ = expert_outputs.shape

            # Score every expert against every sample, then weight each
            # expert's loss by the gate assigned to it for that sample.
            losses = criterion(
                expert_outputs.flatten(),
                y_test.repeat(num_experts, 1).flatten(),
            )
            predictions = torch.einsum("be, ebo -> bo", gating_weights, expert_outputs)
            final_loss = torch.einsum(
                "be, eb -> b", gating_weights, losses.reshape(num_experts, batch_size)
            )
        else:
            predictions = model(X_test)
            final_loss = criterion(predictions, y_test)

        # Aggregate the losses
        final_loss = torch.mean(final_loss)

        # n-out-of-n accuracy: a sample is correct only if all of its outputs
        # match. Comparing row-wise (instead of dividing the element-wise match
        # count by the sample count) keeps the result in [0, 1] for
        # multi-output targets and is identical for single-output targets.
        num_samples = len(y_test)
        rounded_predictions = torch.round(predictions).reshape(num_samples, -1)
        targets = y_test.reshape(num_samples, -1)
        num_correct = (rounded_predictions == targets).all(dim=1).sum().item()
        accuracy = num_correct / num_samples

        return final_loss, accuracy

        


def train(model, X_train, y_train, X_test, y_test, criterion, optimizer,
          scheduler=None, batch_size=16, num_epochs=10):
    """Train ``model`` on mini-batches, then evaluate it once on both splits.

    Models whose class name contains ``"moe"`` (case-insensitive, the same
    rule ``validate`` uses) must return ``(expert_outputs, gating_weights)``
    with ``expert_outputs`` indexed ``(expert, batch, output)`` and
    ``gating_weights`` indexed ``(batch, expert)``; each expert's loss is
    weighted by its gate. Any other model must return predictions directly.

    Args:
        model: Model to optimise; switched between train and eval mode here.
        X_train, y_train: Training inputs and targets.
        X_test, y_test: Held-out inputs and targets.
        criterion: Loss callable. In the mixture-of-experts branch it must
            return one unreduced value per (expert, sample) pair, because the
            result is reshaped to ``(num_experts, samples_in_batch)``.
        optimizer: Optimizer stepped once per mini-batch.
        scheduler: Optional scheduler, also stepped once per mini-batch.
        batch_size: Mini-batch size used by the shuffled ``DataLoader``.
        num_epochs: Number of passes over the training data.

    Returns:
        ``(train_losses_list, test_losses_list, train_acc, test_acc)``.
        Evaluation runs a single time after the last epoch, so each list
        holds exactly one loss value and each accuracy is that of the final
        model.
    """
    dataset = torch.utils.data.TensorDataset(X_train, y_train)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    is_moe = "moe" in model.__class__.__name__.lower()

    # Training loop
    for _ in range(num_epochs):
        model.train()
        for X, y in dataloader:
            optimizer.zero_grad()

            if is_moe:
                expert_outputs, gating_weights = model(X)
                # The last batch may be smaller than ``batch_size``, so read
                # the actual sample count from the output; use a separate name
                # so the ``batch_size`` argument is not overwritten.
                num_experts, samples_in_batch, _ = expert_outputs.shape
                # Loss of each expert on each input, weighted by the gates.
                losses = criterion(
                    expert_outputs.flatten(),
                    y.repeat(num_experts, 1).flatten(),
                )
                per_sample_loss = torch.einsum(
                    "be, eb -> b",
                    gating_weights,
                    losses.reshape(num_experts, samples_in_batch),
                )
            else:
                per_sample_loss = criterion(model(X), y)
                samples_in_batch = len(X)

            # Mean times batch size: the batch loss scales with the number of
            # samples in the batch.
            batch_loss = samples_in_batch * torch.mean(per_sample_loss)

            batch_loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

    # Evaluate the trained model (test split first, then train split).
    model.eval()
    with torch.no_grad():
        test_loss, test_acc = validate(model, X_test, y_test, criterion)
        train_loss, train_acc = validate(model, X_train, y_train, criterion)

    train_losses_list = [train_loss.item()]
    test_losses_list = [test_loss.item()]
    return train_losses_list, test_losses_list, train_acc, test_acc


class ProbitLoss(nn.Module):
    """Probit-style loss ``0.5 - 0.5 * erf(y * f / sqrt(2))`` on 1-D inputs.

    ``reduction`` selects how the per-element losses are combined: ``'mean'``
    averages them, ``'sum'`` adds them, and any other value (including the
    default ``'None'``) returns them unreduced.
    """

    def __init__(self, reduction='None'):
        super().__init__()
        self.reduction = reduction

    def forward(self, predictions, targets):
        """Return the loss for 1-D ``predictions`` and ``targets``."""
        # Element-wise product of prediction and target.
        margins = torch.einsum("b, b -> b", predictions, targets)
        scaled_margins = margins * (1 / np.sqrt(2))
        per_element = 0.5 - 0.5 * torch.erf(scaled_margins)

        if self.reduction == 'mean':
            return torch.mean(per_element)
        if self.reduction == 'sum':
            return torch.sum(per_element)
        return per_element

    




