import torch
from torch import nn
from utils import *
from tensors import *

### Models
def Regularizer(model, loss, weight_decay):
    """Return ``loss`` plus an L2 (ridge) penalty over all model parameters.

    The result is ``loss + 0.5 * weight_decay * sum_p ||p||_2 ** 2``, where the
    sum runs over every tensor yielded by ``model.parameters()`` (biases
    included). The penalty is built from differentiable tensor ops, so calling
    ``backward()`` on the result propagates ``weight_decay * p`` into each
    parameter's gradient.

    Args:
        model: Module whose parameters are penalized.
        loss: Loss value to augment (presumably a scalar tensor).
        weight_decay: Strength of the L2 penalty.

    Returns:
        The penalized loss.
    """
    penalty = 0
    for param in model.parameters():
        penalty = penalty + param.norm(2) ** 2
    return loss + 0.5 * weight_decay * penalty

class LogReg(nn.Module):
    """Multinomial logistic regression with a fixed reference class.

    Only ``num_classes - 1`` logits are produced by the learned linear layer;
    a constant logit of 1 is prepended as column 0, so class 0 acts as a fixed
    reference. ``forward`` returns class probabilities (softmax over dim 1),
    not raw logits.

    Args:
        input_size: Number of input features.
        num_classes: Total number of classes, including the reference class.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, num_classes - 1)
        )

    def forward(self, x):
        """Return class probabilities with ``num_classes`` columns.

        Args:
            x: Input batch; assumed 2-D ``(batch, input_size)`` — the
                reference column is built from ``logits.shape[0]``, so a 1-D
                input would not work. TODO confirm callers.

        Returns:
            Tensor of shape ``(batch, num_classes)`` whose rows sum to 1, on
            the same device and with the same dtype as the layer output.
        """
        logits = self.layers(x)
        # Constant reference logit for class 0. ``new_ones`` inherits the
        # logits' device *and* dtype. The previous ``.to(x.device.type)``
        # dropped the device index (inputs on cuda:1 got a column on cuda:0),
        # and ``torch.ones`` defaulted to float32, which changed the output
        # dtype of half-precision models.
        reference = logits.new_ones((logits.shape[0], 1))
        logits = torch.hstack((reference, logits))
        return nn.functional.softmax(logits, dim=1)

### Training
def CIRisk(loss_tensor, Z, set_Z, P_Y_Z, device):
    """Expected per-sample loss under P(Y | Z), averaged over the batch.

    ``Z`` is one-hot encoded against ``set_Z`` (via ``OneHotEncode``) and
    multiplied by ``P_Y_Z.T`` to get per-sample class weights. Those weights
    are applied to the squeezed ``loss_tensor``, summed over dim 1 and then
    averaged.

    Assumed shapes (not enforced here — TODO confirm against callers):
    one-hot ``(N, |set_Z|)``, ``P_Y_Z`` ``(num_classes, |set_Z|)``, so the
    weights are ``(N, num_classes)`` and must broadcast against
    ``loss_tensor.squeeze()``.

    Args:
        loss_tensor: Per-sample, per-class losses.
        Z: Observed values to be one-hot encoded.
        set_Z: The set of possible values of ``Z``.
        P_Y_Z: Matrix whose transpose maps a one-hot Z to class weights.
        device: Forwarded to ``OneHotEncode``.

    Returns:
        Scalar tensor: the mean over samples of the weighted loss.
    """
    z_onehot = OneHotEncode(Z, set_Z, device)
    class_weights = z_onehot @ P_Y_Z.T
    # NOTE(review): squeeze() removes *every* size-1 dim. If there were a
    # single class column, this would broadcast (N,) against (N, 1) into
    # (N, N); verify that shape can't occur.
    weighted = loss_tensor.squeeze() * class_weights
    per_sample = weighted.sum(dim=1)
    return per_sample.mean()

def TrainModelCI(model, X, Z, set_Z, P_Y_Z, weight_decay, tol, max_epochs, device):
    """Train ``model`` in place with L-BFGS on the L2-regularized CI risk.

    The objective evaluated by the closure is
    ``Regularizer(model, CIRisk(GetLogLossTensor(model, X), ...), weight_decay)``.
    Each outer epoch performs one ``optimizer.step`` (L-BFGS runs its own
    inner iterations with a strong-Wolfe line search). The loop stops early
    once ``GetGradNorm(model)`` falls below ``tol``.

    Args:
        model: Module to train; its parameters are updated in place.
        X: Inputs passed to ``GetLogLossTensor``.
        Z, set_Z, P_Y_Z, device: Forwarded unchanged to ``CIRisk``.
        weight_decay: L2 penalty strength passed to ``Regularizer``.
        tol: Early-stopping threshold compared against ``GetGradNorm(model)``.
        max_epochs: Maximum number of ``optimizer.step`` calls (cast to int).

    Returns:
        The same ``model`` object after training.
    """
    optimizer = torch.optim.LBFGS(model.parameters(), line_search_fn='strong_wolfe')

    def closure():
        # L-BFGS re-evaluates the objective several times per step. Each call
        # must clear stale gradients, recompute the loss and backprop it.
        if torch.is_grad_enabled():
            optimizer.zero_grad()
        loss_tensor = GetLogLossTensor(model, X)
        loss = CIRisk(loss_tensor, Z, set_Z, P_Y_Z, device)
        loss = Regularizer(model, loss, weight_decay)
        if loss.requires_grad:
            loss.backward()
        return loss

    for _ in range(int(max_epochs)):
        optimizer.step(closure)
        # Presumably the norm of the parameters' .grad — TODO confirm.
        # NOTE(review): after a strong-Wolfe step, .grad holds the gradient
        # from the last closure call inside the line search, which may not be
        # exactly the accepted iterate.
        grad_norm = GetGradNorm(model)
        if grad_norm < tol:
            break

    return model
