import torch
import torch.nn as nn

import higher

class MAML:
  """Model-Agnostic Meta-Learning driver built on the `higher` library.

  For each episode, a functional copy of ``self.model`` is adapted on the
  episode's support set using ``self.inner_opt`` and is then scored on the
  episode's query set.

  Each episode is indexed as a mapping with the keys ``'support_im'``,
  ``'support_labels'``, ``'query_im'`` and ``'query_labels'``.
  """

  def __init__(self, model, inner_opt, inner_steps=1, loss_fn=None):
    # Module whose parameters act as the meta-initialization.
    self.model = model
    # Optimizer handed to higher.innerloop_ctx for the inner (adaptation)
    # loop; presumably constructed over self.model's parameters -- TODO confirm.
    self.inner_opt = inner_opt
    # Default number of support-set optimizer steps per episode.
    self.inner_steps = inner_steps
    # Callable loss_fn(outputs, labels) returning a scalar tensor. There is no
    # default: it must be set (here or later as an attribute) before
    # compute_meta_grad / eval are called.
    self.loss_fn = loss_fn

  def _require_loss_fn(self):
    """Raise a descriptive TypeError if no usable loss function is configured."""
    if not callable(self.loss_fn):
      raise TypeError(
        f'MAML.loss_fn must be a callable loss_fn(outputs, labels), '
        f'got {self.loss_fn!r}')

  def _adapt(self, fmodel, diffopt, ep, inner_steps):
    """Take `inner_steps` inner-loop optimizer steps on `ep`'s support set."""
    for _ in range(inner_steps):
      support_loss = self.loss_fn(fmodel(ep['support_im']), ep['support_labels'])
      diffopt.step(support_loss)

  def compute_meta_grad(self, episodes):
    """Accumulate the MAML meta-gradient over `episodes` into the model's grads.

    Each episode is adapted with ``copy_initial_weights=False``, so the query
    loss stays differentiable with respect to ``self.model``'s own parameters
    (the meta-initialization). Each query loss is divided by the number of
    episodes and back-propagated immediately. As a result ``.grad`` accumulates
    the gradient of the mean query loss, and each episode's graph is freed
    before the next one is built. The caller must zero gradients beforehand.

    Intended use case:

    ```
    for episodes in loader:
      meta_opt.zero_grad()
      qry_losses, qry_outputs = maml.compute_meta_grad(episodes)
      meta_opt.step()
    ```

    Args:
      episodes: iterable of episode mappings. Iterables without ``len()`` are
        materialized into a list first, because the episode count is needed
        to average the loss.

    Returns:
      A tuple ``(qry_losses, qry_outputs)``: a list of each episode's detached
      query loss, and a list of ``(query_outputs, query_labels)`` pairs, both
      detached. None of these are differentiable.

    Raises:
      TypeError: if ``self.loss_fn`` is not callable.
    """
    self._require_loss_fn()
    if not hasattr(episodes, '__len__'):
      # Generators/iterators have no len(); materialize so we can average.
      episodes = list(episodes)
    num_episodes = len(episodes)

    qry_losses = []
    qry_outputs = []
    for ep in episodes:
      with higher.innerloop_ctx(
        self.model, self.inner_opt, copy_initial_weights=False
      ) as (fmodel, diffopt):
        self._adapt(fmodel, diffopt, ep, self.inner_steps)
        qout = fmodel(ep['query_im'])
        qloss = self.loss_fn(qout, ep['query_labels'])

        # Back-propagate this episode's share of the mean query loss now: the
        # gradient accumulates into the meta-initialization's .grad and this
        # episode's graph is released before the next episode is processed.
        (qloss / num_episodes).backward()

        qry_losses.append(qloss.detach())
        qry_outputs.append((qout.detach(), ep['query_labels'].detach()))
    return qry_losses, qry_outputs

  def eval(self, episodes, inner_steps=None):
    """Adapt to each episode and report query-set results without meta-gradients.

    Uses ``track_higher_grads=False``, so no higher-order graph is built
    through the inner loop. The query forward pass runs under
    ``torch.no_grad()``, and this method never calls ``backward()``.

    Args:
      episodes: iterable of episode mappings.
      inner_steps: number of support-set steps per episode; defaults to
        ``self.inner_steps`` when None.

    Returns:
      ``(qry_losses, qry_outputs)``, in the same format as compute_meta_grad.

    Raises:
      TypeError: if ``self.loss_fn`` is not callable.
    """
    self._require_loss_fn()
    if inner_steps is None:
      inner_steps = self.inner_steps

    qry_losses = []
    qry_outputs = []
    for ep in episodes:
      with higher.innerloop_ctx(
        self.model, self.inner_opt, track_higher_grads=False
      ) as (fmodel, diffopt):
        # Support steps need autograd enabled: the inner optimizer
        # differentiates the support loss with respect to the fast weights.
        self._adapt(fmodel, diffopt, ep, inner_steps)
        # The query results are only reported, so skip building their graph.
        with torch.no_grad():
          qout = fmodel(ep['query_im'])
          qloss = self.loss_fn(qout, ep['query_labels'])
        qry_losses.append(qloss.detach())
        qry_outputs.append((qout.detach(), ep['query_labels'].detach()))
    return qry_losses, qry_outputs
