from xmeta.utils.data import degrade_images
from xmeta.maml.maml import xfast_adapt
from xmeta.utils.seed import set_seed
from torch import nn
import numpy as np


# Cross-entropy criterion averaged over its inputs (reduction='mean'); passed as
# ``loss=`` to ``xfast_adapt`` when adapting the cloned model in
# ``explain_degraded_task``.
LOSS = nn.CrossEntropyLoss(reduction='mean')


def explain_degraded_task(explainer, task, ratio=0.2, alpha=1., top_k=5, seed=42,
                          preprocess=(lambda x: x), shots=5, ways=5):
    """Degrade ``task``, fast-adapt a clone of the explainer's model, then explain.

    Steps, in order:
      1. clone ``explainer.model``;
      2. call ``set_seed(seed)``;
      3. degrade the task with ``degrade_images(task, ratio=ratio, alpha=alpha)``;
      4. apply ``preprocess`` to the degraded task;
      5. adapt the clone with ``xfast_adapt`` using ``LOSS``, ``shots``, ``ways``
         and ``output_hessian=True``;
      6. ask the explainer for ``top_k`` indices, using the negated evaluation
         error as the target.

    Returns:
        ``(idxes, scores, acc, error)``: ``idxes`` and ``scores`` are returned
        unchanged from ``explainer.explain``. ``acc`` and ``error`` are the
        ``'accuracy'`` and ``'error'`` entries of the adaptation result's
        ``'evaluation'`` section, detached, moved to CPU and converted with
        ``.numpy()``.
    """
    model_copy = explainer.model.clone()
    # Seed immediately before degradation -- presumably so that any randomness
    # in ``degrade_images`` (and the steps after it) is reproducible per seed.
    set_seed(seed)
    degraded = degrade_images(task, ratio=ratio, alpha=alpha)
    adaptation = xfast_adapt(preprocess(degraded), model_copy, loss=LOSS,
                             shots=shots, ways=ways, output_hessian=True)
    evaluation = adaptation['evaluation']

    # NOTE: an earlier variant passed the training Hessian to the explainer via
    # ``explainer.set_trg_param_matrix(hessian=..., params=...)``. That call is
    # disabled, so the Hessian requested above is not used in this function.
    # The explanation target is the *negated* error (presumably because the
    # explainer treats larger targets as better -- confirm against ``explain``).
    idxes, scores = explainer.explain(y=-evaluation['error'], top_k=top_k)

    def to_host_array(tensor):
        return tensor.detach().to('cpu').numpy()

    acc = to_host_array(evaluation['accuracy'])
    error = to_host_array(evaluation['error'])
    return idxes, scores, acc, error


def explain_degraded_tasks(explainer, taskset, task_idx, top_k=None, param_dicts=None,
                           preprocess=(lambda x: x)):
    """Explain one task under several degradation settings and track its rank.

    For each keyword dict in ``param_dicts``, the task ``taskset[task_idx]`` is
    degraded, adapted on and explained with ``explain_degraded_task``. The
    function records:
      * the position of ``task_idx`` within the returned indices (its rank);
      * the score at that position;
      * the evaluation accuracy and error.

    Args:
        explainer: object exposing ``model`` and ``explain`` (as used by
            ``explain_degraded_task``).
        taskset: indexable collection of tasks. ``len(taskset)`` is the
            default ``top_k``.
        task_idx: index of the task to degrade and then locate in the ranking.
        top_k: number of indices requested from the explainer. Defaults to
            ``len(taskset)`` so that ``task_idx`` is expected to be present.
        param_dicts: iterable of keyword-argument dicts forwarded to
            ``explain_degraded_task`` (e.g. keys ``ratio``, ``alpha``,
            ``seed``, ``shots``, ``ways``). ``top_k`` and ``preprocess`` are
            passed separately and must not appear in these dicts. ``None``
            (the default) means no settings, which gives empty results.
        preprocess: callable applied to each degraded task before adaptation.

    Returns:
        Tuple ``(ranks, scores, accuracies, errors)`` of lists, with one entry
        per element of ``param_dicts``.

    Raises:
        ValueError: if ``task_idx`` is not among the indices returned by the
            explainer (e.g. when ``top_k`` is too small).
    """
    # ``None`` instead of a shared mutable ``[]`` default; explicit lists from
    # callers behave exactly as before.
    if param_dicts is None:
        param_dicts = ()
    if top_k is None:
        top_k = len(taskset)
    ranks = []
    scores = []
    accuracies = []
    errors = []
    for param_dict in param_dicts:
        # Re-fetch the task on every iteration rather than hoisting it:
        # degradation might modify the task in place, and each setting should
        # start from the original task.
        idxes, task_scores, acc, error = explain_degraded_task(
            explainer, taskset[task_idx], top_k=top_k, preprocess=preprocess,
            **param_dict)
        try:
            rank = idxes.index(task_idx)
        except ValueError as err:
            raise ValueError(
                f"task {task_idx} not found among the {top_k} indices returned "
                f"by the explainer for settings {param_dict!r}") from err
        ranks.append(rank)
        scores.append(task_scores[rank])
        accuracies.append(acc)
        errors.append(error)
    return ranks, scores, accuracies, errors
    

def loop_alpha(explainer, taskset, task_idx, ratio=1.,
               alphas=tuple(np.linspace(0, 1, 11)), preprocess=(lambda x: x)):
    """Sweep the degradation ``alpha`` at a fixed ``ratio`` for one task.

    Builds one ``{'ratio': ratio, 'alpha': alpha}`` setting per value in
    ``alphas`` and delegates to ``explain_degraded_tasks``.

    Args:
        explainer: explainer forwarded to ``explain_degraded_tasks``.
        taskset: collection of tasks; ``taskset[task_idx]`` is degraded.
        task_idx: index of the task under study.
        ratio: degradation ratio shared by every setting.
        alphas: iterable of alpha values to sweep. The default is the 11
            evenly spaced values from 0 to 1 (a tuple, so the shared default
            cannot be mutated).
        preprocess: callable applied to each degraded task.

    Returns:
        ``(ranks, scores, accuracies, errors)`` as returned by
        ``explain_degraded_tasks``, one entry per alpha.
    """
    param_dicts = [{'ratio': ratio, 'alpha': alpha} for alpha in alphas]
    return explain_degraded_tasks(
        explainer, taskset, task_idx=task_idx, param_dicts=param_dicts,
        preprocess=preprocess)


def loop_ratio(explainer, taskset, task_idx,
               ratios=tuple(np.linspace(0, 1, 11)), alpha=1.,
               preprocess=(lambda x: x)):
    """Sweep the degradation ``ratio`` at a fixed ``alpha`` for one task.

    Builds one ``{'ratio': ratio, 'alpha': alpha}`` setting per value in
    ``ratios`` and delegates to ``explain_degraded_tasks``.

    Args:
        explainer: explainer forwarded to ``explain_degraded_tasks``.
        taskset: collection of tasks; ``taskset[task_idx]`` is degraded.
        task_idx: index of the task under study.
        ratios: iterable of ratio values to sweep. The default is the 11
            evenly spaced values from 0 to 1 (a tuple, so the shared default
            cannot be mutated).
        alpha: degradation alpha shared by every setting.
        preprocess: callable applied to each degraded task.

    Returns:
        ``(ranks, scores, accuracies, errors)`` as returned by
        ``explain_degraded_tasks``, one entry per ratio.
    """
    param_dicts = [{'ratio': ratio, 'alpha': alpha} for ratio in ratios]
    return explain_degraded_tasks(
        explainer, taskset, task_idx=task_idx, param_dicts=param_dicts,
        preprocess=preprocess)
