"""
Functions for slicing data

NOTE: Going to refactor this with slice_train.py and spurious_train.py
      - Currently methods support different demos / explorations
"""
import copy
import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler, SubsetRandomSampler
from tqdm import tqdm

from datasets import train_val_split, get_resampled_set, get_resampled_indices
from train import train_model, test_model
from network import get_criterion, get_optim, get_net, get_output
from utils.logging import summarize_acc


def compute_slice_indices(net, dataloader, criterion,
                          batch_size, args, resample_by='class',
                          loss_factor=1., use_dataloader=False):
    """
    Use trained model to predict "slices" of data belonging to different subgroups

    Samples are grouped by their predicted label ("pseudolabel"). As side
    effects, dataloader.dataset.targets_all['pred_groups'] is set to an integer
    id for every (true target, pseudolabel) combination, and `net` (if given)
    is moved to the CPU before returning.

    Args:
    - net (torch.nn.Module): Pytorch neural network model. Only used to compute
      pseudolabels when no args.pred_labels_path / args.pred_outputs_path is given
    - dataloader (torch.utils.data.DataLoader): Pytorch data loader
    - criterion (torch.nn.Loss): Pytorch cross-entropy loss (with reduction='none')
    - batch_size (int): Batch size to compute slices over
    - args (argparse): Experiment arguments
    - resample_by (str): How to resample, ['class', 'correct']. Only used when
      args.subsample_labels or args.supersample_labels is True
    - loss_factor (float): Multiplier applied to per-sample losses before the
      softmax that gives supersampling probabilities
      (only with args.weigh_slice_samples_by_loss)
    - use_dataloader (bool): If True, iterate `dataloader` itself when computing
      pseudolabels with `net`; otherwise compute_pseudolabels builds its own loader
    Returns:
    - sliced_data_indices (int(np.array)[]): List of numpy arrays denoting indices of the dataloader.dataset
                                             corresponding to different slices
    - all_losses (torch.Tensor[]): Per-slice losses aligned with the indices
      (empty list unless args.weigh_slice_samples_by_loss)
    - all_correct (bool(np.array)[]): Per-slice flags, prediction == target
    - all_probs (torch.Tensor[]): Per-slice softmax output probabilities
    Raises:
    - ValueError: if resampling is requested with an unknown `resample_by`
    """
    dataset = dataloader.dataset

    # First compute pseudolabels
    if args.pred_labels_path is not None or args.pred_outputs_path is not None:
        # Supply predictions from external file
        slice_outputs = compute_pseudolabels_from_path(dataset, args)
    else:
        slice_outputs = compute_pseudolabels(
            net, dataset, batch_size, args, criterion,
            dataloader=dataloader if use_dataloader else None)
    pseudo_labels, outputs, correct, _, losses = slice_outputs
    pseudo_labels = np.asarray(pseudo_labels)

    # Give every sample an id for its (true target, pseudolabel) pair. This
    # equals a nested loop over sorted classes (outer) and sorted predictions
    # (inner) with a running counter: id = class_rank * n_preds + pred_rank
    targets = np.asarray(dataset.targets_all['target'])
    unique_classes = np.unique(targets)
    unique_preds = np.unique(pseudo_labels)
    class_rank = np.searchsorted(unique_classes, targets)
    pred_rank = np.searchsorted(unique_preds, pseudo_labels)
    dataset.targets_all['pred_groups'] = (
        class_rank * len(unique_preds) + pred_rank).astype(int)

    # torch.softmax subtracts the row max first, so large logits cannot
    # overflow to inf / NaN the way exp(x) / sum(exp(x)) does
    output_probabilities = torch.softmax(outputs, dim=1)
    correct = correct.detach().cpu().numpy()

    resample = args.subsample_labels is True or args.supersample_labels is True
    if resample and resample_by not in ('class', 'correct'):
        raise ValueError(
            f"resample_by must be 'class' or 'correct', got {resample_by!r}")

    sliced_data_indices = []
    all_losses = []
    all_correct = []
    all_probs = []
    for label in np.unique(pseudo_labels):
        group = np.where(pseudo_labels == label)[0]
        losses_per_group = (losses[group] if args.weigh_slice_samples_by_loss
                            else None)
        correct_by_group = correct[group]
        probs_by_group = output_probabilities[group]

        if resample:
            group_targets = dataset.targets[group]
            group_vals = np.unique(group_targets, return_counts=True)[1]
            sample_size = (np.min(group_vals) if args.subsample_labels is True
                           else np.max(group_vals))
            target_values = (group_targets if resample_by == 'class'
                             else correct_by_group)
            print(f'> Resampling by {resample_by}...')
            positions = _resample_slice_positions(
                target_values, sample_size, args, losses_per_group, loss_factor)
        elif args.weigh_slice_samples_by_loss:
            # Highest-loss samples first
            positions = torch.argsort(losses_per_group, descending=True).numpy()
        else:
            positions = np.arange(len(group))

        sliced_data_indices.append(group[positions])
        all_correct.append(correct_by_group[positions])
        all_probs.append(probs_by_group[positions])
        if args.weigh_slice_samples_by_loss:
            all_losses.append(losses_per_group[positions])

    # Save GPU memory: moving the module moves its parameters with it
    if net is not None:
        net.to(torch.device('cpu'))
    return sliced_data_indices, all_losses, all_correct, all_probs


def _resample_slice_positions(target_values, sample_size, args,
                              losses_per_group=None, loss_factor=1.):
    """
    Sub- or super-sample positions within one slice, per value of target_values.

    With args.subsample_labels, each value keeps at most `sample_size`
    positions (drawn without replacement). Otherwise (supersampling), every
    position is kept and extra ones are drawn with replacement until the value
    has `sample_size` positions; with args.weigh_slice_samples_by_loss those
    draws favour high-loss samples via softmax(losses * loss_factor).

    Returns:
    - np.array of positions into the slice (may contain repeats)
    """
    sampled_positions = []
    for value in np.unique(target_values):
        value_positions = np.where(target_values == value)[0]
        p = None
        if args.subsample_labels is True:
            sampling_size = min(len(value_positions), sample_size)
            replace = False
        else:
            sampling_size = max(0, sample_size - len(value_positions))
            sampled_positions.append(value_positions)
            replace = True
            if args.weigh_slice_samples_by_loss:
                logits = losses_per_group[value_positions] * loss_factor
                # float64 softmax: stable and sums to 1 within numpy's tolerance
                p = torch.softmax(logits.double(), dim=0).numpy()
        sampled_positions.append(np.random.choice(
            value_positions, size=sampling_size, replace=replace, p=p))
    return np.concatenate(sampled_positions)


# Util for ground-truth ablation
def get_one_hot(a, num_classes):
    """
    One-hot encode integer labels.

    Args:
    - a (torch.Tensor): Integer class labels of any shape; flattened first
    - num_classes (int): Number of classes (width of the encoding)
    Returns:
    - torch.Tensor: float tensor of shape (a.numel(), num_classes) on a's device.
      The batch dimension is always kept, so a single label yields
      shape (1, num_classes) rather than (num_classes,)
    """
    # Build the identity on a's device: indexing a CPU tensor with
    # indices on another device raises an error
    eye = torch.eye(num_classes, device=a.device)
    return eye[a.reshape(-1)]


def compute_pseudolabels(net, dataset, batch_size, args, criterion=None,
                         dataloader=None):
    """
    Run `net` over a dataset and collect its predictions ("pseudolabels").

    Args:
    - net (torch.nn.Module): Trained model; set to eval mode and moved to args.device
    - dataset: Dataset whose loader yields (inputs, labels, index) triples and
      which exposes targets_all['spurious']
    - batch_size (int): Batch size of the loader built when `dataloader` is None
    - args (argparse): Experiment arguments (device, num_workers, train_classes,
      slice_with, slice_noise, weigh_slice_samples_by_loss)
    - criterion (torch.nn.Loss): Per-sample loss (reduction='none'); required
      when args.weigh_slice_samples_by_loss is set
    - dataloader (DataLoader): Optional loader to iterate instead of building
      one; its dataset then replaces `dataset`
    Returns (all on CPU, concatenated in loader order):
    - pseudo_labels (torch.Tensor): Predicted label per sample
    - outputs (torch.Tensor): Model outputs, or one-hot spurious labels for the
      args.slice_with == 'true' ablation
    - correct (torch.Tensor): pseudo_labels == labels
    - correct_spurious (torch.Tensor): pseudo_labels == spurious labels
    - losses (torch.Tensor or None): Per-sample criterion(outputs, labels), or
      None when args.weigh_slice_samples_by_loss is not set
    Raises:
    - ValueError: if losses are requested but no criterion is given
    """
    if args.weigh_slice_samples_by_loss and criterion is None:
        raise ValueError('Need to specify criterion')

    net.eval()
    if dataloader is None:
        # shuffle=False keeps the collected results in dataset order
        loader = DataLoader(dataset, batch_size=batch_size,
                            shuffle=False, num_workers=args.num_workers)
    else:
        loader = dataloader
        dataset = dataloader.dataset
    net.to(args.device)

    num_classes = len(args.train_classes)
    targets_s = dataset.targets_all['spurious']

    all_outputs = []
    all_predicted = []
    all_correct = []
    all_correct_spurious = []
    all_losses = []
    with torch.no_grad():
        for inputs, labels, data_ix in tqdm(loader):
            labels_spurious = torch.tensor(
                [targets_s[ix] for ix in data_ix]).to(args.device)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)

            # For ablation studying dependence on ERM predictions
            if args.slice_with == 'true':
                predicted = labels_spurious
                outputs = get_one_hot(predicted, num_classes).to(args.device)
            else:
                outputs = get_output(net, inputs, labels, args)
                _, predicted = torch.max(outputs, 1)
            all_outputs.append(outputs.detach().cpu())

            # Add noise to ERM predictions for ablation: with probability
            # args.slice_noise, replace a prediction with a uniformly
            # sampled class (mask value 1 means "replace")
            if args.slice_noise > 0:
                replace_mask = np.random.choice(
                    2, len(predicted), p=[1 - args.slice_noise, args.slice_noise])
                replace_mask = torch.tensor(replace_mask, dtype=torch.bool).to(args.device)
                replace = torch.randint(0, num_classes, replace_mask.shape).to(args.device)
                # torch.where(cond, x, y) takes x where cond is True, so the
                # random class must be the first value argument
                predicted = torch.where(replace_mask, replace, predicted)

            all_predicted.append(predicted.detach().cpu())
            if args.weigh_slice_samples_by_loss:
                all_losses.append(criterion(outputs, labels).detach().cpu())

            all_correct.append((predicted == labels).cpu())
            all_correct_spurious.append((predicted == labels_spurious).cpu())

    pseudo_labels = torch.hstack(all_predicted)
    outputs = torch.vstack(all_outputs)
    correct = torch.hstack(all_correct)
    correct_spurious = torch.hstack(all_correct_spurious)
    losses = torch.hstack(all_losses) if all_losses else None
    return pseudo_labels, outputs, correct, correct_spurious, losses


def compute_pseudolabels_from_path(dataset, args):
    """
    Load precomputed model outputs from disk and derive pseudolabels from them.

    Returns the same 5-tuple as compute_pseudolabels so callers can use either.

    Args:
    - dataset: Dataset exposing `targets` and targets_all['spurious']; assumed
      to be in the same order as the rows of the saved outputs (TODO confirm)
    - args (argparse): Experiment arguments; reads pred_outputs_path and
      pred_labels_path
    Returns:
    - pseudo_labels (torch.Tensor): argmax of the loaded outputs per row
    - outputs (torch.Tensor): Loaded outputs (np.load of pred_outputs_path)
    - correct (torch.Tensor): pseudo_labels == dataset.targets
    - correct_spurious (torch.Tensor): pseudo_labels == spurious labels
    - losses (torch.Tensor): Per-sample cross-entropy of outputs w.r.t. the
      true targets (same target as compute_pseudolabels uses)
    Raises:
    - NotImplementedError: if only pred_labels_path is given
    - ValueError: if neither path is given
    """
    if args.pred_outputs_path is None:
        if args.pred_labels_path is not None:
            raise NotImplementedError(
                'Loading predicted labels (pred_labels_path) is not supported; '
                'supply pred_outputs_path instead')
        raise ValueError(
            'Either args.pred_outputs_path or args.pred_labels_path must be set')

    with open(args.pred_outputs_path, 'rb') as f:
        outputs = torch.tensor(np.load(f))

    targets_t = torch.as_tensor(dataset.targets).long()
    targets_s = torch.as_tensor(dataset.targets_all['spurious']).long()

    pseudo_labels = torch.argmax(outputs, dim=1)
    correct = pseudo_labels == targets_t
    correct_spurious = pseudo_labels == targets_s

    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    losses = criterion(outputs, targets_t)
    return pseudo_labels, outputs, correct, correct_spurious, losses


# Ablation for studying how downstream performance depends on the ERM model's ability to predict the ground truth correctly
def compute_slices_from_groundtruth(dataloader, args):
    """
    Split the data into slices based on the spurious attribute value

    The ground-truth spurious attribute stands in for the ERM model's
    predictions. As a side effect, dataloader.dataset.targets_all['pred_groups']
    is set to an integer id for every (target, spurious value) combination.

    Args:
    - dataloader: Loader whose dataset exposes targets_all['target'] and
      targets_all['spurious']
    - args (argparse): Experiment arguments (currently unused)
    Returns:
    - sliced_data_indices (int(np.array)[]): Dataset indices for each spurious value
    - all_correct (bool(np.array)[]): Per-slice flags, spurious value == target
    - all_losses (float(np.array)[]): Dummy all-zero losses
    - None: placeholder for output probabilities
    NOTE(review): this return order (indices, correct, losses, None) differs from
    compute_slice_indices (indices, losses, correct, probs); confirm against callers.
    """
    print('> Loading ground-truth labels...')
    targets_all = dataloader.dataset.targets_all
    spurious = np.asarray(targets_all['spurious'])
    targets = np.asarray(targets_all['target'])

    sliced_data_indices = []
    all_correct = []
    all_losses = []
    for value in np.unique(spurious):
        indices = np.where(spurious == value)[0]
        sliced_data_indices.append(indices)
        # The spurious value acts as the prediction; it is "correct" when it
        # matches the target
        all_correct.append(targets[indices] == value)
        all_losses.append(np.zeros(len(indices)))

    # Id for each (target, spurious value) pair; equals a nested loop over
    # sorted targets (outer) and sorted spurious values (inner) with a counter
    unique_classes = np.unique(targets)
    unique_preds = np.unique(spurious)
    class_rank = np.searchsorted(unique_classes, targets)
    pred_rank = np.searchsorted(unique_preds, spurious)
    targets_all['pred_groups'] = (
        class_rank * len(unique_preds) + pred_rank).astype(int)

    return sliced_data_indices, all_correct, all_losses, None


# Supply slices from external file
def compute_slices_from_path(dataset, args):
    """
    Build slices from externally predicted group ids (Waterbirds only).

    Loads predicted group ids from args.pred_groups_path, prints their accuracy
    against targets_all['group_idx'], maps each group to its spurious value
    (groups 0, 2 -> 0; groups 1, 3 -> 1) and slices the dataset by that value.

    Args:
    - dataset: Dataset exposing targets_all['group_idx'] and targets_all['target']
    - args (argparse): Experiment arguments; reads dataset and pred_groups_path
    Returns:
    - sliced_data_indices (int(np.array)[]): Dataset indices per mapped value
    - all_correct (bool(np.array)[]): Per-slice flags, mapped value == target
    - all_losses (float(np.array)[]): Dummy all-zero losses
    """
    assert args.dataset in ['waterbirds'], 'Right now only supports Waterbirds'

    with open(args.pred_groups_path, 'rb') as f:
        group_array = np.load(f)
    group_idx = dataset.targets_all['group_idx']
    print(f"Predicted group accuracy: {np.sum(group_array == group_idx) / len(group_idx) * 100:<.4f}%")

    # Map the groups to the spurious components
    # group 0 -> 0; group 1 -> 1; group 2 -> 0; group 3 -> 1
    # Array lookup instead of np.vectorize, which fails on empty input
    group_to_prediction = np.array([0, 1, 0, 1])
    pseudo_labels = group_to_prediction[np.asarray(group_array).astype(int)]
    correct = pseudo_labels == dataset.targets_all['target']

    sliced_data_indices = []
    all_correct = []
    all_losses = []
    for label in np.unique(pseudo_labels):
        group = np.where(pseudo_labels == label)[0]
        sliced_data_indices.append(group)
        all_correct.append(correct[group])
        all_losses.append(np.zeros(len(group)))

    return sliced_data_indices, all_correct, all_losses


def train_spurious_model(train_loader, args, resample=False,
                         return_loaders=False, test_loader=None,
                         test_criterion=None):
    """
    Train the "spurious" model on one split of the training data.

    train_val_split(val_split=args.spurious_train_split, seed=args.seed) splits
    train_loader.dataset in two. The model is trained on the second returned
    split and validated on the first. If `resample` is True, the training
    split is resampled with get_resampled_indices(..., args.resample_class) and
    its loader shuffles; otherwise both loaders iterate in order.
    Test-set logging is disabled (log_test_results=False) even when
    test_loader is given.

    Args:
    - train_loader (DataLoader): Loader over the full training data
    - args (argparse): Experiment arguments
    - resample (bool): Whether to resample the spurious training split
    - return_loaders (bool): Whether to also return the two split loaders
    - test_loader (DataLoader): Forwarded to train_model
    - test_criterion (torch.nn.Loss): Forwarded to train_model
    Returns:
    - net: The trained model
    - outputs: Whatever train_model returns
    - (train_loader_new, train_loader_spurious) if return_loaders, else None
    """
    full_dataset = train_loader.dataset

    def make_loader(subset, shuffle=False):
        return DataLoader(subset,
                          batch_size=args.bs_trn,
                          shuffle=shuffle,
                          num_workers=args.num_workers)

    # TODO: debug this resampling path
    train_indices, train_indices_spurious = train_val_split(
        full_dataset, val_split=args.spurious_train_split, seed=args.seed)
    spurious_targets = full_dataset.targets_all['target'][train_indices_spurious]
    unique_target_counts = np.unique(spurious_targets, return_counts=True)
    print(f'Target values in spurious training data: {unique_target_counts}')

    train_set_new = get_resampled_set(full_dataset, train_indices,
                                      copy_dataset=True)
    train_set_spurious = get_resampled_set(full_dataset, train_indices_spurious,
                                           copy_dataset=True)
    train_loader_new = make_loader(train_set_new)
    train_loader_spurious = make_loader(train_set_spurious)

    if resample is True:
        resampled_indices = get_resampled_indices(train_loader_spurious, args,
                                                  args.resample_class)
        resampled_set = get_resampled_set(train_set_spurious, resampled_indices)
        train_loader_spurious = make_loader(resampled_set, shuffle=True)

    net = get_net(args)
    optim = get_optim(net, args, model_type='spurious')
    criterion = get_criterion(args)

    outputs = train_model(net, optim, criterion,
                          train_loader=train_loader_spurious,
                          val_loader=train_loader_new,
                          args=args, epochs=args.max_epoch_s,
                          log_test_results=False,
                          test_loader=test_loader,
                          test_criterion=test_criterion)

    loaders = (train_loader_new, train_loader_spurious) if return_loaders else None
    return net, outputs, loaders


def train_batch_model(train_loader, sliced_data_indices, args,
                      val_loader, test_loader=None):
    """
    Train a fresh (non-pretrained) model on the union of all slices.

    The slice indices are concatenated and drawn in random order by a
    SubsetRandomSampler over train_loader.dataset, and `val_loader` is used
    for validation. Side effect: sets args.model_type = 'mb_slice'.

    Args:
    - train_loader (DataLoader): Loader whose dataset the slice indices refer to
    - sliced_data_indices (int(np.array)[]): Dataset indices per slice
    - args (argparse): Experiment arguments
    - val_loader (DataLoader): Validation loader passed to train_model
    - test_loader (DataLoader): Optional test loader passed to train_model
    Returns:
    - net: The trained model
    """
    net = get_net(args, pretrained=False)
    optim = get_optim(net, args, model_type='pretrain')
    criterion = get_criterion(args)
    test_criterion = torch.nn.CrossEntropyLoss(reduction='none')
    aggregated_indices = np.hstack(sliced_data_indices)

    title = 'Training on aggregated slices'
    print('-' * len(title))
    print(title)

    aggregated_loader = DataLoader(
        train_loader.dataset,
        batch_size=args.bs_trn,
        sampler=SubsetRandomSampler(aggregated_indices),
        num_workers=args.num_workers)
    args.model_type = 'mb_slice'
    train_model(net, optim, criterion, aggregated_loader,
                val_loader, args, 0, args.max_epoch,
                True, test_loader, test_criterion)
    return net
