from xmeta.utils.evaluation import accuracy
import numpy as np
import torch
from torch import nn
from xmeta.utils.higher_grad import higher_grad
from xmeta.utils.tensor import tensor2numpy
from xmeta.utils.opa import CrossEntropyHessian, inverse_psdmat
from xmeta.utils.gd_inverse import dot_generalized_inv_gd, dot_inv_gd, repeat_dot_inv_gd
from xmeta.explainer.explainer import ExplainerBase, ExplainerOPA
from xmeta.utils.seed import set_seed
from xmeta.utils.data\
    import ImpureTasksets, get_tasksets, pollute_tasks
from xmeta.utils.experiment import compare_tasksets
from xmeta.utils.sift import SiftFeature
from xmeta.utils.preprocess import TensorImageFFT
import pandas as pd
from tqdm import tqdm
import pickle
import os
from torchsummary import summary
import learn2learn as l2l

# Number of columns of ``src_param_matrix`` pushed through one
# Hessian-vector-product call in ``xfast_adapt`` (presumably chosen to bound
# peak memory of the batched second-order gradient).
TASK_INTERVAL = 512


def setup_experiment(k=None,
                     ways=5,
                     shots=5,
                     num_tasks=None,
                     num_test_tasks=None,
                     experiment_dir=None,
                     explainer_path=None,
                     sift_centroids_path=None,
                     fft_crop_size=None,
                     impurity_dict_path=None,
                     train_mask_tasks=None,
                     train_noise_tasks=None,
                     train_shuffle_tasks=None,
                     train_dark_tasks=None,
                     train_recolor_tasks=None,
                     test_recolor_tasks=None,
                     train_bgr_tasks=None,
                     test_bgr_tasks=None,
                     seed=42,
                     score_shuffling=False,
                     dataset='mini-imagenet',
                     device=torch.device('cpu'),
                     ):
    """Build tasksets, the (optional) explainer and the feature map for an experiment.

    Args:
        k: forwarded to ``SiftFeature`` (only used with ``sift_centroids_path``).
        ways, shots: classes per task and adaptation samples per class; tasks
            are requested with ``2 * shots`` samples per class.
        num_tasks: number of tasks forwarded to ``get_tasksets`` /
            ``ImpureTasksets``.
        num_test_tasks: task count forwarded to ``pollute_tasks`` for the test
            split; defaults to ``num_tasks``.
        experiment_dir: created if missing; default home of ``index_dict.pkl``.
        explainer_path: optional pickled explainer loaded via ``load_explainer``.
        sift_centroids_path: if given, ``feature`` is a ``SiftFeature``.
        fft_crop_size: otherwise, if given, ``feature`` is a square-crop
            ``TensorImageFFT``; with neither option ``feature`` is the identity.
        impurity_dict_path: pickle holding one index list per impurity kind;
            defaults to ``<experiment_dir>/index_dict.pkl``.
        train_*_tasks: per-impurity task counts forwarded to ``ImpureTasksets``.
            A ``None`` count is replaced by the length of the matching list in
            the impurity dict file when that file exists.
        test_recolor_tasks, test_bgr_tasks: forwarded to ``pollute_tasks`` for
            the test split.
        seed: used by ``get_tasksets`` and ``set_seed``.
        score_shuffling: when True and a positive number of training tasks is
            shuffled, store ``compare_tasksets`` output under
            ``impurity_dict['num_correct_labels']``.
        dataset: dataset name for ``get_tasksets`` and ``load_explainer``.
        device: device for the loaded explainer's model.

    Returns:
        ``(tasks_train, tasks_test, explainer, maml, feature, impurity_dict)``.
        ``explainer`` and ``maml`` are ``None`` when ``explainer_path`` is None.
    """
    if not os.path.exists(experiment_dir):
        os.makedirs(experiment_dir)
        print(f'created {experiment_dir}')
    else:
        print(f'found {experiment_dir}')

    if num_test_tasks is None:
        num_test_tasks = num_tasks

    # tasksets
    _tasksets = get_tasksets(seed=seed,
                             name=dataset,
                             train_ways=ways,
                             train_samples=2 * shots,
                             test_ways=ways,
                             test_samples=2 * shots,
                             num_tasks=num_tasks,
                             root='~/data',
                             )

    # load explainer
    if explainer_path is not None:
        explainer = load_explainer(path=explainer_path, ways=ways,
                                   device=device, dataset=dataset)
        maml = explainer.model
    else:
        explainer = None
        maml = None

    set_seed(seed)
    if sift_centroids_path is not None:
        feature = SiftFeature(k=k, name='mifeature', use_cache=True,
                              pkl_path=sift_centroids_path)
    elif fft_crop_size is not None:
        feature = TensorImageFFT(crop_shape=(fft_crop_size, fft_crop_size))
    else:
        def feature(x):
            return x

    # Keys match both the ImpureTasksets keyword arguments and the keys of the
    # pickled impurity dict, so one mapping drives loading and construction.
    impurity_counts = {'train_mask_tasks': train_mask_tasks,
                       'train_noise_tasks': train_noise_tasks,
                       'train_shuffle_tasks': train_shuffle_tasks,
                       'train_dark_tasks': train_dark_tasks,
                       'train_recolor_tasks': train_recolor_tasks,
                       'train_bgr_tasks': train_bgr_tasks,
                       }

    if impurity_dict_path is None:
        impurity_dict_path = os.path.join(experiment_dir, "index_dict.pkl")
    if os.path.exists(impurity_dict_path):
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(impurity_dict_path, 'rb') as f:
            impurity_dict = pickle.load(f)
        print(f'loaded {impurity_dict}')
        # Explicit counts win; missing ones are recovered from the saved lists.
        for key, count in impurity_counts.items():
            if count is None:
                impurity_counts[key] = len(impurity_dict[key])
    else:
        impurity_dict = {key: [] for key in impurity_counts}

    if any(count is not None for count in impurity_counts.values()):
        tasksets = ImpureTasksets(_tasksets,
                                  num_tasks=num_tasks, ways=ways, shots=shots,
                                  **impurity_counts,
                                  )
        impurity_dict = tasksets.impurity_dict
        # The shuffle count can still be None when only other impurities were
        # requested; treat that as "no shuffled tasks" instead of None > 0.
        num_shuffled = impurity_counts['train_shuffle_tasks']
        if score_shuffling and num_shuffled is not None and num_shuffled > 0:
            impurity_dict['num_correct_labels'] = compare_tasksets(
                tasksets.train, _tasksets.train, idxes=impurity_dict['train_shuffle_tasks'])
    else:
        tasksets = _tasksets

    tasks_train = tasksets.train
    tasks_test = tasksets.test
    if test_recolor_tasks is not None:
        tasks_test = pollute_tasks(tasks_test, num_tasks=num_test_tasks,
                                   test_recolor_tasks=test_recolor_tasks)
    if test_bgr_tasks is not None:
        tasks_test = pollute_tasks(tasks_test, num_tasks=num_test_tasks,
                                   test_bgr_tasks=test_bgr_tasks)

    return tasks_train, tasks_test, explainer, maml, feature, impurity_dict


def fast_adapt(batch, learner, loss,
               shots, ways):
    """Adapt ``learner`` once on a task's support set and score it on the rest.

    The samples at positions ``0, 2, 4, ..., 2 * (shots * ways - 1)`` of
    ``batch`` form the adaptation (support) set; every other sample is used
    for evaluation.

    Args:
        batch: ``(data, labels)`` pair indexed along the first axis.
        learner: model exposing ``__call__`` and ``adapt(loss_value)``
            (e.g. a cloned MAML learner).
        loss: callable ``loss(predictions, labels)``.
        shots, ways: samples per class and number of classes in the support set.

    Returns:
        ``(evaluation_loss, evaluation_accuracy)`` after one adaptation step.
    """
    data, labels = batch

    support_mask = np.zeros(data.size(0), dtype=bool)
    support_mask[np.arange(shots * ways) * 2] = True
    query_mask = torch.from_numpy(~support_mask)
    support_mask = torch.from_numpy(support_mask)

    support_x, support_y = data[support_mask], labels[support_mask]
    query_x, query_y = data[query_mask], labels[query_mask]

    # Single inner-loop update driven by the support-set loss.
    learner.adapt(loss(learner(support_x), support_y))

    query_pred = learner(query_x)
    return loss(query_pred, query_y), accuracy(query_pred, query_y)


def xfast_adapt(batch,
                learner, loss, shots, ways, src_param_matrix=None,
                output_hessian=False, output_gradient=False, output_tensor=True):
    """Adapt ``learner`` on one task and collect diagnostics used for explanation.

    Support/query split is the same as ``fast_adapt``: positions
    ``0, 2, ..., 2 * (shots * ways - 1)`` are the adaptation set, the rest is
    the evaluation set.

    Args:
        batch: ``(data, labels)`` pair indexed along the first axis.
        learner: cloned learner exposing ``parameters()``, ``adapt()`` and, when
            ``src_param_matrix`` is given, an ``lr`` attribute.
        loss: callable ``loss(predictions, labels)``.
        shots, ways: samples per class and number of classes in the support set.
        src_param_matrix: optional numpy matrix whose first axis is contracted
            with the training-loss gradient (``tensordot(..., dims=1)``); its
            columns are processed ``TASK_INTERVAL`` at a time.
        output_hessian: also return the Hessian of the training loss (numpy).
        output_gradient: also return the gradient of the training loss.
        output_tensor: if False, convert tensor outputs (including the gradient,
            when present) to numpy arrays.

    Returns:
        dict with ``'train'`` (pre-adaptation loss/prediction/accuracy plus
        ``'hessian'``, ``'gradient'`` and ``'sensitivity'``, each ``None`` unless
        requested), ``'adaptation'`` and ``'evaluation'`` (post-adaptation
        metrics) sub-dicts, and the two boolean index masks.
    """
    data, labels = batch
    adaptation_indices = np.zeros(data.size(0), dtype=bool)
    adaptation_indices[np.arange(shots * ways) * 2] = True
    evaluation_indices = torch.from_numpy(~adaptation_indices)
    adaptation_indices = torch.from_numpy(adaptation_indices)
    adaptation_data, adaptation_labels =\
        data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels =\
        data[evaluation_indices], labels[evaluation_indices]

    # Pre-adaptation loss on the support set
    train_pred = learner(adaptation_data)
    train_error = loss(train_pred, adaptation_labels)
    train_acc = accuracy(train_pred, adaptation_labels)

    params = list(learner.parameters())

    h = None
    if output_hessian:
        h = higher_grad(train_error, params, params).detach().to('cpu').numpy()

    # The gradient is both an optional output and the input of the
    # Hessian-vector products below; compute it at most once.
    v = None
    if output_gradient or src_param_matrix is not None:
        v = higher_grad(train_error, params)

    task_sensitivity = None
    if src_param_matrix is not None:
        h_dot_src_chunks = []
        for start in range(0, src_param_matrix.shape[1], TASK_INTERVAL):
            src_mat = torch.Tensor(src_param_matrix[:, start: start + TASK_INTERVAL]
                                   ).to(v.device)
            v_dot_src_mat = torch.tensordot(v, src_mat, dims=1)
            # Differentiating (v . src) w.r.t. params gives H @ src for this chunk
            h_dot_src_chunks.append(
                higher_grad(v_dot_src_mat, params, create_graph=False)
                .to('cpu').numpy().T)
        h_dot_src_mat = np.hstack(h_dot_src_chunks)
        assert src_param_matrix.shape == h_dot_src_mat.shape
        # Equivalent to (I - lr * H) @ src_param_matrix
        task_sensitivity = src_param_matrix - learner.lr * h_dot_src_mat

    learner.adapt(train_error)

    adapt_pred = learner(adaptation_data)
    adapt_error = loss(adapt_pred, adaptation_labels)
    adapt_acc = accuracy(adapt_pred, adaptation_labels)

    # Evaluate the adapted model
    valid_pred = learner(evaluation_data)
    valid_error = loss(valid_pred, evaluation_labels)
    valid_acc = accuracy(valid_pred, evaluation_labels)

    if not output_tensor:
        train_error = train_error.detach().to('cpu').numpy()
        train_pred = train_pred.detach().to('cpu').numpy()
        train_acc = train_acc.detach().to('cpu').numpy()
        adapt_error = adapt_error.detach().to('cpu').numpy()
        adapt_pred = adapt_pred.detach().to('cpu').numpy()
        adapt_acc = adapt_acc.detach().to('cpu').numpy()
        valid_error = valid_error.detach().to('cpu').numpy()
        valid_pred = valid_pred.detach().to('cpu').numpy()
        valid_acc = valid_acc.detach().to('cpu').numpy()
        adaptation_indices = adaptation_indices.detach().to('cpu').numpy()
        evaluation_indices = evaluation_indices.detach().to('cpu').numpy()
        # v is None unless a gradient was needed, and may carry a graph.
        if v is not None:
            v = v.detach().to('cpu').numpy()

    return {'train': {'error': train_error, 'prediction': train_pred,
                      'accuracy': train_acc, 'hessian': h, 'gradient': v,
                      'sensitivity': task_sensitivity},
            'adaptation': {'error': adapt_error, 'prediction': adapt_pred,
                           'accuracy': adapt_acc},
            'evaluation': {'error': valid_error, 'prediction': valid_pred,
                           'accuracy': valid_acc},
            'adaptation_indices': adaptation_indices,
            'evaluation_indices': evaluation_indices
            }

 
def zero_shot(batch, learner, loss, evaluation_indices, output_tensor=True):
    """Score ``learner`` on the selected samples of ``batch`` without adapting it.

    Args:
        batch: ``(data, labels)`` pair.
        learner: model called directly on the selected samples.
        loss: callable ``loss(predictions, labels)``.
        evaluation_indices: index/mask selecting the evaluation samples.
        output_tensor: if False, loss/predictions/accuracy become numpy arrays.

    Returns:
        ``{'zero_shot': {...metrics...}, 'evaluation_indices': evaluation_indices}``.
    """
    data, labels = batch
    query_x = data[evaluation_indices]
    query_y = labels[evaluation_indices]

    # No adaptation step: the learner is evaluated as-is.
    pred = learner(query_x)
    err = loss(pred, query_y)
    acc = accuracy(pred, query_y)

    if not output_tensor:
        err, pred, acc = (t.detach().to('cpu').numpy() for t in (err, pred, acc))

    metrics = {'error': err, 'prediction': pred, 'accuracy': acc}
    return {'zero_shot': metrics,
            'evaluation_indices': evaluation_indices}


def meta_test(model, tasksets, preprocess=lambda x: x,
              loss=nn.CrossEntropyLoss(reduction='mean'),
              shots=5, ways=5, num_test_tasks: int = -1,
              num_test_iterations: int = 1024):
    """Average ``fast_adapt`` loss/accuracy over test tasks and print a summary.

    With ``num_test_tasks > 0`` the first ``min(num_test_tasks,
    num_test_iterations)`` tasks of ``tasksets.test`` are used in order;
    otherwise ``num_test_iterations`` tasks are drawn with
    ``tasksets.test.sample()``. Each task gets a fresh ``model.clone()``.

    Returns:
        ``(mean_error, mean_accuracy)``.
    """
    use_fixed_tasks = num_test_tasks > 0
    if use_fixed_tasks:
        n_iter = min(num_test_tasks, num_test_iterations)
    else:
        n_iter = num_test_iterations

    total_error = 0.0
    total_accuracy = 0.0
    lowest_acc, highest_acc = 10., -1.
    for task_idx in range(n_iter):
        learner = model.clone()
        raw_batch = tasksets.test[task_idx] if use_fixed_tasks else tasksets.test.sample()
        error, acc_tensor = fast_adapt(preprocess(raw_batch), learner, loss,
                                       shots, ways)
        acc = acc_tensor.item()
        highest_acc = max(highest_acc, acc)
        lowest_acc = min(lowest_acc, acc)
        total_error += error.item()
        total_accuracy += acc

    meta_test_error = total_error / n_iter
    meta_test_accuracy = total_accuracy / n_iter

    print('Meta Test Error', meta_test_error)
    print('Meta Test Accuracy', meta_test_accuracy)
    print('Min Test Accuracy', lowest_acc)
    print('Max Test Accuracy', highest_acc)

    return meta_test_error, meta_test_accuracy


def sum_scalars(scalars: list):
    """Return the sum of all elements of same-shaped tensors in ``scalars``.

    The tensors are stacked along a new leading axis and the result is fully
    reduced, so the output is a 0-d tensor that stays in the autograd graph.
    """
    stacked = torch.stack(list(scalars))
    return stacked.sum()


def save_explainer(explainer,
                   prefix: str = None,
                   postfix: str = None,
                   model_path: str = None,
                   ):
    """Pickle ``explainer`` to ``<savedir>/[<prefix>_]<name>[_<postfix>].pkl``.

    Args:
        explainer: object with ``name``, ``savedir`` and ``model`` attributes.
        prefix, postfix: optional strings joined to ``explainer.name`` with ``_``.
        model_path: if given, ``explainer.model`` is temporarily replaced by this
            path while pickling, so the network is stored by reference
            (``load_explainer`` rebuilds it from a ``.pth`` path).

    The original ``explainer.model`` is always restored, even if pickling fails.
    """
    parts = [explainer.name]
    if prefix is not None:
        parts.insert(0, prefix)
    if postfix is not None:
        parts.append(postfix)
    pkl_name = '_'.join(parts) + '.pkl'
    path = os.path.join(explainer.savedir, pkl_name)

    # Remember the live model unconditionally so restoring never hits an
    # unbound name when model_path is None.
    model = explainer.model
    if model_path is not None:
        explainer.model = model_path
    try:
        with open(path, 'wb') as f:
            pickle.dump(explainer, f)
        print(f'saved {path}')
    finally:
        explainer.model = model


def load_explainer(path: str, ways: int = 5,
                   device=torch.device('cpu'),
                   dataset: str = 'mini-imagenet'):
    """Unpickle an explainer and, if needed, rebuild its MAML model.

    When the pickled ``explainer.model`` is a ``.pth`` path string, a
    learn2learn CNN (``OmniglotCNN`` for ``dataset == 'omniglot'``, otherwise
    ``MiniImagenetCNN``) with ``ways`` outputs is built. The weights are loaded
    and the model is moved to ``device``. It is then wrapped in
    ``l2l.algorithms.MAML`` with ``lr=explainer.adapt_lr``.

    Security: this uses ``pickle``; only load files from trusted sources.

    Raises:
        AssertionError: if the stored model path does not end with ``.pth``.
    """
    with open(path, 'rb') as f:
        explainer = pickle.load(f)
    print(f'loaded {path}')

    if isinstance(explainer.model, str):
        weights_path = explainer.model
        assert weights_path.endswith('.pth')
        if dataset == 'omniglot':
            model = l2l.vision.models.OmniglotCNN(ways)
            input_size = (1, 28, 28)
        else:
            model = l2l.vision.models.MiniImagenetCNN(ways)
            input_size = (3, 84, 84)
        # torchsummary defaults to device="cuda" and would feed CUDA inputs to
        # this still-on-CPU model whenever a GPU is available.
        summary(model, input_size, device='cpu')
        # map_location lets weights saved on a GPU load on CPU-only machines.
        model.load_state_dict(torch.load(weights_path, map_location=device))
        print(f'loaded {weights_path}')
        model.to(device)
        explainer.model = l2l.algorithms.MAML(model, lr=explainer.adapt_lr,
                                              first_order=False)

    return explainer


class MAMLExplainer(ExplainerBase):
    """Explainer for MAML-trained models.

    Relies on members expected from ``ExplainerBase`` (defined elsewhere):
    ``model``, ``src_param_matrix``, ``n_meta_param``, ``add_src_test_error``
    and ``src_test_hessian``.
    """

    def __init__(self, adapt_lr, **kwargs):
        """
        Args:
            adapt_lr: inner-loop learning rate used to propagate the source
                parameter matrix through one adaptation step.
            **kwargs: forwarded to ``ExplainerBase``.
        """
        super().__init__(**kwargs)
        # Snapshot of the meta-parameters; may be disconnected from the live
        # graph after a save_explainer / load_explainer round trip.
        self.meta_params = list(self.model.parameters())
        self.adapt_lr = adapt_lr

    def set_src_test_hessian(self, src_test_errors):
        """Register every source test error and report the resulting shape."""
        for error in src_test_errors:
            self.add_src_test_error(error)

        print(f'done (shape {self.src_test_hessian.shape})')

    def set_trg_param_matrix(self, trg_param_matrix=None,
                             hessian=None, gradient=None, trg_train_error=None,
                             params=None):
        """Set ``trg_param_matrix``, i.e. ``(I - adapt_lr * H) @ src_param_matrix``.

        The first available input wins:
          1. ``trg_param_matrix``: stored as-is.
          2. ``hessian``: numpy ``(n_meta_param, n_meta_param)`` matrix ``H``.
          3. otherwise ``H @ src_param_matrix`` is obtained by differentiating
             ``gradient . src_param_matrix`` w.r.t. ``params``; ``gradient`` is
             computed from ``trg_train_error`` when not supplied.

        Args:
            params: list of parameters the gradient/Hessian refer to (required).
        """
        assert self.src_param_matrix is not None
        assert isinstance(params, list)
        self.params = params
        if self.n_meta_param is None:
            self.n_meta_param = len(self.src_param_matrix)

        if trg_param_matrix is not None:
            self.trg_param_matrix = trg_param_matrix
        elif hessian is not None:
            assert isinstance(hessian, np.ndarray)
            assert hessian.shape == (self.n_meta_param, self.n_meta_param)
            m = (np.eye(self.n_meta_param) - self.adapt_lr * hessian)
            self.trg_param_matrix = m.dot(self.src_param_matrix)
        else:
            if gradient is None:
                # The training loss is only needed when the gradient is missing.
                assert isinstance(trg_train_error, torch.Tensor)
                v = higher_grad(trg_train_error, params)
            else:
                v = gradient
            src_mat = torch.Tensor(self.src_param_matrix).to(v.device)
            v_dot_src_mat = torch.tensordot(v, src_mat, dims=1)
            # higher_grad yields one row per output element (n_src, n_param);
            # xfast_adapt transposes the same call and asserts the result
            # matches src_param_matrix.shape, so transpose here as well.
            h_dot_src_mat = higher_grad(v_dot_src_mat, params,
                                        create_graph=False).to('cpu').numpy().T
            assert h_dot_src_mat.shape == self.src_param_matrix.shape
            self.trg_param_matrix = self.src_param_matrix - self.adapt_lr * h_dot_src_mat

    def calc_src_task_scores(self, y):
        """Return ``grad(y) @ src_param_matrix``: one score per source column."""
        assert self.src_param_matrix is not None
        # [Note] self.meta_params may be disconnected from the graph
        # by save_explainer and load_explainer, so re-read the parameters.
        variables = list(self.model.parameters())

        # [Note] materialize_grads does not exist in torch <= 1.13
        m = higher_grad(y, variables,
                        # allow_unused=True, materialize_grads=True
                        ).detach().to('cpu').numpy()
        scores = m.dot(self.src_param_matrix)
        return scores


class MAMLExplainerOPA(ExplainerOPA, MAMLExplainer):
    """MAML explainer that mixes ``ExplainerOPA`` into ``MAMLExplainer``.

    ``ExplainerOPA`` precedes ``MAMLExplainer`` in the MRO, so its methods take
    precedence unless overridden here. ``ExplainerOPA`` is defined outside
    this file.
    """

    def __init__(self, adapt_lr, **kwargs):
        # NOTE(review): super() starts the MRO walk at ExplainerOPA. If that
        # chain reaches MAMLExplainer.__init__, it requires ``adapt_lr``, which
        # is not forwarded here; confirm ExplainerOPA.__init__ does not
        # delegate to it.
        super().__init__(**kwargs)
        self.meta_params = list(self.model.parameters())
        self.adapt_lr = adapt_lr

    def calc_src_task_scores(self, y):
        # Force the gradient-based MAMLExplainer scoring instead of any
        # ExplainerOPA implementation that would otherwise win in the MRO.
        return MAMLExplainer.calc_src_task_scores(self, y)


# Short alias: OPAExplainer refers to the same class as MAMLExplainerOPA.
OPAExplainer = MAMLExplainerOPA


def explan_adaptation(explainer, task,
                      preprocess=(lambda x: x),
                      loss=nn.CrossEntropyLoss(reduction='mean'),
                      shots=5,
                      ways=5,
                      num_task=100
                      ):
    """Adapt the explainer's model to ``task`` and explain its query loss.

    Steps: adapt a clone of ``explainer.model`` with ``xfast_adapt``, using
    ``explainer.src_param_matrix``. Then ask ``explainer.explain`` about the
    negated evaluation loss, returning up to ``num_task`` results. Finally
    evaluate a fresh, un-adapted clone on the same evaluation samples via
    ``zero_shot``.

    Returns:
        ``(result, idxes, scores)``. ``result`` is the numpy-converted
        ``xfast_adapt`` output extended with the ``'zero_shot'`` entry;
        ``idxes`` and ``scores`` come from ``explainer.explain``.
    """
    adapted_learner = explainer.model.clone()
    features = preprocess(task)
    outcome = xfast_adapt(features, adapted_learner, loss, shots, ways,
                          src_param_matrix=explainer.src_param_matrix,
                          output_hessian=False,
                          output_gradient=False,
                          output_tensor=True)

    query_loss = outcome['evaluation']['error']
    idxes, scores = explainer.explain(-query_loss, top_k=num_task)

    outcome = tensor2numpy(outcome)

    # Baseline: the meta model without any adaptation, on the same samples.
    baseline = zero_shot(features, explainer.model.clone(), loss,
                         outcome['evaluation_indices'], output_tensor=False)
    outcome.update(baseline)
    return outcome, idxes, scores


def explain_test_performance(explainer, train_taskset, test_taskset,
                             preprocess=(lambda x: x),
                             loss=nn.CrossEntropyLoss(reduction='mean'),
                             shots=5,
                             ways=5,
                             num_train_task=100,
                             num_test_task=100
                             ):
    """Run ``explan_adaptation`` on the first ``num_test_task`` test tasks.

    Args:
        explainer: explainer whose model is adapted for each test task.
        train_taskset: not used by this function; kept for interface
            compatibility.
        test_taskset: indexable taskset; ``test_taskset[ii]`` is one task.
        preprocess, loss, shots, ways: forwarded to ``explan_adaptation``.
        num_train_task: ``top_k`` forwarded to ``explainer.explain``.
        num_test_task: number of test tasks to process.

    Returns:
        ``pandas.DataFrame`` with one row per test task. It holds error/accuracy
        for the evaluation ("test"), adaptation, pre-adaptation ("train") and
        zero-shot passes, plus the explainer's returned indices and scores.
    """
    results, idx_lists, score_lists = [], [], []
    for ii in tqdm(range(num_test_task)):
        result, idxes, scores = explan_adaptation(explainer, test_taskset[ii],
                                                  preprocess=preprocess,
                                                  loss=loss,
                                                  shots=shots,
                                                  ways=ways,
                                                  num_task=num_train_task,
                                                  )
        results.append(result)
        idx_lists.append(idxes)
        score_lists.append(scores)

    def column(stage, metric):
        """Collect ``result[stage][metric]`` across all test tasks."""
        return [r[stage][metric] for r in results]

    return pd.DataFrame({
        'test_task_idx': list(range(num_test_task)),
        'test_accuracy': column('evaluation', 'accuracy'),
        'test_error': column('evaluation', 'error'),
        'adaptation_accuracy': column('adaptation', 'accuracy'),
        'adaptation_error': column('adaptation', 'error'),
        'train_accuracy': column('train', 'accuracy'),
        # Previously filled from 'accuracy' by mistake.
        'train_error': column('train', 'error'),
        'zeroshot_accuracy': column('zero_shot', 'accuracy'),
        'zeroshot_error': column('zero_shot', 'error'),
        'train_task_idx': idx_lists,
        'train_task_score': score_lists,
    })
