import numpy as np
import os
import torch
import matplotlib.pyplot as plt
from toy_dataset import gaussian_reward, action_sampler

from utils import utils
from utils.data_sampler import Data_Sampler
from agents.eql_diffusion import Diffusion_EQL as QL_Agent
from agents.bc_diffusion import Diffusion_BC as BC_Agent
from tqdm import tqdm
import argparse

# Experiment variants selectable via --method (the flag is an index into this list).
names = ['entropy-qensemble', 'entropy-sde', 'sde']

parser = argparse.ArgumentParser()
# Restrict --method to valid indices. Without `choices`, an out-of-range value
# fails later with an opaque IndexError, and a negative value (e.g. -1) silently
# selects a variant through Python's negative indexing.
parser.add_argument("-m", "--method", default=0, type=int, choices=range(len(names)),
                    help="variant index: " + ", ".join(f"{i}={n}" for i, n in enumerate(names)))
args = parser.parse_args()

name = names[args.method]

seed = 661
np.random.seed(seed)
torch.manual_seed(seed)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
n_samples = 1000   # transitions generated per step of the offline dataset
num_epochs = 500   # train/evaluate cycles per run (also passed to the agent as lr_maxt)
num_iters = 200    # training iterations per epoch (passed to agent.train)
ent_coef = 2. if name != 'sde' else 0  # entropy coefficient for the agent; 0 for the plain 'sde' variant

model_dir = name  # output directory for the saved model, rewards and figures
os.makedirs(model_dir, exist_ok=True)

num_critics = 32 if name == 'entropy-qensemble' else 1  # critics passed to the agent; 32 only for the Q-ensemble variant

########################################################

# Toy two-step environment: the state is a scalar clipped to [-1, 1], an action
# moves it as s' = clip(s + a, -1, 1), and the reward is gaussian_reward(s').
# NOTE: the statements below draw from the global NumPy RNG (seeded above), and
# action_sampler may as well; reordering them changes the generated dataset.

# 101-point state grid on [-1, 1]: used to draw the reward curve and passed to action_sampler.
xs = np.linspace(-1, 1, 101)
reward_function = gaussian_reward(xs)

# Step 1: every episode starts at state 0; actions come from action_sampler
# (its distribution is defined in toy_dataset, which is not visible here).
state1 = np.zeros(n_samples)
action1 = action_sampler(xs, num_samples=n_samples)
state2 = np.clip(state1 + action1, -1, 1)
reward1 = gaussian_reward(state2)
# Jittered, slightly raised copy of reward1 used only for the scatter plot; training uses reward1.
reward1_approx = reward1 + np.random.randn(n_samples) * 0.03 + 0.02 # for better visualization

# Step 2: small zero-mean Gaussian perturbation of the step-1 state.
# NOTE(review): the draw already has std 0.2 and is then scaled by 0.2 again,
# for an effective std of 0.04. Confirm this double scaling is intended.
action2 = np.random.normal(0, 0.2, n_samples) * 0.2
state3 = np.clip(state2 + action2, -1, 1)
reward2 = gaussian_reward(state3)
# Jittered, slightly lowered copy of reward2 used only for the scatter plot; training uses reward2.
reward2_approx = reward2 + np.random.randn(n_samples) * 0.03 - 0.02 # for better visualization

# Stack both steps into one transition buffer of 2 * n_samples rows; each
# field is a 1-D concatenation expanded to shape (2 * n_samples, 1).
observations = np.expand_dims(np.concatenate([state1, state2]), axis=1)
actions = np.expand_dims(np.concatenate([action1, action2]), axis=1)
rewards = np.expand_dims(np.concatenate([reward1, reward2]), axis=1)
next_observations = np.expand_dims(np.concatenate([state2, state3]), axis=1)
# Step-1 transitions are non-terminal (0); step-2 transitions end the episode (1).
terminals = np.expand_dims(np.concatenate([np.zeros_like(state1), np.ones_like(state2)]), axis=1)


# Offline dataset dict consumed by Data_Sampler below.
dataset = {'observations': observations, 
           'actions': actions, 
           'rewards': rewards, 
           'next_observations': next_observations, 
           'terminals': terminals}

#########################################################

def eval_policy(agent, num_evals=10):
    """Average the two-step return of ``agent`` over ``num_evals`` episodes.

    Every episode starts from state ``[0]``. A transition maps (s, a) to
    s' = clip(s + a, -1, 1) and earns gaussian_reward(s'). The episode takes
    two actions, each chosen by ``agent.sample_action``.

    Args:
        agent: object exposing ``sample_action(state)``.
        num_evals: number of episodes to roll out.

    Returns:
        The mean over episodes of (step-1 reward + step-2 reward).
    """

    def transition(state, action):
        # Deterministic toy dynamics: move by the action, stay inside [-1, 1].
        next_state = np.clip(state + action, -1, 1)
        return next_state, gaussian_reward(next_state)

    start_state = np.array([0])
    episode_returns = []
    for _ in range(num_evals):
        mid_state, first_reward = transition(start_state, agent.sample_action(start_state))
        _, second_reward = transition(mid_state, agent.sample_action(mid_state))
        episode_returns.append(first_reward + second_reward)

    return np.mean(episode_returns)

#########################################################

data_sampler = Data_Sampler(dataset, device, 'no')
utils.print_banner(f'Model: {model_dir} - Loaded buffer: {len(terminals)} - ent_coef: {ent_coef}')

num_runs = 6  # independent training runs, each with a freshly constructed agent

# One learning curve (per-epoch evaluation return) per run.
total_rewards = []
for _run in range(num_runs):
    agent = QL_Agent(state_dim=1, action_dim=1, max_action=1., device=device, discount=0.99, tau=0.005,
                     n_timesteps=5, lr_decay=True, lr_maxt=num_epochs, ent_coef=ent_coef,
                     num_critics=num_critics, loss_type='NML', action_clip=True)

    rewards = []
    for _epoch in tqdm(range(num_epochs)):
        agent.train(data_sampler, iterations=num_iters, batch_size=64)
        rewards.append(eval_policy(agent))

    # Record every run's curve. Previously this sat outside the loop, so only
    # the last run was kept.
    total_rewards.append(rewards)

# Checkpoint of the final run's agent.
agent.save_model(model_dir)

# Saved as an array of shape (num_runs, num_epochs).
np.save(f'{model_dir}/rewards.npy', np.array(total_rewards))

#########################################################
# Figure: the offline dataset, i.e. the rewards reached after each step drawn
# over the true reward curve.
fig_data, ax_data = plt.subplots(figsize=(5, 4))
ax_data.plot(xs, reward_function, color='#696969', alpha=0.2, label='Reward function')
ax_data.scatter(state2, reward1_approx, s=10, marker='x', color='#4682B4', alpha=0.4, label='Step-1 samples')
ax_data.scatter(state3, reward2_approx, s=10, color='#FF6247', alpha=0.2, label='Step-2 samples')

ax_data.legend(fontsize=12, loc='upper right')
ax_data.set_ylabel('Reward', fontsize=14)
ax_data.set_xlabel('State', fontsize=14)
ax_data.set_title('Offline dataset', fontsize=14)

fig_data.tight_layout()
fig_data.savefig(f'{model_dir}/data_samples.pdf')
#########################################################

plt.figure(figsize=(4, 3))

# Summarize all recorded runs rather than only the last one. Rows of
# total_rewards are runs and columns are epochs. With a single run this
# reduces to that run's curve and a zero-width band.
curves = np.asarray(total_rewards, dtype=float)
mean_curve = curves.mean(axis=0)
std_curve = curves.std(axis=0)
epochs = np.arange(mean_curve.shape[0])

plt.plot(epochs, mean_curve)
plt.fill_between(epochs, mean_curve - std_curve, mean_curve + std_curve, alpha=0.2)
plt.ylabel('Average Return', fontsize=14)
plt.xlabel('Epoch', fontsize=14)

plt.tight_layout()
plt.savefig(f'{model_dir}/training_curve.pdf')