import math
import os
from os import makedirs
from os.path import join

import numpy as np
import torch
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image

import data
import losses
from density import GaussianDensitySklearn, GaussianDensityTorch


class Evaluator:
    """Score samples with a Gaussian density fitted on model embeddings.

    Typical use: call ``run_model`` (or ``load_embeddings``) to populate the
    cached embeddings, labels and density scores, then query them with
    ``measure_auc``, ``to_validation_loss``, ``to_variance`` or
    ``get_scores``.
    """

    def __init__(self, root, dataset, obj_type, ano_type, device,
                 density='torch', batch_size=32, syn_args=None):
        """Build the density estimator and (optionally) the data loaders.

        Args:
            root: Data root forwarded to ``data.load_data``. When ``None``,
                no loaders are created (``trn_loader``/``test_loader`` are
                ``None``), so only methods working on cached or loaded
                embeddings can be used.
            dataset: Dataset name; forwarded to ``data.load_data`` and used
                to build the input normalization via ``data.Normalize``.
            obj_type, ano_type, syn_args: Forwarded to ``data.load_data``.
            device: Device every input batch is moved to before embedding.
            density: Density backend, ``'torch'`` or ``'sklearn'``.
            batch_size: Batch size for the train and test loaders.

        Raises:
            ValueError: If ``density`` names an unknown backend.
        """
        self.dataset = dataset
        self.obj_type = obj_type
        self.ano_type = ano_type
        self.device = device
        self.batch_size = batch_size
        self.normalize = data.Normalize(dataset)

        if density == 'torch':
            self.density = GaussianDensityTorch()
        elif density == 'sklearn':
            self.density = GaussianDensitySklearn()
        else:
            raise ValueError(f'Unknown density backend: {density!r}')

        if root is None:
            self.trn_loader = None
            self.test_loader = None
        else:
            trn_data = data.load_data(root, dataset, obj_type, ano_type,
                                      mode='train', syn_args=syn_args)
            test_data = data.load_data(root, dataset, obj_type, ano_type,
                                       mode='test', syn_args=syn_args)
            self.trn_loader = DataLoader(trn_data, batch_size, shuffle=True)
            self.test_loader = DataLoader(test_data, batch_size, shuffle=False)

        # Cached embeddings, labels and scores, filled by run_model() or
        # load_embeddings().
        self.trn_emb = None
        self.trn_labels = None
        self.trn_scores = None
        self.test_emb = None
        self.test_labels = None
        self.test_scores = None

    def _get_training_embeddings(self, model, trn_size):
        """Embed clean training images and return the first ``trn_size``.

        ``model`` is called on normalized batches and must return an
        ``(embedding, logit)`` pair. Iteration stops as soon as at least
        ``trn_size`` samples have been embedded.
        """
        emb_list = []
        seen = 0
        for x, _labels in self.trn_loader:
            embed, _logit = model(self.normalize(x.to(self.device)))
            emb_list.append(embed)
            seen += len(x)
            if seen >= trn_size:
                # No need to embed the rest of the loader; it is sliced off.
                break
        return torch.cat(emb_list)[:trn_size]

    def _get_augmented_embeddings(self, model, aug_func, trn_size, ratio):
        """Embed augmented training images.

        Produces ``ceil(ratio / (1 - ratio) * trn_size)`` embeddings, so that
        augmented samples make up roughly a fraction ``ratio`` of the combined
        clean + augmented set. The training loader is traversed up to
        ``ceil(ratio / (1 - ratio))`` times to collect enough samples.

        Args:
            model: Callable returning ``(embedding, logit)`` for a batch.
            aug_func: Applied to each (device-resident) batch before it is
                normalized and embedded.
            trn_size: Number of clean training samples.
            ratio: Target fraction of augmented samples, in ``(0, 1)``.

        Raises:
            ValueError: If ``ratio`` is not strictly between 0 and 1.
        """
        if not 0 < ratio < 1:
            raise ValueError(f'ratio must be in (0, 1), got {ratio!r}')
        multiplier = ratio / (1 - ratio)
        remaining = math.ceil(multiplier * trn_size)

        emb_list = []
        for _pass in range(math.ceil(multiplier)):
            for x, _labels in self.trn_loader:
                # Trim the final batch so exactly the requested count is made.
                x = x[:remaining].to(self.device)
                embed, _logit = model(self.normalize(aug_func(x)))
                emb_list.append(embed)
                remaining -= len(x)
                if remaining <= 0:
                    return torch.cat(emb_list)
        return torch.cat(emb_list)

    def _get_test_embeddings(self, model):
        """Embed every test batch.

        Returns:
            ``(embeddings, labels)``, with labels concatenated from the test
            loader and cast to ``torch.long``.
        """
        emb_list, label_list = [], []
        for x, y in self.test_loader:
            embed, _logit = model(self.normalize(x.to(self.device)))
            emb_list.append(embed)
            label_list.append(y)
        test_emb = torch.cat(emb_list)
        test_labels = torch.cat(label_list).to(torch.long)
        return test_emb, test_labels

    def run_model(self, model, aug_func, sample_size=None, anom_ratio=False):
        """Embed train/test data, fit the density and cache all results.

        The density is fitted on clean training embeddings only. It then
        scores the combined clean + augmented training embeddings and the
        test embeddings. In ``trn_labels`` clean samples are 0 and augmented
        samples are 1.

        Args:
            model: Model in eval mode returning ``(embedding, logit)``.
            aug_func: Augmentation used to create the label-1 training set.
            sample_size: Number of clean training samples to use; defaults
                to the full training dataset.
            anom_ratio: Target fraction of augmented samples in ``(0, 1)``.
                A falsy value (the default) adds no augmented samples.
        """
        assert not model.training, 'model must be in eval mode'
        test_emb, test_labels = self._get_test_embeddings(model)

        if sample_size is None:
            trn_size = len(self.trn_loader.dataset)
        else:
            trn_size = sample_size

        trn_emb_clean = self._get_training_embeddings(model, trn_size)
        if anom_ratio:
            trn_emb_augment = self._get_augmented_embeddings(
                model, aug_func, trn_size, anom_ratio
            )
        else:
            # No augmented samples requested: use an empty tensor with the
            # same trailing dimensions so the concatenation below still works.
            trn_emb_augment = trn_emb_clean.new_empty(
                (0, *trn_emb_clean.shape[1:])
            )

        trn_emb = torch.cat([trn_emb_clean, trn_emb_augment])
        trn_labels = torch.tensor(
            data=[0] * len(trn_emb_clean) + [1] * len(trn_emb_augment),
            dtype=torch.int64
        )

        with torch.no_grad():
            self.density.fit(trn_emb_clean.cpu())
            trn_scores = self.density.predict(trn_emb.cpu())
            test_scores = self.density.predict(test_emb.cpu())

        self.trn_emb = trn_emb
        self.trn_labels = trn_labels
        self.trn_scores = trn_scores
        self.test_emb = test_emb
        self.test_labels = test_labels
        self.test_scores = test_scores

    @torch.no_grad()
    def save_embeddings(self, path):
        """Write the cached tensors as ``.npy`` files into directory ``path``.

        The directory is created if needed.
        """
        def save(tensor, name):
            # detach(): model embeddings may still require grad, and
            # Tensor.numpy() refuses such tensors even under no_grad.
            np.save(join(path, name), tensor.detach().cpu().numpy())

        makedirs(path, exist_ok=True)
        save(self.trn_emb, 'trn_embeddings')
        save(self.trn_labels, 'trn_labels')
        save(self.trn_scores, 'trn_scores')
        save(self.test_emb, 'test_embeddings')
        save(self.test_labels, 'test_labels')
        save(self.test_scores, 'test_scores')

    def load_embeddings(self, path):
        """Restore the cached tensors from files written by ``save_embeddings``."""
        def load(file):
            return torch.from_numpy(np.load(join(path, file)))

        self.trn_emb = load('trn_embeddings.npy')
        self.trn_labels = load('trn_labels.npy')
        self.trn_scores = load('trn_scores.npy')
        self.test_emb = load('test_embeddings.npy')
        self.test_labels = load('test_labels.npy')
        self.test_scores = load('test_scores.npy')

    def to_validation_loss(self, name, normalize=True):
        """Compute a validation loss from the cached embeddings.

        Training embeddings are split by ``trn_labels`` into clean (0) and
        augmented (1) sets. Together with the test embeddings they form a
        triple, optionally normalized with ``losses.EmbNormalizer``.

        Args:
            name: ``'mean'`` (``losses.MeanLoss`` with euclidean metric),
                ``'mmd'`` (``losses.BaseLoss`` with the ``'mmd'`` metric), or
                ``'random'``/``'fixed'`` (a constant zero tensor).
            normalize: Whether to normalize the triple first.

        Raises:
            ValueError: For an unknown ``name``.
        """
        trn_emb_n = self.trn_emb[self.trn_labels == 0]
        trn_emb_a = self.trn_emb[self.trn_labels == 1]
        test_emb = self.test_emb

        triple = trn_emb_n, trn_emb_a, test_emb
        if normalize:
            triple = losses.EmbNormalizer()(*triple)

        if name == 'mean':
            return losses.MeanLoss(metric='euclidean')(*triple)
        elif name == 'mmd':
            # Previously the loss object itself was returned. It is now
            # applied like MeanLoss above.
            # NOTE(review): assumes BaseLoss instances are callable on the
            # embedding triple like MeanLoss; confirm against losses.py.
            return losses.BaseLoss(metric='mmd')(*triple)
        elif name in ['random', 'fixed']:
            return torch.tensor(0, dtype=torch.float32, device=test_emb.device)
        else:
            raise ValueError(name)

    def measure_auc(self):
        """Return the ROC AUC of the test scores against the test labels."""
        return roc_auc_score(
            self.test_labels.cpu().numpy(),
            self.test_scores.cpu().numpy()
        )

    def to_variance(self):
        """Return the standard deviation (not variance) of the test scores."""
        return np.std(self.test_scores.cpu().numpy())

    @torch.no_grad()
    def get_scores(self):
        """Return ``(trn_scores, test_scores)`` as NumPy arrays."""
        trn_scores = self.trn_scores.cpu().numpy()
        test_scores = self.test_scores.cpu().numpy()
        return trn_scores, test_scores

    @torch.no_grad()
    def visualize_images(self, aug_func, path_out, num_images=10):
        """Save an image grid of augmented training vs. anomalous test images.

        Uses up to ``num_images`` augmented training images followed by up
        to ``num_images`` test images whose label is nonzero. They are laid
        out ``num_images`` per row and written to ``path_out``. Loaders are
        only read until enough images are collected.
        """
        aug_parts, n_aug = [], 0
        for x, _y in self.trn_loader:
            aug_parts.append(aug_func(x.to(self.device)))
            n_aug += len(x)
            if n_aug >= num_images:
                break

        ano_parts, n_ano = [], 0
        for x, y in self.test_loader:
            ano = x[y != 0].to(self.device)
            ano_parts.append(ano)
            n_ano += len(ano)
            if n_ano >= num_images:
                break

        out = torch.cat([
            torch.cat(aug_parts)[:num_images],
            torch.cat(ano_parts)[:num_images],
        ])

        # os.makedirs('') raises, so only create a directory when one is given.
        out_dir = os.path.dirname(path_out)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        save_image(make_grid(out, nrow=num_images), path_out)
