import torch
import random
from torch import nn
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-<style>' and
# later removed the old aliases, so prefer the new name when this installation
# provides it and fall back to the legacy name otherwise.
_ticks_style = 'seaborn-v0_8-ticks' if 'seaborn-v0_8-ticks' in plt.style.available else 'seaborn-ticks'
plt.style.use(_ticks_style)
# Avoid Type 3 fonts: http://phyletica.org/matplotlib-fonts/
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
# Set the global font size (family and weight are left at the style defaults)
#font = {'family':'normal', 'weight':'normal', 'size': 12}
font = {'size': 15}
matplotlib.rc('font', **font)


# Registry of activation modules, looked up by name.
# Every entry is a single module instance created at import time, so repeated
# lookups of the same name return the same object.
activations = {
  'Linear': nn.Identity(),
  'ReLU': nn.ReLU(),
  'ELU': nn.ELU(),
  'Softplus': nn.Softplus(),
  'LeakyReLU': nn.LeakyReLU(),
  'Tanh': nn.Tanh(),
  'Sigmoid': nn.Sigmoid(),
  'Hardsigmoid': nn.Hardsigmoid(),
}
# 'Softmax<d>' applies softmax along dimension d ('Softmax-1' is the last one).
activations.update(
  {f'Softmax{softmax_dim}': nn.Softmax(dim=softmax_dim) for softmax_dim in (-1, 0, 1, 2)}
)

def set_random_seed(seed):
  """Seed Python's `random`, NumPy and PyTorch with the same value.

  When CUDA is available, all CUDA generators are seeded as well and cuDNN is
  switched to deterministic mode with benchmarking disabled.

  Args:
    seed: Seed value handed unchanged to every generator.
  """
  for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)
  if not torch.cuda.is_available():
    return
  torch.cuda.manual_seed_all(seed)
  cudnn = torch.backends.cudnn
  cudnn.deterministic = True
  cudnn.benchmark = False

def to_tensor(x, device='cpu'):
  """Return `x` as a float32 tensor placed on `device`.

  Args:
    x: Array-like data accepted by `torch.as_tensor`.
    device: Target device for the tensor; defaults to the CPU.

  Returns:
    A `torch.Tensor` with dtype `torch.float32`.
  """
  return torch.as_tensor(x, dtype=torch.float32, device=device)


def data_generator(batch_size, task=1):
  """Sample a random batch from y = sin(pi * x) on the interval of `task`.

  Task 1 draws x uniformly from [0, 1) and task 2 from [1, 2).

  Args:
    batch_size: Number of (x, y) pairs to draw.
    task: Task identifier, either 1 or 2.

  Returns:
    A pair `(x, y)` of tensors produced by `to_tensor`, each converted from
    an array of shape `(batch_size, 1)`.

  Raises:
    ValueError: If `task` is neither 1 nor 2.
  """
  task_ranges = {1: (0, 1), 2: (1, 2)}
  if task not in task_ranges:
    # An unknown task used to fall through both branches and crash with an
    # UnboundLocalError on `low`; fail with an explicit message instead.
    raise ValueError(f'Unknown task {task!r}; expected one of {sorted(task_ranges)}')
  low, high = task_ranges[task]
  x = np.random.uniform(low=low, high=high, size=(batch_size, 1))
  y = np.sin(np.pi * x)
  return to_tensor(x), to_tensor(y)


def plot_model(model, name, imgType='png'):
  """Plot sin(pi * x) on [0, 2] next to the model's prediction and save it.

  The model is evaluated without gradient tracking on 100 evenly spaced
  points, passed as a float32 tensor of shape (100, 1).

  Args:
    model: Callable mapping that input tensor to a tensor with 100 values.
    name: Output path without extension.
    imgType: File extension appended to `name`.
  """
  fig, ax = plt.subplots()
  grid = np.linspace(0, 2, 100)
  target = np.sin(grid * np.pi)
  with torch.no_grad():
    inputs = to_tensor(grid.reshape(-1, 1))
    learned = model(inputs).cpu().detach().numpy().reshape(-1)
  ax.plot(grid, target, color='black', label='True function')
  ax.plot(grid, learned, color='red', label='Learned function')
  ax.set_xlabel('X')
  ax.set_ylabel('Y')
  ax.legend()
  fig.savefig(f'{name}.{imgType}')
  # Release the figure so repeated calls do not accumulate open figures.
  plt.close(fig)


def plot_all(x, y_list, name, imgType='png'):
  """Plot the true function and the learned functions after each stage.

  Curves in `y_list` are drawn in order and paired with the fixed labels
  'True function', 'Learned function after Stage 1' and 'Learned function
  after Stage 2'. A dashed vertical line at x = 1 marks the input boundary.
  The figure is written to '{name}.{imgType}'.

  Args:
    x: Shared x coordinates for every curve.
    y_list: Sequence of at most three y arrays, in the label order above.
    name: Output path without extension.
    imgType: File extension appended to `name`.

  Raises:
    ValueError: If `y_list` holds more curves than there are labels.
  """
  label_list = ['True function', 'Learned function after Stage 1', 'Learned function after Stage 2']
  color_list = ['black', 'red', 'limegreen']
  if len(y_list) > len(label_list):
    # Check before any figure exists; indexing past the label list used to
    # raise IndexError half-way through plotting and leave the figure open.
    raise ValueError(
      f'plot_all supports at most {len(label_list)} curves, got {len(y_list)}'
    )
  fig, ax = plt.subplots()
  try:
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.vlines(1, -0.9, 0.9, colors='blue', linestyles='dashed', label='Input boundary')
    for y, color, label in zip(y_list, color_list, label_list):
      ax.plot(x, y, color=color, label=label)
    ax.set_xlabel("X", fontsize=16)
    ax.set_ylabel('Y', fontsize=16)
    ax.tick_params(axis='both', labelsize=11)
    # Set legend
    ax.legend(loc='lower left', frameon=False, fontsize=11)
    fig.savefig(f'{name}.{imgType}')
  finally:
    # Always release the figure, even if plotting or saving fails.
    plt.close(fig)

# Experiment: fit y = sin(pi * x) with a small MLP, first on task 1 and then on
# task 2, and record the learned function over [0, 2] after each stage.
method = 'vanilla'
seed = 42
epoch = 1000  # gradient steps per task; each step draws a fresh batch
hidden_dim = 32
batch_size = 32
hidden_act = 'ReLU'
lr = 1e-2
set_random_seed(seed)
model = nn.Sequential(
    nn.Linear(1, hidden_dim),
    activations[hidden_act],
    nn.Linear(hidden_dim, hidden_dim),
    activations[hidden_act],
    nn.Linear(hidden_dim, 1)
  )
optim = torch.optim.Adam(model.parameters(), lr=lr)
# Build the loss module once instead of on every training step.
loss_fn = nn.MSELoss(reduction='mean')
# plot_model(model, f'initial_{method}_{batch_size}')

# Evaluation grid covering both task intervals; the first curve is the ground truth.
x_true = np.linspace(0, 2, 100)
y_true = np.sin(x_true*np.pi)
y_list = [y_true]

for t in [1, 2]:
  # Training on task t
  for _ in range(epoch):
    x, y = data_generator(batch_size, task=t)
    # Compute predicted y
    y_pred = model(x)
    # Compute loss function; nn.MSELoss is called as (input, target)
    loss = loss_fn(y_pred, y)
    # Backpropagation
    optim.zero_grad()
    loss.backward()
    optim.step()
  # Collect data for later visualization
  with torch.no_grad():
    y_pred = model(to_tensor(x_true.reshape((-1, 1)))).cpu().numpy().reshape(-1)
  y_list.append(y_pred)

# Plot the same figure in both output formats
for img_type in ('png', 'pdf'):
  plot_all(x_true, y_list, f'sin_{method}', imgType=img_type)