import os
import torch
import numpy as np
from torch.distributions import Normal
import seaborn as sns
import matplotlib.pyplot as plt

from toy_helpers import Data_Sampler, reward_fun


def sample_corners(num, pos=0.8, std=0.05):
    """Sample ``num // 4`` 2-D points around each of the four corners (±pos, ±pos).

    Each corner is an isotropic Gaussian with standard deviation ``std``; every
    sample is clipped to [-1, 1]. Corners are drawn (and concatenated along
    dim 0) in the order left-up, left-bottom, right-up, right-bottom, so the
    result is a float32 tensor of shape ``(4 * (num // 4), 2)``. Any remainder
    of ``num`` not divisible by 4 is dropped.
    """
    each_num = num // 4
    scale = torch.tensor([std, std])
    # Order matters: it fixes both the RNG draw sequence and the row layout.
    centers = [(-pos, pos), (-pos, -pos), (pos, pos), (pos, -pos)]
    samples = [
        Normal(torch.tensor(center), scale).sample((each_num,)).clip(-1.0, 1.0)
        for center in centers
    ]
    return torch.cat(samples, dim=0)


def generate_data(num, device='cpu', pos=0.8, std=0.05):
    """Build the four-corner toy dataset and wrap it in a ``Data_Sampler``.

    Args:
        num: requested number of samples; ``4 * (num // 4)`` are generated.
        device: forwarded unchanged to ``Data_Sampler``.
        pos: absolute x/y coordinate of each corner mode (default 0.8).
        std: per-axis standard deviation of each corner mode (default 0.05).

    Returns:
        ``Data_Sampler(state, action, reward, device)`` where ``action`` holds
        the corner samples, ``state`` is an all-zeros tensor of the same shape,
        and ``reward`` is ``reward_fun(action)``.
    """
    action = sample_corners(num, pos=pos, std=std)
    # The toy task is stateless: every action is paired with the zero state.
    state = torch.zeros_like(action)
    reward = reward_fun(action)
    return Data_Sampler(state, action, reward, device)


def _to_numpy(tensor):
    """Return ``tensor`` as a NumPy array, detached and copied to host memory."""
    # ``Tensor.numpy()`` raises for CUDA tensors; ``.cpu()`` is a no-op on CPU.
    return tensor.detach().cpu().numpy()


def _train_agent(agent, sampler, epochs, iterations, batch_size):
    """Call ``agent.train`` once per epoch, printing progress every 100 epochs."""
    for epoch in range(epochs):
        agent.train(sampler, iterations=iterations, batch_size=batch_size)
        if epoch % 100 == 0:
            print(f'Epoch: {epoch}')


def _sample_policy(sample_fn, num, state_dim, device):
    """Draw ``num`` actions from ``sample_fn`` at the all-zeros state.

    The state tensor is created on ``device`` so it sits next to the agent's
    parameters (assumes the agent constructor moved its networks to ``device``
    — TODO confirm in the bc_* modules). Returns a NumPy array.
    """
    state = torch.zeros((num, state_dim), device=device)
    with torch.no_grad():  # evaluation only; no autograd graph needed
        action = sample_fn(state)
    return _to_numpy(action)


def _plot_actions(ax, actions, title):
    """Scatter 2-D ``actions`` on ``ax`` using the shared [-1, 1] axis layout."""
    ax.scatter(actions[:, 0], actions[:, 1], alpha=0.3)
    ax.set_xlim(-1., 1.)
    ax.set_ylim(-1., 1.)
    ax.set_xlabel('x', fontsize=20)
    ax.set_ylabel('y', fontsize=20)
    ax.set_title(title, fontsize=25)


# Fall back to CPU so the script still runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
num_data = 10000
data_sampler = generate_data(num_data, device)

state_dim = 2
action_dim = 2
max_action = 1.0

discount = 0.99
tau = 0.005
model_type = 'MLP'

T = 50  # passed to the diffusion agent as n_timesteps
beta_schedule = 'vp'
hidden_dim = 32
lr = 1e-3

num_epochs = 1000
batch_size = 100
iterations = num_data // batch_size  # iterations * batch_size == num_data

img_dir = 'toy_imgs'
os.makedirs(img_dir, exist_ok=True)
fig, axs = plt.subplots(1, 5, figsize=(6.5 * 5, 5))

# Plot the ground truth
num_eval = 2000
_, action_samples, _ = data_sampler.sample(num_eval)
_plot_actions(axs[0], _to_numpy(action_samples), 'Ground Truth')


# Plot MLE BC
from bc_mle import BC_MLE as MLE_Agent
mle_agent = MLE_Agent(state_dim=state_dim,
                      action_dim=action_dim,
                      max_action=max_action,
                      device=device,
                      discount=discount,
                      tau=tau,
                      lr=lr,
                      hidden_dim=hidden_dim)
_train_agent(mle_agent, data_sampler, num_epochs, iterations, batch_size)
_plot_actions(axs[1],
              _sample_policy(mle_agent.actor.sample, num_eval, state_dim, device),
              'BC-MLE')


# Plot CVAE BC
from bc_cvae import BC_CVAE as CVAE_Agent
cvae_agent = CVAE_Agent(state_dim=state_dim,
                        action_dim=action_dim,
                        max_action=max_action,
                        device=device,
                        discount=discount,
                        tau=tau,
                        lr=lr,
                        hidden_dim=hidden_dim)
_train_agent(cvae_agent, data_sampler, num_epochs, iterations, batch_size)
_plot_actions(axs[2],
              _sample_policy(cvae_agent.vae.sample, num_eval, state_dim, device),
              'BC-CVAE')


# Plot MMD BC
from bc_mmd import BC_MMD as MMD_Agent
mmd_agent = MMD_Agent(state_dim=state_dim,
                      action_dim=action_dim,
                      max_action=max_action,
                      device=device,
                      discount=discount,
                      tau=tau,
                      lr=lr,
                      hidden_dim=hidden_dim)
_train_agent(mmd_agent, data_sampler, num_epochs, iterations, batch_size)
_plot_actions(axs[3],
              _sample_policy(mmd_agent.actor.sample, num_eval, state_dim, device),
              'BC-MMD')


# Plot Diffusion BC
from bc_diffusion import BC as Diffusion_Agent
diffusion_agent = Diffusion_Agent(state_dim=state_dim,
                                  action_dim=action_dim,
                                  max_action=max_action,
                                  device=device,
                                  discount=discount,
                                  tau=tau,
                                  beta_schedule=beta_schedule,
                                  n_timesteps=T,
                                  model_type=model_type,
                                  hidden_dim=hidden_dim,
                                  lr=lr)
_train_agent(diffusion_agent, data_sampler, num_epochs, iterations, batch_size)
_plot_actions(axs[4],
              _sample_policy(diffusion_agent.actor.sample, num_eval, state_dim, device),
              'BC-Diffusion')


# matplotlib Figures are written with ``savefig``; there is no ``Figure.save``.
fig.savefig(os.path.join(img_dir, f'bc_all_{T}.pdf'))
