"""
Functions for training models, evaluating them, and computing embeddings
"""
import os
import numpy as np
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from network import get_criterion, get_optim
from network import save_checkpoint, get_output, get_bert_scheduler
from utils import print_header
from utils.logging import summarize_acc
from utils.metrics import compute_roc_auc
from activations import compute_activation_mi, save_activations, compute_align_loss


def _pooled_accuracy_per_target(correct_by_groups, total_by_groups):
    """Return one accuracy per target class, pooled over its non-empty groups."""
    per_target_acc = []
    for correct_row, total_row in zip(correct_by_groups, total_by_groups):
        kept = [(c, t) for c, t in zip(correct_row, total_row) if t > 0]
        per_target_acc.append(np.sum([c for c, _ in kept]) /
                              np.sum([t for _, t in kept]))
    return per_target_acc


def train_model(net, optimizer, criterion, train_loader, val_loader,
                args, start_epoch=0, epochs=None, log_test_results=False,
                test_loader=None, test_criterion=None,
                checkpoint_interval=None, scheduler=None):
    """
    Train model for specified number of epochs

    Args:
    - net (torch.nn.Module): Pytorch model network
    - optimizer (torch.optim): Model optimizer
    - criterion (torch.nn.Criterion): Pytorch loss function
    - train_loader (torch.utils.data.DataLoader): Training dataloader
    - val_loader (torch.utils.data.DataLoader): Validation dataloader
    - args (argparse): Experiment args
    - start_epoch (int): Which epoch to start from
    - epochs (int): Number of epochs to train; defaults to args.max_epoch
    - log_test_results (bool): If true evaluate model on test set after each epoch and save results
    - test_loader (torch.utils.data.DataLoader): Testing dataloader
    - test_criterion (torch.nn.Criterion): Pytorch testing loss function, most likely has reduction='none'.
      Used for both validation and test evaluation; falls back to `criterion` when None.
    - checkpoint_interval (int): If set, save a checkpoint every this many epochs
    - scheduler (torch.optim.lr_scheduler): Learning rate scheduler

    Returns:
    - If log_test_results is True (and at least one epoch ran):
      (net, (max_robust_test_acc, max_robust_epoch, max_robust_test_group_acc), all_acc)
    - Otherwise: (val_running_loss, val_correct, val_total, correct_by_groups,
      total_by_groups, correct_indices) from the last epoch.
      NOTE(review): the group counts returned here are the *training* ones,
      not the validation ones - confirm this is what callers expect.

    Side effects: sets several attributes on `args` (max_robust_acc_v,
    max_avg_worst_gap_v, checkpoint_name, max_robust_acc, ...) and writes
    checkpoints / an accuracy plot to disk.
    """
    # train() puts the network in training mode itself at the start of every
    # epoch, so no encoder-specific mode selection is needed here.
    net.train()
    max_robust_test_acc = 0
    max_robust_epoch = None
    max_robust_test_group_acc = None
    all_acc = []
    args.max_robust_acc_v = 0
    args.max_avg_worst_gap_v = 0
    # Only computed in verbose mode; kept as None otherwise so the final
    # checkpoint can fall back to the test metric instead of raising NameError.
    robust_acc_v = None

    # evaluate() needs a callable loss; test_criterion is optional.
    eval_criterion = test_criterion if test_criterion is not None else criterion

    epochs = args.max_epoch if epochs is None else epochs
    net.to(args.device)
    for epoch in range(start_epoch, start_epoch + epochs):
        scheduler_ = None
        if args.optim == 'AdamW' and args.dataset == 'civilcomments':
            # NOTE(review): the loop runs `epochs` epochs, but the schedule
            # length subtracts start_epoch - confirm when resuming training.
            total_updates = int(np.round(
                len(train_loader) * (epochs - start_epoch)))
            last_epoch = int(np.round(epoch * len(train_loader)))
            scheduler_ = get_bert_scheduler(optimizer, n_epochs=total_updates,
                                            warmup_steps=args.warmup_steps,
                                            dataloader=train_loader,
                                            last_epoch=last_epoch)
        train_outputs = train(net, train_loader, optimizer, criterion, args, scheduler_)
        running_loss, correct, total, correct_by_groups, total_by_groups = train_outputs

        if checkpoint_interval is not None and (epoch + 1) % checkpoint_interval == 0:
            save_checkpoint(net, optimizer, running_loss,
                            epoch, batch=0, args=args,
                            replace=True, retrain_epoch=None)

        val_outputs = evaluate(net, val_loader, eval_criterion, args, testing=True)
        val_running_loss, val_correct, val_total, correct_by_groups_v, total_by_groups_v, correct_indices = val_outputs
        if (epoch + 1) % args.log_interval == 0:
            print(f'Epoch: {epoch + 1:3d} | Train Loss: {running_loss / total:<.3f} | Train Acc: {100 * correct / total:<.3f} | Val Loss: {val_running_loss / val_total:<.3f} | Val Acc: {100 * val_correct / val_total:<.3f}')

        if args.verbose is True:
            print('Training:')
            summarize_acc(correct_by_groups, total_by_groups)

            print('Validating:')
            average_acc_v, robust_acc_v = summarize_acc(correct_by_groups_v, total_by_groups_v)
            avg_worst_gap_v = average_acc_v - robust_acc_v

            if args.model_name_ == 'bias':
                # The biased model is checkpointed after every epoch.
                print(f'Save biased model at epoch {epoch}')
                save_checkpoint(net, None,
                                avg_worst_gap_v,  # override loss
                                epoch, -1, args,
                                replace=True,
                                retrain_epoch=-1,
                                identifier=f'{args.exp}_model_b_epoch{epoch}')

            if (avg_worst_gap_v > args.max_avg_worst_gap_v) and (args.model_name_ == 'bias'):
                # Biased model: "best" means the largest average-vs-worst gap.
                print(f'New max average-worst acc gap: {avg_worst_gap_v}')
                args.max_avg_worst_gap_v = avg_worst_gap_v
                args.max_avg_worst_gap_epoch_v = epoch
                args.max_avg_worst_gap_group_v = (correct_by_groups_v, total_by_groups_v)

                print(f'{args.model_name_} model - Saving best checkpoint at epoch {epoch}')
                checkpoint_name = save_checkpoint(net, None,
                                                  avg_worst_gap_v,  # override loss
                                                  epoch, -1, args,
                                                  replace=True,
                                                  retrain_epoch=-1,
                                                  identifier=f'{args.exp}_model_b_worst_avg_gap_best_epoch{epoch}')
                args.checkpoint_name = checkpoint_name

            elif (robust_acc_v > args.max_robust_acc_v) and (args.model_name_ == 'debias'):
                # Debiased model: "best" means the highest worst-group accuracy.
                print(f'New max robust acc: {robust_acc_v}')
                args.max_robust_acc_v = robust_acc_v
                args.max_robust_epoch_v = epoch
                args.max_robust_group_acc_v = (correct_by_groups_v, total_by_groups_v)

                print(f'{args.model_name_} model - Saving best checkpoint at epoch {epoch}')
                checkpoint_name = save_checkpoint(net, None,
                                                  robust_acc_v,  # override loss
                                                  epoch, -1, args,
                                                  replace=True,
                                                  retrain_epoch=-1,
                                                  identifier=f'{args.model_name_}-wga-best')
                args.checkpoint_name = checkpoint_name

        if args.optim == 'sgd' and scheduler is not None:
            # The scheduler is stepped with the mean of the per-target
            # validation accuracies (empty groups are ignored).
            group_acc = _pooled_accuracy_per_target(correct_by_groups_v,
                                                    total_by_groups_v)
            group_avg_acc = np.mean(group_acc)
            print(group_acc)
            print(group_avg_acc)
            scheduler.step(group_avg_acc)

        if args.dataset == 'celebA' and scheduler is not None:
            scheduler.step()

        if log_test_results:
            assert test_loader is not None
            test_outputs = test_model(net, test_loader, eval_criterion, args, epoch, mode='Training')
            test_running_loss, test_correct, test_total, correct_by_groups_t, total_by_groups_t, correct_indices, all_losses, losses_by_groups = test_outputs

            print('Testing:')
            average_test_acc, robust_test_acc = summarize_acc(correct_by_groups_t,
                                                              total_by_groups_t)
            all_acc.append(robust_test_acc)
            if robust_test_acc >= max_robust_test_acc:
                max_robust_test_acc = robust_test_acc
                args.max_robust_acc = max_robust_test_acc
                max_robust_epoch = epoch
                max_robust_test_group_acc = (correct_by_groups_t,
                                             total_by_groups_t)

            # Overwrite the worst-group test accuracy curve after each epoch.
            plt.plot(all_acc)
            plt.title(f'Worst-group test accuracy (max acc: {args.max_robust_acc:<.4f})')
            figpath = os.path.join(args.results_path, f'ta-{args.experiment_name}.png')
            plt.savefig(figpath)
            plt.close()

            max_robust_metrics = (max_robust_test_acc, max_robust_epoch,
                                  max_robust_test_group_acc)
            if epoch + 1 == start_epoch + epochs:
                # save last epoch model; without verbose mode no validation
                # robust accuracy exists, so record the test one instead.
                final_metric = robust_acc_v if robust_acc_v is not None else robust_test_acc
                save_checkpoint(net, None,
                                final_metric,  # override loss
                                epoch, -1, args,
                                replace=True,
                                retrain_epoch=-1,
                                identifier=f'{args.model_name_}-end')

                return net, max_robust_metrics, all_acc

    return (val_running_loss, val_correct, val_total, correct_by_groups, total_by_groups, correct_indices)


def test_model(net, test_loader, criterion, args, epoch, mode='Testing'):
    """
    Evaluate the model on a test set, print a summary and, in 'Testing' mode,
    append per-group metrics to args.test_metrics.

    Args:
    - net (torch.nn.Module): Pytorch model network
    - test_loader (torch.utils.data.DataLoader): Testing dataloader
    - criterion (torch.nn.Criterion): Pytorch loss function passed to evaluate()
    - args (argparse): Experiment args; in 'Testing' mode args.test_metrics
      must already hold the lists 'epoch', 'target', 'spurious', 'acc',
      'loss', 'model_type', 'robust_acc' and 'max_robust_acc'
    - epoch (int): Current epoch (0-indexed; logged as epoch + 1)
    - mode (str): 'Testing' records metrics; any other value only prints

    Returns:
    - (test_running_loss, test_correct, test_total, correct_by_groups,
       total_by_groups, correct_indices, 0, 0); the two trailing zeros are
      placeholders for per-sample and per-group losses, which are not
      currently collected.
    """
    net.eval()
    test_running_loss, test_correct, test_total, correct_by_groups, total_by_groups, correct_indices = evaluate(
        net, test_loader, criterion, args, testing=True, return_losses=True)
    acc_by_groups = correct_by_groups / total_by_groups
    if args.dataset != 'civilcomments':
        loss_header_1 = f'Avg Test Loss: {test_running_loss / test_total:<.3f} | Avg Test Acc: {100 * test_correct / test_total:<.3f}'
        print_header(loss_header_1, style='top')
    loss_header_3 = f'Robust Acc: {100 * np.min(acc_by_groups):<.3f} | Best Acc: {100 * np.max(acc_by_groups):<.3f}'

    print_header(loss_header_3, style='bottom')
    print(f'{mode}, Epoch {epoch}:')
    acc_summary = summarize_acc(correct_by_groups, total_by_groups)
    # train_model unpacks summarize_acc as (average_acc, robust_acc); accept
    # that shape here too so the comparison below gets a scalar.
    if isinstance(acc_summary, (tuple, list)):
        min_acc = acc_summary[-1]
    else:
        min_acc = acc_summary

    if mode == 'Testing':
        max_robust_acc = min_acc if min_acc > args.max_robust_acc else args.max_robust_acc

        # Compute MI of activations
        attributes = ['target']
        if args.dataset != 'civilcomments':
            attributes.append('spurious')

        embeddings, _ = save_activations(net, test_loader, args)
        mi_attributes = compute_activation_mi(attributes, test_loader,
                                              method='logistic_regression',
                                              classifier_test_size=0.5,
                                              max_iter=5000,
                                              model=net,
                                              embeddings=embeddings,
                                              seed=args.seed, args=args)
        attribute_names = [f'embedding_mutual_info_{attribute}'
                           for attribute in attributes]
        for name in attribute_names:
            if name not in args.test_metrics:
                args.test_metrics[name] = []

        # Compute Loss Align
        has_align_loss = args.dataset in ['waterbirds', 'colored_mnist']
        align_loss_metrics = ['target', 'spurious']
        align_loss_metric_values = []
        if has_align_loss:
            for align_loss_metric in align_loss_metrics:
                align_loss = compute_align_loss(embeddings, test_loader,
                                                measure_by=align_loss_metric,
                                                norm=True)
                align_loss_metric_values.append(align_loss)
                if f'loss_align_{align_loss_metric}' not in args.test_metrics:
                    args.test_metrics[f'loss_align_{align_loss_metric}'] = []

        # One metrics row per (target, spurious) group; run-level values are
        # repeated on every row.
        for yix, y_group in enumerate(correct_by_groups):
            for aix, _ in enumerate(y_group):
                args.test_metrics['epoch'].append(epoch + 1)
                args.test_metrics['target'].append(yix)  # (y_group)
                args.test_metrics['spurious'].append(aix)  # (a_group)
                args.test_metrics['acc'].append(acc_by_groups[yix][aix])
                # Per-group losses are not collected; 0 is a placeholder.
                args.test_metrics['loss'].append(0)
                # Change this depending on setup
                args.test_metrics['model_type'].append(args.model_type)
                args.test_metrics['robust_acc'].append(min_acc)
                args.test_metrics['max_robust_acc'].append(max_robust_acc)

                # Mutual Info:
                for ix, name in enumerate(attribute_names):
                    args.test_metrics[name].append(mi_attributes[ix])

                if has_align_loss:
                    for alix, align_loss_metric in enumerate(align_loss_metrics):
                        args.test_metrics[f'loss_align_{align_loss_metric}'].append(align_loss_metric_values[alix])
    else:
        # NOTE(review): this prints the same summary a second time in
        # non-'Testing' modes; kept so existing log output is unchanged.
        summarize_acc(correct_by_groups, total_by_groups)

    return (test_running_loss, test_correct, test_total, correct_by_groups, total_by_groups, correct_indices, 0, 0)


# Rules deciding whether group (y, s) counts as bias-conflicting, keyed by
# args.bias_conflicting_criterion. Criterion 1 (split by target row) is
# handled separately in _split_by_bias_alignment.
_BIAS_CONFLICTING_RULES = {
    0: lambda y, s: y != s,               # conflicting: y != bias attribute
    2: lambda y, s: y == 1 and s == 1,    # conflicting: (y, s) == (1, 1)
    3: lambda y, s: y == 1 and s == 0,    # conflicting: (y, s) == (1, 0)
}


def _count_samples_by_group(targets_t, targets_s, device):
    """Count dataset samples in every (target, spurious) group, as a tensor on `device`."""
    group_list = np.array(list(zip(targets_t.tolist(), targets_s.tolist())))
    group_num = torch.zeros([len(np.unique(targets_t)),
                             len(np.unique(targets_s))]).to(device)
    for y in range(group_num.shape[0]):
        for s in range(group_num.shape[1]):
            # A sample is in group (y, s) when both of its columns match.
            group_num[y][s] = (np.sum(group_list == np.array([y, s]), axis=-1) == 2).sum()
    return group_num


def _split_by_bias_alignment(values_by_group, criterion_id):
    """Sum a (target, spurious) tensor into (aligned, conflicting) totals.

    Raises ValueError for an unknown criterion. A side that receives no
    group stays the int 0.
    """
    if criterion_id == 1:
        # (aligned, conflicting) -> (y=0, y=1)
        per_target = torch.sum(values_by_group, dim=1)
        return per_target[0], per_target[1]
    if criterion_id not in _BIAS_CONFLICTING_RULES:
        raise ValueError(
            f'Unknown bias_conflicting_criterion: {criterion_id!r} (expected 0, 1, 2 or 3)')
    is_conflicting = _BIAS_CONFLICTING_RULES[criterion_id]
    aligned, conflicting = 0, 0
    for y in range(values_by_group.shape[0]):
        for s in range(values_by_group.shape[1]):
            if is_conflicting(y, s):
                conflicting += values_by_group[y][s]
            else:
                aligned += values_by_group[y][s]
    return aligned, conflicting


def _sum_batch_loss_by_group(loss, labels, labels_spurious, shape, device):
    """Return (summed per-sample loss, sample count) per (target, spurious) group for one batch."""
    batch_total = torch.zeros(shape).to(device)
    batch_losses = torch.zeros(shape).to(device)
    for ix, s in enumerate(labels_spurious):
        y = labels[ix]
        batch_total[int(y)][int(s)] += 1
        batch_losses[int(y)][int(s)] += loss[ix]
    return batch_losses, batch_total


def _worst_aligned_conflicting_loss(aligned_losses, aligned_count,
                                    conflicting_losses, conflicting_count,
                                    aligned_adj, conflicting_adj):
    """Return the larger of the two adjusted mean losses (aligned vs. conflicting)."""
    # A side with no samples has zero summed loss; dividing by any non-zero
    # constant keeps its mean at zero instead of producing 0 / 0.
    aligned_count = 10000 if aligned_count == 0 else aligned_count
    conflicting_count = 10000 if conflicting_count == 0 else conflicting_count
    aligned_mean = aligned_losses / aligned_count
    conflicting_mean = conflicting_losses / conflicting_count
    return max(aligned_mean + aligned_adj, conflicting_mean + conflicting_adj)


def _step_optimizer(net, optimizer, scheduler, args, bert_adamw):
    """Apply the accumulated gradients and step the per-batch scheduler if given.

    Gradients are cleared with net.zero_grad() afterwards, except in the
    plain case (no BERT/AdamW, no scheduler) where the caller clears them.
    """
    if bert_adamw and args.clip_grad_norm:
        torch.nn.utils.clip_grad_norm_(net.parameters(),
                                       args.max_grad_norm)
    optimizer.step()
    if scheduler is not None:
        scheduler.step()
    if bert_adamw or scheduler is not None:
        net.zero_grad()


def _tally_group_accuracy(correct_by_groups, total_by_groups,
                          labels_target, labels_spurious, all_correct):
    """Add one batch's per-sample correctness to the per-group counters (in place)."""
    for ix, s in enumerate(labels_spurious):
        y = labels_target[ix]
        correct_by_groups[int(y)][int(s)] += all_correct[ix].item()
        total_by_groups[int(y)][int(s)] += 1


def train(net, dataloader, optimizer, criterion, args, scheduler=None):
    """
    Train the model for one pass over `dataloader`.

    The loss reduction depends on args.exp:
    - contains 'groupDRO': max over (target, spurious) groups of the batch
      mean loss plus args.group_adj / sqrt(group size)
    - contains 'ours-with-bias-label': max over the bias-aligned and
      bias-conflicting sides (chosen by args.bias_conflicting_criterion) of
      the batch mean loss plus args.group_adj / sqrt(side size)
    - contains 'ours-with-adj': same, but the sides come from
      dataloader.dataset.bias_aligned_conflicting_label (0 aligned, 1 conflicting)
    - otherwise: mean of the per-sample losses
    The first matching rule wins. When args.batch_factor is not None each
    loaded batch is instead processed in chunks of args.bs_trn with gradient
    accumulation and a summed loss divided by the batch size.

    Args:
    - net (torch.nn.Module): Pytorch model network
    - dataloader (torch.utils.data.DataLoader): Training dataloader yielding
      (inputs, labels, dataset indices); its dataset must expose
      targets_all['target'] and targets_all['spurious']
    - optimizer (torch.optim): Model optimizer
    - criterion (torch.nn.Criterion): Loss function; the group-based modes
      index its output per sample, so it needs reduction='none' there
    - args (argparse): Experiment args
    - scheduler: Optional scheduler stepped after every optimizer step

    Returns:
    - (running_loss, correct, total, correct_by_groups, total_by_groups),
      where running_loss is the sum of the per-step scalar losses.

    Raises:
    - ValueError: if args.bias_conflicting_criterion is not 0-3 in the
      'ours-with-bias-label' mode.
    """
    running_loss = 0.0
    correct = 0
    total = 0

    targets_s = dataloader.dataset.targets_all['spurious']
    targets_t = dataloader.dataset.targets_all['target']

    correct_by_groups = np.zeros([len(np.unique(targets_t)),
                                  len(np.unique(targets_s))])
    total_by_groups = np.zeros(correct_by_groups.shape)

    # Exactly one loss mode is active; earlier names take precedence.
    use_group_dro = 'groupDRO' in args.exp
    use_bias_label = (not use_group_dro) and 'ours-with-bias-label' in args.exp
    use_adj = (not use_group_dro) and (not use_bias_label) and 'ours-with-adj' in args.exp

    if use_group_dro:
        group_num = _count_samples_by_group(targets_t, targets_s, args.device)
        group_adj = args.group_adj / torch.sqrt(group_num)

    elif use_bias_label:
        group_num = _count_samples_by_group(targets_t, targets_s, args.device)
        aligned_num, conflicting_num = _split_by_bias_alignment(
            group_num, args.bias_conflicting_criterion)
        aligned_adj = args.group_adj / torch.sqrt(aligned_num)
        conflicting_adj = args.group_adj / torch.sqrt(conflicting_num)

    elif use_adj:
        assert dataloader.dataset.bias_aligned_conflicting_label is not None
        bias_conflicting_label = dataloader.dataset.bias_aligned_conflicting_label
        conflicting_num = np.sum(bias_conflicting_label == 1)
        aligned_num = np.sum(bias_conflicting_label == 0)

        aligned_adj = args.group_adj / np.sqrt(aligned_num)
        conflicting_adj = args.group_adj / np.sqrt(conflicting_num)

    net.train()
    net.zero_grad()

    for data in dataloader:
        bert_adamw = args.arch == 'bert-base-uncased_pt' and args.optim == 'AdamW'
        plain_update = (not bert_adamw) and scheduler is None

        if args.batch_factor is None:
            inputs, labels, data_ix = data
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)

            labels_spurious = [targets_s[ix] for ix in data_ix]

            # Add this here to generalize NLP, CV models
            outputs = get_output(net, inputs, labels, args)
            loss = criterion(outputs, labels)
            if use_group_dro:
                batch_losses_by_groups, batch_total_by_groups = _sum_batch_loss_by_group(
                    loss, labels, labels_spurious, correct_by_groups.shape, args.device)
                # Groups absent from the batch have zero summed loss; a
                # non-zero divisor keeps their mean at zero.
                batch_total_by_groups[batch_total_by_groups == 0] = 10000
                batch_loss_mean_by_groups = batch_losses_by_groups / batch_total_by_groups
                loss = torch.max(batch_loss_mean_by_groups + group_adj)

            elif use_bias_label:
                batch_losses_by_groups, batch_total_by_groups = _sum_batch_loss_by_group(
                    loss, labels, labels_spurious, correct_by_groups.shape, args.device)
                bias_aligned_losses, bias_conflicting_losses = _split_by_bias_alignment(
                    batch_losses_by_groups, args.bias_conflicting_criterion)
                bias_aligned_count, bias_conflicting_count = _split_by_bias_alignment(
                    batch_total_by_groups, args.bias_conflicting_criterion)
                loss = _worst_aligned_conflicting_loss(
                    bias_aligned_losses, bias_aligned_count,
                    bias_conflicting_losses, bias_conflicting_count,
                    aligned_adj, conflicting_adj)

            elif use_adj:
                conflicting_label = [bias_conflicting_label[ix] for ix in data_ix]
                bias_aligned_losses = 0
                bias_aligned_count = 0
                bias_conflicting_losses = 0
                bias_conflicting_count = 0
                for ix, s in enumerate(conflicting_label):
                    if s == 0:
                        bias_aligned_losses += loss[ix]
                        bias_aligned_count += 1
                    elif s == 1:
                        bias_conflicting_losses += loss[ix]
                        bias_conflicting_count += 1
                loss = _worst_aligned_conflicting_loss(
                    bias_aligned_losses, bias_aligned_count,
                    bias_conflicting_losses, bias_conflicting_count,
                    aligned_adj, conflicting_adj)

            else:
                loss = loss.mean()

            if plain_update:
                # The other modes clear gradients after stepping instead.
                optimizer.zero_grad()
            loss.backward()
            _step_optimizer(net, optimizer, scheduler, args, bert_adamw)

            # Save performance
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            all_correct = (predicted == labels).detach().cpu()
            correct += all_correct.sum().item()
            running_loss += loss.item()

            # Save group-wise accuracy
            _tally_group_accuracy(correct_by_groups, total_by_groups,
                                  labels.detach().cpu().numpy(),
                                  labels_spurious, all_correct)

            # Clear memory: dropping the references is enough, no host copy
            # is needed first.
            del outputs; del inputs; del labels; del loss

        else:
            inputs, labels, data_idx = data
            num_data = inputs.shape[0]
            iters = int(num_data / args.bs_trn) + (0 if num_data % args.bs_trn == 0 else 1)
            start, end = 0, 0
            if plain_update:
                optimizer.zero_grad()
            else:
                net.zero_grad()

            # Accumulate gradients over chunks of at most bs_trn samples.
            for _ in range(iters):
                end = int(min(end + args.bs_trn, num_data))
                inputs_ = inputs[start:end].to(args.device)
                labels_ = labels[start:end].to(args.device)
                data_idx_ = data_idx[start:end]

                labels_spurious_ = [targets_s[ix] for ix in data_idx_]

                outputs_ = get_output(net, inputs_, labels_, args)
                loss_ = criterion(outputs_, labels_)
                # Dividing each chunk's sum by the full batch size makes the
                # accumulated gradient that of the batch mean.
                loss_ = loss_.sum() / num_data
                loss_.backward()

                start = end

                # Save performance
                _, predicted_ = torch.max(outputs_.data, 1)
                total += labels_.size(0)
                all_correct_ = (predicted_ == labels_).detach().cpu()
                correct += all_correct_.sum().item()
                running_loss += loss_.item()

                # Save group-wise accuracy
                _tally_group_accuracy(correct_by_groups, total_by_groups,
                                      labels_.detach().cpu().numpy(),
                                      labels_spurious_, all_correct_)

                # Clear memory
                del outputs_; del inputs_; del labels_; del loss_

            _step_optimizer(net, optimizer, scheduler, args, bert_adamw)
            if plain_update:
                optimizer.zero_grad()

    return running_loss, correct, total, correct_by_groups, total_by_groups


def evaluate(net, dataloader, criterion, args, testing=False, return_losses=False):
    """
    Evaluate the model on a dataloader without gradient tracking.

    For args.dataset == 'civilcomments' the call is delegated to
    evaluate_civilcomments(). Otherwise accuracy (and, when
    args.compute_auroc is True, AUROC) is computed over the whole loader.

    Args:
    - net (torch.nn.Module): Pytorch model network
    - dataloader (torch.utils.data.DataLoader): Dataloader yielding
      (inputs, labels, dataset indices); its dataset must expose
      targets_all['target'] and targets_all['spurious']
    - criterion (torch.nn.Criterion): Loss function; its output is averaged
      per batch, so any reduction works
    - args (argparse): Experiment args; args.auroc_by_groups and
      args.robust_auroc are written when AUROC is computed
    - testing (bool): If true also collect per-group correct/total counts
    - return_losses (bool): Accepted for backward compatibility; per-sample
      and per-group losses are not part of the return value

    Returns:
    - testing=True: (running_loss, correct, total, correct_by_groups,
      total_by_groups, correct_indices)
    - testing=False: (running_loss, correct, total, correct_indices)
      running_loss is the sum of per-batch mean losses and correct_indices
      is a list with one boolean array per batch.
    """
    if args.dataset == 'civilcomments':
        return evaluate_civilcomments(net, dataloader, criterion, args)

    # Validation
    running_loss = 0.0
    correct = 0
    total = 0

    targets_s = dataloader.dataset.targets_all['spurious'].astype(int)
    targets_t = dataloader.dataset.targets_all['target'].astype(int)

    group_shape = [len(np.unique(targets_t)), len(np.unique(targets_s))]
    correct_by_groups = np.zeros(group_shape)
    auroc_by_groups = np.zeros(group_shape)
    total_by_groups = np.zeros(correct_by_groups.shape)

    correct_indices = []
    net.to(args.device)
    net.eval()

    with torch.no_grad():
        all_probs = []
        all_targets = []
        for inputs, labels, data_ix in dataloader:
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)

            labels_spurious = [targets_s[ix] for ix in data_ix]

            # Add this here to generalize NLP, CV models
            outputs = get_output(net, inputs, labels, args)
            loss = criterion(outputs, labels)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            all_correct = (predicted == labels).detach().cpu()
            correct += all_correct.sum().item()
            running_loss += loss.mean().item()

            # For AUROC: keep the predicted probability of class 1.
            if args.compute_auroc is True:
                all_probs.append(F.softmax(outputs, dim=1).detach().cpu()[:, 1])
                all_targets.append(labels.detach().cpu())

            correct_indices.append(all_correct.numpy())

            if testing:
                # Copy the labels to host once per batch, not once per sample.
                labels_np = labels.detach().cpu().numpy()
                for ix, s in enumerate(labels_spurious):
                    y = labels_np[ix]
                    correct_by_groups[int(y)][int(s)] += all_correct[ix].item()
                    total_by_groups[int(y)][int(s)] += 1

            # Drop the batch tensors; no host copy is needed before deleting.
            del inputs; del labels; del outputs; del loss

        if args.compute_auroc is True:
            targets_cat, probs_cat = torch.cat(all_targets), torch.cat(all_probs)
            auroc = compute_roc_auc(targets_cat, probs_cat)

            # NOTE(review): the indices below come from the dataset-order
            # targets but index the loader-order predictions - this assumes
            # the dataloader is not shuffled; confirm.
            malignant_indices = np.where(targets_t == 1)[0]
            # Every group of target 1 is assigned the overall AUROC.
            auroc_by_groups[1][:] = auroc

            benign_indices = np.where(targets_t == 0)[0]
            for s in np.unique(targets_s[benign_indices]):
                # Pair all target-1 samples with the target-0 samples of
                # spurious value s.
                spurious_indices = np.where(targets_s[benign_indices] == s)[0]
                paired_auroc_indices = np.union1d(malignant_indices,
                                                  benign_indices[spurious_indices])
                auroc = compute_roc_auc(targets_cat[paired_auroc_indices],
                                        probs_cat[paired_auroc_indices])
                auroc_by_groups[0][s] = auroc

            args.auroc_by_groups = auroc_by_groups
            min_auroc = np.min(args.auroc_by_groups.flatten())
            print('-' * 18)
            print(f'AUROC by group:')
            for yix, y_group in enumerate(auroc_by_groups):
                for aix, _ in enumerate(y_group):
                    print(f'{yix}, {aix}  auroc: {auroc_by_groups[yix][aix]:>5.3f}')
            # args.robust_auroc may be unset (or None) on the first call.
            previous_best = getattr(args, 'robust_auroc', None)
            if previous_best is None or min_auroc > previous_best:
                print(f'- New max robust AUROC: {min_auroc:<.3f}')
                args.robust_auroc = min_auroc

    if testing:
        return running_loss, correct, total, correct_by_groups, total_by_groups, correct_indices
    return running_loss, correct, total, correct_indices


def evaluate_civilcomments(net, dataloader, criterion, args):
    """
    Evaluate a model on CivilComments and tally accuracy per identity group

    For every identity variable of the dataset, examples are assigned to
    groups by the matching eval grouper. Only groups 1 and 3 (the ones where
    the identity is present, see the table in the code below) are counted:
    group 1 goes to row 0 and group 3 goes to row 1 of the returned arrays.

    Args:
    - net (torch.nn.Module): Pytorch model network
    - dataloader (torch.utils.data.DataLoader): Evaluation dataloader. Its dataset
          must expose `metadata_array`, `_identity_vars` and `_eval_groupers`,
          and every batch must unpack as (inputs, labels, data_ix).
          NOTE(review): predictions are matched to `metadata_array` by position,
          so this assumes the dataloader walks the dataset in order without
          shuffling or dropping samples - confirm at the call site.
    - criterion (torch.nn.Criterion): Unused; kept so the signature stays the same
    - args (argparse): Experiment args, `args.device` is read here

    Returns:
    - tuple: (0, total_correct, len(dataset), correct_by_groups, total_by_groups, None)
          where both arrays have shape [2, number of identity variables].
          The leading 0 and trailing None are constant placeholders;
          presumably they keep the tuple layout compatible with the other
          evaluation routine in this file (loss first, correct indices last).
    """
    dataset = dataloader.dataset
    metadata = dataset.metadata_array
    identity_vars = dataset._identity_vars

    # Column of the result arrays that each identity variable is written to
    identity_to_ix = {identity: ix for ix, identity in enumerate(identity_vars)}

    # Group assignments only depend on the metadata, so compute them once and
    # reuse them both for the logging here and for the tally after inference.
    groupings = []
    for identity_var, eval_grouper in zip(identity_vars,
                                          dataset._eval_groupers):
        group_idx = eval_grouper.metadata_to_group(metadata).numpy()
        g_list, g_counts = np.unique(group_idx, return_counts=True)
        groupings.append((identity_var, group_idx, g_list, g_counts))

        print(identity_var, identity_to_ix[identity_var])
        print(g_counts)
        for g_ix, g in enumerate(g_list):
            # Only report the positive-identity groups (1 and 3)
            if g in [1, 3]:
                print(g_ix, g, g_counts[g_ix])

    net.to(args.device)
    net.eval()
    total_correct = 0
    all_correct = []
    with torch.no_grad():
        for inputs, labels, _data_ix in dataloader:
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)

            # Add this here to generalize NLP, CV models
            outputs = get_output(net, inputs, labels, args)
            _, predicted = torch.max(outputs, 1)
            correct = (predicted == labels).cpu()
            total_correct += correct.sum().item()
            all_correct.append(correct)

            # Drop the device tensors before the next batch is loaded.
            # Copying them to the CPU first would only add a wasted transfer.
            del inputs, labels, outputs

    all_correct = torch.cat(all_correct).numpy()

    # Evaluate predictions
    correct_by_groups = np.zeros([2, len(identity_vars)])
    total_by_groups = np.zeros(correct_by_groups.shape)

    for identity_var, group_idx, g_list, g_counts in groupings:
        print(g_counts)

        idx = identity_to_ix[identity_var]

        for g, g_count in zip(g_list, g_counts):
            # Only pick from positive identities
            # e.g. only 1 and 3 from here:
            #   0 y:0_male:0
            #   1 y:0_male:1
            #   2 y:1_male:0
            #   3 y:1_male:1
            if g in [1, 3]:
                class_ix = 0 if g == 1 else 1  # 1 y:0_male:1
                correct_by_groups[class_ix][idx] += all_correct[group_idx == g].sum()
                total_by_groups[class_ix][idx] += g_count
    return 0, total_correct, len(dataset), correct_by_groups, total_by_groups, None
