'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - msr_init: net parameter initialization.
    - progress_bar: progress bar mimic xlua.progress.
'''
import os
import sys
import time
import math
import random
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib.patches as mpatches

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import pickle
import matplotlib
import io
from collections import defaultdict, deque
import datetime


import torch.distributed as dist
matplotlib.use('Agg')


def random_weighted_sum(samples_outputs_reshaped):
    """
    Reduce class outputs to one dimension with a random +/-1 weight per class.

    For every pair a weight vector with entries drawn uniformly from {-1, +1}
    is sampled; the same vector is applied to every sampled point of that pair.

    Args:
        samples_outputs_reshaped: [num_pairs, resolution, num_classes] Output of sampled points.

    Returns:
        random_weighted_output_batch: [num_pairs, resolution, 1] The resulting weighted output.
    """
    num_pairs, _, num_classes = samples_outputs_reshaped.shape

    # Freshly sampled tensors carry no autograd history, so neither
    # torch.no_grad() nor detach() is required here.
    random_ints = torch.randint(
        0, 2, (num_pairs, num_classes),
        device=samples_outputs_reshaped.device, dtype=torch.float32,
    )

    # Map {0, 1} -> {-1, +1}, then match the input dtype: torch.bmm refuses
    # mixed dtypes, so float32 weights would break float64/half inputs.
    random_weights = (2 * random_ints - 1).to(samples_outputs_reshaped.dtype)

    random_weights = random_weights.unsqueeze(2)  # [num_pairs, num_classes, 1]
    # [num_pairs, resolution, num_classes] @ [num_pairs, num_classes, 1]
    return torch.bmm(samples_outputs_reshaped, random_weights)



def adjust_mask_ratio(epoch, args):
    """Linearly anneal the mask ratio from 0.8 at epoch 0 to 0.2 at ``args.epochs``.

    Args:
        epoch: Current epoch index.
        args: Object exposing ``epochs``, the total number of epochs.

    Returns:
        The mask ratio to use at ``epoch``.
    """
    ratio_at_start, ratio_at_end = 0.8, 0.2
    # Fraction of training still ahead: 1.0 at epoch 0, 0.0 at the last epoch.
    remaining_fraction = (args.epochs - epoch) / args.epochs
    return ratio_at_end + remaining_fraction * (ratio_at_start - ratio_at_end)

def label_for_samples(alpha_expanded, samples_outputs_reshaped, num_pairs, label1_batch, label2_batch, device):
    """
    Overwrite the two endpoint predictions of every pair with label targets.

    At the first and last position along the resolution axis each class entry
    is set to the pair's first alpha value, except the labelled class (label1
    at the first position, label2 at the last) which is set to the pair's last
    alpha value.  When alpha runs from 0 to 1 this is an ordinary one-hot
    vector.  All interior positions keep the model outputs.

    Args:
        alpha_expanded: Alpha values, indexed as [num_pairs, resolution, 1, 1, 1].
        samples_outputs_reshaped: Model outputs, [num_pairs, resolution, num_classes].
        num_pairs: Number of sample pairs.
        label1_batch: Class index of the first endpoint of each pair, [num_pairs].
        label2_batch: Class index of the second endpoint of each pair, [num_pairs].
        device: Device on which the pair index tensor is created.

    Returns:
        samples_outputs_modified: Tensor shaped like ``samples_outputs_reshaped``.
    """
    class_count = samples_outputs_reshaped.size(2)
    pair_idx = torch.arange(num_pairs, device=device)

    off_value = alpha_expanded[:, 0, 0, 0, 0]    # value for non-target classes
    on_value = alpha_expanded[:, -1, 0, 0, 0]    # value for the target class
    background = off_value.view(-1, 1).expand(-1, class_count)

    endpoint_targets = torch.zeros_like(samples_outputs_reshaped)
    # First endpoint is processed before the last one, so with a single
    # position the second label wins.
    for position, endpoint_labels in ((0, label1_batch), (-1, label2_batch)):
        endpoint_targets[pair_idx, position] = background
        endpoint_targets[pair_idx, position, endpoint_labels] = on_value

    # 1.0 at the two endpoints, 0.0 everywhere else; broadcast over classes.
    endpoint_mask = torch.zeros_like(samples_outputs_reshaped[:, :, :1])
    endpoint_mask[:, 0] = 1.0
    endpoint_mask[:, -1] = 1.0

    return (
        samples_outputs_reshaped * (1.0 - endpoint_mask) +
        endpoint_targets * endpoint_mask
    )


def label_for_samples_random(alpha_expanded, samples_outputs_reshaped, num_pairs, label1_batch, label2_batch, device, mask_ratio=0.2):
    """
    Randomly mask specific positions and replace their values with the results
    of linear interpolation between labels.

    For each pair, ``max(1, int(mask_ratio * resolution))`` positions along
    the resolution axis are drawn without replacement (endpoints are eligible
    too).  At those positions the output is replaced by
    ``(1 - alpha) * onehot(label1) + alpha * onehot(label2)``; all other
    positions keep the original model output.

    Args:
        alpha_expanded: Tensor of alpha values used for interpolation,
            indexed as [num_pairs, resolution, 1, 1, 1].
        samples_outputs_reshaped: The original output tensor from the model,
            [num_pairs, resolution, num_classes].
        num_pairs: Total number of sample pairs.
        label1_batch: Batch of the first set of labels (class indices).
        label2_batch: Batch of the second set of labels (class indices).
        device: The computing device (e.g., 'cpu' or 'cuda').
        mask_ratio: The ratio of positions to be randomly masked (float in [0, 1]).

    Returns:
        Tensor shaped like ``samples_outputs_reshaped`` with the masked
        positions replaced by interpolated labels.
    """
    num_classes = samples_outputs_reshaped.size(2)
    resolution = samples_outputs_reshaped.size(1)

    alpha_values = alpha_expanded[:, :, 0, 0, 0]  # [num_pairs, resolution]

    batch_indices = torch.arange(num_pairs, device=device)

    # At least one position per pair is always replaced.
    num_points_to_mask = max(1, int(mask_ratio * resolution))

    # Draw an independent set of positions for every pair.
    mask_positions = torch.zeros((num_pairs, resolution), device=device, dtype=torch.bool)
    for i in range(num_pairs):
        random_indices = torch.randperm(resolution, device=device)[:num_points_to_mask]
        mask_positions[i, random_indices] = True

    # [num_pairs, resolution, 1] so the mask broadcasts over classes.
    mask_expanded = mask_positions.unsqueeze(-1).float()

    # One-hot encodings of both endpoint labels.
    label1_one_hot = torch.zeros(num_pairs, num_classes, device=device)
    label1_one_hot[batch_indices, label1_batch] = 1.0

    label2_one_hot = torch.zeros(num_pairs, num_classes, device=device)
    label2_one_hot[batch_indices, label2_batch] = 1.0

    label1_expanded = label1_one_hot.unsqueeze(1)  # [num_pairs, 1, num_classes]
    label2_expanded = label2_one_hot.unsqueeze(1)  # [num_pairs, 1, num_classes]
    alpha_for_interp = alpha_values.unsqueeze(-1)  # [num_pairs, resolution, 1]

    # alpha is the position of the sample point along the path label1 -> label2.
    interpolated_labels = (1.0 - alpha_for_interp) * label1_expanded + alpha_for_interp * label2_expanded

    # Interpolated labels at masked positions, original outputs elsewhere.
    samples_outputs_modified = (
        samples_outputs_reshaped * (1.0 - mask_expanded) +
        interpolated_labels * mask_expanded
    )

    return samples_outputs_modified

def pca(samples_outputs_reshaped, num_pairs, k: int = 1):
    """
    Parallelized PCA implementation

    Each pair is centred along its resolution axis and projected onto its own
    leading right-singular vectors.  The principal directions are computed
    without gradient tracking; gradients flow only through the centred data
    in the final projection.

    Args:
        samples_outputs_reshaped: [num_pairs, resolution, num_classes]
        num_pairs: Total number of sample pairs.  Unused (the batch size is
            read from the tensor); kept for interface consistency.
        k: Number of principal components to retain.  ``None`` means 1.

    Returns:
        pca_output_batch: [num_pairs, resolution, k_eff] with
        ``k_eff = min(k, resolution, num_classes)``; no padding is added when
        fewer than ``k`` components exist.

    Raises:
        ValueError: If ``k`` is not positive.
    """
    if k is None:
        k = 1
    k = int(k)
    if k <= 0:
        raise ValueError(f"k must be a positive int, got {k}")

    # 1. Batch centering along the resolution axis.
    mean = samples_outputs_reshaped.mean(dim=1, keepdim=True)  # [N, 1, C]
    X_centered = samples_outputs_reshaped - mean               # [N, R, C]

    batch_size, resolution, num_classes = X_centered.shape
    # At most min(R, C) principal components exist.
    k_eff = min(k, resolution, num_classes)

    # 2. Batched SVD (torch.linalg.svd accepts [batch, M, N] inputs).
    with torch.no_grad():
        try:
            # Vh: [N, min(R, C), C]; its rows are the principal directions.
            _, _, Vh = torch.linalg.svd(X_centered, full_matrices=False)
            pcs = Vh[:, :k_eff, :]  # [N, k_eff, C]
        except RuntimeError:
            # One failing matrix aborts the whole batched SVD.  Fall back to
            # the standard basis e1..e_k_eff for every pair, created with the
            # input's dtype/device so the bmm below stays valid.
            basis = torch.eye(
                num_classes, device=X_centered.device, dtype=X_centered.dtype
            )[:k_eff]
            pcs = basis.unsqueeze(0).expand(batch_size, -1, -1)

    pcs = pcs.detach()

    # 3. Projection: [N, R, C] @ [N, C, k_eff] -> [N, R, k_eff]
    return torch.bmm(X_centered, pcs.transpose(1, 2))

def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate: linear warmup followed by half-cycle cosine decay.

    During the first ``args.warmup_epochs`` epochs the rate rises linearly
    from 0 to ``args.lr``; afterwards it follows half a cosine period from
    ``args.lr`` down to ``args.min_lr`` at ``args.epochs``.  Every parameter
    group receives the new rate, multiplied by its ``"lr_scale"`` entry when
    the group has one.

    Returns:
        The unscaled learning rate for ``epoch``.
    """
    warmup = args.warmup_epochs
    if epoch < warmup:
        lr = args.lr * epoch / warmup
    else:
        angle = math.pi * (epoch - warmup) / (args.epochs - warmup)
        amplitude = (args.lr - args.min_lr) * 0.5
        lr = args.min_lr + amplitude * (1. + math.cos(angle))

    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
    return lr

def adjust_lambda_reg_linear(epoch, args):
    """Return the regularisation weight for ``epoch`` using a linear warmup.

    A non-positive ``args.lambda_reg`` is returned unchanged.  Otherwise the
    weight grows linearly from 0 to ``args.lambda_reg`` over the first
    ``args.warmup_epochs_for_lambda`` epochs and stays constant afterwards.
    """
    target = args.lambda_reg
    if target <= 0:
        return target

    warmup = args.warmup_epochs_for_lambda
    if epoch < warmup:
        return target * epoch / warmup
    return target


def adjust_lambda_reg_sin(epoch, args):
    """Return the regularisation weight for ``epoch`` using a sine warmup.

    A non-positive ``args.lambda_reg`` is returned unchanged.  Otherwise,
    during the first ``args.warmup_epochs_for_lambda`` epochs the weight is
    ``lambda_reg * sin(progress * pi / 2)`` with ``progress`` going from 0 to
    1; afterwards it equals ``args.lambda_reg``.
    """
    target = args.lambda_reg
    if target <= 0:
        return target

    warmup = args.warmup_epochs_for_lambda
    if epoch >= warmup:
        return target

    # Quarter sine wave: 0 at the start of warmup, 1 at its end.
    progress = epoch / warmup
    return target * math.sin(progress * math.pi / 2)


def _gaussian_kernel1d_local(kernel_size: int, sigma: float, device, dtype):
    """Build a 1D Gaussian kernel of length ``kernel_size`` whose entries sum to 1.

    The kernel is centred on the middle of the window (a half-integer offset
    for even sizes) and created on ``device`` with ``dtype``.
    """
    center = (kernel_size - 1) / 2.0
    offsets = torch.arange(kernel_size, device=device, dtype=dtype) - center
    exponent = offsets.div(sigma).pow(2).mul(-0.5)
    weights = torch.exp(exponent)
    return weights / weights.sum()


def simple_lapsed_time(text, lapsed):
    """Print ``lapsed`` (seconds) as ``HH:MM:SS.ss`` prefixed by ``text`` and a colon."""
    whole_hours, leftover = divmod(lapsed, 3600)
    whole_minutes, secs = divmod(leftover, 60)
    stamp = ": {:0>2}:{:0>2}:{:05.2f}".format(int(whole_hours), int(whole_minutes), secs)
    print(text + stamp)


def init_params(net):
    '''Init layer parameters in place.

    - ``nn.Conv2d``: Kaiming-normal weights (``mode='fan_out'``), zero bias.
    - ``nn.BatchNorm2d``: weight 1, bias 0 (skipped when the layer has no
      affine parameters).
    - ``nn.Linear``: normal weights with std 1e-3, zero bias.

    Args:
        net: Module whose sub-modules are initialised.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # `if m.bias:` would evaluate the tensor's truth value, which
            # raises for any bias with more than one element.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # weight/bias are None when the layer was built with affine=False.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)



def get_loss_function(args):
    """Select the training criterion named by ``args.criterion``.

    Matching rules, checked in this order:
        - ``''``: ``nn.CrossEntropyLoss`` with ``args.label_smoothing``.
        - contains ``'kl'``: KL divergence between ``log_softmax(outputs)``
          and ``softmax(targets)`` (both over dim 1), using ``F.kl_div``'s
          default reduction.
        - contains ``'MSE'``: ``torch.nn.MSELoss``.
        - contains ``'mixup'``: the module-level ``mixup_criterion`` function.
          NOTE(review): it takes ``(criterion, pred, y_a, y_b, lam)`` rather
          than ``(outputs, targets)``; callers must invoke it accordingly.

    Args:
        args: Object exposing ``criterion`` (str) and, for the default
            cross-entropy case, ``label_smoothing``.

    Returns:
        A callable loss.

    Raises:
        ValueError: If ``args.criterion`` matches none of the rules above.
    """
    name = args.criterion
    if name == '':
        return nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
    if 'kl' in name:
        def kl_loss(outputs, targets):
            return torch.nn.functional.kl_div(F.log_softmax(outputs, dim=1),F.softmax(targets, dim=1))
        return kl_loss
    if 'MSE' in name:
        return torch.nn.MSELoss()
    if 'mixup' in name:
        return mixup_criterion

    # Previously fell through to an UnboundLocalError on `criterion`.
    raise ValueError(f"Unsupported criterion: {name!r}")


def get_scheduler(args, optimizer):
    """Build the learning-rate scheduler named by ``args.scheduler``.

    - ``'cosine'``: ``CosineAnnealingLR`` with ``T_max = args.epochs``.
    - ``'linear'``: ``MultiStepLR`` that multiplies the rate by 0.1 at
      ``epochs // 2.667``, ``epochs // 1.6`` and ``epochs // 1.142``.

    Args:
        args: Object exposing ``scheduler`` (str) and ``epochs`` (int).
        optimizer: Optimizer to wrap.

    Returns:
        The constructed scheduler.

    Raises:
        ValueError: If ``args.scheduler`` is not one of the names above.
    """
    if args.scheduler == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
    if args.scheduler == 'linear':
        # Floor division by a float yields a float; milestones are epoch
        # indices, so store them as ints.
        milestones = [int(args.epochs // divisor) for divisor in (2.667, 1.6, 1.142)]
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    # Previously fell through to an UnboundLocalError on `scheduler`.
    raise ValueError(f"Unsupported scheduler: {args.scheduler!r}")


def get_random_images(trainset):
    """Draw three random samples with pairwise different labels.

    Indices are sampled uniformly with ``random.randint`` and a sample is
    rejected when its label was already picked, so the loop only terminates
    if ``trainset`` contains at least three distinct labels.

    Args:
        trainset: Indexable dataset whose items unpack as ``(image, label)``.

    Returns:
        ``(imgs, labels, ids)``: three parallel lists with the images, their
        labels and their dataset indices, in the order they were drawn.
    """
    picked = []        # (image, dataset index) of every accepted sample
    seen_labels = []   # labels accepted so far, in drawing order
    while len(picked) < 3:
        candidate = random.randint(0, len(trainset)-1)
        img, label = trainset[candidate]
        if label in seen_labels:
            continue
        seen_labels.append(label)
        picked.append((img, candidate))

    imgs = [img for img, _ in picked]
    ids = [candidate for _, candidate in picked]
    return imgs, seen_labels, ids

def get_noisy_images(dummy_imgs, dataset, net, device, from_scratch=False):
    """Generate uniform-noise images and label them with the network's prediction.

    Noise is drawn uniformly from [0, 1) with the shape of ``dummy_imgs`` and
    normalised with a fixed mean of 0.5 and std of 0.25.  ``net`` is switched
    to eval mode (and left there) and run without gradients.

    Args:
        dummy_imgs: Tensor whose shape (batch first) is used for the noise.
        dataset: Unused; kept for interface compatibility.
        net: Classifier returning ``[batch, num_classes]`` scores.
        device: Device the noise is moved to before the forward pass.
        from_scratch: When true, the result is taken from a single forward
            pass and the loop stops immediately.

    Returns:
        ``(new_imgs, new_labels)``: lists of CPU image tensors and the
        matching CPU 0-dim label tensors.
    """
    # Scalars instead of [N, 1, 1] tensors: the old shape broadcast against
    # the channel axis of [N, C, H, W] and only worked when N == C or N == 1.
    noise_mean = 0.5
    noise_std = 0.25

    new_imgs = []
    new_labels = []
    net.eval()
    with torch.no_grad():
        while len(new_labels) < dummy_imgs.shape[0]:
            imgs = (torch.rand(dummy_imgs.shape) - noise_mean) / noise_std
            imgs = imgs.to(device)
            _, labels = net(imgs).max(1)
            if from_scratch:
                new_imgs = [img.cpu() for img in imgs]
                new_labels = [label.cpu() for label in labels]
                break
            # Every sample is kept; filtering for unseen labels proved too
            # slow for randomly initialised networks.
            for img, label in zip(imgs, labels):
                new_imgs.append(img.cpu())
                new_labels.append(label.cpu())
    return new_imgs, new_labels


def _get_class_preds(planeset, preds, label, avoid_labels=None):
    """Collect plane coordinates and confidences of the selected predictions.

    A prediction is selected when its arg-max class equals ``label``
    (``avoid_labels is None``) or, otherwise, when its arg-max class is not
    contained in ``avoid_labels`` (``label`` is then ignored).

    Args:
        planeset: Object exposing ``coefs1`` and ``coefs2``, indexable in
            parallel with ``preds``.
        preds: Iterable of per-sample score vectors.
        label: Class to keep when ``avoid_labels`` is None.
        avoid_labels: Optional collection of classes to exclude.

    Returns:
        ``(vals, x, y)``: lists of NumPy values holding the maximum softmax
        probability and the two plane coefficients of each selected sample.
    """
    confidences, xs, ys = [], [], []
    for idx, scores in enumerate(preds):
        confidence = torch.softmax(scores,0).max()
        predicted = scores.argmax()

        if avoid_labels is None:
            selected = predicted == label
        else:
            selected = predicted not in avoid_labels
        if not selected:
            continue

        xs.append(planeset.coefs1[idx].cpu().numpy())
        ys.append(planeset.coefs2[idx].cpu().numpy())
        confidences.append(confidence.cpu().numpy())
    return confidences, xs, ys


def imscatter(x, y, image, ax=None, zoom=1):
    """Draw ``image`` as a marker at data coordinates ``(x, y)``.

    Args:
        x, y: Position of the image in data coordinates.
        image: Image data accepted by ``OffsetImage``.
        ax: Axes to draw on; defaults to the current axes.
        zoom: Scaling factor passed to ``OffsetImage``.

    Returns:
        List containing the artist that was added to ``ax``.
    """
    # The default ax=None used to crash on ax.add_artist.
    if ax is None:
        ax = plt.gca()

    im = OffsetImage(image, zoom=zoom)
    x, y = np.atleast_1d(x, y)
    ab = AnnotationBbox(im, (x, y), xycoords='data', frameon=False)
    artists = [ax.add_artist(ab)]

    # Artists added via add_artist do not affect the data limits, so extend
    # them explicitly and rescale.
    ax.update_datalim(np.column_stack([x, y]))
    ax.autoscale()
    return artists

def produce_plot(path, preds, planeloader, images, labels, trainloader, method='greys'):
    """Scatter-plot the predictions on a plane and save it to ``images/<path>.png``.

    Points predicted as one of the three anchor ``labels`` are drawn with the
    'Reds', 'Blues' and 'Greens' colormaps (coloured by confidence).  The
    remaining points are drawn according to ``method``:

        - ``'greys'``: every other prediction in a single 'Greys' colormap.
        - ``'all'``: every other class in its own colormap.
        - anything else: other predictions are not drawn.

    The anchor ``images`` are de-normalised with the mean/std of the last
    transform of ``trainloader.dataset`` and drawn at
    ``planeloader.dataset.coords``.

    Args:
        path: File stem for the output image; nothing is saved when None.
        preds: Sequence of per-point score vectors.
        planeloader: Loader whose ``dataset`` exposes ``coefs1``, ``coefs2``
            and ``coords``.
        images: The anchor images (normalised tensors).
        labels: Class indices of the three anchor images.
        trainloader: Loader whose dataset's last transform exposes ``mean``
            and ``std``.
        method: ``'greys'`` or ``'all'`` (see above).
    """
    color_list = ['Reds', 'Blues', 'Greens']
    other_colors = ['Purples', 'Oranges', 'YlOrBr', 'YlGnBu', 'PuBuGn', 'BuGn', 'YlGn']
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')
    fig = plt.figure()
    ax1 = fig.add_subplot(111)

    for i, label in enumerate(labels):
        vals, x, y = _get_class_preds(planeloader.dataset, preds, label, avoid_labels=None)
        ax1.scatter(x, y, c=vals, cmap=color_list[i], label=f'class={label}')

    if method == 'greys':
        # The label argument is ignored when avoid_labels is given.
        vals, x, y = _get_class_preds(planeloader.dataset, preds, None, avoid_labels=labels)
        ax1.scatter(x, y, c=vals, cmap='Greys', label='class=other')
    elif method == 'all':
        num_classes = list(preds[0].shape)[0]
        # Sorted so the class -> colormap assignment is deterministic.
        remaining = sorted(set(range(num_classes)) - set(labels))
        for i, ind in enumerate(remaining):
            vals, x, y = _get_class_preds(planeloader.dataset, preds, ind, avoid_labels=None)
            ax1.scatter(x, y, c=vals, cmap=other_colors[i % len(other_colors)], label=f'class={ind}')

    coords = planeloader.dataset.coords

    # Undo the dataset normalisation so the anchor images display correctly.
    dm = torch.tensor(trainloader.dataset.transform.transforms[-1].mean)[:, None, None]
    ds = torch.tensor(trainloader.dataset.transform.transforms[-1].std)[:, None, None]
    for i, image in enumerate(images):
        img = torch.clamp(image * ds + dm, 0, 1)
        img = img.cpu().numpy().transpose(1,2,0)
        coord = coords[i]
        imscatter(coord[0], coord[1], img, ax1)

    red_patch = mpatches.Patch(color='red', label=f'{classes[labels[0]]}')
    blue_patch = mpatches.Patch(color='blue', label=f'{classes[labels[1]]}')
    green_patch = mpatches.Patch(color='green', label=f'{classes[labels[2]]}')
    plt.legend(handles=[red_patch, blue_patch, green_patch], loc='upper center', bbox_to_anchor=(0.5, 1.05),
              ncol=3, fancybox=True, shadow=True)
    if path is not None:
        os.makedirs('images', exist_ok=True)
        plt.savefig(f'images/{path}.png')
    plt.close(fig)
    return

def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda

    Each sample is blended with another sample of the same batch:
    ``mixed_x = lam * x + (1 - lam) * x[index]`` where ``index`` is a random
    permutation of the batch.

    Args:
        x: Input batch (batch dimension first, at least 2-D).
        y: Targets, indexable by the permutation.
        alpha: Beta-distribution parameter; ``lam ~ Beta(alpha, alpha)`` when
            positive, otherwise ``lam = 1`` (no mixing).
        use_cuda: Kept for backward compatibility.  The permutation now
            follows the device of ``x``, so this flag has no effect.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` with ``y_a = y`` and ``y_b = y[index]``.
    '''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    # Sample on the CPU generator (as before), then follow the data's device
    # instead of unconditionally calling .cuda().
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    target_index = index.to(y.device) if torch.is_tensor(y) else index
    y_a, y_b = y, y[target_index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Compute the mixup loss ``lam * L(pred, y_a) + (1 - lam) * L(pred, y_b)``.

    Args:
        criterion: Loss callable ``L(pred, target)``; ``nn.CrossEntropyLoss()``
            is used when None.
        pred: Model predictions for the mixed inputs.
        y_a: Targets of the original samples.
        y_b: Targets of the samples they were mixed with.
        lam: Mixing coefficient applied to ``y_a``.

    Returns:
        The weighted combination of both losses.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

def rand_bbox(size, lam):
    """Sample a random box covering about ``1 - lam`` of the image area.

    The box centre is drawn uniformly over the image and the box is clipped
    to the image bounds, so the actual area can be smaller.

    Args:
        size: Shape of the input batch; ``size[2]`` and ``size[3]`` are read
            as W and H.
        lam: Mixing coefficient in [0, 1].

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``: box corners along the W and H axes.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # np.int was removed in NumPy 1.24; the builtin truncates identically.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Uniformly sampled box centre.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2


def produce_plot_alt(path, preds, planeloader, images, labels, trainloader, epoch='best', temp=1.0):
    """Plot the predicted class of every plane point and save it to ``<path>.png``.

    Every point is coloured by its arg-max class (tab10 palette) with its
    softmax confidence as opacity.  The anchor ``images`` are de-normalised
    with the mean/std of the last transform of ``trainloader.dataset``,
    shrunk to 32x32 when larger, and drawn at ``planeloader.dataset.coords``.

    Args:
        path: Output path without extension; nothing is saved when None.
        preds: Sequence of per-point score tensors (stacked along dim 0).
        planeloader: Loader whose ``dataset`` exposes ``coefs1``, ``coefs2``
            and ``coords``.
        images: The anchor images (normalised tensors).
        labels: Class indices of the three anchor images (legend entries).
        trainloader: Loader whose dataset's last transform exposes ``mean``
            and ``std``.
        epoch: Value shown in the title.
        temp: Softmax temperature applied to ``preds``.
    """
    classes = ['airpl', 'autom', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # is available in old and new releases.
    tab10 = plt.get_cmap('tab10')
    cmaplist = [tab10(i) for i in range(tab10.N)][:len(classes)]

    fig, ax1 = plt.subplots()
    preds = torch.stack((preds))
    preds = nn.Softmax(dim=1)(preds / temp)
    val = torch.max(preds,dim=1)[0].cpu().numpy()
    class_pred = torch.argmax(preds, dim=1).cpu().numpy()
    x = planeloader.dataset.coefs1.cpu().numpy()
    y = planeloader.dataset.coefs2.cpu().numpy()
    label_color_dict = dict(zip([*range(10)], cmaplist))

    color_idx = [label_color_dict[label] for label in class_pred]
    ax1.scatter(x, y, c=color_idx, alpha=val, s=0.1)

    # Full class legend, placed outside the axes on the right.
    markers = [plt.Line2D([0,0],[0,0],color=color, marker='o', linestyle='') for color in label_color_dict.values()]
    legend1 = plt.legend(markers, classes, numpoints=1,bbox_to_anchor=(1.01, 1))
    ax1.add_artist(legend1)
    coords = planeloader.dataset.coords

    dm = torch.tensor(trainloader.dataset.transform.transforms[-1].mean)[:, None, None]
    ds = torch.tensor(trainloader.dataset.transform.transforms[-1].std)[:, None, None]
    for i, image in enumerate(images):
        img = torch.clamp(image * ds + dm, 0, 1)
        img = img.cpu().numpy().transpose(1,2,0)
        if img.shape[0] > 32:
            # Shrink large images so the thumbnails do not cover the plot.
            from PIL import Image
            img = img*255
            img = img.astype(np.uint8)
            img = Image.fromarray(img).resize(size=(32, 32))
            img = np.array(img)

        coord = coords[i]
        imscatter(coord[0], coord[1], img, ax1)

    red_patch = mpatches.Patch(color =cmaplist[labels[0]] , label=f'{classes[labels[0]]}')
    blue_patch = mpatches.Patch(color =cmaplist[labels[1]], label=f'{classes[labels[1]]}')
    green_patch = mpatches.Patch(color =cmaplist[labels[2]], label=f'{classes[labels[2]]}')
    plt.legend(handles=[red_patch, blue_patch, green_patch], loc='upper center', bbox_to_anchor=(0.5, 1.05),
              ncol=3, fancybox=True, shadow=True)
    plt.title(f'Epoch: {epoch}')
    if path is not None:
        # os.makedirs('') raises, so only create a directory when the path
        # actually has one.
        img_dir = os.path.dirname(path)
        if img_dir:
            os.makedirs(img_dir, exist_ok=True)
        plt.savefig(f'{path}.png',bbox_extra_artists=(legend1,), bbox_inches='tight')
    plt.close(fig)
    return

def produce_plot_x(path, preds, planeloader, images, labels, trainloader, title='best', temp=1.0,true_labels = None):
    """Plot the predicted class of every plane point and save it to ``<path>_x.png``.

    Every point is coloured by its arg-max class (tab10 palette) with its
    softmax confidence as opacity.  The anchor positions are marked with red
    'o', '^' and 'X' markers instead of thumbnails.  The legend below the
    plot lists the three ``labels``; when ``true_labels`` is given a second
    legend with them is drawn at the top.  Axes and spines are hidden.

    Calling this function changes the global seaborn/matplotlib style
    (whitegrid style, paper context, Times New Roman font).

    Args:
        path: Output path without the ``_x.png`` suffix; nothing is saved
            when None.
        preds: Sequence of per-point score tensors (stacked along dim 0).
        planeloader: Loader whose ``dataset`` exposes ``coefs1``, ``coefs2``
            and ``coords``.
        images: The anchor images; only their count is used.
        labels: Class indices shown in the bottom legend (three entries).
        trainloader: Unused; kept for interface compatibility.
        title: Plot title.
        temp: Softmax temperature applied to ``preds``.
        true_labels: Optional class indices for the "True Labels" legend.
    """
    import seaborn as sns
    sns.set_style("whitegrid")
    paper_rc = {'lines.linewidth': 1, 'lines.markersize': 15,}
    sns.set_context("paper", rc = paper_rc,font_scale=1.5)
    plt.rc("font", family="Times New Roman")

    classes = ['AIRPL', 'AUTO', 'BIRD', 'CAT', 'DEER',
                   'DOG', 'FROG', 'HORSE', 'SHIP', 'TRUCK']

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # is available in old and new releases.
    tab10 = plt.get_cmap('tab10')
    cmaplist = [tab10(i) for i in range(tab10.N)][:len(classes)]
    fig, ax1  = plt.subplots()

    preds = torch.stack((preds))
    preds = nn.Softmax(dim=1)(preds / temp)
    val = torch.max(preds,dim=1)[0].cpu().numpy()
    class_pred = torch.argmax(preds, dim=1).cpu().numpy()
    x = planeloader.dataset.coefs1.cpu().numpy()
    y = planeloader.dataset.coefs2.cpu().numpy()
    label_color_dict = dict(zip([*range(10)], cmaplist))

    color_idx = [label_color_dict[label] for label in class_pred]
    ax1.scatter(x, y, c=color_idx, alpha=val, s=0.1)

    coords = planeloader.dataset.coords
    anchor_markers = ('o', '^', 'X')
    for i, _ in enumerate(images):
        coord = coords[i]
        # Cycle through the markers so more than three anchors do not fail.
        plt.scatter(coord[0], coord[1], s=150, c='red', marker=anchor_markers[i % len(anchor_markers)])

    red_patch = mpatches.Patch(color =cmaplist[labels[0]] , label=f'{classes[labels[0]]}')
    blue_patch = mpatches.Patch(color =cmaplist[labels[1]], label=f'{classes[labels[1]]}')
    green_patch = mpatches.Patch(color =cmaplist[labels[2]], label=f'{classes[labels[2]]}')

    leg2 = None
    if true_labels is not None:
        p0 = mpatches.Patch(color =cmaplist[true_labels[0]] , label=f'{classes[true_labels[0]]}')
        p1 = mpatches.Patch(color =cmaplist[true_labels[1]] , label=f'{classes[true_labels[1]]}')
        p2 = mpatches.Patch(color =cmaplist[true_labels[2]] , label=f'{classes[true_labels[2]]}')
        # Invisible patch used as a caption for the legend row.
        ph = mpatches.Patch(color = 'white' , label="True Labels:",visible=False)

        leg2 = plt.legend(handles=[ph,p0, p1, p2],  loc='upper center', bbox_to_anchor=(0.5, 1),
              ncol=4, fancybox=True, shadow=True,prop={'size': 10})
        # A second plt.legend call replaces the first, so pin this one.
        ax1.add_artist(leg2)

    plt.legend(handles=[red_patch, blue_patch, green_patch], loc='lower center', bbox_to_anchor=(0.5, -0.1),
              ncol=3, fancybox=True, shadow=True,prop={'size': 18},handletextpad=0.2)
    plt.title(f'{title}',fontsize=20)
    for side in ('right', 'top', 'left', 'bottom'):
        ax1.spines[side].set_visible(False)
    ax = plt.gca()
    ax.axes.xaxis.set_visible(False)
    ax.axes.yaxis.set_visible(False)
    plt.margins(0,0)

    if path is not None:
        # os.makedirs('') raises, so only create a directory when the path
        # actually has one.
        img_dir = os.path.dirname(path)
        if img_dir:
            os.makedirs(img_dir, exist_ok=True)
        if leg2 is not None:
            plt.savefig(f'{path}_x.png',bbox_extra_artists=(leg2,), bbox_inches='tight')
        else:
            plt.savefig(f'{path}_x.png', bbox_inches='tight')

    plt.close(fig)
    return

def produce_plot_sepleg(path, preds, planeloader, images, labels, trainloader, title='best', temp=0.01,true_labels = None):
    """Plot the predicted class of every plane point without any legend and save it to ``<path>_x.png``.

    Every point is coloured by its arg-max class using a hand-picked palette
    (mostly taken from 'gist_rainbow') at a fixed opacity of 0.5.  The anchor
    positions are marked with black 'o', '^' and 'X' markers.  Axes, spines
    and title are hidden.

    Calling this function changes the global seaborn/matplotlib style
    (whitegrid style, paper context, Times New Roman font).

    Args:
        path: Output path without the ``_x.png`` suffix; nothing is saved
            when None.
        preds: Sequence of per-point score tensors (stacked along dim 0).
        planeloader: Loader whose ``dataset`` exposes ``coefs1``, ``coefs2``
            and ``coords``.
        images: The anchor images; only their count is used.
        labels: Unused (no legend is drawn); kept for interface compatibility.
        trainloader: Unused; kept for interface compatibility.
        title: Unused; kept for interface compatibility.
        temp: Softmax temperature applied to ``preds`` before the arg-max.
        true_labels: Unused; kept for interface compatibility.
    """
    import seaborn as sns
    sns.set_style("whitegrid")
    paper_rc = {'lines.linewidth': 1, 'lines.markersize': 15,}
    sns.set_context("paper", rc = paper_rc,font_scale=1.5)
    plt.rc("font", family="Times New Roman")

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # is available in old and new releases.
    rainbow = plt.get_cmap('gist_rainbow')
    palette = [rainbow(i) for i in range(rainbow.N)]
    # One hand-picked colour per class index 0..9.
    cmaplist = [palette[i] for i in (45, 30, 170, 150, 65, 245, 0, 220, 180, 90)]
    cmaplist[2] = (0.17254901960784313, 0.6274509803921569, 0.17254901960784313, 1.0)
    cmaplist[4] = (0.6509803921568628, 0.33725490196078434, 0.1568627450980392, 1.0)

    fig, ax1  = plt.subplots()

    preds = torch.stack((preds))
    preds = nn.Softmax(dim=1)(preds / temp)
    class_pred = torch.argmax(preds, dim=1).cpu().numpy()
    x = planeloader.dataset.coefs1.cpu().numpy()
    y = planeloader.dataset.coefs2.cpu().numpy()
    label_color_dict = dict(zip([*range(10)], cmaplist))

    color_idx = [label_color_dict[label] for label in class_pred]
    ax1.scatter(x, y, c=color_idx, alpha=0.5, s=0.1)

    coords = planeloader.dataset.coords
    anchor_markers = ('o', '^', 'X')
    for i, _ in enumerate(images):
        coord = coords[i]
        # Cycle through the markers so more than three anchors do not fail.
        plt.scatter(coord[0], coord[1], s=150, c='black', marker=anchor_markers[i % len(anchor_markers)])

    for side in ('right', 'top', 'left', 'bottom'):
        ax1.spines[side].set_visible(False)
    ax = plt.gca()
    ax.axes.xaxis.set_visible(False)
    ax.axes.yaxis.set_visible(False)

    plt.margins(0,0)
    if path is not None:
        # os.makedirs('') raises, so only create a directory when the path
        # actually has one.
        img_dir = os.path.dirname(path)
        if img_dir:
            os.makedirs(img_dir, exist_ok=True)
        plt.savefig(f'{path}_x.png', bbox_inches='tight')

    plt.close(fig)
    return

class AttackPGD(nn.Module):
    """Wrap a classifier with a projected-gradient-descent (PGD) attack.

    Inputs are expected in the dataset's normalised space.  ``epsilon`` and
    ``step_size`` are given in raw pixel units ([0, 1] range) and divided by
    the per-channel std (``config['ds']``) before use; the adversarial
    example is kept inside the epsilon box around the input and inside the
    normalised image of [0, 1].

    Args:
        basic_net: The classifier under attack.
        dataset: Dataset whose last transform exposes per-channel ``mean``
            and ``std``.
        config: Optional dict with ``epsilon``, ``num_steps``, ``step_size``,
            ``loss_func`` (must be ``'xent'``) and ``num_restarts``; ``dm`` /
            ``ds`` may be supplied too and default to the dataset statistics.
            The dict is copied, never modified.
        numsteps: Optional override for ``config['num_steps']``.
    """

    def __init__(self, basic_net, dataset, config=None, numsteps = None):
        super(AttackPGD, self).__init__()
        # Prefer CUDA as before, but stay usable on CPU-only machines;
        # forward() moves the statistics to the inputs' device anyway.
        stats_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        normalize = dataset.transform.transforms[-1]
        dm = torch.tensor(normalize.mean)[:, None, None].to(stats_device)
        ds = torch.tensor(normalize.std)[:, None, None].to(stats_device)

        if config is None:
            config = {
                'epsilon': 8.0/255.0,
                'num_steps': 20,
                'step_size': 2.0/255.0,
                'loss_func': 'xent',
                'num_restarts': 1,
            }
        else:
            # Work on a copy so the caller's dict is not mutated below.
            config = dict(config)

        # forward() needs the statistics even for user-supplied configs.
        config.setdefault('dm', dm)
        config.setdefault('ds', ds)

        if numsteps is not None:
            config['num_steps'] = numsteps

        self.config = config
        self.basic_net = basic_net
        self.step_size = config['step_size']
        self.epsilon = config['epsilon']
        self.num_steps = config['num_steps']
        self.num_restarts = config['num_restarts']
        assert config['loss_func'] == 'xent', 'Only xent supported for now.'

    def forward(self, inputs, targets, targeted=False):
        """Run the attack and classify the adversarial examples.

        Args:
            inputs: Normalised input batch.
            targets: Class indices; the true labels, or the desired labels
                when ``targeted`` is true.
            targeted: Minimise (instead of maximise) the cross-entropy
                towards ``targets``.

        Returns:
            ``basic_net(inputs)`` alone when ``epsilon == 0``; otherwise the
            tuple ``(basic_net(x_adv), x_adv)``.
        """
        if self.epsilon == 0:
            return self.basic_net(inputs)

        dm = torch.as_tensor(self.config['dm'], device=inputs.device)
        ds = torch.as_tensor(self.config['ds'], device=inputs.device)
        eps = self.epsilon / ds          # perturbation budget in normalised units
        step = self.step_size / ds       # step size in normalised units
        lower = -dm / ds                 # normalised value of pixel 0
        upper = (1 - dm) / ds            # normalised value of pixel 1

        # Random start inside the epsilon box.
        x = inputs.detach()
        x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon) / ds
        for _ in range(self.num_steps):
            x.requires_grad_()
            with torch.enable_grad():
                loss = nn.functional.cross_entropy(self.basic_net(x), targets, reduction='sum')
                if targeted:
                    loss = -loss
            grad = torch.autograd.grad(loss, [x])[0]
            x = x.detach() + step * torch.sign(grad.detach())
            # Project back into the epsilon box, then into the valid range.
            x = torch.min(torch.max(x, inputs - eps), inputs + eps)
            x = torch.max(torch.min(x, upper), lower)

        # The logits must belong to the adversarial example; they used to be
        # computed from a stale copy of the clean inputs.
        return self.basic_net(x), x


def calculate_iou_plot(path, preds, planeloader, images, labels, trainloader, epoch='best', temp=1.0):
    """Scatter-plot the predicted class over a plane of inputs and save it.

    Each point from ``planeloader.dataset`` (``coefs1`` / ``coefs2``) is
    coloured by its arg-max class and made more opaque the higher its
    temperature-scaled softmax confidence. The ``images`` are drawn at
    ``planeloader.dataset.coords`` via ``imscatter``.

    Args:
        path: Output path without extension; ``.png`` is appended. When
            ``None`` nothing is written to disk.
        preds: Sequence of tensors, stacked with ``torch.stack``; softmax is
            taken over dim 1 of the stacked tensor.
        planeloader: Loader whose ``dataset`` exposes ``coefs1``, ``coefs2``
            and ``coords``.
        images: Iterable of image tensors; each is multiplied by the std and
            shifted by the mean read from the last transform of
            ``trainloader.dataset.transform`` and clamped to [0, 1]
            (presumably undoing input normalisation -- confirm with caller).
        labels: Indexable with at least three class indices, used for the
            top legend.
        trainloader: Loader whose last dataset transform has ``mean``/``std``.
        epoch: Unused in this function.
        temp: Softmax temperature dividing the logits.
    """
    # TODO(LF): still unfinished -- the intent is to save the IoU scores
    # alongside the plots.
    classes = ['airpl', 'autom', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9;
    # ``pyplot.get_cmap`` is available in old and new releases alike.
    col_map = plt.get_cmap('tab10')
    cmaplist = [col_map(i) for i in range(col_map.N)][:len(classes)]

    fig, ax1 = plt.subplots()
    preds = torch.stack((preds))
    preds = nn.Softmax(dim=1)(preds / temp)
    val = torch.max(preds, dim=1)[0].cpu().numpy()
    class_pred = torch.argmax(preds, dim=1).cpu().numpy()
    x = planeloader.dataset.coefs1.cpu().numpy()
    y = planeloader.dataset.coefs2.cpu().numpy()

    # One RGBA colour per point; confidence drives the per-point alpha.
    color_idx = [cmaplist[label] for label in class_pred]
    ax1.scatter(x, y, c=color_idx, alpha=val, s=0.1)
    markers = [plt.Line2D([0, 0], [0, 0], color=color, marker='o', linestyle='') for color in cmaplist]
    legend1 = plt.legend(markers, classes, numpoints=1, bbox_to_anchor=(1.01, 1))
    # Keep the class legend alive: the second plt.legend call below would
    # otherwise replace it.
    ax1.add_artist(legend1)
    coords = planeloader.dataset.coords

    dm = torch.tensor(trainloader.dataset.transform.transforms[-1].mean)[:, None, None]
    ds = torch.tensor(trainloader.dataset.transform.transforms[-1].std)[:, None, None]
    for i, image in enumerate(images):
        img = torch.clamp(image * ds + dm, 0, 1)
        img = img.cpu().numpy().transpose(1, 2, 0)
        coord = coords[i]
        imscatter(coord[0], coord[1], img, ax1)

    red_patch = mpatches.Patch(color=cmaplist[labels[0]], label=f'{classes[labels[0]]}')
    blue_patch = mpatches.Patch(color=cmaplist[labels[1]], label=f'{classes[labels[1]]}')
    green_patch = mpatches.Patch(color=cmaplist[labels[2]], label=f'{classes[labels[2]]}')
    plt.legend(handles=[red_patch, blue_patch, green_patch], loc='upper center', bbox_to_anchor=(0.5, 1.05),
              ncol=3, fancybox=True, shadow=True)
    if path is not None:
        img_dir = os.path.dirname(path)
        # A bare file name has no directory part; os.makedirs('') would raise.
        if img_dir:
            os.makedirs(img_dir, exist_ok=True)
        plt.savefig(f'{path}.png', bbox_extra_artists=(legend1,), bbox_inches='tight')
    plt.close(fig)

def calculate_iou_no_plot(pred_arr, many_nets=False):
    """Measure pairwise agreement between the hard predictions of several models.

    Each entry of ``pred_arr`` is a sequence of tensors that is stacked and
    reduced with ``argmax(1)``. For every pair ``(i, j)`` with ``i < j`` the
    score is the fraction of positions along dim 0 where the two label
    tensors are equal.

    Args:
        pred_arr: Sequence of per-model prediction sequences.
        many_nets: When true, return a square tensor whose upper triangle
            holds the pairwise scores (everything else stays zero); otherwise
            return the mean of all pairwise scores.
    """
    hard_labels = [torch.stack(outputs).argmax(1) for outputs in pred_arr]
    n_models = len(hard_labels)
    index_pairs = [(a, b) for a in range(n_models) for b in range(a + 1, n_models)]

    def agreement(a, b):
        # Non-zero entries of the label difference are the disagreements.
        n_points = hard_labels[a].shape[0]
        n_equal = n_points - (hard_labels[a] - hard_labels[b]).count_nonzero()
        return n_equal / n_points

    if many_nets:
        table = torch.zeros((n_models, n_models))
        for a, b in index_pairs:
            table[a, b] = agreement(a, b)
        return table

    return torch.mean(torch.stack([agreement(a, b) for a, b in index_pairs]))


class SmoothedValue(object):
    """Running statistic over a stream of values.

    The most recent ``window_size`` values are kept for the windowed
    statistics (``median``, ``avg``, ``max``, ``value``); ``total`` and
    ``count`` accumulate over the whole series for ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        # Default rendering: windowed median followed by the global average.
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record ``value`` once in the window and with weight ``n`` globally."""
        self.deque.append(value)
        self.count = self.count + n
        self.total = self.total + value * n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across processes.

        Warning: does not synchronize the deque!
        """
        if is_dist_avail_and_initialized():
            packed = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
            dist.barrier()
            dist.all_reduce(packed)
            reduced_count, reduced_total = packed.tolist()
            self.count = int(reduced_count)
            self.total = reduced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window (computed in float32)."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over everything ever passed to ``update``."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = {
            'median': self.median,
            'avg': self.avg,
            'global_avg': self.global_avg,
            'max': self.max,
            'value': self.value,
        }
        return self.fmt.format(**stats)


class MetricLogger(object):
    """Hold a set of named ``SmoothedValue`` meters and print progress lines.

    Meters are created on first use by ``update`` (the store is a
    ``defaultdict(SmoothedValue)``) and can be read back as attributes,
    e.g. ``logger.loss``.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed one new value into each named meter.

        Tensor values are converted with ``.item()``; after conversion every
        value must be a ``float`` or an ``int``.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails. ``meters`` is read
        # through ``__dict__`` so that an instance which has not run
        # ``__init__`` yet (e.g. while being copied or unpickled) raises
        # AttributeError instead of recursing into ``__getattr__`` forever.
        meters = self.__dict__.get('meters', {})
        if attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        return self.delimiter.join(
            "{}: {}".format(name, str(meter))
            for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        """Call ``synchronize_between_processes`` on every meter."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing progress.

        A line is printed every ``print_freq`` iterations and on the last
        one; a total-time summary is printed once the iterable is exhausted.
        ``iterable`` must support ``len()``.
        """
        if not header:
            header = ''
        num_iters = len(iterable)
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the iteration counter to the width of the total count.
        space_fmt = ':' + str(len(str(num_iters))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            # Time spent waiting for the next item.
            data_time.update(time.time() - end)
            yield obj
            # Item wait time plus the time the consumer spent on it.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_iters - 1:
                eta_seconds = iter_time.global_avg * (num_iters - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, num_iters, eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, num_iters, eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # An empty iterable performs no iterations; avoid dividing by zero.
        per_iter = total_time / num_iters if num_iters else 0.0
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, per_iter))


def _load_checkpoint_for_ema(model_ema, checkpoint):
    """
    Workaround for ModelEma._load_checkpoint to accept an already-loaded object
    """
    mem_file = io.BytesIO()
    torch.save({'state_dict_ema':checkpoint}, mem_file)
    mem_file.seek(0)
    model_ema._load_checkpoint(mem_file)


def setup_for_distributed(is_master):
    """
    Replace the built-in ``print`` with a version that is silent unless
    ``is_master`` is true or the call passes ``force=True``.
    """
    import builtins

    original_print = builtins.print

    def gated_print(*args, force=False, **kwargs):
        # ``force`` is consumed here and never forwarded to the real print.
        if not (is_master or force):
            return
        original_print(*args, **kwargs)

    builtins.print = gated_print


def is_dist_avail_and_initialized():
    """Return True only when ``torch.distributed`` is available and initialized."""
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Return the distributed world size, or 1 when no process group is active."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1


def get_rank():
    """Return this process's distributed rank, or 0 when no process group is active."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0


def is_main_process():
    """Return True when ``get_rank()`` reports rank 0."""
    rank = get_rank()
    return rank == 0


def save_on_master(*args, **kwargs):
    """Forward all arguments to ``torch.save``, but only on the main process."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)


def init_distributed_mode(args):
    """Populate distributed settings on ``args`` and start the process group.

    With ``RANK`` and ``WORLD_SIZE`` in the environment, ``args.rank``,
    ``args.world_size`` and ``args.gpu`` are read from ``RANK``,
    ``WORLD_SIZE`` and ``LOCAL_RANK``. Otherwise, with ``SLURM_PROCID`` set,
    the rank comes from it and the GPU index is the rank modulo the visible
    CUDA device count; ``args.world_size`` is not assigned on this path and
    must already be present. With neither, ``args.distributed`` is set to
    False and the function returns.

    ``args.dist_url`` is read as the init method for the process group.
    """
    env = os.environ
    has_rank_env = 'RANK' in env and 'WORLD_SIZE' in env
    has_slurm_env = 'SLURM_PROCID' in env

    if not (has_rank_env or has_slurm_env):
        print('Not using distributed mode')
        args.distributed = False
        return

    if has_rank_env:
        args.rank = int(env["RANK"])
        args.world_size = int(env['WORLD_SIZE'])
        args.gpu = int(env['LOCAL_RANK'])
    else:
        args.rank = int(env['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    dist.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    dist.barrier()
    # From here on only rank 0 prints, unless a call passes force=True.
    setup_for_distributed(args.rank == 0)
