import os

import numpy as np
import matplotlib.pyplot as plt
# matplotlib >= 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8' (the old
# name was removed in 3.8 and raises OSError); fall back to the old name on older
# matplotlib releases that do not ship the renamed style.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_convert

import higher

from sacred import Experiment
from sacred.observers import FileStorageObserver

from maml import MAML
from utils import Gaussian, Uniform
from models import FCNet
import sin_dloader

EXPERIMENT_NAME = "maml_sinusoid"
RUN_DIR = "runs"

# Module-level Sacred experiment. Runs are stored by a FileStorageObserver rooted
# at runs/maml_sinusoid (RUN_DIR joined with EXPERIMENT_NAME).
ex = Experiment(EXPERIMENT_NAME)
ex.observers.append(
    FileStorageObserver(os.path.join(RUN_DIR, EXPERIMENT_NAME)))

@ex.config
def get_config():
  # Sacred config scope: every local variable assigned in this body becomes a
  # config entry that @ex.capture functions and main() receive by name, so
  # adding or renaming a local here changes the experiment configuration.

  ## DATA
  n_tasks = 100  # number of training tasks given to the training SinusoidDataSet
  n_supp = 20  # support-set size for training episodes
  n_query = 10  # query-set size for training episodes
  eps_per_batch = 25  # episodes per batch, for both train and val batch samplers

  k_supp = 10  # support-set size for validation episodes
  k_query = 10  # query-set size for validation episodes
  val_tasks = 100  # number of validation tasks

  train_ampl_range = [1.0, 4.0]  # [low, high] amplitude range for training sinusoids
  train_phase_range = [0.0, np.pi / 2 ]  # [low, high] phase range for training sinusoids
  train_obs_noise = 1.0  # observation-noise setting for the training dataset

  # NOTE: the validation amplitude range extends past the training one
  # (up to 5.0 vs 4.0), so validation partly measures out-of-range tasks.
  val_ampl_range = [3.0, 5.0]
  val_phase_range = [0.0, np.pi / 2]
  val_obs_noise = 1.0  # observation-noise setting for the validation dataset

  ## MAML
  inner_steps = 5  # inner-loop (adaptation) steps, captured by get_maml_learner
  inner_lr = 1e-3  # inner-loop learning rate, captured by get_maml_learner

  ## Optimization
  epochs = 100  # number of train/validation passes run by main()
  meta_lr = 1e-3  # outer-loop (meta) learning rate, captured by get_maml_learner

  ## MODEL
  nhid = 40  # hidden width passed to FCNet via get_model
  nlayers = 4  # layer count passed to FCNet via get_model

  ## MISC
  # Moves the model to the GPU in get_model. NOTE(review): episode tensors are
  # placed on the GPU by separate helpers, so False alone may not give a
  # CPU-only run — confirm before relying on it.
  cuda = True


def get_model(nhid, nlayers, cuda):
  """Construct the FCNet regressor, optionally moved to the GPU.

  ``FCNet`` is built as ``FCNet(1, nhid, 1, nlayers)`` — presumably
  (input dim, hidden width, output dim, layer count); confirm in models.FCNet.
  When ``cuda`` is truthy the network's ``.cuda()`` result is returned,
  otherwise the network itself.
  """
  net = FCNet(1, nhid, 1, nlayers)
  return net.cuda() if cuda else net


@ex.capture
def get_maml_learner(nhid, nlayers, inner_lr, meta_lr, inner_steps, cuda):
  """Build the MAML learner and its meta-optimizer from the experiment config.

  Args:
    nhid, nlayers, cuda: forwarded to ``get_model``.
    inner_lr: learning rate of the SGD optimizer handed to ``MAML`` for the
      inner (adaptation) loop.
    meta_lr: learning rate of the Adam optimizer over the model parameters
      used for the outer (meta) update.
    inner_steps: number of inner-loop steps, forwarded to ``MAML``.

  Returns:
    ``(learner, meta_opt)``.
  """
  model = get_model(nhid, nlayers, cuda)
  # Use the configured rates; hard-coded values would make the inner_lr /
  # meta_lr config entries (and any command-line overrides) silently ineffective.
  inner_opt = torch.optim.SGD(model.parameters(), lr=inner_lr)
  # Loss: squared error averaged over dim 0 and summed over remaining dims.
  learner = MAML(model, inner_opt, inner_steps, loss_fn=lambda x, y: ((x - y) ** 2).mean(0).sum())
  meta_opt = torch.optim.Adam(model.parameters(), lr=meta_lr)
  return learner, meta_opt


@ex.capture
def get_sinusoid_loaders(n_tasks, n_supp, n_query, val_tasks, k_supp, k_query,
                        eps_per_batch, train_ampl_range, train_phase_range,
                        val_ampl_range, val_phase_range, train_obs_noise, val_obs_noise):
  """Build episodic train/validation DataLoaders for sinusoid regression.

  Training data uses ``n_tasks`` tasks with ``n_supp``/``n_query`` support and
  query sizes; validation data uses ``val_tasks`` tasks with ``k_supp``/``k_query``.
  Both batch samplers group ``eps_per_batch`` episodes per batch, and each
  DataLoader uses its batch sampler's ``collate`` method. The validation batch
  sampler gets an extra positional ``False`` (presumably turning off
  shuffling — confirm in sin_dloader).

  Returns:
    ``(train_dloader, val_dloader, train_sin_sampler, val_sin_sampler)``, the
    last two being the datasets' ``sinusoid_sampler`` attributes.
  """
  def make_split(ampl_range, phase_range, tasks, supp, query, noise, *sampler_args):
    # Dataset first, then its batch sampler.
    dset = sin_dloader.SinusoidDataSet(ampl_range[0], ampl_range[1],
                                       phase_range[0], phase_range[1],
                                       tasks, supp, query, noise)
    batch_sampler = sin_dloader.SinusoidEpisodeBatchSampler(
        eps_per_batch, supp, query, tasks, *sampler_args)
    return dset, batch_sampler

  # Train split is built before the validation split, and both splits before
  # any DataLoader.
  splits = [
      make_split(train_ampl_range, train_phase_range, n_tasks, n_supp, n_query,
                 train_obs_noise),
      make_split(val_ampl_range, val_phase_range, val_tasks, k_supp, k_query,
                 val_obs_noise, False),
  ]
  train_loader, val_loader = (
      DataLoader(dset, batch_sampler=bs, collate_fn=bs.collate) for dset, bs in splits)

  (train_dset, _), (val_dset, _) = splits
  return train_loader, val_loader, train_dset.sinusoid_sampler, val_dset.sinusoid_sampler


def plot_model_pred(episode, sin, learner, inner_steps=15):
  """Plot the learner's adapted prediction against the true sinusoid.

  The query set of ``episode[0]`` is replaced IN PLACE with a dense grid of
  100 points on [-5, 5] and the matching labels from ``sin[0]``. Then
  ``learner.eval`` runs with ``inner_steps`` steps. Finally the support
  points, the true curve and the model prediction are drawn. The figure is
  shown non-blocking for 3 seconds and then closed.

  Args:
    episode: list of episode dicts; only ``episode[0]`` is used. Its
      ``'support_im'``/``'support_labels'`` tensors are plotted and its
      ``'query_im'``/``'query_labels'`` entries are overwritten.
    sin: sequence whose first element is a callable mapping x-values to the
      true y-values.
    learner: object whose ``eval(episodes, inner_steps=...)`` returns a pair
      ``(losses, outputs)``; ``outputs[0][0]`` is taken as the prediction.
    inner_steps: adaptation steps passed to ``learner.eval`` (default 15).
  """
  x_range = np.linspace(-5.0, 5.0, 100)
  y_true = sin[0](x_range)
  first = episode[0]
  # Put the query grid on the same device as the support set instead of
  # assuming CUDA.
  device = first['support_im'].device
  first['query_im'] = torch.tensor(x_range).float().unsqueeze(1).to(device)
  first['query_labels'] = torch.tensor(y_true).float().unsqueeze(1).to(device)
  _, out = learner.eval(episode, inner_steps=inner_steps)
  # detach() so the tensor can be converted for plotting even if it tracks grads.
  model_pred = out[0][0].detach().squeeze().cpu()
  plt.scatter(first['support_im'].squeeze().cpu(), first['support_labels'].squeeze().cpu(),
              marker='s', color='purple')
  plt.plot(x_range, y_true, ls='--', color='red', label='True')
  plt.plot(x_range, model_pred, ls='-', color='blue', label='Model')
  plt.legend()
  plt.tight_layout()
  plt.show(block=False)
  plt.pause(3)
  plt.close()


def get_episodes(sampler, n_supp, n_query, noise, ep_count=5, device='cuda'):
  """Sample standalone episodes outside a DataLoader (e.g. for plotting).

  Calls ``sampler.sample_episode(n_supp, n_query, noise)`` ``ep_count``
  times. Item 0 of each result is converted to tensors with
  ``default_convert``; every tensor in it gets a trailing axis
  (``unsqueeze(1)``) and is moved to ``device``. Item 1 is collected
  unchanged as that episode's sinusoid.

  Args:
    sampler: object with ``sample_episode(n_supp, n_query, noise)`` returning
      an indexable whose item 0 is a dict of array-likes and whose item 1 is
      presumably the generating sinusoid (used as a callable by plot_model_pred).
    n_supp, n_query, noise: forwarded to ``sample_episode``.
    ep_count: number of episodes to sample.
    device: device the episode tensors are moved to (default ``'cuda'``).

  Returns:
    ``(episodes, sinusoids)``: two lists of length ``ep_count``.
  """
  samples = [sampler.sample_episode(n_supp, n_query, noise) for _ in range(ep_count)]
  episodes = [default_convert(sample[0]) for sample in samples]
  sinusoids = [sample[1] for sample in samples]
  for episode in episodes:
    for key in episode:
      episode[key] = episode[key].unsqueeze(1).to(device)
  return episodes, sinusoids

def process_episodes(episodes, device='cuda'):
  """Prepare a collated batch of episodes for the model, in place.

  Every tensor value in every episode dict gets a trailing axis via
  ``unsqueeze(1)`` and is moved to ``device`` (default ``'cuda'``). The dicts
  inside ``episodes`` are modified directly; nothing is returned.
  """
  for episode in episodes:
    # Build the converted mapping first, then write it back, so the dict is
    # not being reassigned while it is iterated.
    episode.update({key: value.unsqueeze(1).to(device) for key, value in episode.items()})

def eval_step(maml, episodes):
  """Return the mean query loss of ``maml`` on a batch of ``episodes``.

  ``maml.eval(episodes)`` must return a pair ``(query_losses, query_outputs)``.
  The outputs are ignored; each loss is turned into a Python number with
  ``.item()`` and the values are averaged with ``np.mean``.
  """
  query_losses, _ = maml.eval(episodes)
  per_task = [loss.item() for loss in query_losses]
  return np.mean(per_task)

@ex.automain
def main(epochs, _run):
  """Meta-train the MAML learner on sinusoids for ``epochs`` epochs.

  Each epoch does two passes:
    - over the training loader, with one meta-optimizer step per batch;
    - over the validation loader, with no optimizer step.
  The two mean losses are logged to Sacred as ``train.loss`` and ``val.loss``
  and printed.
  """
  train_loader, val_loader, train_sampler, val_sampler = get_sinusoid_loaders()
  learner, meta_opt = get_maml_learner()

  for epoch in range(epochs):
    # Meta-training: zero grads, compute the meta-gradient, step the optimizer.
    batch_losses = []
    for batch in train_loader:
      process_episodes(batch)
      meta_opt.zero_grad()
      task_losses, _ = learner.compute_meta_grad(batch)
      meta_opt.step()
      batch_losses.append(np.mean([t.item() for t in task_losses]))
    train_loss = np.mean(batch_losses)
    _run.log_scalar("train.loss", train_loss)

    # Validation: score each batch with eval_step; parameters are not updated here.
    val_losses = []
    for batch in val_loader:
      process_episodes(batch)
      val_losses.append(eval_step(learner, batch))
    val_loss = np.mean(val_losses)
    _run.log_scalar("val.loss", val_loss)

    print("Epoch : {}   -    Train Loss: {}   -    Val Loss: {}".format(epoch, train_loss, val_loss))
    # Optional visual check of the adapted model (needs a display):
    # train_ep, train_sin = get_episodes(train_sampler, 10, 10, 0.1, 1)
    # val_ep, val_sin = get_episodes(val_sampler, 10, 10, 0.1, 1)
    # plot_model_pred(train_ep, train_sin, learner)
    
      



      
