import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
from torchmetrics.regression import MeanAbsoluteError


def encode_train_set(clftrainloader, device, net):
    """Run ``net`` over every batch of ``clftrainloader`` and collect the outputs.

    Args:
        clftrainloader: Sized iterable yielding ``(inputs, targets)`` tensor batches.
        device: Device each batch is moved to before the forward pass.
        net: Encoder module. It is switched to eval mode and left in that mode.

    Returns:
        ``(X, y)``: the encoder outputs and the targets, each concatenated
        along dim 0 in loader order.

    Raises:
        ValueError: If the loader yields no batches.
    """
    net.eval()
    num_batches = len(clftrainloader)
    representations, labels = [], []
    with torch.no_grad():
        t = tqdm(enumerate(clftrainloader), desc='Encoded: **/** ', total=num_batches,
                 bar_format='{desc}{bar}{r_bar}')
        for batch_idx, (inputs, targets) in t:
            inputs, targets = inputs.to(device), targets.to(device)
            representations.append(net(inputs))
            labels.append(targets)

            # batch_idx is zero-based; report how many batches are done so the
            # description ends at N/N instead of (N-1)/N.
            t.set_description('Encoded %d/%d' % (batch_idx + 1, num_batches))

    if not representations:
        raise ValueError('clftrainloader yielded no batches to encode')

    return torch.cat(representations, dim=0), torch.cat(labels, dim=0)

def encode_conditional_train_set(clf_conditional_trainloader, device, net, CelebA_UTKFace):
    """Encode a conditional training set, pairing outputs with their conditions.

    Args:
        clf_conditional_trainloader: Sized iterable yielding
            ``(inputs, label, conditions)`` batches; the label is ignored.
        device: Device each batch is moved to before the forward pass.
        net: Encoder module. It is switched to eval mode and left in that mode.
        CelebA_UTKFace: Dataset selector. For ``'UTKFace'`` each ``conditions``
            entry is unpacked as a pair and its second element (the original
            protected attribute) is used; the first element (the embeddings)
            is discarded.

    Returns:
        ``(X, y)``: the encoder outputs and the conditions, each concatenated
        along dim 0 in loader order.

    Raises:
        ValueError: If the loader yields no batches.
    """
    net.eval()
    num_batches = len(clf_conditional_trainloader)
    representations, collected_conditions = [], []
    with torch.no_grad():
        t = tqdm(enumerate(clf_conditional_trainloader), desc='Encoded: **/** ', total=num_batches,
                 bar_format='{desc}{bar}{r_bar}')
        for batch_idx, (inputs, _, conditions) in t:
            if CelebA_UTKFace == 'UTKFace':
                # Keep the original protected attribute, drop the embeddings.
                _, conditions = conditions
            inputs, conditions = inputs.to(device), conditions.to(device)
            representations.append(net(inputs))
            collected_conditions.append(conditions)
            # batch_idx is zero-based; report how many batches are done.
            t.set_description('Encoded %d/%d' % (batch_idx + 1, num_batches))

    if not representations:
        raise ValueError('clf_conditional_trainloader yielded no batches to encode')

    return torch.cat(representations, dim=0), torch.cat(collected_conditions, dim=0)

def train_clf(X, y, representation_dim, num_classes, device, reg_weight=1e-3, optim_choice='lbfgs', continuous=False, CelebA_UTKFace=None):
    """Fit a fresh linear layer on pre-computed representations.

    The whole of ``X``/``y`` is used as a single full batch at every step.

    Args:
        X: Representations fed to the linear layer (presumably
            ``(N, representation_dim)`` and already on ``device`` -- confirm
            against the caller).
        y: Targets. Their layout depends on the mode: compared directly with
            the predictions when ``continuous`` is false; indexed as
            ``y[:, 0]`` (regression) and ``y[:, 1]`` (binary) for continuous
            UTKFace.
        representation_dim: Input width of the linear layer.
        num_classes: Output width of the linear layer.
        device: Device the linear layer is created on.
        reg_weight: Weight of the explicit L2 penalty on the layer weights.
        optim_choice: ``'lbfgs'`` (500 steps) or ``'adam'`` (5000 steps).
        continuous: Train on a continuous / conditional target instead of
            class labels.
        CelebA_UTKFace: ``None``, ``'CelebA'`` or ``'UTKFace'``.

    Returns:
        The trained ``nn.Linear`` module (left in train mode).

    Raises:
        NotImplementedError: If ``optim_choice`` is not supported.
        ValueError: If ``CelebA_UTKFace`` is not one of the supported values.
    """
    print('\nL2 Regularization weight: %g' % reg_weight)
    print(f"optim_choice {optim_choice}")

    if CelebA_UTKFace not in (None, 'CelebA', 'UTKFace'):
        # Previously this surfaced as an unbound `loss` inside the closure.
        raise ValueError('Unsupported CelebA_UTKFace value: %r' % (CelebA_UTKFace,))
    is_celeba = CelebA_UTKFace == 'CelebA'
    is_utkface = CelebA_UTKFace == 'UTKFace'

    class_criterion = None
    if is_celeba:
        # CelebA is trained as a binary problem whether or not `continuous`
        # is set. The continuous case used to fall through to
        # CrossEntropyLoss, which does not match the squeezed logits / float
        # targets used below nor the BCE-with-logits evaluation.
        criterion = nn.BCEWithLogitsLoss()
    elif continuous and is_utkface:
        # Mixed output: column 0 is regressed, column 1 is a binary logit.
        criterion = nn.MSELoss()
        class_criterion = nn.BCEWithLogitsLoss()
    elif continuous:
        criterion = nn.MSELoss()
    else:
        criterion = nn.CrossEntropyLoss()

    # A new classifier is built on every call so evaluations stay independent.
    clf = nn.Linear(representation_dim, num_classes).to(device)
    if optim_choice == 'lbfgs':
        n_optim_steps = 250 * 2
        clf_optimizer = optim.LBFGS(clf.parameters())
    elif optim_choice == 'adam':
        n_optim_steps = 250 * 20
        # NOTE: weight_decay is applied on top of the explicit L2 term below.
        clf_optimizer = optim.Adam(clf.parameters(), lr=1e-3, weight_decay=reg_weight)
    else:
        raise NotImplementedError

    clf.train()

    def closure():
        """Compute the regularised loss, backpropagate and update the bar."""
        clf_optimizer.zero_grad()

        raw_scores = clf(X)
        if is_celeba:
            data_loss = criterion(raw_scores.squeeze(-1), y.float())
        elif continuous and is_utkface:
            reg_loss = criterion(raw_scores[:, 0], y[:, 0])
            class_loss = class_criterion(raw_scores[:, 1], y[:, 1])
            data_loss = 0.1 * reg_loss + 0.9 * class_loss
        else:
            data_loss = criterion(raw_scores, y)

        # Out-of-place sum: an in-place `+=` would also modify `data_loss`,
        # so the value reported for continuous targets would silently include
        # the penalty.
        loss = data_loss + reg_weight * clf.weight.pow(2).sum()
        loss.backward()

        if continuous:
            t.set_description('Loss: %.10f%% ' % data_loss.item())
        else:
            if is_celeba:
                # A positive logit is a positive prediction.
                predicted = (raw_scores > 0).int().squeeze(-1)
            else:
                _, predicted = raw_scores.max(1)
            correct = predicted.eq(y).sum().item()
            t.set_description('Loss: %.3f | Train Acc: %.3f%% ' % (loss.item(), 100. * correct / y.shape[0]))

        return loss

    # The conditional CelebA classifier gets three passes of optimisation steps.
    num_clf_epochs = 3 if continuous and is_celeba else 1
    for _epoch in range(num_clf_epochs):
        # Create a fresh bar for every epoch rather than re-iterating one
        # bar that has already been exhausted.
        t = tqdm(range(n_optim_steps), desc='Loss: **** | Train Acc: ****% ', bar_format='{desc}{bar}{r_bar}')
        for _ in t:
            clf_optimizer.step(closure)

    return clf


def test(testloader, device, net, clf, CelebA_UTKFace):
    """Evaluate the encoder + linear classifier and return accuracy in percent.

    Args:
        testloader: Sized iterable yielding ``(inputs, targets)`` tensor batches.
        device: Device each batch is moved to.
        net: Encoder module (switched to eval mode).
        clf: Classifier applied to the encoder output (switched to eval mode).
        CelebA_UTKFace: ``'CelebA'`` selects binary evaluation
            (BCE-with-logits loss, positive logit => positive class); any
            other value selects cross-entropy with arg-max prediction.

    Returns:
        Accuracy over all samples, as a percentage.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    is_celeba = CelebA_UTKFace == 'CelebA'
    criterion = nn.BCEWithLogitsLoss() if is_celeba else nn.CrossEntropyLoss()

    net.eval()
    clf.eval()
    test_clf_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        t = tqdm(enumerate(testloader), total=len(testloader), desc='Loss: **** | Test Acc: ****% ',
                 bar_format='{desc}{bar}{r_bar}')
        for batch_idx, (inputs, targets) in t:
            inputs, targets = inputs.to(device), targets.to(device)
            raw_scores = clf(net(inputs))

            if is_celeba:
                # Drop only the trailing dimension: a bare squeeze() would
                # also collapse a size-1 batch and break the shape match
                # with `targets`.
                logits = raw_scores.squeeze(-1)
                clf_loss = criterion(logits, targets.float())
                predicted = (logits > 0).int()
            else:
                clf_loss = criterion(raw_scores, targets)
                _, predicted = raw_scores.max(1)

            test_clf_loss += clf_loss.item()
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            t.set_description('Loss: %.3f | Test Acc: %.3f%% ' % (test_clf_loss / (batch_idx + 1), 100. * correct / total))

    return 100. * correct / total



def test_conditional(testloader, device, net, clf, CelebA_UTKFace):
    """Evaluate how well ``clf`` predicts the condition from ``net``'s output.

    Args:
        testloader: Sized iterable yielding ``(inputs, label, condition)``
            batches; the label is ignored. For ``'UTKFace'`` each
            ``condition`` is unpacked as a pair and its second element (the
            original protected attribute) is used.
        device: Device each batch is moved to.
        net: Encoder module (switched to eval mode).
        clf: Classifier applied to the encoder output (switched to eval mode).
        CelebA_UTKFace: ``'CelebA'``, ``'UTKFace'`` or anything else.

    Returns:
        For ``'CelebA'``: a tuple ``(mean BCE-with-logits loss, mean MSE on
        the 0-100 scale, mean MAE * 100)``, each averaged over batches.
        Otherwise: the mean MSE loss over batches. For ``'UTKFace'`` column 1
        of both prediction and condition is scaled to 0-100 first (the
        prediction via a sigmoid).

    Raises:
        ValueError: If the loader yields no batches.
    """
    is_celeba = CelebA_UTKFace == 'CelebA'
    is_utkface = CelebA_UTKFace == 'UTKFace'

    test_mae_loss = 0
    test_mse_loss = 0
    if is_celeba:
        mae = MeanAbsoluteError().to(device)
        criterion = nn.BCEWithLogitsLoss()
        mse_criterion = nn.MSELoss()
    else:
        criterion = nn.MSELoss()

    net.eval()
    clf.eval()
    test_clf_loss = 0
    num_batches = 0
    with torch.no_grad():
        t = tqdm(enumerate(testloader), total=len(testloader), desc='Loss: **** | Test Acc: ****% ',
                 bar_format='{desc}{bar}{r_bar}')
        for batch_idx, (inputs, _, condition) in t:
            num_batches = batch_idx + 1
            if is_utkface:
                _, condition = condition  # select original protected attribute
                # Scale on a copy so the tensor handed out by the loader is
                # not modified (it may be reused on a later pass).
                condition = condition.clone()
                condition[:, 1] = condition[:, 1] * 100  # categorical column to 0-100

            inputs, condition = inputs.to(device), condition.to(device)
            raw_scores = clf(net(inputs))

            if is_utkface:
                # Turn the binary logit into a 0-100 probability.
                raw_scores[:, 1] = torch.sigmoid(raw_scores[:, 1]) * 100

            if is_celeba:
                condition = condition.float()
                # Use the same squeezed logits for every metric so the loss
                # sees the same shape as `condition`, like `probs` does.
                logits = raw_scores.squeeze(-1)
                probs = torch.sigmoid(logits)
                test_mae_loss += mae(probs, condition).item()
                test_mse_loss += mse_criterion(probs * 100, condition * 100).item()
                clf_loss = criterion(logits, condition)
            else:
                clf_loss = criterion(raw_scores, condition)

            test_clf_loss += clf_loss.item()

            if is_celeba:
                t.set_description('Loss: %.10f%% | MSE: %.3f%% ' % (test_clf_loss / num_batches, test_mse_loss / num_batches))
            else:
                t.set_description('Loss: %.10f%% ' % (test_clf_loss / num_batches))

    if num_batches == 0:
        # Previously this failed on the unbound loop variable `batch_idx`.
        raise ValueError('testloader yielded no batches to evaluate')

    mean_clf_loss = test_clf_loss / num_batches
    if is_celeba:
        return mean_clf_loss, test_mse_loss / num_batches, test_mae_loss * 100 / num_batches
    return mean_clf_loss
