from collections import OrderedDict
import hashlib
import io
from PIL import Image
import collections
import numpy as np
import os
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torchvision.utils as vutils

from torch.optim.optimizer import Optimizer

from torch.optim.optimizer import Optimizer
from typing import Optional


class Logger(object):
    """Minimal text logger that writes to a file and optionally echoes to stdout.

    Args:
        log_dir: Directory in which the log file is created (must already exist).
        log_name: Name of the log file inside ``log_dir``.
        verbose: When true, every message is also printed to stdout.
    """

    def __init__(self, log_dir, log_name="log.txt", verbose=False):
        self.log_dir = log_dir
        self.verbose = verbose
        # Defined before open() so that close()/__del__ stay safe if open() raises.
        self.log_file = None
        self.log_file = open("%s/%s" % (log_dir, log_name), "wt")

    def print(self, *objects):
        """Write ``objects`` to the log file (flushed immediately), like ``print``."""
        print(*objects, file=self.log_file, flush=True)
        if self.verbose:
            print(*objects)

    def close(self):
        """Close the underlying file; safe to call repeatedly."""
        log_file = getattr(self, "log_file", None)
        if log_file is not None:
            log_file.close()

    def __del__(self):
        # The attribute may be missing if __init__ never ran or failed early,
        # so go through close(), which tolerates that.
        self.close()


# Adapted from https://github.com/facebookresearch/DomainBed/blob/master/domainbed/lib/misc.py
def seed_hash(*args):
    """
    Map the given arguments to a deterministic integer in ``[0, 2**31)``.

    The tuple of arguments is stringified, hashed with MD5 and the digest is
    reduced modulo ``2**31`` so the result can be used as a random seed.
    """
    digest = hashlib.md5(str(args).encode("utf-8")).digest()
    # Reading the digest as a big-endian integer matches int(hexdigest, 16).
    return int.from_bytes(digest, "big") % (1 << 31)


def flatten_config_dict(config_dict):
    """Yield ``(path, value)`` pairs for every leaf of a nested dict.

    Keys of nested dictionaries are joined with ``/``; non-dict values are
    yielded as they are, in the dictionaries' iteration order.
    """
    for key, value in config_dict.items():
        if not isinstance(value, dict):
            yield key, value
            continue
        # Recurse into the sub-dictionary and prefix its paths with this key.
        yield from (
            (f"{key}/{sub_path}", leaf)
            for sub_path, leaf in flatten_config_dict(value)
        )


def evaluate(model, loader, weights, device):
    """Compute overall and class-averaged accuracy of ``model`` on ``loader``.

    Args:
        model: Object exposing ``eval()``, ``train()`` and ``predict(x)``; the
            latter must return per-class scores of shape (batch, classes).
        loader: Iterable of ``(x, y)`` batches, ``y`` holding integer labels.
        weights: Unused; kept for interface compatibility.
        device: Device the batches are moved to before prediction.

    Returns:
        Tuple ``(results, predictions, classes, samples)`` where ``results`` is
        an ``OrderedDict`` with ``"accuracy"`` and ``"accuracy_class_avg"``
        (both in percent), ``predictions`` / ``classes`` are lists of ints and
        ``samples`` is the first input batch on the CPU (``None`` if the
        loader yielded nothing).
    """
    # Objects without a ``training`` flag fall back to "was training".
    was_training = getattr(model, "training", True)
    model.eval()

    classes = []
    predictions = []
    samples = None

    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                logits = model.predict(x)

                classes.extend(y.tolist())
                predictions.extend(torch.argmax(logits, dim=1).tolist())

                if samples is None:
                    samples = x.detach().cpu()
    finally:
        # Restore the mode the model was in, as the sibling helpers do.
        if was_training:
            model.train()
        else:
            model.eval()

    classes = torch.tensor(classes)
    classes_one_hot = torch.nn.functional.one_hot(classes)
    class_sizes = torch.sum(classes_one_hot, dim=0)
    predictions = torch.tensor(predictions)
    indicators = torch.eq(classes, predictions).float()

    results = OrderedDict()
    results["accuracy"] = torch.mean(indicators).item() * 100.0

    # one_hot infers the class count from the largest label, so an index below
    # it may have no samples; skip those to avoid a 0/0 = NaN average.
    correct_per_class = torch.sum(indicators[:, None] * classes_one_hot, dim=0)
    present = class_sizes > 0
    class_accuracies = correct_per_class[present] / class_sizes[present] * 100.0
    results["accuracy_class_avg"] = torch.mean(class_accuracies).item()
    return results, predictions.tolist(), classes.tolist(), samples


def batch_to_grid(img_batch, data_params=None):
    """Arrange a batch of images into a single grid image.

    Args:
        img_batch: Image tensor; it is detached and moved to the CPU first.
        data_params: Optional mapping; if it contains a non-``None``
            ``"inv_norm_transform"`` callable, that is applied to the batch
            before clamping.

    Returns:
        A one-element list ``[("batch", image)]`` where ``image`` is the grid
        as a NumPy array with the channel axis moved last.
    """
    img_batch = img_batch.detach().cpu()
    if data_params is not None:
        inv_norm_transform = data_params.get("inv_norm_transform", None)
        if inv_norm_transform is not None:
            img_batch = inv_norm_transform(img_batch)
    img_batch = torch.clamp(img_batch, 0.0, 1.0)
    # The batch is already clamped to [0, 1], so normalising it to the (0, 1)
    # range is an identity. Dropping it avoids the ``range=`` keyword, which
    # newer torchvision releases no longer accept.
    grid = vutils.make_grid(img_batch, nrow=8)
    np_image = np.transpose(grid.numpy(), [1, 2, 0])
    return [("batch", np_image)]


# Adapted from https://github.com/timgaripov/swa/blob/4a2ddfdb2692eda91f2ac41533b62027976c605b/utils.py#L107
def update_bn(loader, model, device=None, sample_limit=50000):
    """Recompute BatchNorm running statistics from the samples in ``loader``.

    The running mean/variance of every BatchNorm layer that owns running
    statistics are reset and re-estimated as a cumulative average
    (``momentum=None``) over forward passes in train mode. Layers without
    running-stat buffers (built with ``track_running_stats=False``) are left
    untouched. The original momentum, tracking flag and train/eval mode are
    restored afterwards, even if a forward pass raises.

    Args:
        loader: Iterable of input batches; for list/tuple batches only the
            first element is fed to the model.
        model: Module whose BatchNorm layers are updated in place.
        device: Optional device the inputs are moved to.
        sample_limit: Iteration stops after the first batch that brings the
            number of processed samples above this value.
    """
    was_training = model.training
    model.train()

    # Remember (momentum, track_running_stats) for every BN layer we touch.
    saved = {}
    for module in model.modules():
        if not isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            continue
        if module.running_mean is None or module.running_var is None:
            # Nothing to re-estimate; zeros_like(None) would raise.
            continue
        saved[module] = (module.momentum, module.track_running_stats)

    if not saved:
        model.train(was_training)
        return

    try:
        for module in saved:
            module.running_mean = torch.zeros_like(module.running_mean)
            module.running_var = torch.ones_like(module.running_var)
            module.track_running_stats = True
            # momentum=None switches BatchNorm to a cumulative moving average.
            module.momentum = None
            if module.num_batches_tracked is not None:
                module.num_batches_tracked.zero_()

        with torch.no_grad():
            num_samples = 0
            for x in loader:
                if isinstance(x, (list, tuple)):
                    x = x[0]
                if device is not None:
                    x = x.to(device)

                model(x)

                num_samples += x.size(0)
                if num_samples > sample_limit:
                    break
    finally:
        for module, (momentum, track_running_stats) in saved.items():
            module.momentum = momentum
            module.track_running_stats = track_running_stats
        model.train(was_training)


def extract_features(feature_extractor, loader, device=None):
    """Run ``feature_extractor`` over ``loader`` and collect features and labels.

    Args:
        feature_extractor: Module with an ``n_outputs`` attribute giving the
            feature dimension; it is called in eval mode without gradients.
        loader: Iterable of ``(x, y)`` batches.
        device: Optional device the inputs are moved to.

    Returns:
        Tuple ``(features, labels)`` of CPU tensors concatenated over all
        batches; for an empty loader they have shapes ``(0, n_outputs)`` and
        ``(0,)``.
    """
    was_training = feature_extractor.training
    feature_extractor.eval()

    dim = feature_extractor.n_outputs
    # Empty seed tensors keep the result shape/dtype defined for an empty loader.
    feature_chunks = [torch.empty([0, dim])]
    label_chunks = [torch.empty([0], dtype=torch.long)]
    try:
        with torch.no_grad():
            for x, y in loader:
                if device is not None:
                    x = x.to(device)
                feature_chunks.append(feature_extractor(x).cpu())
                label_chunks.append(y.cpu())
    finally:
        # Restore the original mode even if a forward pass fails.
        feature_extractor.train(was_training)

    # One concatenation at the end instead of re-copying everything per batch.
    return torch.cat(feature_chunks, dim=0), torch.cat(label_chunks, dim=0)


class MidpointNormalize(matplotlib.colors.Normalize):
    """Normalization that maps ``vmin``/``midpoint``/``vmax`` to 0 / 0.5 / 1.

    Values in between are interpolated piecewise-linearly. The ``clip``
    argument of ``__call__`` is accepted but not used.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        anchors = [self.vmin, self.midpoint, self.vmax]
        targets = [0, 0.5, 1]
        return np.ma.masked_array(np.interp(value, anchors, targets))


def select_points(features, labels, max_num_points=500):
    """Pick a fixed pseudo-random subset of at most ``max_num_points`` rows.

    The subset is drawn with a local generator seeded with 1, so the same rows
    are chosen on every call without reseeding NumPy's global random state.

    Args:
        features: Array indexed along its first axis.
        labels: Array aligned with ``features`` along the first axis.
        max_num_points: Upper bound on the number of selected rows.

    Returns:
        Tuple ``(features_subset, labels_subset)``.
    """
    # A private RandomState(1) yields the same permutation as seeding the
    # global generator with 1, minus the global side effect.
    rng = np.random.RandomState(1)
    indices = rng.permutation(np.arange(features.shape[0]))
    n = min(max_num_points, features.shape[0])
    chosen = indices[:n]

    return features[chosen], labels[chosen]


def make_scatter_plot(
    features_src,
    labels_src,
    features_trg,
    labels_trg,
    classes,
    max_num_points=500,
    src=True,
    trg=True,
):
    """Scatter the first two feature dimensions of source/target points.

    A subset of at most ``max_num_points`` points per domain is drawn on the
    current pyplot axes, one ``scatter`` call per class. Source points use
    ``palette[0]``, target points ``palette[3]``; the marker encodes the class
    and is reused cyclically when there are more classes than markers.

    Args:
        features_src, labels_src: Source features and their integer labels.
        features_trg, labels_trg: Target features and their integer labels.
        classes: Iterable of integer class indices to draw.
        max_num_points: Maximum number of points drawn per domain.
        src: Whether to draw source points.
        trg: Whether to draw target points.
    """
    markers = ["o", "X", "^", "s"]
    palette = sns.color_palette()

    domains = []
    if src:
        domains.append(("src", features_src, labels_src, palette[0]))
    if trg:
        domains.append(("trg", features_trg, labels_trg, palette[3]))

    for tag, features, labels, color in domains:
        points, point_labels = select_points(
            features, labels, max_num_points=max_num_points
        )
        for c in classes:
            class_points = points[point_labels == c]
            plt.scatter(
                class_points[:, 0],
                class_points[:, 1],
                # Cycle through the markers so class indices >= 4 do not fail.
                marker=markers[c % len(markers)],
                color=color,
                s=45,
                alpha=0.2,
                edgecolors="k",
                label="{} (c={})".format(tag, c),
            )


def make_cls_plot(meshgrid, classifier, device):
    """Draw the classifier's winning regions over a 2-D grid on the current axes.

    For every classifier output, the region where that output is the largest
    (within a tolerance of 1e-4) is shown with filled and line contours in the
    class colour, and a colorbar is added in an inset axes to the right.

    Args:
        meshgrid: Pair ``(X, Y)`` of coordinate arrays of equal shape.
        classifier: Callable mapping a float32 tensor of (x, y) points to
            per-class outputs with one column per class.
        device: Device the grid points are moved to before classification.
    """
    palette = sns.color_palette()
    c_indices = [2, 4, 5, 7]

    X, Y = meshgrid
    M = np.hstack((X.reshape(-1, 1), Y.reshape(-1, 1)))
    with torch.no_grad():
        M_tensor = torch.tensor(M, dtype=torch.float32).to(device)
        C = classifier(M_tensor).cpu().numpy()

    for i in range(C.shape[1]):
        others = np.concatenate((C[:, :i], C[:, i + 1 :]), axis=1)
        # ``initial`` keeps the reduction defined when there are no other
        # classes (single-output classifier): nothing gets masked then.
        C_other_max = np.max(others, axis=1, initial=-np.inf)

        V = C[:, i]
        V = np.ma.masked_array(V, mask=V < C_other_max - 1e-4)

        V = V.reshape(X.shape)

        # Reuse the colours cyclically so more than four classes do not fail.
        col = palette[c_indices[i % len(c_indices)]]
        cmap_f = sns.light_palette(col, reverse=False, as_cmap=True)
        cmap_c = sns.dark_palette(col, reverse=True, as_cmap=True)

        cf = plt.contourf(X, Y, V, levels=10, alpha=0.3, cmap=cmap_f)
        cax = plt.gca().inset_axes([1.06 + 0.2 * i, 0.0, 0.05, 1.0])
        plt.colorbar(cf, ax=plt.gca(), cax=cax)
        plt.contour(X, Y, V, levels=10, alpha=0.6, linewidths=1.0, cmap=cmap_c)


def make_disc_plot(meshgrid, discriminator, device):
    """Draw the discriminator output over a 2-D grid on the current axes.

    The discriminator is evaluated without gradients on every grid point; its
    values are shown as filled contours ("seismic") and line contours
    ("icefire"), both normalised around 0, with a colorbar in an inset axes on
    the left.
    """
    grid_x, grid_y = meshgrid
    points = np.hstack((grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)))
    with torch.no_grad():
        inputs = torch.tensor(points, dtype=torch.float32).to(device)
        values = discriminator(inputs).cpu().numpy()
    values = values.reshape(grid_x.shape)

    line_cmap = sns.color_palette("icefire", as_cmap=True)
    axes = plt.gca()

    filled = plt.contourf(
        grid_x,
        grid_y,
        values,
        levels=10,
        alpha=0.2,
        cmap="seismic",
        norm=MidpointNormalize(midpoint=0.0),
    )
    colorbar_axes = axes.inset_axes([-0.2, 0.0, 0.05, 1.0])
    plt.colorbar(filled, ax=axes, cax=colorbar_axes)
    # Tick labels go on the left because the colorbar sits left of the plot.
    colorbar_axes.yaxis.tick_left()
    plt.contour(
        grid_x,
        grid_y,
        values,
        levels=10,
        alpha=0.8,
        linewidths=1,
        cmap=line_cmap,
        norm=MidpointNormalize(midpoint=0.0),
    )


def get_lim(array, rel=0.05):
    """Return ``(low, high)`` axis limits for ``array`` with a relative margin.

    The margin on each side is ``rel`` times the data range, where the range
    is taken to be at least 0.01 so constant data still gets a visible span.
    """
    low, high = array.min(), array.max()
    margin = max(high - low, 0.01) * rel
    return low - margin, high + margin


def _format_axes(xlim, ylim, title=None):
    """Apply the shared limits, tick font size and optional title to the current axes."""
    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    if title is not None:
        plt.title(title, fontsize=18, y=1.02)


def feature_plot(
    algorithm,
    features_src,
    labels_src,
    features_trg,
    labels_trg,
    num_classes,
    device,
    title=None,
):
    """Render an overview figure of 2-D features and return it as an RGB array.

    The figure has ``2 + num_classes`` rows and two columns:
    row 0 shows the discriminator output (only if ``algorithm`` has a
    ``discriminator`` attribute) and the classifier output, row 1 shows all
    source / target points, and each remaining row shows the source / target
    points of a single class.

    Args:
        algorithm: Object providing ``network.classifier`` and optionally
            ``discriminator``.
        features_src, labels_src: Source features/labels; ``.numpy()`` is
            called on them and the first two feature columns are plotted.
        features_trg, labels_trg: Target features/labels, same conventions.
        num_classes: Number of classes (class indices ``0..num_classes-1``).
        device: Device used to evaluate classifier/discriminator on the grid.
        title: Optional figure title.

    Returns:
        The rendered figure as a NumPy RGB image array.
    """
    rows = 2 + num_classes
    fig, axes = plt.subplots(figsize=(14, rows * 5.4), nrows=rows, ncols=2)

    # Close the figure even when plotting fails so it does not stay registered
    # in pyplot's figure manager.
    try:
        features_src = features_src.numpy()
        labels_src = labels_src.numpy()
        features_trg = features_trg.numpy()
        labels_trg = labels_trg.numpy()

        coords = np.concatenate((features_src, features_trg), axis=0)
        xlim = get_lim(coords[:, 0])
        ylim = get_lim(coords[:, 1])

        x_grid = np.linspace(*xlim, 50)
        y_grid = np.linspace(*ylim, 50)
        meshgrid = np.meshgrid(x_grid, y_grid)

        all_classes = list(range(num_classes))

        def scatter_panel(ax, classes, src, trg, panel_title):
            # One scatter-only panel: select axes, draw points, apply styling.
            plt.sca(ax)
            make_scatter_plot(
                features_src,
                labels_src,
                features_trg,
                labels_trg,
                src=src,
                trg=trg,
                classes=classes,
            )
            _format_axes(xlim, ylim, panel_title)

        plt.sca(axes[0][0])
        disc_title = None
        if hasattr(algorithm, "discriminator"):
            make_disc_plot(meshgrid, algorithm.discriminator, device=device)
            make_scatter_plot(
                features_src, labels_src, features_trg, labels_trg, classes=all_classes
            )
            disc_title = "Discriminator output"
        # Without a discriminator the panel stays empty and untitled.
        _format_axes(xlim, ylim, disc_title)

        plt.sca(axes[0][1])
        make_cls_plot(meshgrid, algorithm.network.classifier, device=device)
        make_scatter_plot(
            features_src, labels_src, features_trg, labels_trg, classes=all_classes
        )
        _format_axes(xlim, ylim, "Classifier output")

        scatter_panel(axes[1][0], all_classes, True, False, "Source scatter")
        scatter_panel(axes[1][1], all_classes, False, True, "Target scatter")

        for c in range(num_classes):
            scatter_panel(
                axes[c + 2][0], [c], True, False, "Source scatter (c = {})".format(c)
            )
            scatter_panel(
                axes[c + 2][1], [c], False, True, "Target scatter (c = {})".format(c)
            )

        # The classifier panel contains both domains and all classes, so its
        # handles serve as the figure-wide legend.
        fig.legend(
            *(axes[0][1].get_legend_handles_labels()),
            fontsize=18,
            ncol=2,
            loc="lower center",
            bbox_to_anchor=(0.45, 0.94),
        )

        if title is not None:
            plt.suptitle(title, fontsize=20, y=1.04)

        plt.tight_layout(w_pad=1.2, rect=(0.05, 0.02, 0.95, 0.95))

        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=120, bbox_inches="tight")
            buf.seek(0)
            # convert() decodes the image, so the array is complete before the
            # buffer is released.
            np_image = np.array(Image.open(buf).convert("RGB"))

        return np_image
    finally:
        plt.close(fig)


class StepwiseLR:
    r"""
    A lr_scheduler that updates the learning rate using the following schedule:

    .. math::
        \text{lr} = \text{init_lr} \times \text{lr_mult} \times (1+\gamma i)^{-p},

    where `i` is the number of completed :meth:`step` calls and `lr_mult` is
    an optional per-parameter-group multiplier (key ``"lr_mult"``, set to 1.0
    when missing).

    Parameters:
        - **optimizer**: Optimizer whose ``param_groups`` are updated; may be
          ``None``, in which case only the iteration counter advances.
        - **init_lr** (float, optional): initial learning rate. Default: 0.001
        - **gamma** (float, optional): :math:`\gamma`. Default: 0.001
        - **decay_rate** (float, optional): :math:`p` . Default: 0.75
    """

    def __init__(
        self,
        optimizer: Optimizer,
        init_lr: Optional[float] = 0.001,
        gamma: Optional[float] = 0.001,
        decay_rate: Optional[float] = 0.75,
    ):
        self.init_lr = init_lr
        self.gamma = gamma
        self.decay_rate = decay_rate
        self.optimizer = optimizer
        self.iter_num = 0
        # Start from the optimizer's current rates so get_last_lr() is usable
        # before the first step().
        if optimizer is not None:
            self.last_lr = [group["lr"] for group in optimizer.param_groups]
        else:
            self.last_lr = []

    def get_lr(self) -> float:
        """Return the base learning rate (before ``lr_mult``) for the current iteration."""
        return self.init_lr * (1 + self.gamma * self.iter_num) ** (-self.decay_rate)

    def get_last_lr(self) -> list:
        """Return the learning rates most recently set, one per parameter group."""
        return self.last_lr

    def step(self):
        """Update the learning rate in `optimizer`, then increase iteration number `i` by 1"""
        lr = self.get_lr()

        if self.optimizer is not None:
            for param_group in self.optimizer.param_groups:
                # Groups without an explicit multiplier use the base rate.
                param_group.setdefault("lr_mult", 1.0)
                param_group["lr"] = lr * param_group["lr_mult"]
            self.last_lr = [
                param_group["lr"] for param_group in self.optimizer.param_groups
            ]

        self.iter_num += 1


def calculate_marginal(dataset):
    """Return the empirical class distribution of ``dataset``.

    Labels are taken from ``dataset.attributes["classes"]`` when that exists
    and is not ``None``; otherwise the dataset is iterated and the second
    element of each item is used as the label.

    Returns:
        NumPy array whose entry ``c`` is the fraction of samples with label
        ``c``, for ``c`` in ``0..max_label``. Class indices that never occur
        get probability 0.
    """
    classes = None
    if hasattr(dataset, "attributes"):
        classes = dataset.attributes.get("classes", None)
    if classes is None:
        classes = [int(y) for _, y in dataset]
    counter = collections.Counter(classes)

    # Size the vector by the largest label rather than by the number of
    # distinct labels, so gaps in the label set do not drop classes.
    num_classes = int(max(counter)) + 1 if counter else 0
    class_counts = np.array([counter[y] for y in range(num_classes)])

    return class_counts / class_counts.sum()


def default_loader(path):
    """Load the image stored at ``path`` and return it converted to RGB."""
    # Reading through an explicit file handle avoids a ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, "rb") as handle, Image.open(handle) as picture:
        rgb_image = picture.convert("RGB")
    return rgb_image


def make_dataset(image_list, labels=None):
    """Turn text lines (and optional labels) into ``(path, label)`` pairs.

    Args:
        image_list: Sequence of strings. Without ``labels`` each line is
            ``"<path> <label>"`` or ``"<path> <l1> <l2> ..."``; the format is
            decided from the first line.
        labels: Optional 2-D array indexed as ``labels[i, :]``; when given and
            non-empty, row ``i`` becomes the label of line ``i`` and each whole
            stripped line is used as the path.

    Returns:
        List of ``(path, label)`` tuples. ``label`` is an ``int`` for
        single-label lines and a NumPy array otherwise. An empty
        ``image_list`` yields an empty list.
    """
    # An explicit None/length check: truth-testing a multi-element array raises.
    if labels is not None and len(labels) > 0:
        return [(line.strip(), labels[i, :]) for i, line in enumerate(image_list)]

    if len(image_list) == 0:
        return []

    if len(image_list[0].split()) > 2:
        # Multi-label format: every token after the path is an integer label.
        return [
            (fields[0], np.array([int(la) for la in fields[1:]]))
            for fields in (val.split() for val in image_list)
        ]

    return [(val.split()[0], int(val.split()[1])) for val in image_list]


class ImageList(object):
    """Dataset of images described by a list of text lines.

    Each entry of ``image_list`` is parsed by ``make_dataset`` into a
    ``(relative path, label)`` pair; paths are joined with ``root``.

    Args:
            image_list (sequence of str): Lines describing the images.
            root (string): Root directory path prepended to every image path.
            transform (callable, optional): A function/transform that takes in the
                    loaded image and returns a transformed version.
            target_transform (callable, optional): A function/transform that takes in the
                    target and transforms it.
            loader (callable, optional): A function to load an image given its path.
     Attributes:
            root (string): The root directory passed in.
            data (np.ndarray): Full image paths.
            labels (np.ndarray): Labels aligned with ``data``.
    """

    def __init__(
        self,
        image_list,
        root,
        transform=None,
        target_transform=None,
        loader=default_loader,
    ):
        entries = make_dataset(image_list)
        if not entries:
            raise RuntimeError("Found 0 images in subfolders of: " + root + "\n")

        self.root = root
        self.data = np.array([os.path.join(root, rel_path) for rel_path, _ in entries])
        self.labels = np.array([label for _, label in entries])
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """
        Args:
                index (int): Index
        Returns:
                tuple: (image, target) after applying the optional transforms.
        """
        path = self.data[index]
        target = self.labels[index]

        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def __len__(self):
        return len(self.data)
