import os
import torch
import numpy as np
from torch.distributions import Normal
import seaborn as sns
import matplotlib.pyplot as plt
import argparse
from toy_helpers import Data_Sampler, reward_fun

# ---------------------------------------------------------------------------
# Command-line configuration.
# NOTE(review): only --seed, --eta, --lr, --hd and --device are read in the
# code shown here; --x, --y, --dir, --r_fun and --mode are parsed but unused
# in this script — confirm whether they are still needed.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--eta", default=1, type=float)
parser.add_argument("--seed", default=2023, type=int)
parser.add_argument("--lr", default=3e-4, type=float)
parser.add_argument("--device", default=0, type=int)
parser.add_argument("--hd", default=128, type=int)
parser.add_argument("--x", default=0.0, type=float)
parser.add_argument("--y", default=0.0, type=float)
parser.add_argument("--dir", default="ablation", type=str)
parser.add_argument("--r_fun", default="no", type=str)
parser.add_argument("--mode", default="whole_grad", type=str)
args = parser.parse_args()

seed = args.seed
eta = args.eta
lr = args.lr  # learning rate forwarded to every agent below
# NOTE(review): hidden_dim is not passed to any agent in this script (they use
# hard-coded 32/128); confirm whether --hd should be wired through.
hidden_dim = args.hd

r_fun_std = 0.25
# Honor --device (the CUDA device index); fall back to CPU when CUDA is absent
# so the script does not crash on GPU-less machines.
device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'

# Seed torch and numpy so training and sampling are reproducible per --seed.
torch.manual_seed(seed)
np.random.seed(seed)

# toy.npy holds a pickled dict of arrays; allow_pickle=True runs the pickle
# loader, so only load trusted, locally generated files here.
data = np.load("toy.npy", allow_pickle=True).item()
# The entries are presumably torch tensors (they are converted with .float()
# below) — confirm against the script that writes toy.npy.
states = data["states"]
actions = data["actions"]
next_states = data["next_states"]
rewards = data["rewards"]
done = data["done"]

# Sampler shared by every agent trained below; it only receives states,
# actions and rewards (next_states / done are loaded but not passed in).
data_sampler = Data_Sampler(
    states.float(),
    actions.float(),
    rewards.float(),
    device,
)
# ---------------------------------------------------------------------------
# Shared problem dimensions and agent hyperparameters, passed to every agent
# constructor in this script.
# ---------------------------------------------------------------------------
state_dim = 2
action_dim = 2
max_action = 1.0

discount = 0.99
tau = 0.005
model_type = "MLP"            # network type forwarded to the diffusion agent
beta_schedule = "vp"          # noise-schedule name forwarded to the diffusion agent

# Training budget: each epoch runs `iterations` mini-batch updates, i.e. about
# one pass over `num_data` samples.
# NOTE(review): num_data is hard-coded; confirm it matches the number of
# samples actually stored in toy.npy.
num_data = 5000
num_epochs = 1000
batch_size = 256
iterations = num_data // batch_size

# Number of actions sampled from each trained agent at evaluation time.
num_eval = 1000

# Plot MLE BC: train the BC_MLE agent on the toy data, then save num_eval
# actions it samples at the all-zeros state to mle.npy.
from bc_mle import BC_MLE as MLE_Agent
mle_agent = MLE_Agent(state_dim=state_dim,
                      action_dim=action_dim,
                      max_action=max_action,
                      device=device,
                      discount=discount,
                      tau=tau,
                      lr=lr,
                      hidden_dim=32)

for i in range(num_epochs):
    mle_agent.train(data_sampler,
                    iterations=iterations,
                    batch_size=batch_size)
    # Progress report every 100 epochs.
    if i % 100 == 0:
        print(f'Epoch: {i}')

# Evaluation batch: num_eval copies of the all-zeros state (width state_dim).
new_state = torch.zeros((num_eval, state_dim), device=device)
with torch.no_grad():  # inference only; no autograd graph is needed
    new_action = mle_agent.actor.sample(new_state)
new_action = new_action.detach().cpu().numpy()
np.save("mle.npy", new_action)


# Plot CVAE BC: train the BC_CVAE agent on the toy data, then save num_eval
# actions decoded by its VAE at the all-zeros state.
from bc_cvae import BC_CVAE as CVAE_Agent
cvae_agent = CVAE_Agent(state_dim=state_dim,
                        action_dim=action_dim,
                        max_action=max_action,
                        device=device,
                        discount=discount,
                        tau=tau,
                        lr=lr,
                        hidden_dim=32)

for i in range(num_epochs):
    cvae_agent.train(data_sampler,
                     iterations=iterations,
                     batch_size=batch_size)
    # Progress report every 100 epochs.
    if i % 100 == 0:
        print(f'Epoch: {i}')

# Evaluation batch: num_eval copies of the all-zeros state (width state_dim).
new_state = torch.zeros((num_eval, state_dim), device=device)
with torch.no_grad():  # inference only; no autograd graph is needed
    new_action = cvae_agent.vae.sample(new_state)
new_action = new_action.detach().cpu().numpy()
# NOTE(review): "cave.npy" looks like a typo for "cvae.npy"; kept unchanged
# because code outside this script may load this exact filename — confirm
# before renaming.
np.save("cave.npy", new_action)


# Plot MMD BC: train the BC_MMD agent on the toy data, then save num_eval
# actions it samples at the all-zeros state to mmd.npy.
from bc_mmd import BC_MMD as MMD_Agent

mmd_agent = MMD_Agent(state_dim=state_dim,
                      action_dim=action_dim,
                      max_action=max_action,
                      device=device,
                      discount=discount,
                      tau=tau,
                      lr=lr,
                      hidden_dim=32)

for i in range(num_epochs):
    mmd_agent.train(data_sampler,
                    iterations=iterations,
                    batch_size=batch_size)
    # Progress report every 100 epochs.
    if i % 100 == 0:
        print(f'Epoch: {i}')

# Evaluation batch: num_eval copies of the all-zeros state (width state_dim).
new_state = torch.zeros((num_eval, state_dim), device=device)
with torch.no_grad():  # inference only; no autograd graph is needed
    new_action = mmd_agent.actor.sample(new_state)
new_action = new_action.detach().cpu().numpy()
np.save("mmd.npy", new_action)


# Plot Diffusion BC: sweep the number of diffusion steps. For each step count
# T a fresh agent is trained and its samples at the all-zeros state are saved
# to diff{T}.npy. The loop variable drives both n_timesteps and the output
# filename, so each file always reflects the step count in its name.
from bc_diffusion import BC as Diffusion_Agent

for n_timesteps in (15, 25, 30, 50):
    diffusion_agent = Diffusion_Agent(state_dim=state_dim,
                                      action_dim=action_dim,
                                      max_action=max_action,
                                      device=device,
                                      discount=discount,
                                      tau=tau,
                                      beta_schedule=beta_schedule,
                                      n_timesteps=n_timesteps,
                                      model_type=model_type,
                                      hidden_dim=128,
                                      lr=lr)

    for i in range(num_epochs):
        diffusion_agent.train(data_sampler,
                              iterations=iterations,
                              batch_size=batch_size)
        # Progress report every 100 epochs.
        if i % 100 == 0:
            print(f'Epoch: {i}')

    # Evaluation batch: num_eval copies of the all-zeros state (width state_dim).
    new_state = torch.zeros((num_eval, state_dim), device=device)
    with torch.no_grad():  # inference only; no autograd graph is needed
        new_action = diffusion_agent.actor.sample(new_state)
    new_action = new_action.detach().cpu().numpy()
    np.save(f"diff{n_timesteps}.npy", new_action)
