import os
import sys
import random
import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


class APO:
  """Drives a base optimizer and periodically tunes it with a meta-optimizer.

  NOTE(review): the name presumably stands for "Amortized Proximal
  Optimization" -- confirm against the project documentation.

  Every `meta_interval`-th call to `step`, the meta-optimizer is stepped
  `num_meta_steps` times on the objective

      J = loss_fn(updated_model) + 0.5 * lam * D / batch_size_prime

  where `updated_model` is whatever `base_optimizer.update(take_step=False)`
  returns and `D` is the squared norm of the difference between the outputs
  of `model` and `updated_model` on a batch drawn from the training dataset.

  Collaborator interfaces, as used by this class:
    * `base_optimizer`: `zero_grad()` and `update(take_step=bool)`; with
      `take_step=False` it must return a callable model.
    * `meta_optimizer`: `zero_grad()` and `step(closure)`.
    * `loss_fn(model)` and `loss_fn(model, updated=True)`: both return a
      `(loss, predictions)` pair.
  """

  def __init__(self, model, base_optimizer, meta_optimizer, num_meta_steps,
               meta_interval=1, train_dataloader=None, lam=0,
               batch_size_prime=100, device='cuda:0'):
    """Store the collaborators and set up the penalty-batch sampler.

    Args:
      model: Callable model whose outputs anchor the penalty term.
      base_optimizer: Optimizer exposing `zero_grad()` and `update(take_step)`.
      meta_optimizer: Optimizer exposing `zero_grad()` and `step(closure)`.
      num_meta_steps: Number of meta-optimizer steps per meta update.
      meta_interval: A meta update runs when `steps % meta_interval == 0`.
      train_dataloader: Optional loader; only its `.dataset` is used, to
        build a separate shuffled loader with batch size `batch_size_prime`.
        Required when `lam > 0`.
      lam: Weight of the output-difference penalty; the penalty batch is only
        sampled when `lam > 0`.
      batch_size_prime: Batch size of the penalty loader and the divisor of
        the penalty term.
      device: Device the penalty batch (and the zero penalty) is placed on.
    """
    self.model = model
    self.base_optimizer = base_optimizer
    self.meta_optimizer = meta_optimizer
    self.num_meta_steps = num_meta_steps
    self.meta_interval = meta_interval

    self.lam = lam
    self.batch_size_prime = batch_size_prime

    self.device = device
    self.steps = 0

    # Always define both attributes so that a missing dataloader can be
    # reported explicitly instead of surfacing as an AttributeError.
    self.train_dataloader_prime = None
    self.train_iter = None
    if train_dataloader is not None:
      self.train_dataloader_prime = DataLoader(
          train_dataloader.dataset, batch_size=batch_size_prime, shuffle=True,
          pin_memory=True, num_workers=0
      )
      self.train_iter = iter(self.train_dataloader_prime)

  def _next_prime_inputs(self):
    """Return the inputs of the next penalty batch, moved to `self.device`.

    The iterator is restarted transparently once the loader is exhausted.
    Batches are expected to unpack as `(inputs, _)`.

    Raises:
      ValueError: If no `train_dataloader` was supplied at construction.
    """
    if self.train_iter is None:
      raise ValueError(
          'lam > 0 requires a train_dataloader to sample the penalty batch '
          'from, but none was provided.')
    try:
      # Built-in next(): iterators are not required to expose `.next()`.
      inputs_prime, _ = next(self.train_iter)
    except StopIteration:
      # Loader exhausted: start a new pass over the dataset.
      self.train_iter = iter(self.train_dataloader_prime)
      inputs_prime, _ = next(self.train_iter)
    return inputs_prime.to(self.device)

  def _meta_objective(self, loss_fn):
    """Compute the meta objective J, back-propagate it, and return it."""
    updated_model = self.base_optimizer.update(take_step=False)
    f_approx, _ = loss_fn(updated_model, updated=True)

    # Penalty defaults to zero; it is only evaluated when it is weighted in.
    D = torch.zeros(1, device=self.device)
    if self.lam > 0:
      inputs_prime = self._next_prime_inputs()
      # The current model's outputs are a fixed target: no gradient through
      # them, only through the updated model's outputs.
      output_orig_model = self.model(inputs_prime).detach()
      output_updated_model = updated_model(inputs_prime)
      D = torch.norm(output_orig_model - output_updated_model) ** 2

    second_term = (0.5 * self.lam * D) / self.batch_size_prime
    J = f_approx + second_term
    J.backward()
    return J

  def step(self, loss_fn):
    """Run one training step, preceded by a meta update when one is due.

    Args:
      loss_fn: Callable invoked as `loss_fn(model)` and
        `loss_fn(model, updated=True)`, returning `(loss, predictions)`.

    Returns:
      The `(loss, predictions)` pair from `loss_fn(self.model)`, evaluated
      before any update is applied.
    """
    loss, predictions = loss_fn(self.model)
    self.base_optimizer.zero_grad()
    # The graph is retained because the same loss is back-propagated again
    # after every meta step below.
    loss.backward(retain_graph=True)

    if self.steps % self.meta_interval == 0:
      for _ in range(self.num_meta_steps):
        self.meta_optimizer.zero_grad()
        self.meta_optimizer.step(lambda: self._meta_objective(loss_fn))
        # Recompute the gradients of the training loss from a clean slate
        # after the meta objective's backward pass.
        self.base_optimizer.zero_grad()
        loss.backward(retain_graph=True)

    self.steps += 1
    self.base_optimizer.update(take_step=True)
    return loss, predictions
