import argparse
import logging

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib.patches import Rectangle
from pytorch_transformers import AdamW, WarmupLinearSchedule
from torchvision.models import resnet18 as resnet18_torchvision
from torchvision.models import resnet50, wide_resnet50_2
from tqdm import tqdm

from densenet_model import densenet121
from models import cnn, resnet18


def _activation_layer_name(model):
    """Return the ``activation_layer`` tag of ``model`` or, for DataParallel, of the wrapped module."""
    name = getattr(model, 'activation_layer', None)
    if name is None and isinstance(model, torch.nn.DataParallel):
        name = getattr(model.module, 'activation_layer', None)
    return name


def _register_activation_hooks(model, save_output, args):
    """Register ``save_output`` as a forward hook on the embedding layer; return the hook handles."""
    hook_handles = []
    if 'resnet' in args.arch or 'bert' in args.arch:
        # These models are tagged with the name of the module to hook (`activation_layer`).
        target = _activation_layer_name(model)
        if target is None:
            raise AttributeError("model has no 'activation_layer' attribute")
        is_parallel = isinstance(model, torch.nn.DataParallel)
        for name, layer in model.named_modules():
            # Under DataParallel submodule names carry a 'module.' prefix.
            if name == target or (is_parallel and name.replace('module.', '') == target):
                hook_handles.append(layer.register_forward_hook(save_output))
                if 'bert' in args.arch:
                    args.logger.info(f'Activation layer: {name}')
        if not hook_handles:
            raise ValueError(f'No module named {target!r} found in model')
        return hook_handles

    # Other architectures: hook the last ReLU/Identity module (nn.Identity is not
    # defined on older torch releases, so only ReLU is considered there).
    if hasattr(torch.nn, 'Identity'):
        activation_types = (torch.nn.ReLU, torch.nn.Identity)
    else:
        activation_types = (torch.nn.ReLU,)
    activation_layers = [layer for layer in model.modules() if isinstance(layer, activation_types)]
    # CNNs with a projection head are hooked one activation earlier (presumably so the
    # embedding is taken before the head -- confirm against the cnn model definition).
    if 'cnn' in args.arch and args.no_projection_head is False:
        hooked = activation_layers[-2]
    else:
        hooked = activation_layers[-1]
    hook_handles.append(hooked.register_forward_hook(save_output))
    return hook_handles


def _flatten_batch(batch_output):
    """Convert one captured hook output to a numpy array of shape (batch, -1)."""
    array = batch_output.detach().numpy()
    # reshape instead of squeeze(): squeeze() would also drop the batch axis of a
    # batch holding a single sample and break the concatenation below.
    return array.reshape(len(array), -1) if array.ndim > 1 else array


def save_activations(model, dataloader, args, pred_only=False):
    """
    Run inference over a dataloader and collect predictions and, optionally, embeddings.

    Usage: total_embeddings, ... = save_activations(net, train_loader, args)

    Args:
    - model (torch.nn.Module): network to evaluate; moved to ``args.device`` and set to eval mode
    - dataloader: iterable of ``(inputs, labels, data_ix)`` batches
    - args (argparse): experiment args (reads ``arch``, ``device``, ``logger`` and, for
      CNNs, ``no_projection_head``)
    - pred_only (bool): if True, skip embedding capture and return only predictions

    Returns:
    - if ``pred_only``: ``total_predictions`` (np.ndarray of softmax outputs)
    - otherwise: ``(total_embeddings, total_predictions, total_outputs, total_labels)``,
      where embeddings are flattened to ``(n_datapoints, -1)``
    """
    save_output = SaveOutput()
    hook_handles = [] if pred_only else _register_activation_hooks(model, save_output, args)

    model.to(args.device)
    model.eval()

    total_outputs, total_predictions, total_labels = [], [], []
    args.logger.info('> Saving activations')

    try:
        with torch.no_grad():
            for inputs, labels, _ in tqdm(dataloader, desc='Running inference'):
                inputs = inputs.to(args.device)
                labels = labels.to(args.device)
                outputs = get_output(model, inputs, labels, args)
                # Move everything to CPU right away to keep device memory flat.
                total_predictions.append(torch.softmax(outputs, 1).cpu().numpy())
                total_outputs.append(outputs.cpu().numpy())
                total_labels.append(labels.cpu().numpy())
    finally:
        # Hooks stay attached to the model until removed explicitly; otherwise every
        # later forward pass (e.g. training) would keep appending to `save_output`.
        for handle in hook_handles:
            handle.remove()

    total_predictions = np.concatenate(total_predictions)
    total_outputs = np.concatenate(total_outputs)
    total_labels = np.concatenate(total_labels)

    if pred_only:
        return total_predictions

    total_embeddings = np.concatenate([_flatten_batch(out) for out in save_output.outputs])
    save_output.clear()
    return total_embeddings, total_predictions, total_outputs, total_labels


def accuracy(output, target, topk=(1,)):
    """
    Compute the top-k accuracy (percentage of the batch) for each k in ``topk``.

    Args:
    - output (torch.Tensor): class scores, classes along dim 1 (batch_size, num_classes)
    - target (torch.Tensor): class indices, one per sample (batch_size,)
    - topk (tuple[int]): values of k to evaluate

    Returns:
    - list of 0-dim tensors, one per k, each in [0, 100]
    """
    batch_size = target.size(0)
    # topk() fails when asked for more predictions than there are classes, and any
    # k >= num_classes is 100% by definition, so cap the number of ranked guesses.
    maxk = min(max(topk), output.size(1))

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch_size): row r is each sample's r-th best guess
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` keeps the transposed (non-contiguous) layout of
        # `pred`, so view(-1) raises for k > 1 on recent PyTorch versions.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

class AverageMeter:
    """Track the most recent value and the running, count-weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running average, running sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` new samples and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def _seed_everything(seed):
    """Seed torch (CPU and every CUDA device) and numpy, and make cuDNN reproducible."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune its convolution algorithm per run, and the
    # choice can differ between runs, which defeats deterministic=True. Keep it off
    # so seeded runs are reproducible (at some speed cost).
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)


def _resolve_arch(args, infer):
    """Return the architecture to build: ``args.infer_arch`` when inferring and set, else ``args.arch``."""
    if infer and args.infer_arch != '':
        # Truncate e.g. 'resnet50wide' to 'resnet50'; the 'wide' suffix is read
        # again from args.infer_arch when the network is built.
        if 'resnet' in args.infer_arch:
            return args.infer_arch[:8]
        return args.infer_arch
    return args.arch


def _build_network(args, arch, infer):
    """Instantiate the (un-wrapped) network for ``arch`` and tag its embedding layer name."""
    if arch[:6] == 'resnet':
        if arch == 'resnet50':
            if args.wide or (infer and 'wide' in args.infer_arch):
                model = wide_resnet50_2(pretrained=args.pretrained)
            else:
                model = resnet50(pretrained=args.pretrained)
        elif arch == 'resnet18':
            # ResNet-18 is always built without pretrained weights (this also
            # overwrites args.pretrained for the caller).
            args.pretrained = False
            if 'cifar' in args.dataset:
                model = resnet18()
            else:
                model = resnet18_torchvision(pretrained=args.pretrained)
        else:
            # Any other depth would leave `model` undefined.
            raise NotImplementedError(f'Unsupported ResNet variant: {arch}')
        if args.freeze:
            for param in model.parameters():
                param.requires_grad = False
        # The new classifier head is created after freezing, so it stays trainable.
        model.fc = nn.Linear(model.fc.in_features, args.num_classes)
        model.activation_layer = 'avgpool'
    elif arch == 'densenet':
        model = densenet121(num_classes=args.num_classes)
        model.activation_layer = 'adaptive_avgpool'
    elif arch == 'lenet5':
        model = cnn(args.num_classes)
    elif arch == 'bert':
        from pytorch_transformers import BertConfig, BertForSequenceClassification

        config = BertConfig.from_pretrained(
            'bert-base-uncased',
            num_labels=3,
            finetuning_task='mnli')
        model = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased',
            from_tf=False,
            config=config)
        model.activation_layer = 'bert.pooler.activation'
    else:
        raise NotImplementedError
    return model


def _build_optimizer(args, model, arch, infer, t_total):
    """Create the optimizer and LR scheduler (or None) selected by ``args``."""
    if infer:
        lr, wd, optimizer_name = args.infer_lr, args.infer_weight_decay, args.infer_optimizer
    else:
        lr, wd, optimizer_name = args.lr, args.weight_decay, args.optimizer

    if optimizer_name == 'sgd':
        args.logger.info(f'Using SGD with lr {lr}, momentum {args.momentum}, weight decay {wd}')
        if arch == 'bert':
            # Only trainable parameters; Nesterov momentum is requested on this path only.
            optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=lr,
                momentum=args.momentum,
                nesterov=args.nesterov,
                weight_decay=wd,
            )
        else:
            optimizer = optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=wd)

        if args.scheduler == 'none':
            scheduler = None
        elif args.scheduler == 'milestone':
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=args.milestones, gamma=args.gamma, verbose=True)
        else:
            raise NotImplementedError
    elif optimizer_name == 'adamw':
        args.logger.info(f'Using AdamW with lr {lr}, weight decay {wd}')
        # Biases and LayerNorm weights are excluded from weight decay.
        no_decay = ['bias', 'LayerNorm.weight']
        named_params = list(model.named_parameters())
        optimizer_grouped_parameters = [
            {'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)], 'weight_decay': wd},
            {'params': [p for n, p in named_params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=lr,
            eps=args.adam_epsilon)
        args.logger.info(f't_total is {t_total}')
        scheduler = WarmupLinearSchedule(
            optimizer,
            warmup_steps=args.warmup_steps,
            t_total=t_total)
    else:
        raise NotImplementedError
    return optimizer, scheduler


def load_model(args, infer=False, t_total=0):
    """
    Seed all RNGs, build the network, and create its optimizer and LR scheduler.

    Args:
    - args (argparse): experiment args (architecture, optimizer/scheduler settings,
      ``seed``, ``gpu``, ``device``, ``logger``, ...)
    - infer (bool): build the inference-stage model, i.e. use ``infer_arch`` (when set),
      ``infer_lr``, ``infer_weight_decay`` and ``infer_optimizer``
    - t_total (int): total number of optimization steps; only used by the AdamW
      warmup-linear schedule

    Returns:
    - (model, optimizer, scheduler). ``model`` is wrapped in ``nn.DataParallel`` when
      more than one GPU is listed; ``scheduler`` is None for SGD without a scheduler.

    Raises:
    - NotImplementedError: unsupported architecture, optimizer or scheduler
    """
    _seed_everything(args.seed)
    arch = _resolve_arch(args, infer)
    model = _build_network(args, arch, infer)

    if len(args.gpu) > 1:
        model = nn.DataParallel(model)
    model = model.to(args.device)

    optimizer, scheduler = _build_optimizer(args, model, arch, infer, t_total)
    return model, optimizer, scheduler


def set_logger(args, logger):
    """
    Configure root logging to ``{args.save_dir}/log.txt`` and echo ``logger`` to the console.

    The root configuration (file in write mode, INFO level) is applied through
    ``logging.basicConfig``. A new console StreamHandler at DEBUG level is attached
    to ``logger`` on every call.

    Returns:
    - the same ``logger`` object
    """
    log_path = f"{args.save_dir}/log.txt"
    logging.basicConfig(
        level=logging.INFO,
        datefmt="%m/%d/%Y %H:%M:%S",
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        filemode='w',
        filename=log_path,
    )

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    logger.addHandler(console_handler)
    return logger


class GroupWeightedLoss(nn.Module):
    """
    Group-reweighted loss with an exponentiated-gradient weight update (as in online Group DRO).

    On each call, the per-group mean loss of the batch is computed (0 for groups
    absent from the batch). The group weights are multiplied by
    ``exp(args.group_weight_lr * group_loss)`` and renormalised. The returned
    loss is the weighted sum of the group losses under the *updated* weights.
    The weights are treated as constants (no gradient flows through them).
    """

    def __init__(self, args, num_groups):
        super().__init__()
        self.args = args
        self.num_groups = num_groups
        # Start uniform. Kept as a plain tensor attribute, not a registered buffer.
        uniform = torch.ones(num_groups).to(args.device)
        self.group_weights = uniform / uniform.sum()

    def forward(self, loss, groups):
        """Reduce per-sample ``loss`` to per-group means, update weights, return the weighted sum."""
        per_group = torch.zeros(self.num_groups).to(self.args.device)
        for g in range(self.num_groups):
            in_group = groups == g
            if in_group.sum() > 0:
                per_group[g] += loss[in_group].mean()
        self.update_group_weights(per_group)

        return (per_group * self.group_weights).sum()

    def update_group_weights(self, group_loss):
        """Multiplicatively up-weight high-loss groups, then renormalise to sum to 1."""
        scaled = self.group_weights * torch.exp(self.args.group_weight_lr * group_loss)
        self.group_weights.data = (scaled / scaled.sum()).data


class SaveOutput:
    """
    Forward-hook callable that collects module outputs.

    Tensor outputs are stored detached and on the CPU. Outputs that cannot be
    detached or moved (e.g. tuples) are stored unchanged after printing the error.
    """

    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        try:
            captured = module_out.detach().cpu()
        except Exception as err:
            # Best effort: keep the raw output rather than failing the forward pass.
            print(err)
            captured = module_out
        self.outputs.append(captured)

    def clear(self):
        """Drop all collected outputs."""
        self.outputs = []


def get_output(model, inputs, labels, args):
    """
    Run a forward pass and return the logits, for both BERT and non-BERT models.

    Model and data batch must already be on the same device.

    Args:
    - model (torch.nn.Module): Pytorch network
    - inputs (torch.tensor): Data features batch. For BERT, the token ids, attention
      mask and segment ids are stacked along the last axis (``inputs[:, :, 0/1/2]``)
    - labels (torch.tensor or None): Data labels batch; forwarded to BERT, whose
      output then starts with the loss
    - args (argparse): Experiment args (reads ``arch``)

    Returns:
    - torch.tensor: the model's logits
    """
    if args.arch != 'bert':
        return model(inputs)

    outputs = model(input_ids=inputs[:, :, 0],
                    attention_mask=inputs[:, :, 1],
                    token_type_ids=inputs[:, :, 2],
                    labels=labels)
    # pytorch_transformers returns a plain tuple: (logits, ...) without labels and
    # (loss, logits, ...) with labels. A tuple has no `.logits` attribute, so index
    # positionally; this works for transformers' ModelOutput as well.
    return outputs[0] if labels is None else outputs[1]


def plot_confusion_matrix(groups, cluster_labels, args):
    """
    Plot the group-vs-cluster confusion matrix under the best cluster/group matching.

    Each cluster is matched one-to-one to a group so that the number of samples on
    the diagonal is maximal (``linear_sum_assignment`` on the negated matrix).
    Columns are reordered so that column j shows the cluster matched to group j.

    Args:
    - groups (array-like): true group id of each sample
    - cluster_labels (array-like): inferred cluster id of each sample
    - args (argparse): experiment args (reads ``dataset`` and ``logger``)

    Returns:
    - fig: matplotlib figure with the matched confusion matrix
    - cluster_labels (np.ndarray): each sample's cluster id replaced by the id of the
      group that cluster was matched to, so it can be compared directly to ``groups``
    """
    from scipy.optimize import linear_sum_assignment
    from sklearn.metrics import confusion_matrix

    groups = np.asarray(groups)
    cluster_labels = np.asarray(cluster_labels)

    # print number of groups and number of clusters
    num_groups = len(np.unique(groups))
    num_clusters = len(np.unique(cluster_labels))
    args.logger.info(f"Number of groups: {num_groups}")
    args.logger.info(f"Number of clusters: {num_clusters}")

    # Row/column p of the confusion matrix stands for labels[p]. Passing the label
    # set explicitly pins that correspondence for the relabelling at the end.
    labels = np.unique(np.concatenate([groups.ravel(), cluster_labels.ravel()]))
    cm = confusion_matrix(groups, cluster_labels, labels=labels)
    # Best matching: group position row_ind[j] <-> cluster position col_ind[j].
    row_ind, col_ind = linear_sum_assignment(-cm)
    args.logger.info(f"Best mapping: {dict(zip(row_ind, col_ind))}")

    if 'cmnist' in args.dataset:
        fig, ax = plt.subplots(figsize=(16, 16))
    else:
        fig, ax = plt.subplots(figsize=(8, 8))
    sns.set(font_scale=1.5)

    # plot confusion matrix with warm cmap; column j is the cluster matched to group j
    sns.heatmap(cm[row_ind[:, np.newaxis], col_ind], annot=True, fmt='d', ax=ax, cmap='Oranges', cbar=False)
    ax.set_xlabel('Predicted groups', fontsize=20)
    ax.set_ylabel('True groups', fontsize=20)
    if args.dataset == 'waterbirds':
        ax.xaxis.set_ticklabels(['Landbird\ndownweight', 'Landbird\nupweight', 'Waterbird\nupweight', 'Waterbird\ndownweight'])
        ax.yaxis.set_ticklabels(['Landbird\nland', 'Landbird\nwater', 'Waterbird\nland', 'Waterbird\nwater'])
    elif args.dataset == 'celeba':
        ax.xaxis.set_ticklabels(['Dark\nupsample', 'Dark\ndownsample', 'Blonde\ndownsample', 'Blonde\nupsample'])
        ax.yaxis.set_ticklabels(['Dark\nfemale', 'Dark\nmale', 'Blonde\nfemale', 'Blonde\nmale'])
    else:
        # One label per matrix row/column; the matrix spans every group and cluster id.
        ax.xaxis.set_ticklabels(labels)
        ax.yaxis.set_ticklabels(labels)

    # make all fonts in the figure larger
    for item in (ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(18)

    # highlight the diagonal with red box
    for i in range(num_clusters):
        ax.add_patch(Rectangle((i, i), 1, 1, fill=False, edgecolor='red', lw=3))

    fig.tight_layout()

    # Relabel each cluster with its matched group. Cluster position col_ind[j] is
    # matched to group position row_ind[j], so cluster -> group is the *inverse* of
    # col_ind; indexing col_ind by the cluster id is only right when the matching
    # happens to be its own inverse (always true for 2 groups, not for 3+).
    cluster_to_group = np.empty_like(col_ind)
    cluster_to_group[col_ind] = row_ind
    cluster_labels = labels[cluster_to_group[np.searchsorted(labels, cluster_labels)]]

    return fig, cluster_labels


def visualize_traject(traject, groups, dataset, args):
    """
    Plot the per-group confidence trajectory over recorded steps and save it.

    Args:
    - traject (list of np.ndarray): one array per step; the arrays are stacked along a
      new last axis and max-reduced over axis 1 (presumably per-sample class
      probabilities of shape (n_samples, num_classes) -- confirm with the caller)
    - groups (array-like): group id of each sample, aligned with the first axis of
      each trajectory array
    - dataset: currently unused; kept for interface compatibility
    - args (argparse): experiment args (reads ``save_dir``)

    Returns:
    - fig: the matplotlib figure, also saved to ``{args.save_dir}/group_trajectory.png``
    """
    # Stack steps on a new last axis, then keep the max over axis 1 -> (n_samples, n_steps).
    confidence = np.amax(np.stack(traject, axis=-1), axis=1)

    fig, ax = plt.subplots(figsize=(10, 10))

    # Long-format frame (group, step, value); seaborn aggregates each group per step.
    df = pd.DataFrame(confidence)
    df['group'] = groups
    df = df.melt(id_vars='group', var_name='step', value_name='value')
    sns.lineplot(x='step', y='value', hue='group', data=df, ax=ax)

    ax.set_xlabel('Step')
    ax.set_ylabel('Confidence')
    ax.set_title('Trajectory of Confidence')

    # TODO: example images per group are not shown. Raw image arrays cannot serve as
    # legend handles (matplotlib would need a custom legend handler for that).
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=handles, labels=labels, loc='center left', bbox_to_anchor=(1, 0.5))

    # make all fonts in the figure larger
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
                 ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(20)

    fig.tight_layout()
    fig.savefig(f"{args.save_dir}/group_trajectory.png")

    return fig


def _str2bool(value):
    """
    Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every non-empty
    string, so ``--flag False`` would turn the flag on.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_parser():
    """Build the argument parser for all training, inference and visualization options."""
    parser = argparse.ArgumentParser()

    # add dataset, model, optimizer, and training options
    parser.add_argument('--dataset', type=str, default='waterbirds',
                        choices=['waterbirds', 'celeba', 'cmnist', 'multinli',
                                 'civilcomments', 'cifar100sup', 'cifar10', 'imagenet', 'balance_cmnist'])
    # add cmnist options
    parser.add_argument('--p_correlation', type=float, default=0.995)

    # add model options
    parser.add_argument('--arch', type=str, default='resnet50', choices=['resnet18', 'resnet50', 'bert', 'mlp', 'densenet'])
    parser.add_argument('--infer_arch', type=str, default='', choices=['', 'resnet18', 'resnet50', 'resnet50wide', 'densenet'])
    parser.add_argument('--pretrained', type=_str2bool, default=False)
    parser.add_argument('--freeze', type=_str2bool, default=False, help='Freeze the pretrained model')
    parser.add_argument('--wide', type=_str2bool, default=False, help='Use a wider model')

    # add optimizer options
    parser.add_argument('--optimizer', type=str, default='sgd', choices=['sgd', 'adamw'])
    parser.add_argument('--infer_optimizer', type=str, default='sgd', choices=['sgd', 'adamw'])
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--nesterov', type=_str2bool, default=True)
    parser.add_argument('--weight_decay', type=float, default=1e-3)
    parser.add_argument('--infer_lr', type=float, default=1e-3)
    parser.add_argument('--infer_weight_decay', type=float, default=1e-4)
    parser.add_argument('--eiil_lr', type=float, default=1e-2)
    parser.add_argument('--eiil_steps', type=int, default=20000)
    # add scheduler options
    parser.add_argument('--scheduler', type=str, default='none', choices=['none', 'milestone'])
    parser.add_argument('--milestones', type=int, nargs='+', default=[30, 60, 80])
    parser.add_argument('--gamma', type=float, default=0.1)

    # add training options
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--save_dir', type=str, default='./outputs')
    parser.add_argument('--save_unit', type=str, default='epoch', choices=['epoch', 'batch'])
    parser.add_argument('--checkpoint_path', type=str, default='./saved_models')
    parser.add_argument('--seed', type=int, default=11111111)
    parser.add_argument('--gpu', type=int, nargs='+', default=[0])

    # add spurious correlation inference and training options
    parser.add_argument('--infer_steps', type=int, nargs='+', default=[])
    parser.add_argument('--infer', type=str, default='none',
                        choices=['none', 'misclass', 'eiil', 'cluster', 'ssa', 'conf_thresh', 'linear_embed', 'dfr'])
    parser.add_argument('--dfr_val', type=_str2bool, default=False)
    parser.add_argument('--cluster_method', type=str, default='kmeans', choices=['kmeans', 'gmm', 'submod'])
    parser.add_argument('--cluster_metric', type=str, default='pred', choices=['pred', 'embed', 'conf', 'grad', 'true_conf', 'logit'])
    parser.add_argument('--cluster_all', type=_str2bool, default=False)
    parser.add_argument('--silhouette', type=_str2bool, default=False)
    parser.add_argument('--sample_by_silhouette', type=_str2bool, default=False)
    parser.add_argument('--max_clusters', type=int, default=10)
    parser.add_argument('--reuse_clusters', type=_str2bool, default=False)
    parser.add_argument('--num_clusters', type=int, default=2)
    parser.add_argument('--train', type=str, default='erm', choices=['erm', 'group_dro'])
    parser.add_argument('--group_weight_lr', type=float, default=0.01)
    parser.add_argument('--sample', type=str, default='none', choices=['none', 'upsample', 'upsample_by_factor', 'downsample'])
    parser.add_argument('--upsample_factor', type=int, default=100)
    parser.add_argument('--uniform_class_sampler', type=_str2bool, default=False)
    parser.add_argument('--uniform_group_sampler', type=_str2bool, default=False)
    parser.add_argument('--weighted_sampler', type=_str2bool, default=False)
    parser.add_argument('--equal', type=_str2bool, default=True)
    parser.add_argument('--continue_train', type=_str2bool, default=False, help='continue training from the inferred model')
    parser.add_argument('--traject_length', type=int, default=1)
    parser.add_argument('--sep_conf', type=_str2bool, default=False)
    parser.add_argument('--conf_thresh', type=float, default=0.5)

    # add options for data augmentation to use mixup or cutmix
    parser.add_argument('--mixup', type=str, default='none', choices=['none', 'mixup', 'cutmix'])
    parser.add_argument('--mixup_alpha', type=float, default=1.0)
    parser.add_argument('--mixup_lisa', type=_str2bool, default=False)
    parser.add_argument('--mixup_sample_by_cluster', type=_str2bool, default=False)
    parser.add_argument('--mixup_sample_by_cluster_power', type=float, default=1.0)

    # add option to use standard data augmentation
    parser.add_argument('--train_augment', type=_str2bool, default=False)
    parser.add_argument('--infer_augment', type=_str2bool, default=False)

    # add options to weight the loss of each group by the number of samples in each group
    parser.add_argument('--weight_loss', type=_str2bool, default=False)
    parser.add_argument('--weight_loss_power', type=float, default=1.0)

    # add options to sample from each group by the number of samples in each group
    parser.add_argument('--sample_by_cluster', type=_str2bool, default=False)
    parser.add_argument('--sample_by_cluster_power', type=float, default=1.0)
    parser.add_argument('--adaptive_sample_power', type=_str2bool, default=False)
    parser.add_argument('--upsample_by_cluster_size', type=_str2bool, default=False)
    # add sample_every option to sample every n epochs
    parser.add_argument('--sample_every', type=int, default=-1)

    # add visualization options
    parser.add_argument('--visualize', type=_str2bool, default=False)
    parser.add_argument('--cmap', type=str, default='tab10')
    parser.add_argument('--cluster_umap', type=_str2bool, default=False)

    # infer only
    parser.add_argument('--infer_only', type=_str2bool, default=False)
    parser.add_argument('--balance_classes_infer', type=_str2bool, default=False)
    parser.add_argument('--num_infer_ckpts', type=int, default=1)
    parser.add_argument('--include_init', type=_str2bool, default=False)
    parser.add_argument('--class_dro_infer', type=_str2bool, default=False)
    parser.add_argument('--infer_loss_thresh', type=float, default=-1)
    parser.add_argument('--infer_loss_diff', type=_str2bool, default=False)
    parser.add_argument('--sep_infer', type=_str2bool, default=False)
    parser.add_argument('--adaptive_conf_thresh', type=_str2bool, default=False)
    parser.add_argument('--linear_embed_val', type=_str2bool, default=False)

    # SSA options
    parser.add_argument('--num_splits', type=int, default=3)
    parser.add_argument('--num_iters', type=int, default=1000)  # 1000 for Waterbirds, 45k for CelebA
    parser.add_argument('--smallest_group_conf_thresh', type=float, default=0.95)
    parser.add_argument('--val_freq', type=int, default=10)

    # add wandb
    parser.add_argument('--use_wandb', type=_str2bool, default=False)
    parser.add_argument('--wandb_project', type=str, default='early_trajectory')

    return parser