from collections import OrderedDict

import numpy as np
import torch
from numpy.linalg import multi_dot
from torch.nn.utils.clip_grad import clip_grad_norm_


class NTK_MAML(object):
    """MAML-style inner-loop adaptation plus per-sample gradient ("NTK") features.

    The wrapped ``model`` is used as ``model(x, params=...)`` with a
    name -> tensor mapping of parameters, and must expose ``param_dict``,
    ``zero_grad()``, ``to()`` and ``state_dict()`` -- these are the only
    model members this class touches. Tasks are objects with ``.x`` (inputs)
    and ``.y`` (targets) attributes.
    """

    def __init__(self, dataset, model, fast_lr, loss_func,
                 num_updates, device):
        self.meta_dataset = dataset  # stored for callers; unused in this class
        self.model = model
        self.fast_lr = fast_lr  # inner-loop SGD step size
        self.loss_func = loss_func  # called as loss_func(preds, targets)
        self.num_updates = num_updates  # inner-loop steps per task in adapt()
        self._cum_loss = []
        # Moves the model and records the device in self._device.
        self.to(device)

    @staticmethod
    def _param_tensors(params):
        """Return ``params`` as a list of tensors (mapping values, a sequence, or one tensor)."""
        if isinstance(params, torch.Tensor):
            return [params]
        if hasattr(params, 'values'):
            return list(params.values())
        return list(params)

    def get_ntk_feature(self, input_x, params=None, grads=None):
        """Flatten a set of parameter gradients into one detached 1-D feature.

        If ``grads`` is None, they are computed as the gradient of the model
        output on ``input_x`` with respect to ``params``, which defaults to
        ``self.model.param_dict``. The output is summed first because
        ``torch.autograd.grad`` requires a scalar. For a model with a single
        scalar output, this is exactly the output gradient.

        Entries of ``grads`` that are None (parameters unused by the
        differentiated quantity) are replaced with zeros shaped like the
        matching parameter, so every feature has the same length. This
        requires ``params``.

        Returns:
            1-D tensor: all gradients flattened and concatenated, detached.

        Raises:
            ValueError: if a gradient is None and ``params`` was not given.
        """
        if grads is None:
            if params is None:
                params = self.model.param_dict
            out = self.model(input_x, params)
            self.model.zero_grad()
            grads = torch.autograd.grad(out.sum(), self._param_tensors(params),
                                        allow_unused=True)

        param_list = None if params is None else self._param_tensors(params)
        pieces = []
        for idx, grad in enumerate(grads):
            if grad is None:
                if param_list is None:
                    raise ValueError(
                        "gradient %d is None; pass `params` so it can be "
                        "zero-filled" % idx)
                grad = torch.zeros_like(param_list[idx])
            pieces.append(grad.reshape(-1))
        return torch.cat(pieces, -1).detach()

    def update_params(self, loss, params):
        """Take one gradient-descent step on ``loss`` with step size ``self.fast_lr``.

        Args:
            loss: scalar tensor, differentiable w.r.t. the values of ``params``.
            params: ordered mapping name -> tensor.

        Returns:
            ``(new_params, grads)``, where:

            - ``new_params`` is a NEW OrderedDict. The input mapping is not
              modified, so the ``self.model.param_dict`` handed in by
              ``adapt`` is never overwritten.
            - Entries whose gradient is None are carried over unchanged.
            - ``grads`` is the tuple from ``torch.autograd.grad``, built with
              ``create_graph=True`` so the step stays differentiable.
        """
        grads = torch.autograd.grad(loss, list(params.values()),
                                    create_graph=True, allow_unused=True)
        new_params = OrderedDict(params)
        for (name, param), grad in zip(params.items(), grads):
            if grad is not None:
                new_params[name] = param - self.fast_lr * grad
        return new_params, grads

    def adapt(self, task):
        """Run ``self.num_updates`` inner-loop SGD steps on ``task``.

        Returns:
            Mapping name -> adapted tensor, starting from ``self.model.param_dict``.
        """
        params = self.model.param_dict
        for _ in range(self.num_updates):
            preds = self.model(task.x, params=params)
            loss = self.loss_func(preds, task.y)
            params, _ = self.update_params(loss, params=params)
        return params

    def evaluate(self, train_tasks, val_tasks):
        """Adapt on each train task, score on its paired validation task, and
        collect per-sample gradient features for the train task.

        Args:
            train_tasks: iterable of tasks used for adaptation and features.
            val_tasks: iterable of tasks, paired with ``train_tasks`` via zip,
                on which the adapted parameters are scored.

        Returns:
            ``(val_losses, ntk_features)``, where:

            - ``val_losses`` holds one detached loss per task pair.
            - ``ntk_features`` holds, per train task, a tensor stacking one
              flattened loss-gradient feature per ``(x, y)`` sample. The
              gradients are taken at the un-adapted base parameters.
        """
        val_losses = []
        ntk_features = []
        for task, vtask in zip(train_tasks, val_tasks):
            adapted_params = self.adapt(task)
            val_preds = self.model(vtask.x, params=adapted_params)
            loss = self.loss_func(val_preds, vtask.y)
            val_losses.append(loss.detach())

            # Base parameters are no longer mutated by update_params, so one
            # read per task suffices.
            base_params = self.model.param_dict
            base_tensors = self._param_tensors(base_params)
            per_sample = []
            for x_i, y_i in zip(task.x, task.y):
                sample_loss = self.loss_func(
                    self.model(x_i, params=base_params), y_i)
                # Only the gradient is needed here; no differentiable step.
                grads = torch.autograd.grad(sample_loss, base_tensors,
                                            allow_unused=True)
                per_sample.append(
                    self.get_ntk_feature(x_i, params=base_params, grads=grads))
            ntk_features.append(torch.stack(per_sample))
        return val_losses, ntk_features

    def to(self, device, **kwargs):
        """Record ``device`` and move the wrapped model there."""
        self._device = device
        self.model.to(device, **kwargs)

    def state_dict(self):
        """Return a checkpoint dict.

        Always contains ``'model_state_dict'``. ``'optimizer'`` is included
        only when an ``_optimizer`` attribute has been set on this object.
        """
        state = {'model_state_dict': self.model.state_dict()}
        optimizer = getattr(self, '_optimizer', None)
        if optimizer is not None:
            state['optimizer'] = optimizer.state_dict()
        return state
