import argparse

import os

import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import numpy as np
import pyrallis
import torch
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa

import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# from examples.configs.bc_configs import BC_DEFAULT_CONFIG, BCTrainConfig
from osrl.algorithms import BC, BCTrainer
from osrl.common import TransitionDataset
from osrl.common.dataset import process_bc_dataset, process_realworld_dataset
from osrl.common.exp_util import auto_name, seed_all


ROBOMIMIC = ['lift', 'can', 'square', 'tool_hang', 'transport']

def boolean(v):
    """Convert a command-line token into a ``bool`` (argparse ``type=`` helper).

    Values that already are ``bool`` are returned unchanged. Any other value
    is lower-cased and matched against the accepted spellings below.

    Raises:
        argparse.ArgumentTypeError: if the token is not one of the accepted
            true/false spellings.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def gradient_reversal_hook(module, grad_input, grad_output):
    """Backward-hook style callback that negates selected input gradients.

    Entry 0 of ``grad_input`` is passed through untouched; entry 1 is
    multiplied by -1.0 (the original author's note says this entry is the
    weight gradient). If a third entry exists (noted by the author as the
    bias gradient) it is negated as well. Entries beyond index 2 are not
    included in the result. ``module`` and ``grad_output`` are unused.

    Returns:
        A tuple with two or three entries, as described above.
    """
    flipped = [grad_input[0], -1.0 * grad_input[1]]
    if len(grad_input) > 2:
        # Also reverse the optional third gradient.
        flipped.append(-1.0 * grad_input[2])
    return tuple(flipped)

# Command-line parameters for this script.
# Boolean flags use the ``boolean`` converter defined above: argparse's
# ``type=bool`` would treat every non-empty string (including "False") as True.
parser = argparse.ArgumentParser()
# Dataset / experiment selection
parser.add_argument('--project', default='OSRL-baselines', type=str)
parser.add_argument('--group', default=None, type=str)
parser.add_argument('--name', default=None, type=str)
parser.add_argument('--suffix', default='', type=str)
parser.add_argument('--logdir', default='logs', type=str)
parser.add_argument('--verbose', default=True, type=boolean)
parser.add_argument('--task', default='OfflineCarCircle-v0', type=str)
parser.add_argument('--cost_limit', default=10, type=int)
parser.add_argument('--frontier_ratio', default=0.1, type=float)
# Options carried over from the older version of the script
parser.add_argument('--dataset', default='medium-replay', type=str)
parser.add_argument('--state', default=False, type=boolean)
parser.add_argument('--mismatch', default=False, type=boolean)
parser.add_argument('--algo_type', default='smodice', type=str)
parser.add_argument('--disc_type', default='learned', type=str)
parser.add_argument('--gamma', default=0.99, type=float)

parser.add_argument('--num_expert_traj', default=0, type=int)
parser.add_argument('--num_offline_traj', default=2000, type=int)
parser.add_argument('--total_iterations', default=int(1e6), type=int)
parser.add_argument('--disc_iterations', default=int(1e3), type=int)
parser.add_argument('--log_iterations', default=int(5e3), type=int)
parser.add_argument('--episodes', default=10, type=int)

parser.add_argument('--bc_only', default=False, type=boolean)
parser.add_argument('--actor_deterministic', default=True, type=boolean)
parser.add_argument('--absorbing_state', default=True, type=boolean)
parser.add_argument('--standardize_reward', default=True, type=boolean)
parser.add_argument('--standardize_obs', default=True, type=boolean)
parser.add_argument('--reward_type', default='P', type=str, help='choose from T/P/C')
parser.add_argument('--reward_scale', default=1, type=float)
parser.add_argument('--res_scale', default=3, type=float)
# NOTE(review): the next two options (and --hidden_sizes) have no ``type``,
# so a value given on the command line arrives as a raw string rather than a
# tuple; only the defaults are usable as-is.
parser.add_argument('--mean_range', default=(-7.24, 7.24))
parser.add_argument('--logstd_range', default=(-5., 2.))

parser.add_argument('--hidden_sizes', default=(256, 256))
parser.add_argument('--num_hidden', default=264, type=int)
parser.add_argument('--batch_size', default=512, type=int)
parser.add_argument('--f', default='kl', type=str)
parser.add_argument('--lr', default=3e-4, type=float)
parser.add_argument('--actor_lr', default=3e-4, type=float)
parser.add_argument('--lr_ratio', default=0.001, type=float)
parser.add_argument('--v_l2_reg', default=0.0001, type=float)
parser.add_argument('--r_l2_reg', default=0.0001, type=float)
parser.add_argument('--alpha', default=0.5, type=float)
parser.add_argument('--use_policy_entropy_constraint', default=True, type=boolean)
parser.add_argument('--target_entropy', default=None, type=float)

parser.add_argument('--device', default='cuda:0', type=str)
# Read as ``args.threads`` when --device is "cpu"; without this option that
# code path fails with AttributeError. The default of 4 is a choice made
# here -- TODO confirm the preferred value.
parser.add_argument('--threads', default=4, type=int,
                    help='number of torch CPU threads (used when --device cpu)')
parser.add_argument('--wandb', default=True, type=boolean)
parser.add_argument('--render', default=False, type=boolean)
parser.add_argument('--seed', default=0, type=int)

# Segment index handed to the labelling step at the end of the script.
parser.add_argument('--idx', default=0, type=int)

args = parser.parse_args()

seed_all(args.seed)
if args.device == "cpu":
    # ``--threads`` may not be declared by the parser; only cap the thread
    # count when a value is actually available instead of raising
    # AttributeError.
    _num_threads = getattr(args, "threads", None)
    if _num_threads is not None:
        torch.set_num_threads(_num_threads)

import gymnasium as gym  # noqa
if "Metadrive" in args.task:
    # Metadrive tasks are created through the legacy ``gym`` package instead.
    import gym
env = gym.make(args.task)

# Loading the dataset can fail, so retry -- but only a bounded number of
# times, and without catching KeyboardInterrupt/SystemExit, so that a
# deterministic failure surfaces as an error instead of an endless loop.
_MAX_LOAD_ATTEMPTS = 10
_last_load_error = None
for _attempt in range(1, _MAX_LOAD_ATTEMPTS + 1):
    try:
        data = env.get_dataset()
    except Exception as exc:
        _last_load_error = exc
        print(f'Fail to load data... One time '
              f'(attempt {_attempt}/{_MAX_LOAD_ATTEMPTS}: {exc!r})')
    else:
        break
else:
    raise RuntimeError(
        f'Could not load the dataset for {args.task!r} after '
        f'{_MAX_LOAD_ATTEMPTS} attempts'
    ) from _last_load_error
env.set_target_cost(args.cost_limit)

# Fixed episode length used to fold the flat transition arrays into
# (episode, step, ...) form; ``batch_size`` is kept as a module-level setting.
max_ep_len = 300
batch_size = 8
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

# Split the raw data into the two subsets returned by the project helper.
optimal_data, qualified_data = process_realworld_dataset(
    data, args.cost_limit, args.gamma, args.frontier_ratio, args.task
)
num_expert = len(optimal_data['observations'])
num_offline = len(qualified_data['observations'])

# Number of whole episodes in each subset. The reshapes below assume every
# episode has exactly ``max_ep_len`` transitions -- TODO confirm for the task.
_expert_eps = num_expert // max_ep_len
_offline_eps = num_offline // max_ep_len

# Expert subset: per-episode state/action sequences plus per-episode returns.
expert_state_d = optimal_data['observations'].reshape(_expert_eps, max_ep_len, state_dim)
expert_action_d = optimal_data['actions'].reshape(_expert_eps, max_ep_len, action_dim)
expert_reward_v = optimal_data['rewards'].reshape(_expert_eps, max_ep_len).sum(-1)
expert_cost_v = optimal_data['costs'].reshape(_expert_eps, max_ep_len).sum(-1)

# Offline subset, laid out the same way.
offline_state_d = qualified_data['observations'].reshape(_offline_eps, max_ep_len, state_dim)
offline_action_d = qualified_data['actions'].reshape(_offline_eps, max_ep_len, action_dim)
offline_reward_v = qualified_data['rewards'].reshape(_offline_eps, max_ep_len).sum(-1)
offline_cost_v = qualified_data['costs'].reshape(_offline_eps, max_ep_len).sum(-1)

def refresh_model(model_path="logs/var-random/transformer_disc_19.pth", map_location=None):
    """Load a serialized model object from disk with ``torch.load``.

    Args:
        model_path: checkpoint file to load. Defaults to the checkpoint this
            script was last configured for; pass another path to label with a
            different run instead of editing the function.
        map_location: forwarded to ``torch.load``. ``None`` keeps the tensors
            on the device they were saved from; pass e.g. ``'cpu'`` to load a
            GPU checkpoint on a machine without that GPU.

    Returns:
        Whatever object was stored in the checkpoint.
    """
    # NOTE(security): ``torch.load`` unpickles the file, which can execute
    # arbitrary code -- only load checkpoints from trusted sources.
    # NOTE(review): recent PyTorch releases default to ``weights_only=True``,
    # which rejects pickled full model objects; if loading fails there, pass
    # ``weights_only=False`` explicitly -- confirm against the installed version.
    transformer = torch.load(model_path, map_location=map_location)
    print('Load Successfully!!!')
    return transformer

import GPUtil

def cur_mem(gpu_index=0):
    """Return the ``memoryUsed`` reading GPUtil reports for one GPU.

    Args:
        gpu_index: position in the list returned by ``GPUtil.getGPUs()``.
            Defaults to the first GPU.

    Returns:
        The ``memoryUsed`` attribute of the selected GPU, or ``None`` when
        GPUtil reports no GPU at that index (e.g. on a CPU-only machine), so
        that diagnostic printing does not abort the run.
    """
    gpus = GPUtil.getGPUs()
    try:
        return gpus[gpu_index].memoryUsed
    except IndexError:
        return None

# Timestep indices 0..max_ep_len-1, repeated once per trajectory, as int64
# NumPy arrays of shape (num_trajectories, max_ep_len). They are built directly
# on the host: the previous version shipped the tensor to ``args.device`` and
# straight back, which only added a device round trip and required that
# device to be available.
_episode_steps = np.arange(max_ep_len, dtype=np.int64)
p_timesteps = np.tile(_episode_steps, (len(offline_state_d), 1))
e_timesteps = np.tile(_episode_steps, (len(expert_state_d), 1))

def tensor_size_in_bytes(tensor):
    """Return the number of bytes taken by the tensor's elements.

    Computed as the element count (``numel``) times the per-element size
    (``element_size``).
    """
    element_count = tensor.numel()
    bytes_per_element = tensor.element_size()
    return element_count * bytes_per_element

def iter_predict(state, action, timesteps, go=True, batch_size=10):
    """Run the saved model over trajectories in mini-batches.

    The three inputs are sliced in lock-step along their first axis, moved to
    ``args.device`` and passed to the model returned by ``refresh_model()``.
    Only the first of the model's two outputs is kept.

    Args:
        state: array-like indexed by trajectory along axis 0.
        action: array-like aligned with ``state``.
        timesteps: array-like aligned with ``state``.
        go: when true, print batch progress and the ``cur_mem()`` reading.
        batch_size: number of trajectories per forward pass.

    Returns:
        The per-batch first outputs concatenated along dim 0.

    Raises:
        ValueError: if ``state`` is empty (previously this surfaced as an
            opaque error from ``torch.cat`` on an empty list).
    """
    num_traj = len(state)
    if num_traj == 0:
        raise ValueError('iter_predict received no trajectories to label')

    def to_tensor(arr):
        return torch.tensor(arr).to(args.device)

    # Load the checkpoint once rather than once per mini-batch.
    # NOTE(review): this assumes the model's forward pass in eval mode does
    # not carry state between calls -- confirm for the saved model.
    transformer = refresh_model()
    transformer.eval()

    num_batches = (num_traj + batch_size - 1) // batch_size
    tot_rew = []
    # Inference only: without no_grad every batch would keep its autograd
    # graph alive for as long as its result is stored in ``tot_rew``.
    with torch.no_grad():
        for batch_idx, start in enumerate(range(0, num_traj, batch_size)):
            if go:
                print(f'{batch_idx + 1}/{num_batches}', end=',')
                print(cur_mem())
            stop = start + batch_size
            tmp_state = to_tensor(state[start:stop])
            tmp_action = to_tensor(action[start:stop])
            tmp_timesteps = to_tensor(timesteps[start:stop])
            res, _ = transformer(tmp_state, tmp_action, tmp_timesteps)
            print('size', tensor_size_in_bytes(res))
            tot_rew.append(res)
            del tmp_state, tmp_action, tmp_timesteps

    # Drop the model and hand cached GPU blocks back before returning.
    del transformer
    torch.cuda.empty_cache()
    return torch.cat(tot_rew, dim=0)

# Each invocation labels the whole expert set plus one window of ``seg_size``
# offline trajectories, selected by the --idx argument.
seg_size = 30
_i = args.idx
_window = slice(_i * seg_size, (_i + 1) * seg_size)

expert_rew = iter_predict(expert_state_d, expert_action_d, e_timesteps)
offline_rew = iter_predict(
    offline_state_d[_window],
    offline_action_d[_window],
    p_timesteps[_window],
)

# Bundle inputs, model predictions and the per-episode reward/cost sums.
predicted_rew = dict(
    expert_state=expert_state_d,
    expert_action=expert_action_d,
    expert_predicted=expert_rew.detach().cpu().numpy(),
    expert_reward=expert_reward_v,
    expert_cost=expert_cost_v,
    offline_state=offline_state_d[_window],
    offline_action=offline_action_d[_window],
    offline_predicted=offline_rew.detach().cpu().numpy(),
    offline_reward=offline_reward_v[_window],
    offline_cost=offline_cost_v[_window],
)

# One .npz file per segment index; the directory is created on first use.
os.makedirs('labeled/var-random-19', exist_ok=True)
np.savez(f'labeled/var-random-19/predicted_rew_{_i}.npz', **predicted_rew)