"""
Updated methods for training initial spurious model
- Also updated methods for slicing, via inspecting hidden layer representations
"""
import os
import copy
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import torchvision.transforms as transforms
from PIL import Image
from itertools import permutations
from tqdm import tqdm

# Representation-based slicing
import umap
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture

# Use a scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


# Data
from torch.utils.data import DataLoader, SequentialSampler, SubsetRandomSampler
from datasets import get_data_args, train_val_split, get_resampled_indices, get_resampled_set
# from datasets.isic import load_isic, visualize_isic

# Logging and training
from slice import compute_pseudolabels, train_spurious_model, compute_slice_indices
from utils.logging import log_data, initialize_csv_metrics
from train import train_model, test_model, train, evaluate
from utils import print_header, init_experiment

from utils.logging import summarize_acc, log_data
from utils.visualize import plot_confusion, plot_data_batch

# Model
from network import get_net, get_optim, get_criterion, save_checkpoint
from activations import save_activations


def _mean_group_accuracy(correct_by_groups, total_by_groups,
                         diagonal_only=False):
    """
    Average per-class validation accuracy over (target, spurious) groups

    Args:
    - correct_by_groups: Nested sequence, indexed [target][spurious], of correct counts
    - total_by_groups: Same layout as above, but holding group sizes
    - diagonal_only (bool): If True, only groups with target index == spurious index
                            contribute to a class's accuracy
    Returns:
    - group_acc (list): Accuracy per target class; classes with no counted samples are skipped
    - group_avg_acc (float): Mean of group_acc, or 0. when group_acc is empty
    """
    group_acc = []
    for yix, y_group in enumerate(correct_by_groups):
        y_correct = []
        y_total = []
        for aix, a_group in enumerate(y_group):
            a_total = total_by_groups[yix][aix]
            if a_total <= 0:
                continue
            if diagonal_only and yix != aix:
                continue
            y_correct.append(a_group)
            y_total.append(a_total)
        # A class with nothing counted would otherwise contribute 0 / 0 = nan,
        # which poisons the mean and makes every later comparison False
        if len(y_total) > 0:
            group_acc.append(np.sum(y_correct) / np.sum(y_total))
    group_avg_acc = np.mean(group_acc) if len(group_acc) > 0 else 0.
    return group_acc, group_avg_acc


def train_spurious_model(train_loader, val_loader, args, test_loader,
                         test_criterion, resample=False, return_loaders=False):
    """
    Train the initial (spurious) model, optionally resampling the training set each epoch

    With resample=True, a freshly resampled training set is built every epoch,
    the model is checkpointed (and tested) whenever the group-averaged
    validation accuracy improves, and training stops early after more than
    20 consecutive epochs without improvement. Otherwise training is
    delegated to train_model.

    Args:
    - train_loader: Training data loader; its dataset must expose `targets_all`
                    with 'spurious' and 'target' entries
    - val_loader: Validation data loader
    - args: Experiment arguments (reads max_epoch_s, device, bs_trn_s,
            num_workers, resample_class, seed)
    - test_loader: Test data loader, passed to test_model on improvement
    - test_criterion: Loss passed to test_model
    - resample (bool): Only the literal True enables the resampling loop
    - return_loaders (bool): If True, also return (train_loader, val_loader)
    Returns:
    - net: The trained model
    - outputs: For resample=True, the validation outputs of the last epoch run
               (None if no epoch ran); otherwise whatever train_model returns
    - loaders: (train_loader, val_loader) if return_loaders else None
    """
    train_targets_all = train_loader.dataset.targets_all
    unique_spurious_counts = np.unique(train_targets_all['spurious'],
                                       return_counts=True)
    unique_target_counts = np.unique(train_targets_all['target'],
                                     return_counts=True)
    print(f'Spurious values in spurious training data: {unique_spurious_counts}')
    print(f'Target values in spurious training data: {unique_target_counts}')
    
    net = get_net(args)
    optim = get_optim(net, args, model_type='spurious')
    # 'max' because the scheduler is stepped on an accuracy, not a loss
    scheduler = ReduceLROnPlateau(optim, 'max')
    criterion = get_criterion(args)
    
    if resample is True:
        # Defined up front so that a run with zero epochs still returns cleanly
        outputs = None
        max_val_acc = 0
        early_stopping_counter = 0
        for epoch in range(args.max_epoch_s):
            net.train()
            net.to(args.device)
            # Seed changes with the epoch so each epoch sees a different resample
            resampled_indices = get_resampled_indices(train_loader,
                                                      args,
                                                      args.resample_class,
                                                      args.seed + epoch)
            train_set_resampled = get_resampled_set(train_loader.dataset,
                                                    resampled_indices)
            train_loader_rs = DataLoader(train_set_resampled,
                                         batch_size=args.bs_trn_s,
                                         shuffle=True,
                                         num_workers=args.num_workers)
            log_data(train_set_resampled, 
                     f'Resampled dataset at epoch {epoch}', indices=None)
            
            train_outputs = train(net, train_loader_rs, optim, criterion, args)
            running_loss, correct, total, correct_by_groups, total_by_groups = train_outputs
            val_outputs = evaluate(net, val_loader, criterion, args, testing=True)
            val_running_loss, val_correct, val_total, correct_by_groups_v, total_by_groups_v, correct_indices = val_outputs
            outputs = (val_running_loss, val_correct, val_total,
                       correct_by_groups_v, total_by_groups_v, correct_indices)
            
            print(f'Epoch: {epoch + 1:3d} | Train Loss: {running_loss / total:<.3f} | Train Acc: {100 * correct / total:<.3f} | Val Loss: {val_running_loss / val_total:<.3f} | Val Acc: {100 * val_correct / val_total:<.3f}')

            print('Training:')
            summarize_acc(correct_by_groups, total_by_groups)

            print('Validating:')
            summarize_acc(correct_by_groups_v, total_by_groups_v)

            # NOTE(review): seed 1 is special-cased to score only the diagonal
            # (target index == spurious index) groups - confirm this is intended
            group_acc, group_avg_acc = _mean_group_accuracy(
                correct_by_groups_v, total_by_groups_v,
                diagonal_only=(args.seed == 1))
            print(group_acc)
            print(group_avg_acc)
            scheduler.step(group_avg_acc)
                
            if group_avg_acc > max_val_acc:
                save_checkpoint(net, optim, running_loss,
                                epoch, batch=0, args=args,
                                replace=True, retrain_epoch=None)
                max_val_acc = group_avg_acc
                test_model(net, test_loader, test_criterion, args, epoch)
                early_stopping_counter = 0
            else:
                early_stopping_counter += 1
                if early_stopping_counter > 20:
                    break
    else:
        outputs = train_model(net, optim, criterion,
                              train_loader=train_loader,
                              val_loader=val_loader,
                              args=args, epochs=args.max_epoch_s,
                              checkpoint_interval=10)
    if return_loaders:
        return net, outputs, (train_loader, val_loader)
    return net, outputs, None


def resample_sliced_data_indices_by_probs(sliced_data_indices, sliced_data_probs, 
                                          train_targets, temp=1):
    """
    Resample the members of each slice according to the model's confidence

    A slice whose most frequent true class equals the slice position is kept
    unchanged. Any other slice is resampled with replacement (same size),
    with sampling weights given by a softmax over `temp` times each sample's
    top predicted probability (or one minus it, if no sample exceeds 0.5).
    A histogram of the drawn positions is shown for every slice.

    Args:
    - sliced_data_indices (np.array[]): Dataset indices per slice
    - sliced_data_probs: Per-slice 2D class probabilities, one row per sample
                         (CPU torch tensors or numpy arrays)
    - train_targets (np.array): True class per dataset index
    - temp (float): Temperature multiplying the confidences before the softmax
    Returns:
    - resampled_sliced_data_indices (np.array[]): Resampled dataset indices per slice
    - all_resampled_indices (np.array[]): Positions within each slice that were drawn
    """
    resampled_sliced_data_indices = []
    all_resampled_indices = []
    for ix, probs in enumerate(sliced_data_probs):
        target_class, target_counts = np.unique(
            train_targets[sliced_data_indices[ix]], 
            return_counts=True)
        if target_counts.size == 0:
            # Empty slice: nothing to resample (argmax would raise)
            resampled_indices = np.arange(len(probs))
        elif target_class[np.argmax(target_counts)] == ix:
            # Same as the predicted slice. Compare the class *value*: the
            # argmax position only equals the class when every class is
            # present in the slice
            resampled_indices = np.arange(len(probs))
        else:
            sample_size = len(probs)
            # float64 so the normalized weights pass np.random.choice's
            # sum-to-one check, which float32 rounding can fail
            confidence = np.max(np.asarray(probs), axis=1).astype(np.float64)
            # Check if probabilities should be <0.5 or >0.5
            if np.max(confidence) <= 0.5:
                # In this case, upweight the lower probabilities
                confidence = 1 - confidence
            exp = np.exp(confidence * temp)
            p = exp / exp.sum()
            resampled_indices = np.random.choice(np.arange(len(confidence)),
                                                 size=sample_size,
                                                 p=p,
                                                 replace=True)
        plt.hist(resampled_indices, alpha=0.7, 
                 label=f'resampled_indices, slice {ix}')
        plt.legend()
        plt.show()
        resampled_sliced_data_indices.append(
            sliced_data_indices[ix][resampled_indices])
        all_resampled_indices.append(resampled_indices)
    return resampled_sliced_data_indices, all_resampled_indices


def get_resampled_sliced_data_indices(slice_outputs, train_loader, temp=10):
    """
    Resample every slice by prediction confidence and return the new indices

    Args:
    - slice_outputs (tuple): (sliced_data_indices, sliced_data_losses,
                              sliced_data_correct, sliced_data_probs), each a
                              per-slice list. A component may be an empty list
                              (e.g. no losses were recorded) and is then passed
                              through as an empty list
    - train_loader: Data loader whose dataset exposes `targets_all['target']`
    - temp (float): Softmax temperature forwarded to the resampler
    Returns:
    - sliced_data_indices (np.array[]): Resampled dataset indices per slice
    """
    sliced_data_indices, _, _, sliced_data_probs = slice_outputs
    train_targets = train_loader.dataset.targets_all['target']
    
    _, all_resampled_indices_ = resample_sliced_data_indices_by_probs(
        sliced_data_indices, sliced_data_probs, train_targets, temp)

    # Apply the same drawn positions to every component so they stay aligned
    slice_outputs_resampled = []
    for slice_output in slice_outputs:
        if len(slice_output) == 0:
            # Nothing recorded for this component; indexing it would raise
            slice_outputs_resampled.append([])
            continue
        slice_outputs_resampled.append(
            [slice_output[ix][indices]
             for ix, indices in enumerate(all_resampled_indices_)])
    visualize_slice_stats(train_loader, slice_outputs_resampled)
    return slice_outputs_resampled[0]


def visualize_slice_stats(dataloader, slice_outputs):
    """
    Print target / spurious label statistics for every slice

    For each slice this reports the spurious-attribute breakdown of its least
    and most frequent target class, followed by the overall spurious and
    target value counts of the slice.

    Args:
    - dataloader: Data loader whose dataset exposes `targets_all`; if its
                  sampler has `indices`, targets are first restricted to them
    - slice_outputs (tuple): 4-tuple whose first element is the per-slice
                             list of indices; the other elements are unused
    """
    slice_index_groups, _losses, _correct, _probs = slice_outputs
    try:
        subset = dataloader.sampler.indices
        targets_by_type = copy.deepcopy(dataloader.dataset.targets_all)
        for key in targets_by_type:
            targets_by_type[key] = targets_by_type[key][subset]
    except Exception as e:
        # Fall back to the full, unrestricted targets
        print(e)
        targets_by_type = dataloader.dataset.targets_all

    def _class_breakdown(slice_targets, slice_spurious, class_value):
        """Value counts of targets and spurious labels for one target class"""
        members = np.where(slice_targets == class_value)[0]
        return (np.unique(slice_targets[members], return_counts=True),
                np.unique(slice_spurious[members], return_counts=True))

    for slice_no, member_ix in enumerate(slice_index_groups, start=1):
        slice_spurious = targets_by_type['spurious'][member_ix]
        unique_spurious_counts = np.unique(slice_spurious, return_counts=True)
        slice_targets = targets_by_type['target'][member_ix]
        unique_target_counts = np.unique(slice_targets, return_counts=True)
        classes, counts = unique_target_counts

        # Give info on spurious attrib of minority classes
        min_count_target_vals, min_count_spurious_vals = _class_breakdown(
            slice_targets, slice_spurious, classes[np.argmin(counts)])
        print(f"Min count class target values: {min_count_target_vals}")
        print(f"Min count class spurious values: {min_count_spurious_vals}")

        # ... and the same for the majority class
        max_count_target_vals, max_count_spurious_vals = _class_breakdown(
            slice_targets, slice_spurious, classes[np.argmax(counts)])
        print(f"Max count class target values: {max_count_target_vals}")
        print(f"Max count class spurious values: {max_count_spurious_vals}")

        print(f'Spurious values in slice {slice_no}: {unique_spurious_counts}')
        print_header(f'Target values in slice {slice_no}: {unique_target_counts}',
                     style='bottom')
        

# ------------------------------------------------------------------        
# Adding this from slice.py for easier reference, future refactoring
# ------------------------------------------------------------------
def compute_slice_indices(net, dataloader, criterion, 
                          batch_size, args, resample_by='class',
                          loss_factor=1., use_dataloader=False):
    """
    Use trained model to slice data given a dataloader

    Samples are grouped by the model's predicted class. Within each group the
    members are optionally sub- or super-sampled (args.subsample_labels /
    args.supersample_labels) so the values selected by `resample_by` are
    balanced; otherwise they are kept, sorted by decreasing loss when
    args.weigh_slice_samples_by_loss is set.

    Args:
    - net (torch.nn.Module): Pytorch neural network model
    - dataloader (torch.nn.utils.DataLoader): Pytorch data loader
    - criterion (torch.nn.Loss): Pytorch cross-entropy loss (with reduction='none')
    - batch_size (int): Batch size to compute slices over
    - args (argparse): Experiment arguments
    - resample_by (str): How to resample, ['class', 'correct']
    - loss_factor (float): Scale applied to the losses before they are turned
                           into supersampling probabilities
    - use_dataloader (bool): If True, iterate `dataloader` itself instead of a
                             new sequential loader over its dataset
    Returns:
    - sliced_data_indices (int(np.array)[]): List of numpy arrays denoting indices of the dataloader.dataset
                                             corresponding to different slices
    - all_losses (list): Per-slice losses aligned with the indices; empty
                         unless args.weigh_slice_samples_by_loss is set
    - all_correct (np.array[]): Per-slice correctness of the predictions
    - all_probs (torch.Tensor[]): Per-slice softmax probabilities
    Raises:
    - ValueError: If resampling is enabled and `resample_by` is not recognized
    """
    # First compute pseudolabels
    dataloader_ = dataloader if use_dataloader else None
    dataset = dataloader.dataset
    slice_outputs = compute_pseudolabels(net, dataset, 
                                         batch_size, args,  # Added this dataloader
                                         criterion, dataloader=dataloader_)
    pseudo_labels, outputs, correct, correct_spurious, losses = slice_outputs
    
    # torch.softmax is numerically stable; exp(outputs) overflows for large logits
    output_probabilities = torch.softmax(outputs, dim=1)
    pseudo_labels = np.asarray(pseudo_labels)
    correct = correct.detach().cpu().numpy()
    weigh_by_loss = args.weigh_slice_samples_by_loss

    sliced_data_indices = []
    all_losses = []
    all_correct = []
    all_probs = []
    for label in np.unique(pseudo_labels):
        group = np.where(pseudo_labels == label)[0]
        if weigh_by_loss:
            losses_per_group = losses[group]
        correct_by_group = correct[group]
        probs_by_group = output_probabilities[group]
        if args.subsample_labels is True or args.supersample_labels is True:
            group_targets = dataloader.dataset.targets[group]
            group_vals = np.unique(group_targets, return_counts=True)[1]
            sample_size = (np.min(group_vals) if args.subsample_labels is True
                           else np.max(group_vals))
            if resample_by == 'class':
                target_values = group_targets
            elif resample_by == 'correct':
                target_values = correct_by_group
            else:
                raise ValueError(
                    f"resample_by must be 'class' or 'correct', got {resample_by!r}")
            print(f'> Resampling by {resample_by}...')
            sampled_indices = []
            for v in np.unique(target_values):
                group_indices = np.where(target_values == v)[0]
                if args.subsample_labels is True:
                    # Draw at most sample_size members, without repeats
                    sampling_size = np.min([len(group_indices), sample_size])
                    replace = False
                    p = None
                else:
                    # Keep every member, then top up to sample_size with repeats
                    sampling_size = np.max(
                        [0, sample_size - len(group_indices)])
                    sampled_indices.append(group_indices)
                    replace = True
                    if weigh_by_loss:
                        scaled = losses_per_group[group_indices] * loss_factor
                        # float64 + renormalization so np.random.choice accepts
                        # p; float32 rounding can fail its sum-to-one check
                        p = torch.softmax(scaled.double(), dim=0).numpy()
                        p = p / p.sum()
                    else:
                        p = None
                sampled_indices.append(np.random.choice(
                    group_indices, size=sampling_size, replace=replace, p=p))
            sampled_indices = np.concatenate(sampled_indices)
            if weigh_by_loss:
                all_losses.append(losses_per_group[sampled_indices])
            sliced_data_indices.append(group[sampled_indices])
            all_correct.append(correct_by_group[sampled_indices])
            all_probs.append(probs_by_group[sampled_indices])
        else:
            if weigh_by_loss:
                order = torch.argsort(losses_per_group, descending=True)
                all_losses.append(losses_per_group[order])
                # numpy arrays are indexed with a numpy index array
                order = order.numpy()
            else:
                order = np.arange(len(group))
            sliced_data_indices.append(group[order])
            all_correct.append(correct_by_group[order])
            all_probs.append(probs_by_group[order])
    # Save GPU memory
    net.to(torch.device('cpu')) 
    return sliced_data_indices, all_losses, all_correct, all_probs


def compute_pseudolabels(net, dataset, batch_size, args, criterion=None, 
                         dataloader=None):
    """
    Run the model over a dataset and collect its predictions

    Args:
    - net (torch.nn.Module): Pytorch neural network model; put in eval mode
                             and moved to args.device
    - dataset: Dataset yielding (inputs, labels, data_ix) and exposing
               `targets_all['spurious']`; ignored when `dataloader` is given
    - batch_size (int): Batch size of the sequential loader built over `dataset`
    - args (argparse): Experiment arguments (reads device, num_workers,
                       weigh_slice_samples_by_loss)
    - criterion (torch.nn.Loss): Loss used when args.weigh_slice_samples_by_loss
                                 is set; presumably with reduction='none' so
                                 losses align with samples - TODO confirm
    - dataloader: Optional loader to iterate instead of building a new one;
                  its dataset then replaces `dataset`
    Returns:
    - pseudo_labels (torch.Tensor): Predicted class per sample
    - outputs (torch.Tensor): Stacked model outputs
    - correct (torch.Tensor): Whether each prediction equals the label
    - correct_spurious (torch.Tensor): Whether each prediction equals the spurious label
    - all_losses (torch.Tensor): Collected losses, or None if none were computed
    Raises:
    - ValueError: If losses are requested but no criterion is given
    """
    compute_losses = args.weigh_slice_samples_by_loss
    # Validate up front rather than asserting inside the loop (asserts vanish under -O)
    if compute_losses and criterion is None:
        raise ValueError('Need to specify criterion')

    net.eval()
    if dataloader is None:
        new_loader = DataLoader(dataset, batch_size=batch_size,
                                shuffle=False, num_workers=args.num_workers)
    else:
        new_loader = dataloader
        dataset = dataloader.dataset
    all_outputs = []
    all_predicted = []
    all_correct = []
    all_correct_spurious = []
    all_losses = []
    net.to(args.device)

    with torch.no_grad():
        targets_s = dataset.targets_all['spurious']
        for batch_ix, data in enumerate(tqdm(new_loader)):
            inputs, labels, data_ix = data
            # Look up each sample's spurious label through its dataset index
            labels_spurious = torch.tensor(
                [targets_s[ix] for ix in data_ix]).to(args.device)

            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            all_outputs.append(outputs.detach().cpu())
            all_predicted.append(predicted.detach().cpu())
            if compute_losses:
                loss = criterion(outputs, labels)
                all_losses.append(loss.detach().cpu())

            # Save correct
            all_correct.append((predicted == labels).cpu())
            all_correct_spurious.append((predicted == labels_spurious).cpu())

    pseudo_labels = torch.hstack(all_predicted)
    outputs = torch.vstack(all_outputs)
    correct = torch.hstack(all_correct)
    correct_spurious = torch.hstack(all_correct_spurious)
    if len(all_losses) > 0:
        all_losses = torch.hstack(all_losses)
    else:
        all_losses = None
    return pseudo_labels, outputs, correct, correct_spurious, all_losses
        
        
# ----------------------------------------
# Alternative slicing with representations
# ----------------------------------------
def compute_cluster_assignment(cluster_labels, dataloader):
    """
    Relabel clusters so they agree with the target classes as much as possible

    Every permutation of the observed cluster label values is tried as a
    relabelling, and the one matching the most targets is kept (the first
    such permutation on ties).

    Args:
    - cluster_labels (np.array): Cluster id per sample; the ids need not be contiguous
    - dataloader: Data loader whose dataset exposes `targets_all['target']`
    Returns:
    - cluster_labels (np.array): Relabelled cluster id per sample
    - cluster_correct (int(np.array)): 1 where the relabelled id equals the target, else 0
    """
    all_targets = dataloader.dataset.targets_all['target']
    cluster_labels = np.asarray(cluster_labels)
    # `inverse` gives each sample's position in `unique_labels`, so a permutation
    # can be applied by fancy indexing whatever the raw label values are
    unique_labels, inverse = np.unique(cluster_labels, return_inverse=True)
    inverse = inverse.reshape(cluster_labels.shape)

    # This permutations thing is gross - not actually Hungarian here?
    # NOTE: cost grows factorially with the number of clusters
    best_labels = None
    best_correct = None
    best_count = -1
    for cluster_map in permutations(unique_labels):
        preds = np.asarray(cluster_map)[inverse]
        correct = (preds == all_targets)
        n_correct = correct.sum()
        # Strict '>' keeps the first permutation that reaches the maximum
        if n_correct > best_count:
            best_labels = preds
            best_correct = correct
            best_count = n_correct
    return best_labels, best_correct.astype(int)


def _save_label_scatter(X, labels, cmap, fpath):
    """Scatter the first two columns of X colored by integer labels and save to fpath"""
    colors = np.array(labels).astype(int)
    num_colors = len(np.unique(colors))
    # plt.get_cmap replaces plt.cm.get_cmap, which newer Matplotlib removed
    plt.scatter(X[:, 0], X[:, 1], c=colors, s=1.0,
                cmap=plt.get_cmap(cmap, num_colors))
    plt.colorbar(ticks=np.unique(colors))
    plt.savefig(fname=fpath, dpi=300, bbox_inches="tight")
    plt.close()
    print(f'Saved UMAP to {fpath}!')


def compute_slice_indices_by_rep(model, dataloader,
                                 cluster_umap=True, 
                                 umap_components=2,
                                 cluster_method='kmeans',
                                 args=None,
                                 visualize=False,
                                 cmap='tab10'):
    """
    Slice data by clustering the model's saved activations

    Args:
    - model: Model whose activations are collected via save_activations
    - dataloader: Data loader whose dataset exposes `targets_all`
    - cluster_umap (bool): If True, cluster a UMAP projection of the embeddings
                           rather than the embeddings themselves
    - umap_components (int): Dimensionality of the UMAP projection
    - cluster_method (str): How to cluster, ['kmeans', 'gmm']
    - args (argparse): Experiment arguments; required (reads seed, num_classes,
                       and image_path / experiment_name when visualizing)
    - visualize (bool): If True, save scatter plots colored by cluster,
                        target and spurious label
    - cmap (str): Matplotlib colormap name for the plots
    Returns:
    - sliced_data_indices (np.array[]): Sample positions per relabelled cluster
    - sliced_data_correct (np.array[]): Per slice, whether the relabelled cluster equals the target
    - sliced_data_losses (np.array[]): Per slice, L2 distance from each sample to its cluster's center
    Raises:
    - ValueError: If args is None
    - NotImplementedError: If cluster_method is not supported
    """
    if args is None:
        raise ValueError('args must be provided')
    embeddings, predictions = save_activations(model, 
                                               dataloader, 
                                               args)
    if cluster_umap:
        umap_ = umap.UMAP(random_state=args.seed, 
                      n_components=umap_components)
        X = umap_.fit_transform(embeddings)
    else:
        X = embeddings
    n_clusters = args.num_classes
    if cluster_method == 'kmeans':
        clusterer = KMeans(n_clusters=n_clusters,
                           random_state=args.seed,
                           n_init=10)
        raw_cluster_labels = clusterer.fit_predict(X)
        means = clusterer.cluster_centers_
    elif cluster_method == 'gmm':
        clusterer = GaussianMixture(n_components=n_clusters,
                                    random_state=args.seed,
                                    n_init=10)
        raw_cluster_labels = clusterer.fit_predict(X)
        means = clusterer.means_
    else:
        raise NotImplementedError
    # Match clustering labels to training set    
    cluster_labels, cluster_correct = compute_cluster_assignment(raw_cluster_labels, 
                                                                 dataloader)
    sliced_data_indices = []
    sliced_data_correct = []
    sliced_data_losses = []  # Not actually losses, but distance from point to cluster mean
    for label in np.unique(cluster_labels):
        group = np.where(cluster_labels == label)[0]
        sliced_data_indices.append(group)
        sliced_data_correct.append(cluster_correct[group])
        # `means` is indexed by the clusterer's own ids, not the relabelled
        # ones. The relabelling is one-to-one, so any member of the group
        # identifies the raw cluster it came from
        center = means[raw_cluster_labels[group[0]]]
        l2_dist = np.linalg.norm(X[group] - center, axis=1)
        sliced_data_losses.append(l2_dist)
    if visualize:
        fpath = os.path.join(args.image_path,
                             f'umap-init_slice-cr-{args.experiment_name}.png')
        _save_label_scatter(X, cluster_labels, cmap, fpath)
        
        # Save based on other info too
        targets_all = dataloader.dataset.targets_all
        for target_type in ['target', 'spurious']:
            # File tag is the first and last letter of the target type
            t = f'{target_type[0]}{target_type[-1]}'
            fpath = os.path.join(args.image_path,
                                 f'umap-init_slice-{t}-{args.experiment_name}.png')
            _save_label_scatter(X, targets_all[target_type], cmap, fpath)
    return sliced_data_indices, sliced_data_correct, sliced_data_losses


def combine_data_indices(sliced_data_indices, sliced_data_correct):
    """
    If computing slices from both the ERM model's predictions and 
    representation clustering, use to consolidate into single list of slice indices

    Each slice keeps everything from the first source. From the second source
    it additionally receives the indices marked incorrect there, unless the
    first source already holds them as incorrect.

    Args:
    - sliced_data_indices (np.array[][]): List of list of sliced indices from ERM and representation clustering, 
                                          e.g. [sliced_indices_pred, sliced_indices_rep],
                                          where sliced_indices_pred = [indices_with_pred_val_1, ... indices_with_pred_val_N]
    - sliced_data_correct (np.array[][]): Same as above, but if the prediction / cluster assignment was correct
    Returns:
    - total_sliced_data_indices (np.array[]): List of combined data indices per slice
    - total_sliced_data_correct (np.array[]): List of combined per-data correctness flags per slice
                                              (0 for every index added from the second source)
    """
    first_indices, second_indices = sliced_data_indices
    first_correct, second_correct = sliced_data_correct
    total_sliced_data_indices = [[i] for i in first_indices]
    total_sliced_data_correct = [[c] for c in first_correct]
    for slice_ix, indices in enumerate(second_indices):
        second_wrong = np.asarray(indices)[
            np.where(second_correct[slice_ix] == 0)[0]]
        first_wrong = total_sliced_data_indices[slice_ix][0][
            np.where(total_sliced_data_correct[slice_ix][0] == 0)[0]]
        # Vectorized membership test; keeps order and duplicates of second_wrong
        new_wrong = second_wrong[~np.isin(second_wrong, first_wrong)].astype(int)
        total_sliced_data_indices[slice_ix].append(new_wrong)
        total_sliced_data_correct[slice_ix].append(np.zeros(len(new_wrong)))
    # Concatenate every slice, including ones the second source did not cover,
    # so each returned entry is a flat array
    total_sliced_data_indices = [np.concatenate(parts)
                                 for parts in total_sliced_data_indices]
    total_sliced_data_correct = [np.concatenate(parts)
                                 for parts in total_sliced_data_correct]
    return total_sliced_data_indices, total_sliced_data_correct
    
