import os
import torch
import numpy as np
from torch.distributions import Normal
import seaborn as sns
import matplotlib.pyplot as plt

from toy_helpers import Data_Sampler, reward_fun
from utils.logger import logger, setup_logger


def generate_data(num, device='cpu'):
    """Build a toy dataset of 2-D actions clustered around four corners.

    ``int(num / 4)`` points are drawn from a Gaussian centred on each of
    (-0.8, 0.8), (-0.8, -0.8), (0.8, 0.8) and (0.8, -0.8) with a standard
    deviation of 0.05 per coordinate, and every sample is clipped to
    [-1, 1].

    Args:
        num: Total number of samples requested; a quarter of it (truncated
            to an integer) is drawn per corner.
        device: Forwarded unchanged to ``Data_Sampler``.

    Returns:
        A ``Data_Sampler`` constructed from all-zero states (same shape as
        the actions), the sampled actions, and ``reward_fun(action)``.
    """
    per_corner = int(num / 4)
    offset = 0.8
    spread = 0.05

    # Order matters: top-left, bottom-left, top-right, bottom-right. It fixes
    # both the row layout of the dataset and the sequence of random draws.
    centers = (
        (-offset, offset),
        (-offset, -offset),
        (offset, offset),
        (offset, -offset),
    )

    clusters = [
        Normal(torch.tensor([cx, cy]), torch.tensor([spread, spread]))
        .sample((per_corner,))
        .clip(-1.0, 1.0)
        for cx, cy in centers
    ]

    action = torch.cat(clusters, dim=0)
    # States carry no information in this toy problem: they are all zeros.
    state = torch.zeros_like(action)
    return Data_Sampler(state, action, reward_fun(action), device)


# --- Dataset ---------------------------------------------------------------
device = 'cpu'
num_data = 10_000
data_sampler = generate_data(num_data, device)

# --- Problem dimensions (forwarded to QL_Diffusion) -------------------------
state_dim = 2
action_dim = 2
max_action = 1.0

# --- Agent hyper-parameters (forwarded to QL_Diffusion) ---------------------
discount = 0.99
tau = 0.005
model_type = 'MLP'

T = 50  # passed to the agent as ``n_timesteps``
beta_schedule = 'vp'
hidden_dim = 32
lr = 1e-3

# --- Training schedule ------------------------------------------------------
num_epochs = 1000
batch_size = 100
# Value handed to ``agent.train(iterations=...)`` on every epoch.
iterations = num_data // batch_size

# --- Logging ----------------------------------------------------------------
results_dir = 'res'
setup_logger(
    os.path.basename(results_dir),
    variant=None,
    log_dir=results_dir,
)

from ql_diffusion import QL_Diffusion

# Instantiate the agent from the module-level settings above. Every keyword
# mirrors the variable of the same name, except ``n_timesteps`` which takes T.
agent = QL_Diffusion(
    state_dim=state_dim,
    action_dim=action_dim,
    max_action=max_action,
    device=device,
    discount=discount,
    tau=tau,
    beta_schedule=beta_schedule,
    n_timesteps=T,
    model_type=model_type,
    hidden_dim=hidden_dim,
    lr=lr,
)


# Call ``agent.train`` once per epoch, handing it the sampler together with
# the configured ``iterations`` and ``batch_size``.
for i in range(num_epochs):
    agent.train(data_sampler, iterations=iterations, batch_size=batch_size)

    # NOTE(review): the epoch index is recorded every epoch but only dumped
    # on multiples of 100; how the logger treats the records made in between
    # depends on its implementation -- confirm this is intended.
    logger.record_tabular('Trained Epochs', i)

    # Only report progress on every 100th epoch (0, 100, 200, ...).
    # NOTE(review): epochs after the last multiple of 100 are never dumped.
    if i % 100 != 0:
        continue

    print(f'Epoch: {i}')
    logger.dump_tabular()

# Sample actions from the trained actor for a batch of all-zero states and
# save a scatter plot of the two action coordinates.
img_dir = 'imgs'
os.makedirs(img_dir, exist_ok=True)

num_eval = 2000

fig, ax = plt.subplots()
# Build the query states with the configured state width and on the same
# device the agent was constructed with, rather than a hard-coded CPU
# tensor of width 2.
new_state = torch.zeros((num_eval, state_dim), device=device)
new_action = agent.actor.sample(new_state)
# ``.cpu()`` is a no-op for CPU tensors but is required before ``.numpy()``
# if ``device`` is ever switched to a CUDA device.
new_action = new_action.detach().cpu().numpy()
ax.scatter(new_action[:, 0], new_action[:, 1])
# ``Figure`` has no ``save`` method; ``savefig`` is the call that writes the
# image to disk. The original ``fig.save(...)`` raised AttributeError after
# the whole training run had completed.
fig.savefig(os.path.join(img_dir, f'ql_diffusion_{T}.png'))




