from torch.utils.data import Subset
from anonlibrary.loader.model import evaluate_model
from anonlibrary.loader.model import get_module
from anonlibrary.loader.model import mask_model
from anonlibrary.loader.model import module_out_shape
from tqdm.auto import tqdm
import os
import torch


class AccuracyTopK:
    '''
    Running top-k accuracy.

    The counters are created by init(), so init() has
    to be called before the first update().

    Parameters
    ----------
    k: int
        Number of highest-scoring outputs among
        which the label is searched
    '''

    def __init__(self, k):
        self.k = k

    def init(self):
        '''
        Resets the counters of matches
        and of seen examples.
        '''
        self.n_examples = 0
        self.correct = 0

    def update(self, y, y_out):
        '''
        Accumulates the results of a batch.

        Parameters
        ----------
        y: torch.Tensor
            Labels; the tiling below yields a
            (batch, k) tensor when y is 1-D
            (presumably class indices, to be
            confirmed against the caller)
        y_out: torch.Tensor
            Model outputs, ranked along dim 1
        '''
        seen = y.shape[0]

        # Indices of the k largest outputs of each row
        best_k = y_out.topk(self.k, dim=1)[1]

        # One copy of the label for each of the k candidates
        tiled_labels = y.repeat(self.k, 1).t()

        self.correct += (tiled_labels == best_k).sum()
        self.n_examples += seen

    def finalize(self):
        '''
        Returns the number of matches divided
        by the number of examples seen.
        '''
        return self.correct / self.n_examples


def restrict_dataset(target_dataset, target_class):
    '''
    Builds the subset of a dataset made of
    the examples whose label is equal to
    the requested class.

    Labels are read from the ``targets``
    attribute of the dataset.

    Parameters
    ----------
    target_dataset: torch.utils.data.Dataset
        Dataset exposing a ``targets`` iterable
    target_class: int
        Label to keep

    Returns
    -------
    target_class_dataset: torch.utils.data.Dataset
        Subset wrapping target_dataset
    '''
    kept_positions = []
    for position, label in enumerate(target_dataset.targets):
        if label == target_class:
            kept_positions.append(position)

    return Subset(target_dataset, kept_positions)


def target_loss(units, model, dataset, metric, target_class=None,
                batch_size=32, mu=0.0, gpu=False, silent=True):
    '''
    Compares the metric obtained by a model
    with the one obtained after masking a
    given set of units.

    Parameters
    ----------
    units: dict of list
        For each module, the list of
        units that should be masked

    model: torch.nn.Module

    dataset: torch.utils.data.Dataset

    metric:
        Object used by evaluate_model to
        score the outputs against the
        annotations

    target_class: int, optional
        When given, only the examples
        labelled with this class are used

    batch_size: int, optional

    mu: float, optional
        Value to substitute when masking
        units

    gpu: bool, optional
        Whether to use the GPU or not

    silent: bool, optional
        Forwarded as is to evaluate_model

    Returns
    -------
    v_a: float
        Metric on the input model
    v_b: float
        Metric on the masked model
    diff: float
        Absolute value of the difference
        between v_a and v_b
    '''

    eval_data = dataset
    if target_class is not None:
        eval_data = restrict_dataset(dataset, target_class)

    # Options shared by the two evaluations
    eval_options = dict(batch_size=batch_size, gpu=gpu, silent=silent)

    # The baseline is measured before mask_model is invoked
    v_a = evaluate_model(model, eval_data, metric, **eval_options)

    masked_model = mask_model(units, model, mu, gpu)
    v_b = evaluate_model(masked_model, eval_data, metric, **eval_options)

    return v_a, v_b, torch.abs(v_a - v_b)


def _sample_loss_cache_paths(cache, module_name, target_class):
    '''
    Builds the paths of the three files (units, v_a, v_b)
    used by sample_loss to persist its results.
    '''
    # '%d' cannot format None: unfiltered runs get a dedicated
    # tag, while integer classes keep the original file names
    class_tag = 'all' if target_class is None else '%d' % target_class
    prefix = 'drop_%s_%s' % (module_name, class_tag)

    return tuple(os.path.join(cache, '%s_%s.pt' % (prefix, suffix))
                 for suffix in ('units', 'va', 'vb'))


def sample_loss(module_name, model, dataset, metric, target_class=None,
                batch_size=32, sample_rate=0.1, mu=0.0, cache='', gpu=False):
    '''
    Estimates the effect of removing
    random units from a given module
    by sampling: each sampled unit is
    masked alone and the model is
    evaluated again.

    Parameters
    ----------
    module_name: str
        Module whose importance
        is to be estimated

    model: torch.nn.Module

    dataset: torch.utils.data.Dataset

    metric:
        Object to evaluate the accuracy
        of the actual output and the
        annotation

    target_class: int, optional
        Class to consider when filtering

    batch_size: int, optional

    sample_rate: float, optional
        Fraction of the units to evaluate

    mu: float, optional
        Value to substitute when masking
        units

    cache: str, optional
        Path of the folder in which to
        ensure persistency. When the three
        cached files are found they are
        returned without any evaluation;
        otherwise they are written once
        the results have been computed
        (the folder is created if missing)

    gpu: bool, optional
        Whether to use the GPU or not

    Returns
    -------
    units: torch.tensor
        Sorted IDs of the units analyzed
    v_a: float
        Metric on the input model
    v_b: torch.tensor
        Metric on the masked models,
        one entry for each sampled unit
    diff: torch.tensor
        Absolute value of the difference
        between v_a and v_b
    '''

    # Try to load from file before touching the dataset,
    # so that a cache hit does not pay for the filtering
    if cache:
        units_fname, va_fname, vb_fname = _sample_loss_cache_paths(
            cache, module_name, target_class)

        try:
            units = torch.load(units_fname)
            v_a = torch.load(va_fname)
            v_b = torch.load(vb_fname)
        except FileNotFoundError:
            # Missing or incomplete cache: compute from scratch
            pass
        else:
            return units, v_a, v_b, torch.abs(v_a - v_b)

    if target_class is not None:
        dataset = restrict_dataset(dataset, target_class)

    v_a = evaluate_model(model, dataset, metric,
                         batch_size=batch_size, gpu=gpu)

    module = get_module(model, module_name)
    out_shape = module_out_shape(model, module, gpu)
    # Dimension 1 of the module output is taken as the number of units
    n_units = out_shape[1]

    # Random sample without replacement, reported in increasing order
    units = torch.randperm(n_units)[:int(n_units*sample_rate)]
    units, _ = torch.sort(units)

    # One evaluation for each sampled unit, masked alone
    v_b = torch.tensor([
        evaluate_model(mask_model({module_name: [unit]}, model, mu, gpu),
                       dataset, metric, batch_size=batch_size,
                       silent=True, gpu=gpu)
        for unit in tqdm(units)
        ])

    diff = torch.abs(v_a - v_b)

    # Ensure persistency
    if cache:
        # Do not lose the results only because the folder is missing
        os.makedirs(cache, exist_ok=True)
        torch.save(units, units_fname)
        torch.save(v_a, va_fname)
        torch.save(v_b, vb_fname)

    return units, v_a, v_b, diff
