import numpy as np
import utils
import torch
from sklearn.neighbors import LocalOutlierFactor
from tqdm import tqdm

class Metrics:
    """Metrics (cost, LOF score, robustness, implausibility) for candidate counterfactual points.

    Holds the training data and a Local Outlier Factor model fitted on it so
    that individual metrics can compare candidate points against the data.
    """

    def __init__(self, training_set, device, n_neighbors=10, n_jobs=-1):
        """Store the training data and fit an LOF novelty detector on it.

        Args:
            training_set: ``(features, labels)`` pair. ``features`` is what the
                LOF is fitted on and what ``implausibility_soa_metric`` indexes
                with a boolean mask built from ``labels``.
            device: torch device that model inputs / generated samples are
                moved to.
            n_neighbors: currently NOT used (see note below).
            n_jobs: number of parallel jobs passed to ``LocalOutlierFactor``.
        """
        self.training_set, self.training_labels = training_set
        # NOTE(review): ``n_neighbors`` is not forwarded. The LOF has always
        # used 20 neighbours while the argument defaults to 10. Forwarding it
        # would silently change default results, so 20 is kept until the
        # intended value is confirmed.
        self.lof = LocalOutlierFactor(n_neighbors=20, novelty=True, n_jobs=n_jobs)
        self.lof.fit(self.training_set)
        self.device = device

    def get_lof_score(self, sample, target):
        """Return the fitted LOF's ``score_samples`` value(s) for ``sample``.

        ``target`` is accepted for interface compatibility but is not used.
        Per scikit-learn, lower scores indicate more abnormal points.
        """
        return self.lof.score_samples(sample)

    def get_cost(self, sample, unperturbed_input, metric='l1'):
        """Return a scalar distance between ``sample`` and ``unperturbed_input``.

        Args:
            metric: ``'l1'`` -- sum of absolute differences;
                ``'l2'`` -- sum of squared differences, i.e. the *squared*
                Euclidean distance (no square root is taken);
                ``'linf'`` -- maximum absolute difference.

        Raises:
            ValueError: if ``metric`` is not one of the values above.
        """
        if metric == 'l1':
            return torch.sum(torch.abs(sample - unperturbed_input))
        elif metric == 'l2':
            return torch.sum(torch.square(sample - unperturbed_input))
        elif metric == 'linf':
            return torch.max(torch.abs(sample - unperturbed_input))
        else:
            raise ValueError('Invalid metric')

    def get_robustness_jpm(self, sample, model, target_cls, std_dev=1, num_local_samps=100):
        """Average Lipschitz-penalised target output over Gaussian perturbations of ``sample``.

        Draws ``num_local_samps`` Gaussian perturbations around ``sample``
        (standard deviation ``std_dev``). It evaluates the model on them and
        returns the mean of ``output[target_cls] - L * |sample - perturbation|``,
        where ``L`` is ``model.get_local_lipschitz(sample)`` (must hold a
        single value, since ``.item()`` is taken).

        ``sample`` must be 1-D: ``unsqueeze(-1).expand(-1, n)`` below only
        works for a single feature axis.
        """
        local_lips = model.get_local_lipschitz(sample)
        # One column of perturbed features per local sample: (features, num_local_samps).
        local_samples = torch.normal(mean=sample.float().unsqueeze(-1).expand(-1, num_local_samps), std=std_dev)
        outputs = model._forward_op(local_samples[:, None, :, None].to(self.device))
        # ``local_samples`` lives on ``sample``'s device while ``outputs`` lives on
        # the model's device; align them so the subtraction works when they differ.
        diffs = torch.abs(sample[:, None] - local_samples).to(outputs.device)
        return torch.mean(outputs[target_cls, 0, :] - (local_lips.item() * diffs))

    def get_robustness_ours(self, model, cfx, original_label, cfx_target, use_multiclass, budget=2, max_iter=1000, stepsize=1e-3, dist_weight=0):
        """Check whether ``cfx`` keeps its target class after a bounded attack.

        Runs ``utils.attack`` starting from ``cfx``. The result is then projected
        into an L-infinity ball of radius ``budget`` percent of
        ``max(model.data_range)`` around ``cfx``, and clamped to ``[0, 1]``.

        Args:
            budget: perturbation radius as a percentage of ``max(model.data_range)``.
            max_iter, stepsize, dist_weight, use_multiclass: forwarded to
                ``utils.attack``.

        Returns:
            A CPU float tensor: 1.0 if ``argmax(model.forward(attacked))``
            equals ``cfx_target``, else 0.0.
        """
        cfx = cfx.to(self.device)
        real_budget = torch.max(model.data_range) * (budget / 100)
        original_label = original_label.to(self.device)
        x = cfx.clone()
        # Forward the caller's ``dist_weight`` (it used to be hard-coded to 0,
        # which is still the default).
        x_cf = utils.attack(model, x, cfx, original_label, cfx_target,
                            num_iterations=max_iter, stepsize=stepsize,
                            dist_weight=dist_weight, use_multiclass=use_multiclass)
        x_cf = x_cf.clamp(x - real_budget, x + real_budget)
        x_cf = x_cf.clamp(0, 1.)
        return torch.tensor(float(torch.argmax(model.forward(x_cf)) == cfx_target), device='cpu')

    def get_SGDL_samples(self, model, target_range, number_of_samps=20, sgdl_step=2.0, sgdl_max_iter=500, sgdl_variance=1):
        """Generate per-target sample sets with Langevin-style gradient updates.

        Starting points are ``number_of_samps`` columns. Each column is
        ``model.data_range`` (flattened, cast to float64) scaled by an
        independent ``torch.rand`` factor in [0, 1).

        For every ``target`` in ``target_range``, the same starting points are
        updated ``sgdl_max_iter`` times with::

            x <- x + step**2 * grad / 2 + step * N(0, sgdl_variance)

        After each update, ``x`` is clipped to
        ``[min(model.data_range), max(model.data_range)]``. ``step`` starts at
        ``sgdl_step`` and is multiplied by 0.9 after every iteration ``i``
        with ``i % 50 == 0``. ``grad`` comes from ``utils.get_input_grad_``
        for ``target``; presumably this is the gradient of the target-class
        score w.r.t. the input -- confirm in ``utils``.

        Note: ``sgdl_variance`` is passed to ``torch.normal`` as the standard
        deviation, not the variance.

        Args:
            target_range: iterable of tensors; ``.item()`` is used on each one.

        Returns:
            dict mapping ``str(target)`` to the final samples on ``self.device``,
            laid out as (features, number_of_samps).
        """
        samp_scale = torch.rand(number_of_samps)
        init_samples = model.data_range.type(torch.float64).flatten().unsqueeze(-1).expand((-1, number_of_samps))
        init_samples = init_samples * samp_scale[None, :]
        # The clip bounds never change, so compute and transfer them once.
        lower = torch.min(model.data_range).to(self.device)
        upper = torch.max(model.data_range).to(self.device)
        sgdl_samples = {}
        for target in tqdm(target_range, desc="Computing SGDL Set"):
            target = target.item()
            cur_samples = init_samples.clone().to(self.device)
            active_sgdl_step = sgdl_step
            for i in range(sgdl_max_iter):
                grad = utils.get_input_grad_(model, cur_samples.unsqueeze(1).unsqueeze(-1), target, self.device)[:, 0, :, 0]
                noise = torch.normal(mean=0, std=sgdl_variance, size=cur_samples.size(), device=self.device)
                cur_samples = cur_samples + ((active_sgdl_step**2) * (grad/2)) + (active_sgdl_step * noise)
                cur_samples = cur_samples.clip(lower, upper)
                if i % 50 == 0:
                    active_sgdl_step *= 0.9
            sgdl_samples[str(target)] = cur_samples
        return sgdl_samples

    def implausibility_soa_metric(self, cfx, target_label, number_of_samples=100, alternate_sample_set=None):
        """Mean squared-L2 distance from ``cfx`` to reference points of ``target_label``.

        Reference points are the training rows whose label equals
        ``target_label``. If ``alternate_sample_set`` is given, its columns are
        used instead: it is transposed, so it is expected in the
        (features, samples) layout that ``get_SGDL_samples`` returns.

        If at least ``number_of_samples`` reference points exist, exactly
        ``number_of_samples`` are drawn without replacement via
        ``np.random.choice``. Otherwise all of them are used.

        Raises:
            ZeroDivisionError: if no reference points are used.
        """
        if alternate_sample_set is None:
            inclass_samples = self.training_set[self.training_labels == target_label].to(cfx.device)
        else:
            inclass_samples = alternate_sample_set.to(cfx.device).T
        total_samples = inclass_samples.shape[0]
        if total_samples >= number_of_samples:
            # Integer population is equivalent to sampling from range(total_samples).
            used_samples = np.random.choice(total_samples, size=number_of_samples, replace=False)
            inclass_samples = inclass_samples[used_samples]
        implausibility = 0
        for sample in inclass_samples:
            implausibility += self.get_cost(cfx, sample, metric='l2')
        return implausibility / min(number_of_samples, total_samples)
