import torch
import numpy as np

from torch.nn.utils.clip_grad import clip_grad_norm_
from maml.utils import accuracy

from numpy.linalg import multi_dot

# def get_grad_norm(parameters, norm_type=2):
#     if isinstance(parameters, torch.Tensor):
#         parameters = [parameters]
#     parameters = list(filter(lambda p: p.grad is not None, parameters))
#     norm_type = float(norm_type)
#     total_norm = 0
#     for p in parameters:
#         param_norm = p.grad.data.norm(norm_type)
#         total_norm += param_norm.item() ** norm_type
#     total_norm = total_norm ** (1. / norm_type)

    # return total_norm

class MetaLearner(object):
    """Drives MAML-style inner-loop adaptation and outer-loop meta-updates.

    The wrapped ``model`` is used in these ways by this class:

    * ``model(task, params=...)`` -> predictions for ``task``;
    * ``model(task, get_embd=True)[1]`` -> an embedding matrix for ``task``
      (used by :meth:`adapt` to compute a closed-form step size);
    * ``model(task)`` -> predictions with the model's own parameters
      (used by :meth:`step_nonmaml`);
    * ``model.param_dict`` (mapping name -> tensor), ``model.to`` and
      ``model.state_dict``.

    Tasks must expose their targets as ``task.y``. Loss (and optionally
    accuracy) is accumulated by :meth:`_update_measurements` and averaged and
    reset by :meth:`_pop_measurements`.
    """

    def __init__(self, model, optimizer, fast_lr, loss_func,
                 method, num_updates, inner_loop_grad_clip,
                 device, collect_accuracies=False):  #, classifier_schedule=10,
        """
        Args:
            model: network to meta-train (see the class docstring).
            optimizer: torch optimizer stepped by :meth:`step` and
                :meth:`step_nonmaml`.
            fast_lr: inner-loop step size. :meth:`adapt` overwrites it per
                task when ``method == 'adaptive'``.
            loss_func: callable ``loss_func(preds, targets)`` returning a
                scalar tensor.
            method: adaptation method name. Only ``'adaptive'`` is treated
                specially in this class.
            num_updates: number of inner-loop gradient steps per task.
            inner_loop_grad_clip: if > 0, inner-loop gradients are clamped
                element-wise to ``[-inner_loop_grad_clip, inner_loop_grad_clip]``.
            device: device the model is moved to on construction.
            collect_accuracies: if True, also accumulate
                ``accuracy(preds, task.y)`` in the running measurements.
        """
        self._model = model
        self._fast_lr = fast_lr
        self._optimizer = optimizer
        self._loss_func = loss_func
        self._method = method
        # Always False here, so update_params keeps the graph of each inner
        # step (second-order gradients flow into the outer loop).
        self._first_order = False
        self._num_updates = num_updates
        self._inner_loop_grad_clip = inner_loop_grad_clip
        self._collect_accuracies = collect_accuracies
        self._device = device
        # self._classifier_schedule = classifier_schedule
        # Not read anywhere in this class; presumably used by callers.
        self._grads_mean = []

        self.to(device)

        self._reset_measurements()

    def _reset_measurements(self):
        """Zero the running iteration count, loss sum and accuracy sum."""
        self._count_iters = 0.0
        self._cum_loss = 0.0
        self._cum_accuracy = 0.0

    def _update_measurements(self, task, loss, preds):
        """Add one task's ``loss`` (and optionally accuracy) to the running sums."""
        self._count_iters += 1.0
        self._cum_loss += loss.data.cpu().numpy()
        if self._collect_accuracies:
            self._cum_accuracy += accuracy(
                preds, task.y).data.cpu().numpy()

    def _pop_measurements(self):
        """Return the averaged measurements and reset the running sums.

        Returns:
            dict with ``'loss'`` (mean accumulated loss) and, when
            ``collect_accuracies`` is set, ``'accuracy'``.

        Raises:
            ZeroDivisionError: if nothing was accumulated since the last reset.
        """
        measurements = {}
        measurements['loss'] = self._cum_loss / self._count_iters
        if self._collect_accuracies:
            # Local name differs from the module-level ``accuracy`` helper.
            measurements['accuracy'] = self._cum_accuracy / self._count_iters
        self._reset_measurements()
        return measurements

    def measure(self, tasks, train_tasks=None, adapted_params_list=None):
        """Evaluate ``tasks`` and return averaged measurements.

        Args:
            tasks: sequence of tasks to evaluate.
            train_tasks: unused; kept for interface compatibility.
            adapted_params_list: optional list aligned with ``tasks``. A None
                entry, or omitting the list, evaluates that task with the
                model's current ``param_dict``.

        Returns:
            dict from :meth:`_pop_measurements`. Measurements accumulated
            before this call are included in the average and then reset.
        """
        if adapted_params_list is None:
            adapted_params_list = [None] * len(tasks)
        for i, task in enumerate(tasks):
            params = adapted_params_list[i]
            if params is None:
                params = self._model.param_dict
            preds = self._model(task, params=params)
            loss = self._loss_func(preds, task.y)
            self._update_measurements(task, loss, preds)

        return self._pop_measurements()

    def measure_each(self, tasks, train_tasks=None, adapted_params_list=None, accuracy=False):
        """Return one metric value per task.

        Args:
            tasks: sequence of tasks to evaluate.
            train_tasks: unused; kept for interface compatibility.
            adapted_params_list: optional list aligned with ``tasks``. A None
                entry, or omitting the list, evaluates that task with the
                model's current ``param_dict``.
            accuracy: if True, the metric is classification accuracy, i.e. the
                fraction of rows where ``argmax(preds, axis=-1) == task.y``.
                Otherwise it is the mean squared error between ``preds`` and
                ``task.y``.

        Returns:
            list of NumPy scalars, one per task. The running measurements are
            not touched.
        """
        # NOTE: the ``accuracy`` argument shadows the module-level
        # ``accuracy`` helper inside this method; that helper is not used here.
        if adapted_params_list is None:
            adapted_params_list = [None] * len(tasks)

        metrics = []
        for i, task in enumerate(tasks):
            params = adapted_params_list[i]
            if params is None:
                params = self._model.param_dict
            preds = self._model(task, params=params).data.cpu().numpy()
            targets = task.y.data.cpu().numpy()
            if accuracy:
                metrics.append(np.mean(targets == np.argmax(preds, axis=-1)))
            else:
                metrics.append(np.mean((preds - targets) ** 2))

        return metrics

    def update_params(self, loss, params):
        """Take one gradient-descent step of size ``self._fast_lr`` on ``loss``.

        When ``self._first_order`` is False the step is differentiable
        (``create_graph=True``). Gradients are clamped element-wise to
        ``[-clip, clip]`` when ``self._inner_loop_grad_clip`` is positive.
        Entries of ``params`` that receive no gradient are left unchanged.

        Args:
            loss: scalar tensor to differentiate.
            params: mutable mapping name -> tensor. It is updated in place
                and also returned.

        Returns:
            ``params`` with every entry that has a gradient replaced by
            ``param - fast_lr * grad``.
        """
        create_graph = not self._first_order
        grads = torch.autograd.grad(loss, params.values(),
                                    create_graph=create_graph, allow_unused=True)
        clip = self._inner_loop_grad_clip
        for (name, param), grad in zip(params.items(), grads):
            if grad is None:
                # allow_unused=True yields None for params not used by loss.
                continue
            if clip > 0:
                grad = grad.clamp(min=-clip, max=clip)
            params[name] = param - self._fast_lr * grad

        return params

    def adapt(self, train_tasks, val_tasks):
        """Adapt the model to each training task with inner-loop SGD.

        For every ``(task, vtask)`` pair a step size is computed as
        ``0.5 * n * tr(A B) / tr(A B A)`` with ``A = X^T X`` and
        ``B = V^T V``. Here X and V are ``model(., get_embd=True)[1]`` for
        the train and val task, and n is ``len(X)``. (This assumes the
        embeddings are 2-D with one row per example; TODO confirm.)

        When ``self._method == 'adaptive'`` that value replaces
        ``self._fast_lr`` before the inner loop; otherwise it is only
        reported. Then ``self._num_updates`` steps of :meth:`update_params`
        are run on the train task. Only the loss before the first step is
        recorded in the running measurements.

        Returns:
            ``(measurements, adapted_params, fast_lr_list)``: the averaged
            pre-update measurements, one adapted parameter mapping per task,
            and the computed step size per task.
        """
        adapted_params = []
        fast_lr_list = []
        for task, vtask in zip(train_tasks, val_tasks):
            params = self._model.param_dict
            train_x = self._model(task, get_embd=True)[1].data.cpu().numpy()
            val_x = self._model(vtask, get_embd=True)[1].data.cpu().numpy()
            numerator = np.trace(multi_dot([train_x.T, train_x, val_x.T, val_x]))
            denominator = np.trace(multi_dot([train_x.T, train_x, val_x.T, val_x,
                                              train_x.T, train_x]))
            # NOTE(review): a zero denominator yields inf/nan here (NumPy warns
            # rather than raising), which would corrupt params in adaptive mode.
            fast_lr = 0.5 * len(train_x) * numerator / denominator
            if self._method == 'adaptive':
                # Store a plain float so inner-loop tensor updates do not
                # depend on NumPy-scalar/tensor interop.
                self._fast_lr = float(fast_lr)
            for i in range(self._num_updates):
                preds = self._model(task, params=params)
                loss = self._loss_func(preds, task.y)
                params = self.update_params(loss, params=params)
                if i == 0:
                    self._update_measurements(task, loss, preds)
            adapted_params.append(params)
            fast_lr_list.append(fast_lr)

        measurements = self._pop_measurements()
        return measurements, adapted_params, fast_lr_list

    def step(self, adapted_params_list, val_tasks,
             is_training):
        """Outer-loop (meta) update on validation tasks.

        Each task in ``val_tasks`` is evaluated with its matching entry of
        ``adapted_params_list``, and the per-task losses are averaged. When
        ``is_training`` is True, the mean loss is backpropagated and the
        optimizer steps. Gradients pass through the inner-loop updates if
        their graph was kept.

        Returns:
            ``measurements`` when training. Otherwise
            ``(preds, measurements)``, where ``preds`` are the predictions
            for the LAST validation task only.
        """
        self._optimizer.zero_grad()
        post_update_losses = []

        for adapted_params, task in zip(
                adapted_params_list, val_tasks):
            preds = self._model(task, params=adapted_params)
            loss = self._loss_func(preds, task.y)
            post_update_losses.append(loss)
            self._update_measurements(task, loss, preds)

        mean_loss = torch.mean(torch.stack(post_update_losses))
        if is_training:
            mean_loss.backward()
            self._optimizer.step()

            measurements = self._pop_measurements()
            return measurements
        else:
            measurements = self._pop_measurements()
            return preds, measurements

    def step_nonmaml(self, tasks_pack, is_training):
        """Plain (non-adaptive) update over several groups of tasks.

        Every task in every group of ``tasks_pack`` is evaluated by calling
        the model without ``params``. The per-task losses are averaged, and
        when ``is_training`` is True the optimizer takes a step on that mean.

        Returns:
            ``measurements`` when training. Otherwise
            ``(preds, measurements)``, where ``preds`` are the predictions
            for the last task evaluated.
        """
        self._optimizer.zero_grad()
        post_update_losses = []

        # Pool the per-task losses of all groups (e.g. train and val tasks)
        # into one mean loss.
        for tasks in tasks_pack:
            for task in tasks:
                preds = self._model(task)
                loss = self._loss_func(preds, task.y)
                post_update_losses.append(loss)
                self._update_measurements(task, loss, preds)

        mean_loss = torch.mean(torch.stack(post_update_losses))
        if is_training:
            mean_loss.backward()
            self._optimizer.step()

            measurements = self._pop_measurements()
            return measurements
        else:
            measurements = self._pop_measurements()
            return preds, measurements

    def to(self, device, **kwargs):
        """Record ``device`` and move the model there (kwargs go to ``model.to``)."""
        self._device = device
        self._model.to(device, **kwargs)

    def state_dict(self):
        """Return a checkpoint dict holding the model and optimizer states."""
        state = {
            'model_state_dict': self._model.state_dict(),
            'optimizer': self._optimizer.state_dict()
        }
        return state
