#!/usr/bin/env python
# coding: utf-8

# In[14]:


import dsrl
import torch
import numpy as np
import os, sys

# Project root appended to sys.path, presumably so the `data_augmentation`
# package imported below can be resolved — confirm against the repo layout.
# NOTE(review): this is a hard-coded absolute path; the commented-out line
# below derives the same root from this file's own location instead.
proj_direc = "/SafeOfflineRL/OSRL/"
# proj_direc = os.path.abspath(__file__)[:os.path.abspath(__file__).rindex('OSRL')+5]
sys.path.append(proj_direc)

# /*** import regarding diffusion model ***/
import data_augmentation.diffuser.utils as diffuser_utils

from collections import Counter, defaultdict

from data_augmentation.diffuser.datasets.sequence import TrajectoryDataset
from data_augmentation.diffuser.config.default_config import Diffuser_Config

import numpy as np
import torch, json
from torch.nn import functional as F

from argparse import Namespace
from datetime import datetime



# In[16]:


# NOTE: Define useful functions

def load_diffusion(trainer, model_path, load_ema=False):
    """Restore a trainer's step counter and weights from a saved checkpoint.

    The checkpoint at ``model_path`` is read with ``torch.load`` and must
    provide the keys ``'step'``, ``'model'`` and ``'ema'``.

    Args:
        trainer: Object exposing ``step``, ``model`` and ``ema_model``.
        model_path: Location of the checkpoint file.
        load_ema: When true, ``trainer.model`` receives the ``'ema'`` weights
            instead of the ``'model'`` weights. ``trainer.ema_model`` always
            receives the ``'ema'`` weights.
    """
    checkpoint = torch.load(model_path)
    trainer.step = checkpoint['step']
    # Pick which stored state dict feeds the main model.
    weights_key = 'ema' if load_ema else 'model'
    trainer.model.load_state_dict(checkpoint[weights_key])
    trainer.ema_model.load_state_dict(checkpoint['ema'])

def load_diffusion_model(diffusion, model_path):
    """Load the ``'model'`` state dict stored at ``model_path`` into ``diffusion``.

    Args:
        diffusion: Object exposing ``load_state_dict``.
        model_path: Location of a checkpoint readable by ``torch.load`` that
            contains a ``'model'`` entry.
    """
    diffusion.load_state_dict(torch.load(model_path)['model'])


# In[17]:


# NOTE: Load the data through the dataset framework, define the trainer, and
# run the diffusion training; mind where the model gets saved.
# Data loading
# Dataset-framework conventions
# Feed everything into the trainer to run training
from argparse import Namespace

# Experiment settings: the task id plus the values later passed to
# env.set_target_cost and process_trajectory_dataset.
surrogate_args = Namespace(**{
    "task": "OfflineCarCircle-v0",
    "cost_limit": 10,
    "gamma": 0.99,
    "frontier_ratio": 0.1,
})
# gymnasium is bound to the name `gym` by default; for tasks whose id contains
# "Metadrive" the name is rebound to the legacy `gym` package instead.
import gymnasium as gym
if "Metadrive" in surrogate_args.task:
    import gym
# Presumably importing dsrl registers the Offline* environments used by
# gym.make below — confirm against the dsrl package.
import dsrl
env = gym.make(surrogate_args.task)


# In[ ]:


# Fetch the offline dataset, retrying until a call succeeds.
# Only `Exception` is caught (instead of a bare `except:`) so that Ctrl-C and
# SystemExit can still break out of this otherwise endless retry loop, and
# the failure reason is printed so repeated failures can be diagnosed.
while True:
    try:
        data = env.get_dataset()
    except Exception as exc:
        print('Fail to load data... One time')
        print('Reason:', repr(exc))
    else:
        break
# Set the target cost on the environment to the configured cost limit.
env.set_target_cost(surrogate_args.cost_limit)

def print_dict_shape(dic):
    """Print one ``_key: <name> <shape>`` line for every entry of ``dic``.

    Each value must expose a ``.shape`` attribute (e.g. a NumPy array).
    """
    for name, value in dic.items():
        print('_key:', name, value.shape)


def _min_max_normalize(values):
    """Rescale ``values`` linearly to [0, 1]; all-equal input maps to zeros."""
    lowest, highest = np.min(values), np.max(values)
    span = highest - lowest
    if span == 0:
        # Every trajectory ties on this metric, so it carries no ranking
        # information; avoid the 0/0 division that would produce NaN scores.
        return np.zeros(values.shape, dtype=float)
    return (values - lowest) / span


def _stack_trajectories(trajs, indices, keys):
    """Stack the selected trajectories per key; an empty selection gives ``{}``."""
    if not indices:
        return {}
    return {k: np.stack([trajs[i][k] for i in indices], axis=0) for k in keys}


def process_trajectory_dataset(dataset: dict, cost_limit: float, gamma: float, frontier_ratio: float, task_name: str):
    """Split a flat transition dataset into "optimal" and "qualified" trajectories.

    The flat arrays are cut into trajectories at every index where
    ``terminals == 1`` or ``timeouts == 1``; transitions after the last such
    index are dropped. Each trajectory gets a cost return and a reward return
    (plain, undiscounted sums).

    * Optimal trajectories: among trajectories whose cost return is strictly
      below ``cost_limit``, the ``int(n_safe * frontier_ratio)`` with the
      highest reward return (capped at 15).
    * Qualified trajectories: the top half of all trajectories ranked by
      ``1 - normalized_cost + normalized_reward``, excluding the optimal ones.

    Args:
        dataset: Mapping from key to array indexed by transition along axis 0.
            Must contain ``"terminals"``, ``"timeouts"``, ``"costs"`` and
            ``"rewards"``.
        cost_limit: Exclusive upper bound on cost return for a "safe" trajectory.
        gamma: Unused; kept for interface compatibility.
        frontier_ratio: Fraction of safe trajectories selected as optimal.
        task_name: Unused; kept for interface compatibility.

    Returns:
        ``(optimal_data, qualified_data)``: two dicts with the same keys as
        ``dataset``, each value stacked along a new leading trajectory axis.
        A group with no selected trajectory is returned as an empty dict.

    Raises:
        ValueError: If no trajectory boundary exists, or (from ``np.stack``)
            if the selected trajectories differ in length.
    """
    qualified_data_ratio = 0.5
    max_optimal_data_num = 15

    done_idx = np.where(
        (dataset["terminals"] == 1) | (dataset["timeouts"] == 1)
    )[0]
    if done_idx.shape[0] == 0:
        raise ValueError(
            "dataset contains no completed trajectory "
            "(no terminal or timeout flag is set)"
        )

    keys = list(dataset.keys())
    trajs, cost_returns, reward_returns = [], [], []
    start = 0
    for last in done_idx:
        end = last + 1
        trajs.append({k: dataset[k][start:end] for k in keys})
        cost_returns.append(np.sum(dataset["costs"][start:end]))
        reward_returns.append(np.sum(dataset["rewards"][start:end]))
        start = end

    n_trajs = len(trajs)
    cost_returns, reward_returns = np.asarray(cost_returns), np.asarray(reward_returns)

    # Filter out very-low-quality data: low cost and high reward both raise the score.
    quality_score = 1 - _min_max_normalize(cost_returns) + _min_max_normalize(reward_returns)
    qualified_idx = np.argsort(-quality_score)[:int(n_trajs * qualified_data_ratio)].tolist()

    # Select the safe trajectories with the highest reward return.
    safe_idx = np.where(cost_returns < cost_limit)[0]
    optimal_num = min(int(len(safe_idx) * frontier_ratio), max_optimal_data_num)
    optimal_idx = [
        int(i) for i in sorted(safe_idx, key=lambda x: -reward_returns[x])[:optimal_num]
    ]

    # A set gives O(1) membership tests when excluding optimal trajectories.
    optimal_set = set(optimal_idx)
    qualified_only = [i for i in qualified_idx if i not in optimal_set]

    # NOTE: stacking (rather than concatenating) keeps one leading axis per trajectory.
    qualified_data = _stack_trajectories(trajs, qualified_only, keys)
    optimal_data = _stack_trajectories(trajs, optimal_idx, keys)

    return optimal_data, qualified_data


# Split the raw dataset into "optimal" and "qualified" trajectories, then merge
# both groups back into one dict of arrays with the optimal trajectories first.
# NOTE(review): the merge indexes qualified_data with optimal_data's keys, so it
# needs both groups to be non-empty and stackable to a common shape.
optimal_data, qualified_data = process_trajectory_dataset(data, surrogate_args.cost_limit, surrogate_args.gamma, surrogate_args.frontier_ratio, surrogate_args.task)
data = {
    _k: np.concatenate([optimal_data[_k], qualified_data[_k]], axis=0) for _k in optimal_data
}
# The bare string below has no runtime effect; it looks like the shapes printed
# from an earlier run of this cell, kept for reference.
"""
_key: observations (725, 300, 8)
_key: rewards (725, 300)
_key: terminals (725, 300)
_key: timeouts (725, 300)
_key: actions (725, 300, 2)
_key: costs (725, 300)
_key: next_observations (725, 300, 8)
"""
# Per-trajectory totals: sum over the last axis, then add a trailing axis of length 1.
data['returns'] = data['rewards'].sum(axis=-1)[..., None]
data['traj_costs'] = data['costs'].sum(axis=-1)[..., None]

# Save the processed arrays to an .npz file in the current working directory.
np.savez('carcircle_data.npz', **data)
# NOTE: `assert 0` stops the script here with an AssertionError, so the code
# below does not run unless this line is removed. Under `python -O` asserts
# are stripped and execution would continue.
assert 0

# In[ ]:


# Wrap the processed arrays in the project's TrajectoryDataset, passing
# observations, actions, per-trajectory returns and per-trajectory costs.
dataset = TrajectoryDataset()
dataset.add_data(data['observations'], data['actions'], data['returns'], data['traj_costs'])

# Define data loader: shuffled batches of 32, loaded in the main process
# (num_workers=0) with pinned host memory.
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=32, num_workers=0, shuffle=True, pin_memory=True
)


# In[ ]:


# Copy the class attributes of Diffuser_Config into a plain dict, drop the
# dunder entries (names starting with '__'), and expose the rest as a mutable
# Namespace that is customized below.
diffuser_config = dict(vars(Diffuser_Config))
# Iterate over a snapshot of the keys because entries are deleted in the loop.
diffuser_keys = list(diffuser_config.keys())
for _k in diffuser_keys:
    if _k.startswith('__'):
        del diffuser_config[_k]
diffuser_config = Namespace(**diffuser_config)
# Presumably the widths of the return-to-go / cost-to-go conditioning inputs
# — confirm against the backbone implementation.
diffuser_config.rtg_dim = 1
diffuser_config.ctg_dim = 1

# TODO: determine transition dim!!!
# Dimensions are read off the processed arrays: the last axis gives the
# feature size and the second-to-last axis gives the horizon.
diffuser_config.observation_dim = data['observations'].shape[-1]
diffuser_config.action_dim = data['actions'].shape[-1]
diffuser_config.transition_dim = diffuser_config.observation_dim + diffuser_config.action_dim
diffuser_config.horizon = data['observations'].shape[-2]

# Training-length settings; both are only read by the commented-out training
# loop further down, and n_train_steps also feeds label_freq.
diffuser_config.n_train_steps = 100000
diffuser_config.n_steps_per_epoch = 10000

# Per-task path handed to the trainer config as `bucket`.
# NOTE(review): hard-coded absolute path.
specific_word = surrogate_args.task
diffuser_config.bucket = f"/SafeOfflineRL/logs/diffusion_models/{specific_word}"

# /* --- Define the diffusion model --- */
# Three diffuser_utils.Config objects are built here; each is later called
# like a factory to create the backbone, the diffusion wrapper and the trainer.
# Every keyword value is forwarded unchanged from diffuser_config.

# Backbone network configuration.
backbone_config = diffuser_utils.Config(
    diffuser_config.backbone,
    savepath='diffusion_backbone.pkl',
    horizon=diffuser_config.horizon,
    transition_dim=diffuser_config.transition_dim,
    rtg_dim=diffuser_config.rtg_dim,
    ctg_dim=diffuser_config.ctg_dim,
    dim_mults=diffuser_config.dim_mults,
    input_condition=diffuser_config.input_condition,
    dim=diffuser_config.dim,
    condition_dropout=diffuser_config.condition_dropout,
    calc_energy=diffuser_config.calc_energy,
    device=diffuser_config.device,
)
# Diffusion process configuration (called with the backbone instance below).
diffusion_config = diffuser_utils.Config(
    diffuser_config.diffusion,
    savepath='diffusion_config.pkl',
    horizon=diffuser_config.horizon,
    observation_dim=diffuser_config.observation_dim,
    action_dim=diffuser_config.action_dim,
    n_timesteps=diffuser_config.n_diffusion_steps,
    loss_type=diffuser_config.loss_type,
    clip_denoised=diffuser_config.clip_denoised,
    predict_epsilon=diffuser_config.predict_epsilon,
    ## loss weighting
    action_weight=diffuser_config.action_weight,
    loss_weights=diffuser_config.loss_weights,
    loss_discount=diffuser_config.loss_discount,
    input_condition=diffuser_config.input_condition,
    condition_guidance_w=diffuser_config.condition_guidance_w,
    device=diffuser_config.device,
)
# Trainer configuration (called with the diffusion model and dataset below).
trainer_config = diffuser_utils.Config(
    diffuser_utils.Trainer,
    savepath='trainer_config.pkl',
    train_batch_size=diffuser_config.batch_size,
    train_lr=diffuser_config.learning_rate,
    gradient_accumulate_every=diffuser_config.gradient_accumulate_every,
    ema_decay=diffuser_config.ema_decay,
    sample_freq=diffuser_config.sample_freq,
    save_freq=diffuser_config.save_freq,
    log_freq=diffuser_config.log_freq,
    # Total training steps divided by the number of saves, floored to an int.
    label_freq=int(diffuser_config.n_train_steps // diffuser_config.n_saves),
    save_parallel=diffuser_config.save_parallel,
    bucket=diffuser_config.bucket,
    n_reference=diffuser_config.n_reference,
    train_device=diffuser_config.device,
    save_checkpoints=diffuser_config.save_checkpoints,
)

# Instantiate the backbone, then the diffusion model built around it.
diffusion_backbone = backbone_config()
diffusion = diffusion_config(diffusion_backbone)

# Timestamp string; only the commented-out training loop below uses it (as a
# log file name).
time_str = datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
diffuser_trainer = trainer_config(diffusion, dataset)

# Restore a previously saved checkpoint into the trainer (load_ema defaults to
# False, so trainer.model receives the 'model' weights).
# NOTE(review): hard-coded absolute checkpoint path.
diffusion_path = "/SafeOfflineRL/logs/diffusion_models/OfflineCarCircle-v0/checkpoint/state_79999.pt"
load_diffusion(diffuser_trainer, diffusion_path)

# Evaluate the restored model on the data loader and print the result.
evaluation_loss = diffuser_trainer.metrics_test(data_loader)
print('evaluation loss:', evaluation_loss)

# # NOTE: Here; Below is the training code
# # for train_iter in range(int(diffuser_config.n_train_steps // diffuser_config.n_steps_per_epoch * 1.5)):
# for train_iter in range(int(diffuser_config.n_train_steps // diffuser_config.n_steps_per_epoch)):
#     diffuser_trainer.train(diffuser_config.n_steps_per_epoch)
#     tot_loss = []
#     tot_loss.append(diffuser_trainer.metrics_test(data_loader))
#     print('Computed losses:', tot_loss)
#     with open(f'{time_str}.json', 'a+') as f:
#         f.write(str(tot_loss) + '\n')

# Bare expression left over from the notebook (it displays the model as cell
# output there); it has no effect when run as a script.
diffuser_trainer.model

# In[ ]:


# NOTE: `assert 0` stops execution here with an AssertionError, so the UMAP
# code below does not run unless this line is removed (or asserts are
# stripped with `python -O`).
assert 0

import umap

# Project the vectors down to two dimensions with UMAP.
# NOTE(review): `vectors_array` is not defined anywhere in the visible file;
# it must be provided before this cell can run.
reducer = umap.UMAP(n_components=2)
embedding = reducer.fit_transform(vectors_array)


# In[ ]:


import matplotlib.pyplot as plt

# Scatter-plot the 2-D embedding, one colour and legend entry per group of
# `num_samples` consecutive rows, and save the figure to 'Output.pdf'.
# NOTE(review): `total_generated_samples` is not defined in the visible file;
# it must exist (together with `embedding`) before this cell can run.
plt.figure(figsize=(8, 6))
num_vectors = len(total_generated_samples)

# Number of consecutive embedding rows that belong to each group.
num_samples = 100

# Hand-written legend labels, indexed by group.
labelings = ['generated-0', 'generated-1', 'generated-2', 'generated-3',
            '299999', '59999', '99999', '199999', 'expert']

for i in range(num_vectors):
    start_index = i * num_samples
    end_index = start_index + num_samples
    # Prefer the hand-written label; fall back to the generic name when there
    # are more groups than labels instead of raising IndexError. (The generic
    # name used to be computed and then unconditionally overwritten.)
    if i < len(labelings):
        labeling = labelings[i]
    else:
        labeling = f'Generated-{i}' if i != num_vectors - 1 else 'Expert'
    plt.scatter(embedding[start_index:end_index, 0], embedding[start_index:end_index, 1], s=10, label=labeling)

plt.title('UMAP Visualization of Vectors with Different Colors')
plt.xlabel('UMAP Dimension 1')
plt.ylabel('UMAP Dimension 2')
# NOTE(review): the scatter calls pass no colour data, so this colorbar is
# probably not meaningful next to the legend — confirm it is wanted.
plt.colorbar()
plt.tight_layout()
plt.legend()
plt.savefig('Output.pdf')


# In[ ]:




