import argparse

import os
import uuid
import tqdm


import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import numpy as np
import pyrallis
import torch
import torch.nn.functional as F
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from fsrl.utils import WandbLogger
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa
import torch.nn as nn

import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# from examples.configs.bc_configs import BC_DEFAULT_CONFIG, BCTrainConfig
from osrl.algorithms import BC, BCTrainer
from osrl.common import TransitionDataset
from osrl.common.dataset import process_bc_dataset, process_realworld_dataset
from osrl.common.exp_util import auto_name, seed_all

from torch.nn.utils import clip_grad_norm_

import GPUtil

def cur_mem():
    """Return the ``memoryUsed`` value GPUtil reports for the first GPU.

    Raises ``IndexError`` when GPUtil reports no GPUs.
    """
    # NOTE(review): GPUtil presumably reports this figure in MB -- confirm.
    return GPUtil.getGPUs()[0].memoryUsed

import wandb

# SECURITY: the API key used to be hard-coded on this line. A credential that
# has been committed to source control must be treated as leaked and revoked.
# The key is now taken from the WANDB_API_KEY environment variable; when that
# is not set we defer to whatever credentials wandb already has stored.
_wandb_api_key = os.environ.get('WANDB_API_KEY')
if _wandb_api_key:
    wandb.login(key=_wandb_api_key, relogin=True)
else:
    wandb.login()


# List of task-name strings. Nothing in the visible part of this file reads it.
# NOTE(review): presumably the robomimic benchmark task ids -- confirm before
# relying on or removing it.
ROBOMIMIC = ['lift', 'can', 'square', 'tool_hang', 'transport']


def cycle(dl):
    """Yield the items of ``dl`` endlessly, restarting iteration each pass.

    ``dl`` is re-iterated from scratch on every pass, so it must be an
    iterable that can be iterated more than once (a list, a DataLoader, ...).
    """
    while True:
        yield from dl

def boolean(v):
    """Parse a command-line flag value into a ``bool`` (argparse ``type=``).

    Real ``bool`` values pass through untouched. Otherwise ``v`` is
    lower-cased and matched against the accepted spellings below.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v

    spellings = {
        'yes': True, 'true': True, 't': True, 'y': True, '1': True,
        'no': False, 'false': False, 'f': False, 'n': False, '0': False,
    }
    token = v.lower()
    if token not in spellings:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return spellings[token]

def gradient_reversal_hook(module, grad_input, grad_output):
    """Return a 3-tuple built from ``grad_input`` with entries 0 and 2 negated.

    Entry 1 is passed through unchanged. ``module`` and ``grad_output`` are
    accepted but not used. ``grad_input`` must have at least three entries;
    any beyond the third are dropped from the result.

    NOTE(review): the signature looks like a PyTorch module backward hook --
    confirm where it is registered. An earlier comment here said index 1 is
    the weight gradient and that weight/bias gradients should be reversed,
    which does not match the indices negated below -- confirm the intent.
    """
    first_reversed = -1.0 * grad_input[0]
    second_kept = grad_input[1]
    third_reversed = -1.0 * grad_input[2]
    return (first_reversed, second_kept, third_reversed)

def _build_arg_parser():
    """Create the command-line parser for this script.

    Most options are carried over from an older script; ``main`` only reads
    seed, device, threads, task, cost_limit, gamma, frontier_ratio,
    num_hidden and remark.
    """
    parser = argparse.ArgumentParser()
    # Define my dataset
    parser.add_argument('--project', default='OSRL-baselines', type=str)
    parser.add_argument('--group', default=None, type=str)
    parser.add_argument('--name', default=None, type=str)
    parser.add_argument('--suffix', default='', type=str)
    parser.add_argument('--logdir', default='logs', type=str)
    # `type=bool` would turn every non-empty string (even "False") into True,
    # so use the same `boolean` converter as the other flags.
    parser.add_argument('--verbose', default=True, type=boolean)
    parser.add_argument('--task', default='OfflineCarCircle-v0', type=str)
    parser.add_argument('--cost_limit', default=10, type=int)
    parser.add_argument('--frontier_ratio', default=0.1, type=float)
    # Below is the old version
    parser.add_argument('--dataset', default='medium-replay', type=str)
    parser.add_argument('--state', default=False, type=boolean)
    parser.add_argument('--mismatch', default=False, type=boolean)
    parser.add_argument('--algo_type', default='smodice', type=str)
    parser.add_argument('--disc_type', default='learned', type=str)
    parser.add_argument('--gamma', default=0.99, type=float)

    parser.add_argument('--num_expert_traj', default=0, type=int)
    parser.add_argument('--num_offline_traj', default=2000, type=int)
    parser.add_argument('--total_iterations', default=int(1e6), type=int)
    parser.add_argument('--disc_iterations', default=int(1e3), type=int)
    parser.add_argument('--log_iterations', default=int(5e3), type=int)
    parser.add_argument('--episodes', default=10, type=int)

    parser.add_argument('--bc_only', default=False, type=boolean)
    parser.add_argument('--actor_deterministic', default=True, type=boolean)
    parser.add_argument('--absorbing_state', default=True, type=boolean)
    parser.add_argument('--standardize_reward', default=True, type=boolean)
    parser.add_argument('--standardize_obs', default=True, type=boolean)
    parser.add_argument('--reward_type', default='P', type=str, help='choose from T/P/C')
    parser.add_argument('--reward_scale', default=1, type=float)
    parser.add_argument('--res_scale', default=3, type=float)
    # NOTE: the three tuple-valued options below have no `type=`, so a value
    # given on the command line arrives as a plain string.
    parser.add_argument('--mean_range', default=(-7.24, 7.24))
    parser.add_argument('--logstd_range', default=(-5., 2.))

    parser.add_argument('--hidden_sizes', default=(256, 256))
    parser.add_argument('--num_hidden', default=264, type=int)
    parser.add_argument('--batch_size', default=512, type=int)
    parser.add_argument('--f', default='kl', type=str)
    parser.add_argument('--lr', default=3e-4, type=float)
    parser.add_argument('--actor_lr', default=3e-4, type=float)
    parser.add_argument('--lr_ratio', default=0.001, type=float)
    parser.add_argument('--v_l2_reg', default=0.0001, type=float)
    parser.add_argument('--r_l2_reg', default=0.0001, type=float)
    parser.add_argument('--alpha', default=0.5, type=float)
    parser.add_argument('--use_policy_entropy_constraint', default=True, type=boolean)
    parser.add_argument('--target_entropy', default=None, type=float)

    parser.add_argument('--device', default='cuda:0', type=str)
    # `main` reads args.threads when running on CPU; without this option that
    # path raised AttributeError.
    parser.add_argument('--threads', default=4, type=int)
    parser.add_argument('--wandb', default=True, type=boolean)
    parser.add_argument('--render', default=False, type=boolean)
    parser.add_argument('--seed', default=0, type=int)

    parser.add_argument('--extra_weight', default=1e-1, type=float)
    parser.add_argument('--remark', default='adv', type=str)
    parser.add_argument('--n_bins', default=10, type=int)
    parser.add_argument('--n_repeat', default=4, type=int)
    return parser


def _load_dataset_with_retry(env, max_attempts=10):
    """Call ``env.get_dataset()``, retrying up to ``max_attempts`` times.

    Only ``Exception`` subclasses are retried, so Ctrl-C still interrupts.

    Raises:
        RuntimeError: if every attempt fails; chained to the last error.
    """
    last_error = None
    for _ in range(max_attempts):
        try:
            return env.get_dataset()
        except Exception as err:
            print('Fail to load data... One time')
            last_error = err
    raise RuntimeError(
        f'Could not load the dataset after {max_attempts} attempts'
    ) from last_error


def main():
    """Fit the cost-predicting transformer on offline trajectories.

    Steps: parse CLI arguments, seed, start a wandb run, build the environment
    and fetch its dataset, split it with ``process_realworld_dataset``, reshape
    the flat transitions into fixed-length episodes, then train
    ``Structured2Transformer`` so that its per-step outputs summed over an
    episode match the episode's summed cost. The running mean loss is logged
    every 10 iterations and the whole module is saved every 100 iterations
    under ``pred_logs/<remark>/``.
    """
    args = _build_arg_parser().parse_args()

    # Set args.name, args.group, args.logdir
    # TODO: placeholder
    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    wandb.init(
        project='OSRL-predictor',
        group=None,
        name=None,
        id=str(uuid.uuid4()),
        resume="allow",
        config=None,  # type: ignore
    )

    # Metadrive tasks are created through `gym`, everything else through
    # `gymnasium`; either way the local name is `gym`.
    import gymnasium as gym  # noqa
    if "Metadrive" in args.task:
        import gym
    env = gym.make(args.task)
    data = _load_dataset_with_retry(env)
    env.set_target_cost(args.cost_limit)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    # Hard-coded here; note that args.batch_size is NOT what the loop uses.
    max_ep_len = 300
    batch_size = 8

    # obtain the real data
    optimal_data, qualified_data = process_realworld_dataset(
        data, args.cost_limit, args.gamma, args.frontier_ratio, args.task)
    num_expert = len(optimal_data['observations'])
    num_offline = len(qualified_data['observations'])
    # The reshapes below require the transition counts to be exact multiples
    # of max_ep_len, otherwise numpy raises ValueError.
    n_expert_eps = num_expert // max_ep_len
    n_offline_eps = num_offline // max_ep_len

    def to_device(array, shape):
        # Reshape a flat per-transition array and move it to the run device.
        return torch.as_tensor(array.reshape(shape)).to(args.device)

    expert_state_d = to_device(optimal_data['observations'],
                               (n_expert_eps, max_ep_len, state_dim))
    expert_action_d = to_device(optimal_data['actions'],
                                (n_expert_eps, max_ep_len, action_dim))
    offline_state_d = to_device(qualified_data['observations'],
                                (n_offline_eps, max_ep_len, state_dim))
    offline_action_d = to_device(qualified_data['actions'],
                                 (n_offline_eps, max_ep_len, action_dim))

    # Per-episode total cost: the regression target.
    expert_ctg_d = torch.as_tensor(
        optimal_data['costs'].reshape(n_expert_eps, max_ep_len).sum(axis=-1)
    ).to(args.device)
    offline_ctg_d = torch.as_tensor(
        qualified_data['costs'].reshape(n_offline_eps, max_ep_len).sum(axis=-1)
    ).to(args.device)

    from structure_transformer import Structured2Transformer
    transformer = Structured2Transformer(obs_dim=state_dim,
                                         action_dim=action_dim,
                                         embed_dim=args.num_hidden,
                                         pref_embed_dim=args.num_hidden,
                                         seq_len=max_ep_len).to(args.device)

    optimizer = torch.optim.Adam(transformer.parameters(), lr=3e-5)

    # Hard-coded; args.disc_iterations is not consulted.
    disc_iterations = 400
    log_every = 10
    save_every = 100
    transformer.train()

    # One row of timestep indices, tiled to the batch size inside the loop.
    base_timesteps = torch.arange(0, max_ep_len).unsqueeze(0).to(args.device)

    cost_log_loss = 0
    cost1_log_loss, cost2_log_loss = 0, 0
    for i in tqdm.tqdm(range(disc_iterations)):
        # Sample up to batch_size distinct episodes from each pool.
        _offline_indice = torch.randperm(offline_state_d.size(0))[:batch_size]
        _expert_indice = torch.randperm(expert_state_d.size(0))[:batch_size]
        _offline_state, _offline_action = offline_state_d[_offline_indice], offline_action_d[_offline_indice]
        _expert_state, _expert_action = expert_state_d[_expert_indice], expert_action_d[_expert_indice]

        _expert_ctg_d, _offline_ctg_d = expert_ctg_d[_expert_indice], offline_ctg_d[_offline_indice]

        p_timesteps = base_timesteps.repeat(len(_offline_state), 1)
        e_timesteps = base_timesteps.repeat(len(_expert_state), 1)

        _, offline_ctg_pred = transformer(_offline_state, _offline_action, p_timesteps)
        _, expert_ctg_pred = transformer(_expert_state, _expert_action, e_timesteps)

        # Sum the per-step predictions over the episode axis.
        # NOTE(review): assumes the second output is (batch, seq, 1) -- confirm.
        offline_ctg_pred = offline_ctg_pred.sum(dim=1).squeeze(-1)
        expert_ctg_pred = expert_ctg_pred.sum(dim=1).squeeze(-1)
        cost1_loss = ((offline_ctg_pred - _offline_ctg_d) ** 2).mean()
        cost2_loss = ((expert_ctg_pred - _expert_ctg_d) ** 2).mean()
        cost_loss = cost1_loss + cost2_loss

        optimizer.zero_grad()
        cost_loss.backward()
        cost_log_loss += cost_loss.item()
        cost1_log_loss += cost1_loss.item()
        cost2_log_loss += cost2_loss.item()
        optimizer.step()

        if (i + 1) % log_every == 0:
            wandb.log({
                'cost_loss': cost_log_loss / log_every,
                'cost1_loss': cost1_log_loss / log_every,
                'cost2_loss': cost2_log_loss / log_every,
            }, step=i)
            # Reset only after logging so the logged value is a true mean over
            # the window (previously the sums were cleared every iteration).
            cost_log_loss = 0
            cost1_log_loss, cost2_log_loss = 0, 0

        if (i + 1) % save_every == 0:
            os.makedirs(os.path.join('pred_logs', args.remark), exist_ok=True)
            # Saves the whole module object, not just its state_dict.
            torch.save(transformer, f'pred_logs/{args.remark}/transformer_pred_{i}.pth')

    print('Build dataset successfully!')


if __name__ == '__main__':
    main()
