import timm
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim

reweight = True  # True if LiLAW is to be used, False otherwise

# NOTE(review): this script reads configuration defined elsewhere (model_name,
# num_classes, device, init_alpha/init_beta/init_delta, lr, wd, loss,
# focal_gamma, warmup, epochs, dataloader, val_dataloader, alpha_lr/beta_lr/
# delta_lr, alpha_wd/beta_wd/delta_wd). It also instantiates
# WeightedCrossEntropyLoss / WeightedFocalLoss, which must already be defined
# when this code runs (in this file they are defined after the script).

# setup of model and difficulty parameters
model = timm.create_model(model_name, num_classes=num_classes).to(device)
# The difficulty parameters must be floating point: nn.Parameter(...,
# requires_grad=True) rejects integer tensors, and torch.tensor(10) is int64.
alpha = nn.Parameter(torch.tensor(init_alpha, dtype=torch.float32, device=device), requires_grad=True)  # high value for alpha to start with, init_alpha = 10
beta = nn.Parameter(torch.tensor(init_beta, dtype=torch.float32, device=device), requires_grad=True)  # low value for beta to start with, init_beta = 2
delta = nn.Parameter(torch.tensor(init_delta, dtype=torch.float32, device=device), requires_grad=True)  # medium value for delta to start with, init_delta = 6

# setup of optimizer and scheduler
if reweight:
    # alpha/beta/delta are frozen (requires_grad=False) during the training
    # step, so they have no .grad there and the optimizer leaves them
    # untouched; they are updated manually in the meta-validation step.
    optimizer = optim.Adam(list(model.parameters()) + [alpha, beta, delta], lr=lr, weight_decay=wd)  # other optimizers can also be used
else:
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
scheduler = optim.lr_scheduler.LinearLR(optimizer)  # other schedulers can also be used

# choice of loss function
if loss == 'CE':  # cross-entropy
    criterion = WeightedCrossEntropyLoss(reweight=reweight, alpha=alpha, beta=beta, delta=delta, num_classes=num_classes, warmup=warmup, device=device, model_name=model_name)
elif loss == 'FL':  # focal loss
    criterion = WeightedFocalLoss(reweight=reweight, alpha=alpha, beta=beta, delta=delta, gamma=focal_gamma, num_classes=num_classes, warmup=warmup, device=device, model_name=model_name)
else:
    # fail fast instead of hitting a NameError on `criterion` inside the loop
    raise ValueError(f"unknown loss {loss!r}; expected 'CE' or 'FL'")

# NOTE(review): assumes `epochs` is an iterable of integer epoch indices
# (e.g. range(num_epochs)); `epoch` is compared against `warmup` below.
# `import tqdm` binds the module, so the progress-bar class is tqdm.tqdm.
for epoch in tqdm.tqdm(epochs):
    # training and meta-validation
    for i, data in tqdm.tqdm(enumerate(dataloader)):
        inputs, observed_label = data

        # ...

        # update model first in the training step.
        # requires_grad_() toggles every parameter of the module; assigning
        # `model.requires_grad = ...` would only set an unused attribute.
        model.requires_grad_(True)
        alpha.requires_grad = False
        beta.requires_grad = False
        delta.requires_grad = False

        # ...

        # training update each epoch
        outputs = model(inputs).float()
        correct_outputs, max_outputs, alpha_weights, beta_weights, delta_weights, weights, train_loss = criterion(outputs, observed_label, epoch=epoch)

        # ...

        # backward pass
        train_loss.mean().backward()

        # gradient update (model)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # zero grads: the optimizer owns the model parameters and, when
        # reweight is on, alpha/beta/delta as well
        optimizer.zero_grad(set_to_none=True)

        # ...

        # meta-validation
        for j, val_data in enumerate(val_dataloader):
            val_inputs, val_observed_label = val_data

            # update the difficulty parameters in the single mini-batch validation step.
            # Freezing the model keeps validation gradients out of the model's
            # .grad buffers (they would otherwise leak into the next training step).
            model.requires_grad_(False)
            alpha.requires_grad = True
            beta.requires_grad = True
            delta.requires_grad = True

            # ...

            val_outputs = model(val_inputs).float()
            correct_outputs, max_outputs, alpha_weights, beta_weights, delta_weights, weights, val_loss = criterion(val_outputs, val_observed_label, epoch=epoch)

            # ...

            if reweight and epoch > warmup:  # let model warmup for a few epochs before meta-validation, if need be
                # backward only here: otherwise val_loss does not depend on
                # any tensor that requires grad (model frozen, difficulty
                # parameters unused by the criterion) and backward() raises
                val_loss.mean().backward()
                with torch.no_grad():
                    # gradient update (LiLAW parameters)
                    alpha -= alpha_lr * (alpha.grad + alpha_wd * alpha)
                    alpha.data.clamp_(min=1.0)

                    beta -= beta_lr * (beta.grad + beta_wd * beta)

                    delta -= delta_lr * (delta.grad + delta_wd * delta)
                    delta.data.clamp_(min=beta.detach().item())

            # zero grads
            alpha.grad = None
            beta.grad = None
            delta.grad = None

            break  # only need one validation mini-batch (note that val_dataloader shuffles data each time)

        # ...

# custom loss function for weighted cross entropy loss
class WeightedCrossEntropyLoss(nn.CrossEntropyLoss):
    """Cross-entropy loss with optional LiLAW per-sample reweighting.

    Calling the criterion returns a 7-tuple
    ``(correct_outputs, max_outputs, alpha_weights, beta_weights,
    delta_weights, weights, loss)``:

    * ``correct_outputs`` -- softmax probability of the observed label,
    * ``max_outputs`` -- the largest softmax probability of each sample,
    * the four weight entries -- ``None`` unless ``reweight`` is set and
      ``epoch > warmup``; then each per-sample loss is multiplied by
      ``weights`` before averaging,
    * ``loss`` -- a scalar (batch mean).

    Probabilities and per-sample losses come from ``torch.softmax`` /
    ``torch.log_softmax``. A hand-written ``exp``/``sum``/``log`` softmax
    returns ``nan`` once a float32 logit exceeds ~88 (``inf / inf``), or once
    any class probability underflows to 0 (``log(0) * 0``).
    """

    def __init__(self, reweight=True, alpha=None, beta=None, delta=None, num_classes=2, warmup=0, device=None, model_name=None):
        # Run the real CrossEntropyLoss initializer so inherited attributes
        # (ignore_index, label_smoothing, ...) exist; forward() is overridden
        # and does not use them.
        super().__init__()
        self.reweight = reweight  # whether or not to reweight
        self.alpha = alpha  # alpha difficulty parameter (scalar tensor)
        self.beta = beta  # beta difficulty parameter (scalar tensor)
        self.delta = delta  # delta difficulty parameter (scalar tensor)
        self.warmup = warmup  # number of warmup epochs before reweighting
        self.model_name = model_name  # name of the model, if needed (unused here)
        self.num_classes = num_classes  # number of classes
        self.device = device  # device for encode(); None -> targets' device
        self.sigmoid = nn.Sigmoid()

    def softmax(self, outputs):
        """Return the softmax of ``outputs`` over the class dimension (dim 1)."""
        return torch.softmax(outputs, dim=1)

    def encode(self, targets):
        """Return a float one-hot matrix of shape ``(len(targets), num_classes)``.

        The matrix is created on ``self.device``, or on ``targets.device``
        when no device was configured.
        """
        device = self.device if self.device is not None else targets.device
        encoded_targets = torch.zeros(targets.size(0), self.num_classes, device=device)
        encoded_targets.scatter_(1, targets.reshape(-1, 1).long().to(device), 1.0)
        return encoded_targets

    def weights(self, correct_outputs, max_outputs):
        """Return ``(alpha_weights, beta_weights, delta_weights, weights)``.

        With ``c = correct_outputs`` and ``m = max_outputs``:
        ``sigmoid(alpha*c - m)``, ``sigmoid(m - beta*c)`` and the Gaussian
        ``exp(-(delta*c - m)**2 / 2)``; ``weights`` is their sum.
        """
        alpha_weights = self.sigmoid(self.alpha * correct_outputs - max_outputs)
        beta_weights = self.sigmoid(max_outputs - self.beta * correct_outputs)
        delta_weights = torch.exp(-((self.delta * correct_outputs - max_outputs) ** 2) / 2)
        weights = alpha_weights + beta_weights + delta_weights
        return alpha_weights, beta_weights, delta_weights, weights

    def _per_sample(self, outputs, targets):
        """Return ``(per-sample CE, p(observed label), max p)`` for a logits batch.

        ``outputs`` is treated as ``(batch, num_classes)`` (classes on dim 1);
        ``targets`` holds integer class indices.
        """
        log_probs = torch.log_softmax(outputs, dim=1)
        probs = self.softmax(outputs)
        index = targets.reshape(-1, 1).long().to(outputs.device)
        loss = -log_probs.gather(1, index).squeeze(1)
        correct_outputs = probs.gather(1, index).squeeze(1)
        max_outputs = probs.max(dim=1).values
        return loss, correct_outputs, max_outputs

    def _reduce(self, loss, correct_outputs, max_outputs, epoch):
        """Average the per-sample ``loss`` (LiLAW-weighted after warmup) and build the 7-tuple."""
        if self.reweight and epoch > self.warmup:
            alpha_weights, beta_weights, delta_weights, weights = self.weights(correct_outputs, max_outputs)
            return correct_outputs, max_outputs, alpha_weights, beta_weights, delta_weights, weights, (weights * loss).mean()
        return correct_outputs, max_outputs, None, None, None, None, loss.mean()

    def forward(self, outputs, targets, epoch=-1):
        """Compute the (optionally reweighted) mean cross-entropy; see class docstring."""
        loss, correct_outputs, max_outputs = self._per_sample(outputs, targets)
        return self._reduce(loss, correct_outputs, max_outputs, epoch)

# custom loss function for weighted focal loss
class WeightedFocalLoss(WeightedCrossEntropyLoss):
    """Focal loss ``(1 - p_y) ** gamma * CE`` with optional LiLAW reweighting.

    Returns the same 7-tuple as WeightedCrossEntropyLoss. The focal factor is
    applied to each sample's own cross-entropy before averaging. Applying it
    to the already-averaged batch cross-entropy is not focal loss, and would
    make the LiLAW weights scale one shared scalar.
    ``gamma`` must be given: the ``None`` default fails in ``forward``.
    """

    def __init__(self, reweight=True, alpha=None, beta=None, delta=None, gamma=None, num_classes=2, warmup=0, device=None, model_name=None):
        super().__init__(reweight=reweight, alpha=alpha, beta=beta, delta=delta, num_classes=num_classes, warmup=warmup, device=device, model_name=model_name)
        self.gamma = gamma  # focal loss gamma (focusing exponent)

    def forward(self, outputs, targets, epoch=-1):
        """Compute the (optionally reweighted) mean focal loss; see class docstring."""
        cross_entropy_loss, correct_outputs, max_outputs = self._per_sample(outputs, targets)
        # exp(-CE_i) == p_i, so this is (1 - p_i) ** gamma * CE_i per sample
        focal_loss = (1 - torch.exp(-cross_entropy_loss)) ** self.gamma * cross_entropy_loss
        return self._reduce(focal_loss, correct_outputs, max_outputs, epoch)