import os
import torch
import numpy as np
from torch.distributions import Normal
import matplotlib.pyplot as plt
from toy_helpers import Data_Sampler, reward_fun

# Experiment-wide hyperparameters.
seed = 2023  # shared by the torch and numpy RNGs below
LA = 10  # forwarded to the DiffCPS agents as `LA`
lr = 3e-4  # learning rate forwarded to the agents
hidden_dim = 128  # not referenced in the visible part of this script
r_fun_std = 0.25  # not referenced in the visible part of this script
eta = 1  # forwarded to Diffusion_QL as `eta`

# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Seed both RNG sources so repeated runs draw the same random numbers.
np.random.seed(seed)
torch.manual_seed(seed)

# Load the offline toy dataset. The file holds a pickled mapping wrapped in
# an array, hence allow_pickle=True followed by .item() to unwrap it.
data = np.load("toy.npy", allow_pickle=True).item()
states, actions, next_states, rewards, done = (
    data[field]
    for field in ("states", "actions", "next_states", "rewards", "done")
)

# Only states, actions and rewards are handed to the sampler (cast with
# .float()); next_states and done are loaded but not passed on here.
# NOTE(review): .float() implies the stored values are torch tensors - confirm.
data_sampler = Data_Sampler(
    states.float(), actions.float(), rewards.float(), device
)

# Dimensions and action bound of the toy task.
state_dim, action_dim = 2, 2
max_action = 1.0

# Hyperparameters forwarded to the agents built further down.
discount = 0.99
tau = 0.005
model_type = "MLP"
beta_schedule = "vp"

# Training / evaluation schedule.
num_data = 5000
batch_size = 256
num_epochs = 1000
num_eval = 1000
# Whole mini-batches per epoch; a trailing partial batch is dropped.
iterations = num_data // batch_size

# AWR baseline: train an ActorAgent on the offline data, then save the
# actions it proposes for a batch of all-zero evaluation states.
from awr.awr import ActorAgent

print("AWR")
# NOTE(review): use_cuda is hard-coded to False while the evaluation states
# below are created on `device`; confirm the agent copes with CUDA tensors.
agent = ActorAgent(
    state_dim,
    action_dim,
    discount,
    use_gae=True,
    use_cuda=False,
    use_noisy_net=False,
    use_continuous=True,
)

# NOTE(review): range(1, 20000) performs 19,999 updates, not 20,000 -
# kept as-is; confirm whether the off-by-one is intended.
for _ in range(1, 20000):
    # Batch-local names keep the module-level dataset variables (`states`,
    # `actions`, `rewards`, `next_states`) from being overwritten by
    # mini-batches on every step.
    (
        batch_states,
        batch_actions,
        batch_rewards,
        batch_next_states,
        batch_dones,
    ) = data_sampler.full_sample(batch_size)
    agent.train_model(
        batch_states,
        batch_actions,
        batch_rewards,
        batch_next_states,
        batch_dones,
    )


# Query the trained policy at num_eval copies of the zero state.
new_state = torch.zeros((num_eval, state_dim), device=device)
new_action = agent.get_action(new_state)
# NOTE(review): assumes get_action returns something np.save can serialise
# (e.g. a numpy array or CPU tensor) - confirm.
np.save("awr.npy", new_action)

# DiffCPS
#
# Train one DiffCPS agent per diffusion-step count and save the actions its
# actor samples for a batch of all-zero evaluation states. Every run shares
# the same hyperparameters; only `n_timesteps` (and therefore the output
# file name) differs, so the runs are driven by a single loop instead of
# four copy-pasted stanzas.

from agents.diffcps import DiffCPS as Agent

for n_timesteps in (15, 25, 30, 50):
    print("DiffCPS")
    agent = Agent(
        state_dim=state_dim,
        action_dim=action_dim,
        max_action=max_action,
        device=device,
        discount=discount,
        tau=tau,
        max_q_backup=False,
        beta_schedule=beta_schedule,
        n_timesteps=n_timesteps,
        LA=LA,
        lr=lr,
        lr_decay=True,
        lr_maxt=1000,
        grad_norm=4.0,
        policy_freq=1,
        target_kl=0,
        lambda_max=1000000,
        lambda_min=0,
    )

    for i in range(1, num_epochs + 1):
        loss_metric = agent.train(
            data_sampler, iterations=iterations, batch_size=batch_size
        )
        # Report the mean losses of every 100th epoch.
        if i % 100 == 0:
            print(
                f'DiffCPS Epoch: {i} kl_loss {np.mean(loss_metric["kl_loss"])} lambda {np.mean(loss_metric["lambda"])}'
            )

    # Sample actions at num_eval copies of the zero state and store them as
    # diffcps<n_timesteps>.npy (diffcps15.npy, diffcps25.npy, ...).
    new_state = torch.zeros((num_eval, state_dim), device=device)
    new_action = agent.actor.sample(new_state)
    new_action = new_action.detach().cpu().numpy()
    np.save(f"diffcps{n_timesteps}.npy", new_action)


# Diffusion QL
#
# Train one Diffusion_QL agent per diffusion-step count and save the actions
# its actor samples for a batch of all-zero evaluation states.
from dql.ql_diffusion import Diffusion_QL

analytic_reward_f = None
print("DQL")
eta = 1
lr = 3e-4

# The output file name is derived from n_timesteps so the two can never
# disagree. Previously the run saved as dql30.npy was built with
# n_timesteps=25, making it a duplicate of the dql25 configuration.
for n_timesteps in (25, 30):
    agent = Diffusion_QL(
        state_dim=state_dim,
        action_dim=action_dim,
        max_action=max_action,
        device=device,
        discount=discount,
        tau=tau,
        eta=eta,
        beta_schedule=beta_schedule,
        n_timesteps=n_timesteps,
        lr=lr,
    )

    for i in range(1, num_epochs + 1):
        loss_metric = agent.train(
            data_sampler, iterations=iterations, batch_size=batch_size
        )
        # Report the mean losses of every 100th epoch.
        if i % 100 == 0:
            print(
                f'DQL Epoch: {i} actor_loss {np.mean(loss_metric["actor_loss"])} critic_loss {np.mean(loss_metric["critic_loss"])}'
            )

    # Sample actions at num_eval copies of the zero state and store them as
    # dql<n_timesteps>.npy (dql25.npy, dql30.npy).
    new_state = torch.zeros((num_eval, state_dim), device=device)
    new_action = agent.actor.sample(new_state)
    new_action = new_action.detach().cpu().numpy()
    np.save(f"dql{n_timesteps}.npy", new_action)
