import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torchsummary import summary
from xmeta.utils.evaluation import accuracy
from xmeta.explainer.explainer import ExplainerBase, ExplainerOPA, save_explainer
from xmeta.utils.data import ImpureTasksets, get_tasksets
from xmeta.utils.seed import set_seed
from xmeta.networks.simple_networks import ConvFeature as Convnet
from xmeta.utils.tensor import tensor2numpy
from xmeta.utils.higher_grad import higher_grad
import os
import pickle
import pandas as pd
from tqdm import tqdm

# NOTE(review): TASK_INTERVAL is not referenced anywhere in the code visible
# here; its meaning (presumably an interval counted in tasks) is unconfirmed —
# check the importing modules before changing it.
TASK_INTERVAL = 512


def pairwise_distances_logits(a, b):
    """Return negated squared Euclidean distances usable as class logits.

    Args:
        a: ``(n_query, n_feature)`` tensor of query embeddings.
        b: ``(n_way, n_feature)`` tensor of class prototypes.

    Returns:
        ``(n_query, n_way)`` tensor whose entry ``[i, j]`` is
        ``-||a[i] - b[j]||^2`` (closer prototype -> larger logit).
    """
    n_rows, n_cols = a.shape[0], b.shape[0]
    lhs = a.unsqueeze(1).expand(n_rows, n_cols, -1)
    rhs = b.unsqueeze(0).expand(n_rows, n_cols, -1)
    squared_diff = (lhs - rhs).pow(2)
    return squared_diff.sum(dim=2).neg()


def fast_adapt(model, batch, ways, shots, query_num, metric=None, device=None):
    """Evaluate one Prototypical-Network episode and return query loss/accuracy.

    The task's samples are sorted by label. The first ``shots`` samples of each
    class form the support set, which is averaged into one prototype per class.
    The remaining ``query_num`` samples per class form the query set, which is
    classified against the prototypes.

    Args:
        model: embedding network; ``model(data)`` must return one embedding row
            per sample.
        batch: ``(data, labels)`` for one task. Assumes every class has exactly
            ``shots + query_num`` samples and labels lie in ``0 .. ways-1``
            (TODO confirm with the task sampler).
        ways, shots, query_num: episode layout.
        metric: ``metric(query, prototypes) -> logits``; defaults to
            :func:`pairwise_distances_logits`.
        device: device to run on; defaults to ``model.device()`` if the model
            defines such a method, otherwise the device of its first parameter.

    Returns:
        ``(loss, acc)``: cross-entropy of the query logits and ``accuracy`` of
        the same logits against the query labels.
    """
    if metric is None:
        metric = pairwise_distances_logits
    if device is None:
        # torch.nn.Module has no device() method; fall back to the parameters.
        device_fn = getattr(model, 'device', None)
        device = (device_fn() if callable(device_fn)
                  else next(model.parameters()).device)
    data, labels = batch
    data = data.to(device)
    labels = labels.to(device)

    # Sort samples by label so each class occupies one contiguous run of
    # shots + query_num samples.
    # TODO: Can this be replaced by ConsecutiveLabels ?
    order = torch.sort(labels).indices
    data = data.squeeze(0)[order].squeeze(0)
    labels = labels.squeeze(0)[order].squeeze(0)

    # The first `shots` samples of every class run are support samples,
    # the rest are queries.
    embeddings = model(data)
    support_mask = np.zeros(data.size(0), dtype=bool)
    class_starts = np.arange(ways) * (shots + query_num)
    for offset in range(shots):
        support_mask[class_starts + offset] = True
    query_indices = torch.from_numpy(~support_mask)
    support_indices = torch.from_numpy(support_mask)

    prototypes = embeddings[support_indices].reshape(ways, shots, -1).mean(dim=1)
    query = embeddings[query_indices]
    labels = labels[query_indices].long()

    # Honour the caller-supplied metric (previously ignored).
    logits = metric(query, prototypes)
    loss = F.cross_entropy(logits, labels)
    acc = accuracy(logits, labels)
    return loss, acc


def xfast_adapt(model, batch, ways, shots, query_num, metric=None, device=None,
                src_param_matrix=None,
                output_hessian=False, output_gradient=False, output_tensor=True):
    """Run one Prototypical-Network episode and return detailed diagnostics.

    Samples are sorted by label. The first ``shots`` samples of each class form
    the support set, which is averaged into one prototype per class. The
    remaining ``query_num`` samples form the query set. Two classifications
    against the prototypes are reported:

    * ``'adaptation'``: the support samples themselves;
    * ``'evaluation'``: the held-out query samples.

    Args:
        model: embedding network; ``model(data)`` must return one embedding row
            per sample.
        batch: ``(data, labels)`` for a single task. Assumes every class has
            exactly ``shots + query_num`` samples and labels lie in
            ``0 .. ways-1`` (TODO confirm with the task sampler).
        ways, shots, query_num: episode layout.
        metric: ``metric(samples, prototypes) -> logits``; defaults to
            :func:`pairwise_distances_logits`.
        device: device to run on; defaults to ``model.device()`` if the model
            defines such a method, otherwise the device of its first parameter.
        src_param_matrix: optional array. When given, the derivative of the
            flattened prototypes w.r.t. the model parameters computed by
            ``higher_grad`` (presumably a Jacobian) is contracted with it. The
            result is stacked below it as ``['train']['sensitivity']``. This can
            easily run out of memory.
        output_hessian, output_gradient: accepted for interface compatibility;
            Hessian/gradient outputs are not implemented and are always None.
        output_tensor: if False, every tensor in the returned dictionary is
            detached and converted to a NumPy array on the CPU.

    Returns:
        A dict with these keys:

        * ``'train'``: ``'error'``/``'prediction'``/``'accuracy'`` are None;
          also holds ``'hessian'``, ``'gradient'`` and ``'sensitivity'``.
        * ``'adaptation'`` and ``'evaluation'``: each with ``'error'``,
          ``'prediction'`` and ``'accuracy'``.
        * ``'adaptation_indices'`` / ``'evaluation_indices'``: boolean masks
          selecting support / query samples in label-sorted order.
    """
    if metric is None:
        metric = pairwise_distances_logits
    if device is None:
        # torch.nn.Module has no device() method; fall back to the parameters.
        device_fn = getattr(model, 'device', None)
        device = (device_fn() if callable(device_fn)
                  else next(model.parameters()).device)

    data, labels = batch
    data = data.to(device)
    labels = labels.to(device)

    # Sort samples by label so each class occupies one contiguous run of
    # shots + query_num samples.
    order = torch.sort(labels).indices
    data = data.squeeze(0)[order].squeeze(0)
    labels = labels.squeeze(0)[order].squeeze(0)

    # The first `shots` samples of every class run are support samples,
    # the rest are queries.
    embeddings = model(data)
    support_mask = np.zeros(data.size(0), dtype=bool)
    class_starts = np.arange(ways) * (shots + query_num)
    for offset in range(shots):
        support_mask[class_starts + offset] = True
    query_indices = torch.from_numpy(~support_mask)
    support_indices = torch.from_numpy(support_mask)

    support_emb = embeddings[support_indices]
    prototypes = support_emb.reshape(ways, shots, -1).mean(dim=1)
    query_emb = embeddings[query_indices]
    labels_adpt = labels[support_indices].long()
    labels_eval = labels[query_indices].long()

    # Adaptation: support samples scored against the prototypes, paired with
    # the *support* labels (ways * shots rows).
    logits_adpt = metric(support_emb, prototypes)
    loss_adpt = F.cross_entropy(logits_adpt, labels_adpt)
    acc_adpt = accuracy(logits_adpt, labels_adpt)
    # Evaluation: query samples paired with the *query* labels
    # (ways * query_num rows).
    logits = metric(query_emb, prototypes)
    loss = F.cross_entropy(logits, labels_eval)
    acc = accuracy(logits, labels_eval)

    # Hessian / gradient outputs are not implemented for ProtoNets.
    h = None
    v = None

    if src_param_matrix is not None:  # [NOTE] easily raises torch.OutOfMemoryError
        meta_param_tensors = list(model.parameters())
        param_tensors = [prototypes.reshape(-1)]
        grads = []
        for _tensor in param_tensors:
            _v = higher_grad(_tensor, meta_param_tensors,
                             allow_unused=True, materialize_grads=True)
            # Place the source matrix on the same device as the derivative.
            src_mat = torch.Tensor(src_param_matrix).to(_v.device)
            grads.append(torch.tensordot(_v, src_mat, dims=1)
                         .detach().to('cpu').numpy())
        task_sensitivity = np.vstack([src_param_matrix, np.vstack(grads)])
    else:
        task_sensitivity = None

    if not output_tensor:
        def _to_numpy(x):
            # Leaves None / non-tensor values (e.g. v) untouched.
            return x.detach().to('cpu').numpy() if torch.is_tensor(x) else x

        loss_adpt = _to_numpy(loss_adpt)
        logits_adpt = _to_numpy(logits_adpt)
        acc_adpt = _to_numpy(acc_adpt)
        loss = _to_numpy(loss)
        logits = _to_numpy(logits)
        acc = _to_numpy(acc)
        support_indices = _to_numpy(support_indices)
        query_indices = _to_numpy(query_indices)
        v = _to_numpy(v)

    # [TODO]  correct misnames train -> before, error -> loss, prediction -> logits
    return {'train': {'error': None, 'prediction': None,
                      'accuracy': None, 'hessian': h, 'gradient': v,
                      'sensitivity': task_sensitivity},
            'adaptation': {'error': loss_adpt, 'prediction': logits_adpt,
                           'accuracy': acc_adpt},
            'evaluation': {'error': loss, 'prediction': logits,
                           'accuracy': acc},
            'adaptation_indices': support_indices,
            'evaluation_indices': query_indices
            }


def meta_test(model, tasksets, device=torch.device('cuda'),
              loss=nn.CrossEntropyLoss(reduction='mean'),
              shots=5, ways=5, queries=5, num_test_tasks: int = -1,
              num_test_iterations: int = 1024):
    """Average query loss/accuracy of ``model`` over test tasks and print them.

    Args:
        model: embedding network evaluated with :func:`fast_adapt`.
        tasksets: object whose ``test`` attribute is indexed
            (``num_test_tasks > 0``) or sampled via ``test.sample()``
            (otherwise).
        device: device passed to :func:`fast_adapt`.
        loss: unused, because :func:`fast_adapt` always uses cross-entropy;
            kept for interface compatibility.
        shots, ways, queries: episode layout.
        num_test_tasks: if positive, evaluate the fixed tasks
            ``tasksets.test[0 .. n-1]``; otherwise draw tasks with
            ``tasksets.test.sample()``.
        num_test_iterations: upper bound on the number of evaluated tasks.

    Returns:
        ``(mean_error, mean_accuracy)`` over the evaluated tasks.

    Raises:
        ValueError: if the number of tasks to evaluate is not positive.
    """
    n_iter = (min(num_test_tasks, num_test_iterations)
              if num_test_tasks > 0 else num_test_iterations)
    if n_iter <= 0:
        raise ValueError(
            f'meta_test needs at least one test task, got n_iter={n_iter}')

    errors = []
    accuracies = []
    # Evaluation only: the returned values are Python floats, so no autograd
    # graph needs to be built.
    with torch.no_grad():
        for ii in range(n_iter):
            if num_test_tasks > 0:
                batch_test = tasksets.test[ii]
            else:
                batch_test = tasksets.test.sample()

            evaluation_error, evaluation_accuracy = fast_adapt(
                model=model,
                batch=batch_test,
                shots=shots,
                ways=ways,
                query_num=queries,
                metric=pairwise_distances_logits,
                device=device,
            )
            errors.append(evaluation_error.item())
            accuracies.append(evaluation_accuracy.item())

    meta_test_error = sum(errors) / n_iter
    meta_test_accuracy = sum(accuracies) / n_iter

    print('Meta Test Error', meta_test_error)
    print('Meta Test Accuracy', meta_test_accuracy)
    print('Min Test Accuracy', min(accuracies))
    print('Max Test Accuracy', max(accuracies))

    return meta_test_error, meta_test_accuracy


def setup_experiment(train_way=5,
                     train_shot=5,
                     train_query=5,
                     test_way=5,
                     test_shot=5,
                     test_query=5,
                     num_tasks=None,
                     num_test_tasks=None,
                     experiment_dir=None,
                     explainer_path=None,
                     impurity_dict_path=None,
                     train_noise_tasks=None,
                     seed=42,
                     dataset='mini-imagenet',
                     device=torch.device('cpu'),
                     data_root='~/data',
                     ):
    """Build the (optionally impure) task sets and load an optional explainer.

    Args:
        train_way, train_shot, train_query: training episode layout; each
            training task draws ``train_shot + train_query`` samples per class.
        test_way, test_shot, test_query: same for test tasks.
        num_tasks: forwarded to ``get_tasksets`` and ``ImpureTasksets``.
        num_test_tasks: currently unused (only the disabled test-set
            pollution code needed it); kept for interface compatibility.
        experiment_dir: directory for experiment artefacts; created if it does
            not exist. Required.
        explainer_path: optional pickled explainer loaded with
            :func:`load_protoexpl`; its ``model`` is returned as ``model``.
        impurity_dict_path: pickled impurity bookkeeping; defaults to
            ``<experiment_dir>/index_dict.pkl``.
        train_noise_tasks: forwarded to ``ImpureTasksets`` (presumably the
            number of training tasks with noisy labels; confirm). When None and
            an impurity dict is loaded, it is set to
            ``len(impurity_dict['train_noise_tasks'])``.
        seed: forwarded to ``get_tasksets`` and ``ImpureTasksets``.
        dataset: dataset name forwarded to ``get_tasksets`` / ``load_protoexpl``.
        device: device for the loaded explainer model.
        data_root: dataset root directory forwarded to ``get_tasksets``.

    Returns:
        ``(tasks_train, tasks_test, explainer, model, feature, impurity_dict)``.
        ``explainer`` and ``model`` are None when ``explainer_path`` is None,
        and ``feature`` is the identity function.

    Raises:
        ValueError: if ``experiment_dir`` is None.
    """
    if experiment_dir is None:
        raise ValueError('setup_experiment requires experiment_dir')
    if not os.path.exists(experiment_dir):
        # exist_ok guards against another process creating it meanwhile.
        os.makedirs(experiment_dir, exist_ok=True)
        print(f'created {experiment_dir}')
    else:
        print(f'found {experiment_dir}')

    _tasksets = get_tasksets(seed=seed,
                             name=dataset,
                             train_ways=train_way,
                             train_samples=train_shot + train_query,
                             test_ways=test_way,
                             test_samples=test_shot + test_query,
                             num_tasks=num_tasks,
                             root=data_root,
                             )

    # load explainer
    if explainer_path is not None:
        explainer = load_protoexpl(path=explainer_path,
                                   device=device, dataset=dataset)
        model = explainer.model
    else:
        explainer = None
        model = None

    def feature(x):
        return x

    if impurity_dict_path is None:
        impurity_dict_path = os.path.join(experiment_dir, "index_dict.pkl")
    if os.path.exists(impurity_dict_path):
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # impurity dicts produced by this project.
        with open(impurity_dict_path, 'rb') as f:
            impurity_dict = pickle.load(f)
        print(f'loaded {impurity_dict}')
        if train_noise_tasks is None:
            train_noise_tasks = len(impurity_dict['train_noise_tasks'])
    else:
        impurity_dict = {'train_mask_tasks': [],
                         'train_noise_tasks': [],
                         'train_shuffle_tasks': [],
                         'train_dark_tasks': [],
                         'train_recolor_tasks': [],
                         'train_bgr_tasks': [],
                         }

    # Only label-noise impurities are currently enabled; the mask / shuffle /
    # dark / recolor / bgr variants are disabled.
    if train_noise_tasks is not None:
        tasksets = ImpureTasksets(_tasksets, num_tasks=num_tasks,
                                  ways=train_way, shots=train_shot,
                                  train_noise_tasks=train_noise_tasks,
                                  seed=seed
                                  )
        impurity_dict = tasksets.impurity_dict
    else:
        tasksets = _tasksets

    tasks_train = tasksets.train
    tasks_test = tasksets.test
    return tasks_train, tasks_test, explainer, model, feature, impurity_dict


# ProtoNet-specific name for the generic ``save_explainer`` helper imported
# from xmeta.explainer.explainer, exposed alongside ``load_protoexpl``.
save_protoexpl = save_explainer


def load_protoexpl(path: str,
                   device=torch.device('cpu'),
                   dataset: str = 'mini-imagenet'):
    """Unpickle an explainer and, if needed, rebuild its embedding network.

    When the unpickled ``explainer.model`` is a string, it is treated as the
    path of a ``.pth`` state dict. A ``Convnet`` is then built (1-channel
    28x28 input for omniglot, 3-channel 84x84 otherwise), its weights are
    loaded, and it is moved to ``device``.

    Args:
        path: pickled explainer file.
        device: device the rebuilt model is loaded onto.
        dataset: dataset name; selects the Convnet input configuration.

    Returns:
        The explainer, with ``explainer.model`` holding an ``nn.Module``.

    Raises:
        ValueError: if ``explainer.model`` is a path not ending in ``.pth``.
    """
    # NOTE(security): pickle.load executes arbitrary code; only load trusted files.
    with open(path, 'rb') as f:
        explainer = pickle.load(f)
    print(f'loaded {path}')
    if isinstance(explainer.model, str):
        weights_path = explainer.model
        if not weights_path.endswith('.pth'):
            raise ValueError(f'expected a .pth weights file, got {weights_path!r}')
        if dataset == 'omniglot':
            model = Convnet(x_dim=1)
            input_size = (1, 28, 28)
        else:
            model = Convnet()
            input_size = (3, 84, 84)
        # The model is still on the CPU here. torchsummary defaults to
        # device='cuda' and would feed CUDA inputs to CPU weights on GPU hosts.
        summary(model, input_size, device='cpu')
        # map_location lets GPU-saved weights load on CPU-only machines.
        model.load_state_dict(torch.load(weights_path, map_location=device))
        print(f'loaded {weights_path}')
        model.to(device)
        explainer.model = model

    return explainer


def explan_adaptation(explainer, task,
                      preprocess=(lambda x: x),
                      ways=5,
                      shots=5,
                      query_num=5,
                      top_k=100,
                      device=None,
                      ):
    """Run one test task through ``xfast_adapt`` and rank training tasks.

    Args:
        explainer: object exposing ``model`` and ``explain(target, top_k=...)``.
        task: a ``(data, labels)`` task, passed through ``preprocess`` first.
        ways, shots, query_num: episode layout forwarded to ``xfast_adapt``.
        top_k: number of training tasks requested from ``explainer.explain``.
        device: forwarded to ``xfast_adapt``.

    Returns:
        ``(result, idxes, scores)``. ``result`` is the ``xfast_adapt``
        dictionary converted with ``tensor2numpy``. ``idxes`` and ``scores``
        come from ``explainer.explain`` applied to the negated evaluation loss.
    """
    net = explainer.model
    prepared_task = preprocess(task)
    episode = xfast_adapt(model=net,
                          batch=prepared_task,
                          ways=ways,
                          shots=shots,
                          query_num=query_num,
                          metric=pairwise_distances_logits,
                          device=device,
                          )
    # Negate the loss so that a lower evaluation loss gives a larger target.
    explained_target = -episode['evaluation']['error']
    idxes, scores = explainer.explain(explained_target, top_k=top_k)
    return tensor2numpy(episode), idxes, scores


def explain_test_performance(explainer,
                             test_taskset,
                             preprocess=(lambda x: x),
                             ways=5,
                             shots=5,
                             query_num=5,
                             num_train_task=100,
                             num_test_task=100,
                             device=None
                             ):
    """Explain the first ``num_test_task`` test tasks and tabulate the results.

    Each test task ``test_taskset[ii]`` is processed by
    :func:`explan_adaptation`, which asks the explainer for the top
    ``num_train_task`` training tasks.

    Returns:
        A DataFrame with one row per test task. Columns:

        * test/adaptation/train accuracy and error;
        * ``train_task_idx`` and ``train_task_score``: the explainer's ranked
          training-task indices and their scores.

        ``xfast_adapt`` currently reports None for the train metrics.
    """
    train_task_idx_lists = []
    train_task_score_lists = []
    results = []
    for ii in tqdm(range(num_test_task)):
        result, idxes, scores = explan_adaptation(explainer, test_taskset[ii],
                                                  preprocess=preprocess,
                                                  ways=ways,
                                                  shots=shots,
                                                  query_num=query_num,
                                                  top_k=num_train_task,
                                                  device=device
                                                  )
        train_task_idx_lists.append(idxes)
        train_task_score_lists.append(scores)
        results.append(result)

    df = pd.DataFrame({
        'test_task_idx': list(range(num_test_task)),
        'test_accuracy': [r['evaluation']['accuracy'] for r in results],
        'test_error': [r['evaluation']['error'] for r in results],
        'adaptation_accuracy': [r['adaptation']['accuracy'] for r in results],
        'adaptation_error': [r['adaptation']['error'] for r in results],
        'train_accuracy': [r['train']['accuracy'] for r in results],
        # Previously copied the accuracy column by mistake.
        'train_error': [r['train']['error'] for r in results],
        'train_task_idx': train_task_idx_lists,
        'train_task_score': train_task_score_lists
    })

    return df


class ProtonetExplainer(ExplainerBase):
    """Explainer for Prototypical Networks.

    Source-task scores are obtained by differentiating a scalar target through
    the embedding network's parameters and projecting onto
    ``self.src_param_matrix``.
    """

    def set_trg_param_matrix(self, trg_param_matrix=None,
                             gradient=None, trg_train_error=None,
                             params=None):
        """Store the target-task parameter matrix and the parameter list.

        ``gradient`` and ``trg_train_error`` are accepted for interface
        compatibility but unused. When ``trg_param_matrix`` is None, only a
        message is printed and the stored matrix is left unchanged.
        """
        assert self.src_param_matrix is not None
        assert isinstance(params, list)
        self.params = params
        if self.n_meta_param is None:
            self.n_meta_param = len(self.src_param_matrix)

        if trg_param_matrix is None:
            print('set_trg_param_matrix not implemented')
            return
        self.trg_param_matrix = trg_param_matrix

    def calc_src_task_scores(self, y):
        """Return ``higher_grad(y, params) . src_param_matrix`` as a NumPy array.

        ``higher_grad`` presumably returns the gradient of ``y`` w.r.t. the
        model parameters (confirm against xmeta.utils.higher_grad).
        """
        assert self.src_param_matrix is not None
        # [Note] parameters are re-read from self.model because stored
        # meta-params may be disconnected from the graph by
        # save_explainer / load_explainer.
        model_params = list(self.model.parameters())

        # [Note] materialize_grads does not exist in torch <= 1.13, so it is
        # not passed here.
        grad_y = higher_grad(y, model_params)
        grad_np = grad_y.detach().to('cpu').numpy()
        return grad_np.dot(self.src_param_matrix)


class ProtonetExplainerOPA(ExplainerOPA, ProtonetExplainer):
    """OPA explainer variant that scores source tasks the ProtoNet way."""

    def calc_src_task_scores(self, y):
        # ExplainerOPA precedes ProtonetExplainer in the MRO; route scoring
        # explicitly to the ProtoNet implementation in case ExplainerOPA
        # defines its own.
        scores = ProtonetExplainer.calc_src_task_scores(self, y)
        return scores

  
