import matplotlib.pyplot as plt

import os
import torch
import torch.nn as nn
import numpy as np
from maml.datasets.metadataset import Task
from maml.models.fully_connected import FullyConnectedModel
from maml.metalearner import MetaLearner


_save_folder = './train_dir/maml/'
save_name = 'maml_fc_30000.pt'
device = 'cuda:0'
save_path = os.path.join(_save_folder, save_name)
# map_location remaps every stored tensor onto `device`, so the checkpoint
# still loads if it was saved from a different GPU index (or from the CPU).
checkpoint = torch.load(save_path, map_location=device)

# The constructor arguments must match the architecture the checkpoint was
# trained with, otherwise load_state_dict below fails on missing/mismatched keys.
model = FullyConnectedModel(
    input_size=1,
    output_size=1,
    hidden_sizes=[40, 40],
    disable_norm=True)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)

# Optimizer over the MAML model's parameters; it is handed to MetaLearner
# together with the model.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Sinusoid regression task: y = amp * sin(x - phase), x ~ U(input_range).
amp_range = [0.1, 5]
phase_range = [0, np.pi]

input_range = [-5.0, 5.0]
num_samples_per_function = 10

# The amplitude is pinned to `fixed_amp` for this evaluation.
# Set fixed_amp = None to sample it uniformly from amp_range instead.
fixed_amp = 5
if fixed_amp is None:
    amp = np.random.uniform(amp_range[0], amp_range[1])
else:
    amp = fixed_amp
phase = np.random.uniform(phase_range[0], phase_range[1])
# Support-set inputs, shape (num_samples_per_function, 1).
init_inputs = np.random.uniform(input_range[0], input_range[1], [num_samples_per_function, 1])

outputs = amp * np.sin(init_inputs - phase)
# Dense 1-D evaluation grid over the same interval the support inputs come from.
x_axis = np.linspace(input_range[0], input_range[1], num=1000)
true_y = amp * np.sin(x_axis - phase)



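## Alternative target: quadratic function. To use it, uncomment these lines and the
## matching `context_y` line further down; note that the support labels (`outputs`)
## above are still computed from the sinusoid and would need the same change.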
# a =np.random.uniform(0.1, 0.5)
# b = np.random.uniform(-0.3, 0.3)
# c = np.random.uniform(-1, 1)
# x_axis = np.linspace(-5.0, 5.0, num=1000)
# true_y = a * x_axis**2 + b * x_axis +c


# Support set: the sampled points and their labels, as float32 tensors on `device`.
support_x = torch.from_numpy(init_inputs[:num_samples_per_function]).float().to(device)
support_y = torch.from_numpy(outputs[:num_samples_per_function]).float().to(device)
support = Task(support_x, support_y, 'train')

# Query set: the dense grid reshaped to column vectors, used to evaluate the
# adapted model across the whole input interval.
query_x = torch.from_numpy(x_axis.reshape(-1, 1)).float().to(device)
query_y = torch.from_numpy(true_y.reshape(-1, 1)).float().to(device)
query = Task(query_x, query_y, 'test')

# Inner-loop configuration shared by the MAML model and the baseline below.
loss_func = nn.MSELoss()
num_updates = 50
fast_lr = 0.01
meta_learner = MetaLearner(
    model,
    optimizer,
    fast_lr=fast_lr,
    loss_func=loss_func,
    first_order=False,
    num_updates=num_updates,
    inner_loop_grad_clip=-1,
    device=device,
    classifier_schedule=10,
)

# Adapt on the support set, then predict on the query grid with the adapted params.
pre_train_loss, adapted_params = meta_learner.adapt([support])
preds, post_val_loss = meta_learner.step(adapted_params, [query], False)

print(pre_train_loss, post_val_loss)
# Baseline: a second network with the same architecture, adapted on the same
# support set with the same inner-loop settings, for comparison with MAML.
# The support/query tasks are rebuilt from the same numpy arrays as before.
# NOTE(review): this only matters if adapt/step mutate the Task tensors in
# place -- confirm; otherwise the rebuild is redundant.
support = Task(torch.from_numpy(init_inputs[:num_samples_per_function]).float().to(device), torch.from_numpy(outputs[:num_samples_per_function]).float().to(device), 'train')
query = Task(torch.from_numpy(x_axis.reshape(-1, 1)).float().to(device), torch.from_numpy(true_y.reshape(-1, 1)).float().to(device), 'test')
# NOTE: despite its name, random_model does not keep its random initialisation;
# its weights are replaced by the taskwise checkpoint loaded below.
random_model = FullyConnectedModel(
    input_size=1,
    output_size=1,
    hidden_sizes=[40, 40],
    disable_norm=True)
random_model.to(device)
_save_folder2 = './train_dir/taskwise/'
optimizer2 = torch.optim.Adam(random_model.parameters(), lr=fast_lr)
# map_location keeps the load working regardless of which device the
# checkpoint was saved from.
random_model.load_state_dict(
    torch.load(os.path.join(_save_folder2, 'maml_fc_30000.pt'), map_location=device)['model_state_dict'])
random_init = MetaLearner(
    random_model, optimizer2, fast_lr=fast_lr,
    loss_func=loss_func, first_order=False,
    num_updates=num_updates,
    inner_loop_grad_clip=-1,
    device=device,
    classifier_schedule=10)

# detach() drops autograd history so the tensor can be converted to numpy.
pred_y = preds.detach().cpu().numpy()
context = init_inputs[:num_samples_per_function]
# Same values as `outputs` for the support points; recomputed for plotting.
context_y = amp * np.sin(context - phase)
# context_y = a * context**2 + b * context +c # Quadratic


(pre_train_loss, adapted_params2) = random_init.adapt([support])
preds_random, post_val_loss = random_init.step(adapted_params2, [query], False)
print(pre_train_loss, post_val_loss)
preds_random = preds_random.detach().cpu().numpy()


# Plot the true function, both adapted predictions and the support points.
# Every artist is labelled so the saved figure is readable without the code.
plt.plot(x_axis.squeeze(), true_y.squeeze(), '--', label='ground truth')
plt.plot(x_axis, pred_y, label='MAML (adapted)')
plt.plot(x_axis, preds_random, '--', label='baseline (adapted)')

plt.scatter(context.squeeze(), context_y.squeeze(), marker='x', label='support points')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
# Explicit extension: matplotlib would otherwise append rcParams['savefig.format'].
plt.savefig('MAML vs SGD result.png')
plt.close()


