import argparse

import os
import uuid
import tqdm


import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import numpy as np
import pyrallis
import torch
import torch.nn.functional as F
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from fsrl.utils import WandbLogger
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa
import torch.nn as nn

import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# from examples.configs.bc_configs import BC_DEFAULT_CONFIG, BCTrainConfig
from osrl.algorithms import BC, BCTrainer
from osrl.common import TransitionDataset
from osrl.common.dataset import process_bc_dataset, process_realworld_dataset
from osrl.common.exp_util import auto_name, seed_all

from torch.nn.utils import clip_grad_norm_


# Episode length (number of steps) for every supported offline task id.
# main() looks up ``task2episode_len[args.task]`` and uses the value as
# ``max_ep_len`` when reshaping the flat transition arrays into
# (num_episodes, max_ep_len, feature_dim); a task missing from this table
# raises KeyError there.
task2episode_len = {
    # bullet safety gym envs
    "OfflineAntCircle-v0": 500,
    "OfflineAntRun-v0": 200,
    "OfflineCarCircle-v0": 300,
    "OfflineDroneCircle-v0": 300,
    "OfflineDroneRun-v0": 200,
    "OfflineBallCircle-v0": 200,
    "OfflineBallRun-v0": 100,
    "OfflineCarRun-v0": 200,
    # safety gymnasium: car
    "OfflineCarButton1Gymnasium-v0": 1000,
    "OfflineCarButton2Gymnasium-v0": 1000,
    "OfflineCarCircle1Gymnasium-v0": 500,
    "OfflineCarCircle2Gymnasium-v0": 500,
    "OfflineCarGoal1Gymnasium-v0": 1000,
    "OfflineCarGoal2Gymnasium-v0": 1000,
    "OfflineCarPush1Gymnasium-v0": 1000,
    "OfflineCarPush2Gymnasium-v0": 1000,
    # safety gymnasium: point
    "OfflinePointButton1Gymnasium-v0": 1000,
    "OfflinePointButton2Gymnasium-v0": 1000,
    "OfflinePointCircle1Gymnasium-v0": 500,
    "OfflinePointCircle2Gymnasium-v0": 500,
    "OfflinePointGoal1Gymnasium-v0": 1000,
    "OfflinePointGoal2Gymnasium-v0": 1000,
    "OfflinePointPush1Gymnasium-v0": 1000,
    "OfflinePointPush2Gymnasium-v0": 1000,
    # safety gymnasium: velocity
    "OfflineAntVelocityGymnasium-v1": 1000,
    "OfflineHalfCheetahVelocityGymnasium-v1": 1000,
    "OfflineHopperVelocityGymnasium-v1": 1000,
    "OfflineSwimmerVelocityGymnasium-v1": 1000,
    "OfflineWalker2dVelocityGymnasium-v1": 1000,
    # metadrive envs
    "OfflineMetadrive-easysparse-v0": 1000,
    "OfflineMetadrive-easymean-v0": 1000,
    "OfflineMetadrive-easydense-v0": 1000,
    "OfflineMetadrive-mediumsparse-v0": 1000,
    "OfflineMetadrive-mediummean-v0": 1000,
    "OfflineMetadrive-mediumdense-v0": 1000,
    "OfflineMetadrive-hardsparse-v0": 1000,
    "OfflineMetadrive-hardmean-v0": 1000,
    "OfflineMetadrive-harddense-v0": 1000,
}

import wandb

# Authenticate with Weights & Biases when the module is loaded.
# SECURITY: the API key must never be committed to source control, so it is
# read from the WANDB_API_KEY environment variable instead of a literal.
# Any key that was previously hard-coded here should be treated as leaked
# and revoked in the W&B account settings.
_wandb_api_key = os.environ.get('WANDB_API_KEY')
if _wandb_api_key:
    wandb.login(key=_wandb_api_key, relogin=True)
else:
    # No key in the environment: let wandb fall back to its own credential
    # discovery (cached login / interactive prompt).
    wandb.login()


# List of short task names. NOTE(review): judging by the constant's name these
# are robomimic tasks; nothing in the visible part of this file reads the
# list - confirm whether it is still needed.
ROBOMIMIC = ['lift', 'can', 'square', 'tool_hang', 'transport']

def cycle(dl):
    """Yield the items of ``dl`` forever, re-iterating it each time it runs out.

    ``dl`` is iterated afresh on every pass (nothing is cached), so an
    iterable that produces a different order per pass keeps doing so.
    """
    while True:
        yield from dl

def boolean(v):
    """Convert a command-line flag value to ``bool`` (argparse ``type=`` helper).

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string.
            Matching is case-insensitive.

    Returns:
        ``True`` for 'yes', 'true', 't', 'y', '1';
        ``False`` for 'no', 'false', 'f', 'n', '0'.

    Raises:
        argparse.ArgumentTypeError: If the string is none of the above.
    """
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def gradient_reversal_hook(module, grad_input, grad_output):
    """Return a copy of the first three input gradients with two signs flipped.

    The signature matches a module backward hook. ``module`` and
    ``grad_output`` are accepted but not used.

    Args:
        module: Unused.
        grad_input: Indexable with at least three entries.
        grad_output: Unused.

    Returns:
        A 3-tuple ``(-1.0 * grad_input[0], grad_input[1], -1.0 * grad_input[2])``.

    NOTE(review): an earlier comment on this function said that entry 1 is
    the weight gradient and is the one to reverse, but the code leaves entry 1
    untouched and negates entries 0 and 2. Which entry corresponds to which
    tensor depends on the hooked module - confirm the intended behaviour.
    """
    negated_first = -1.0 * grad_input[0]
    passthrough = grad_input[1]
    negated_third = -1.0 * grad_input[2]
    return (negated_first, passthrough, negated_third)

def main():
    """Train the ``PreferenceTransformer`` discriminator on an offline dataset.

    Steps, in order:

    1. Parse command-line options, seed the RNGs and start a wandb run.
    2. Build the environment for ``--task``, load its dataset and split it
       with ``process_realworld_dataset`` into an "optimal" (expert) part
       and a "qualified" (offline) part.
    3. Reshape both parts into episodes of ``task2episode_len[task]`` steps
       and sum the rewards of every episode.
    4. Bucketize those per-episode returns into ``--n_bins`` class labels.
    5. For each outer iteration, alternate between
       (a) ``disc_step`` updates of ``transformer.get_main_parameters()`` on
           an expert-vs-offline binary cross-entropy loss, minus
           ``--extra_weight`` times the return-bin loss, plus the model's
           gradient penalty, and
       (b) ``rew_step`` updates of
           ``transformer.get_rew_output_parameters()`` that minimise the
           return-bin cross-entropy.
    6. Save the whole model every 5 outer iterations under
       ``logs/<remark>/``.

    With ``--dry_run true`` the function returns right after printing the
    dataset sizes (step 3), without building or training the model.

    NOTE(review): many options are parsed but never read in this function -
    for example ``--project`` (wandb uses the literal 'OSRL-disc'),
    ``--batch_size`` (a literal 10 is used), ``--disc_iterations`` (a
    literal 10000 is used) and ``--logdir``. They are kept so existing
    command lines continue to parse.
    """
    # Parameters
    parser = argparse.ArgumentParser()
    # Define my dataset
    parser.add_argument('--project', default='OSRL-baselines', type=str)
    parser.add_argument('--group', default=None, type=str)
    parser.add_argument('--name', default=None, type=str)
    parser.add_argument('--suffix', default='', type=str)
    parser.add_argument('--logdir', default='logs', type=str)
    # `type=bool` would turn every non-empty string (including "False") into
    # True, so use the same string-to-bool converter as the other flags.
    parser.add_argument('--verbose', default=True, type=boolean)
    parser.add_argument('--task', default='OfflineCarCircle-v0', type=str)
    parser.add_argument('--cost_limit', default=10, type=int)
    parser.add_argument('--frontier_ratio', default=0.1, type=float)
    # Below is the old version
    parser.add_argument('--dataset', default='medium-replay', type=str)
    parser.add_argument('--state', default=False, type=boolean)
    parser.add_argument('--mismatch', default=False, type=boolean)
    parser.add_argument('--algo_type', default='smodice', type=str)
    parser.add_argument('--disc_type', default='learned', type=str)
    parser.add_argument('--gamma', default=0.99, type=float)

    parser.add_argument('--num_expert_traj', default=0, type=int)
    parser.add_argument('--num_offline_traj', default=2000, type=int)
    parser.add_argument('--total_iterations', default=int(1e6), type=int)
    parser.add_argument('--disc_iterations', default=int(1e3), type=int)
    parser.add_argument('--log_iterations', default=int(5e3), type=int)
    parser.add_argument('--episodes', default=10, type=int)

    parser.add_argument('--bc_only', default=False, type=boolean)
    parser.add_argument('--actor_deterministic', default=True, type=boolean)
    parser.add_argument('--absorbing_state', default=True, type=boolean)
    parser.add_argument('--standardize_reward', default=True, type=boolean)
    parser.add_argument('--standardize_obs', default=True, type=boolean)
    parser.add_argument('--reward_type', default='P', type=str, help='choose from T/P/C')
    parser.add_argument('--reward_scale', default=1, type=float)
    parser.add_argument('--res_scale', default=3, type=float)
    parser.add_argument('--mean_range', default=(-7.24, 7.24))
    parser.add_argument('--logstd_range', default=(-5., 2.))

    parser.add_argument('--hidden_sizes', default=(256, 256))
    parser.add_argument('--num_hidden', default=264, type=int)
    parser.add_argument('--batch_size', default=512, type=int)
    parser.add_argument('--f', default='kl', type=str)
    parser.add_argument('--lr', default=3e-4, type=float)
    parser.add_argument('--actor_lr', default=3e-4, type=float)
    parser.add_argument('--lr_ratio', default=0.001, type=float)
    parser.add_argument('--v_l2_reg', default=0.0001, type=float)
    parser.add_argument('--r_l2_reg', default=0.0001, type=float)
    parser.add_argument('--alpha', default=0.5, type=float)
    parser.add_argument('--use_policy_entropy_constraint', default=True, type=boolean)
    parser.add_argument('--target_entropy', default=None, type=float)

    parser.add_argument('--device', default='cuda:0', type=str)
    parser.add_argument('--wandb', default=True, type=boolean)
    parser.add_argument('--render', default=False, type=boolean)
    parser.add_argument('--seed', default=0, type=int)
    # `args.threads` is read below when running on CPU but the option was
    # never declared, which raised AttributeError for `--device cpu`.
    parser.add_argument('--threads', default=6, type=int,
                        help='CPU threads for torch when --device is cpu')

    parser.add_argument('--extra_weight', default=1, type=float)
    parser.add_argument('--remark', default='adv', type=str)
    parser.add_argument('--n_bins', default=10, type=int)
    parser.add_argument('--n_repeat', default=4, type=int)
    # Replaces a leftover `assert 0` that unconditionally aborted the script
    # after the dataset shape printout and made all training code unreachable.
    parser.add_argument('--dry_run', default=False, type=boolean,
                        help='only load the data and print its size, then exit')

    args = parser.parse_args()

    # Set args.name, args.group, args.logdir
    # TODO: placeholder
    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    wandb.init(
        project='OSRL-disc',
        group=None,
        name=None,
        id=str(uuid.uuid4()),
        resume="allow",
        config=None,  # type: ignore
    )

    # `gym` is rebound locally: gymnasium by default, classic gym for
    # Metadrive tasks.
    import gymnasium as gym  # noqa
    if "Metadrive" in args.task:
        import gym
    env = gym.make(args.task)

    # Retry the dataset load until it succeeds. Only `Exception` is caught so
    # that Ctrl-C / SystemExit still stop the script, and the cause is shown
    # instead of being silently discarded.
    while True:
        try:
            data = env.get_dataset()
        except Exception as err:
            print('Fail to load data... One time')
            print(f'    cause: {err!r}')
        else:
            break
    env.set_target_cost(args.cost_limit)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_ep_len = task2episode_len[args.task]
    batch_size = 10  # NOTE(review): --batch_size is ignored; confirm intent

    # obtain the real data
    optimal_data, qualified_data = process_realworld_dataset(
        data, args.cost_limit, args.gamma, args.frontier_ratio, args.task)
    num_expert = len(optimal_data['observations'])
    num_offline = len(qualified_data['observations'])

    def as_episodes(flat, n_transitions, *feature_dims):
        """Reshape a flat per-transition array to (episodes, max_ep_len, *feature_dims)."""
        return flat.reshape(n_transitions // max_ep_len, max_ep_len, *feature_dims)

    # assumes every episode has exactly max_ep_len transitions - the reshape
    # fails otherwise; TODO confirm against process_realworld_dataset
    expert_state_d = torch.as_tensor(
        as_episodes(optimal_data['observations'], num_expert, state_dim)).to(args.device)
    expert_action_d = torch.as_tensor(
        as_episodes(optimal_data['actions'], num_expert, action_dim)).to(args.device)
    offline_state_d = torch.as_tensor(
        as_episodes(qualified_data['observations'], num_offline, state_dim)).to(args.device)
    offline_action_d = torch.as_tensor(
        as_episodes(qualified_data['actions'], num_offline, action_dim)).to(args.device)

    # Per-episode sum of rewards.
    expert_rtg_d = torch.as_tensor(
        as_episodes(optimal_data['rewards'], num_expert).sum(axis=-1)).to(args.device)
    offline_rtg_d = torch.as_tensor(
        as_episodes(qualified_data['rewards'], num_offline).sum(axis=-1)).to(args.device)

    print('--->>>'*20)
    print('shape is ok!!!')
    print(num_expert, num_offline)
    print(max_ep_len)
    if args.dry_run:
        return

    from preference_transformer import PreferenceTransformer
    transformer = PreferenceTransformer(state_dim=state_dim,
                                        act_dim=action_dim,
                                        hidden_size=args.num_hidden,
                                        max_ep_len=max_ep_len,
                                        n_bins=args.n_bins,
                                        device=args.device).to(args.device)

    # Two optimizers over two parameter groups exposed by the model; they are
    # stepped in separate phases of every outer iteration.
    optimizer = torch.optim.Adam(transformer.get_main_parameters(), lr=3e-5)
    rew_optimizer = torch.optim.Adam(transformer.get_rew_output_parameters(), lr=1e-6)

    criterion = nn.CrossEntropyLoss()

    disc_iterations = 10000  # NOTE(review): --disc_iterations is ignored
    transformer.train()
    # Running sums for logging; reset after every wandb.log call.
    loss, ce_loss, rtg_loss, rtg2_loss = 0, 0, 0, 0
    rtg11_loss, rtg12_loss = 0, 0
    grad_norm_log = 0  # stays 0: gradient clipping is currently disabled
    var_loss = 0

    # split the reward/rtg
    # Bin edges: n_bins points evenly spaced from the global minimum return up
    # to the smallest expert return, plus one final edge at the global maximum.
    n_bins = args.n_bins
    global_min = min(expert_rtg_d.min(), offline_rtg_d.min())
    global_medium = expert_rtg_d.min()
    global_max = max(expert_rtg_d.max(), offline_rtg_d.max())

    bins = torch.linspace(global_min, global_medium, steps=n_bins).to(expert_rtg_d.device)
    bins = torch.cat([bins, torch.tensor([global_max]).to(expert_rtg_d.device)], dim=0)

    expert_rtg_label = torch.bucketize(expert_rtg_d, bins, right=True) - 1
    offline_rtg_label = torch.bucketize(offline_rtg_d, bins, right=True) - 1

    # Clamp the labels into [0, n_bins - 1].
    expert_rtg_label[expert_rtg_label == n_bins] = n_bins - 1
    offline_rtg_label[offline_rtg_label == n_bins] = n_bins - 1
    expert_rtg_label[expert_rtg_label == -1] = 0
    offline_rtg_label[offline_rtg_label == -1] = 0

    # Offline episode indices grouped by label (only read by the `use_var`
    # branch below).
    _dict = {}
    for _idx in range(n_bins):
        _dict[_idx] = torch.where(offline_rtg_label == _idx)[0]

    def generate_random_label(old_tensor):
        """Shift every label by a random offset modulo ``n_bins``.

        The offset is drawn from [0, n_bins), so a label may stay unchanged.
        (Drawing from [1, n_bins) instead would guarantee a different label.)
        """
        offsets = torch.randint(0, n_bins, old_tensor.shape).to(old_tensor.device)
        new_tensor = (old_tensor + offsets) % n_bins
        return new_tensor

    def sample_batch():
        """Draw up to ``batch_size`` random offline and expert episodes.

        Returns, in order: offline states, offline actions, offline labels,
        offline timesteps, expert states, expert actions, expert labels,
        expert timesteps. Timesteps are 0..max_ep_len-1 repeated per episode.
        """
        offline_idx = torch.randperm(offline_state_d.size(0))[:batch_size]
        expert_idx = torch.randperm(expert_state_d.size(0))[:batch_size]
        off_state, off_action = offline_state_d[offline_idx], offline_action_d[offline_idx]
        exp_state, exp_action = expert_state_d[expert_idx], expert_action_d[expert_idx]
        exp_label, off_label = expert_rtg_label[expert_idx], offline_rtg_label[offline_idx]
        off_steps = torch.arange(0, max_ep_len).unsqueeze(0).repeat(len(off_state), 1).to(args.device)
        exp_steps = torch.arange(0, max_ep_len).unsqueeze(0).repeat(len(exp_state), 1).to(args.device)
        return (off_state, off_action, off_label, off_steps,
                exp_state, exp_action, exp_label, exp_steps)

    use_var = False
    disc_step = 300
    rew_step = 5
    for i in tqdm.tqdm(range(disc_iterations)):
        # ---- Phase (a): update the main parameters -----------------------
        for _ in range(disc_step):
            (_offline_state, _offline_action, _offline_rtg_label, p_timesteps,
             _expert_state, _expert_action, _expert_rtg_label, e_timesteps) = sample_batch()

            policy_d, policy_predrew = transformer(_offline_state, _offline_action, p_timesteps)
            expert_d, expert_predrew = transformer(_expert_state, _expert_action, e_timesteps)

            # The main discriminator loss: expert -> 1, offline -> 0.
            expert_loss = F.binary_cross_entropy_with_logits(
                expert_d,
                torch.ones(expert_d.size()).to(args.device))
            learner_loss = F.binary_cross_entropy_with_logits(
                policy_d,
                torch.zeros(policy_d.size()).to(args.device))
            gail_loss = learner_loss + expert_loss

            # Reward agnostic loss: cross-entropy on the true return bins
            # minus the average cross-entropy on n_repeat sets of randomly
            # shifted labels.
            rtg_pred_loss = criterion(policy_predrew, _offline_rtg_label) + criterion(expert_predrew, _expert_rtg_label)
            rtg_pred_loss1 = rtg_pred_loss.item()
            for _ in range(args.n_repeat):
                _fkoffline_rtg_label = generate_random_label(_offline_rtg_label)
                _fkexpert_rtg_label = generate_random_label(_expert_rtg_label)
                rtg_pred_loss -= (criterion(policy_predrew, _fkoffline_rtg_label) + criterion(expert_predrew, _fkexpert_rtg_label)) / args.n_repeat

            # The return-bin term enters with a minus sign here, whereas
            # phase (b) below minimises it for the other parameter group.
            tot_loss = gail_loss - args.extra_weight * rtg_pred_loss

            if use_var:
                # Disabled experimental variance term (use_var is False).
                # NOTE(review): the transformer call below passes no
                # timesteps, unlike every other call, and torch.clamp's
                # second positional argument is `min` - verify both before
                # enabling this branch.
                offline_batch_size = 2
                # Compute variance loss
                # n_bins - 1
                _offline_bs = min(len(_dict[n_bins - 1]), offline_batch_size)
                _selected = _dict[n_bins - 1][torch.randperm(len(_dict[n_bins - 1]))[:_offline_bs]]
                _offline_state_selected, _offline_action_selected = offline_state_d[_selected], offline_action_d[_selected]
                policy_d_selected, _ = transformer(_offline_state_selected, _offline_action_selected)
                _weights = torch.cat([expert_d.mean(dim=0).detach().clone(), policy_d_selected.squeeze(-1)], dim=0)
                var1 = -torch.clamp(torch.var(_weights), 100)

                tot_loss = tot_loss + var1

            ce_loss += gail_loss.item()
            rtg_loss += rtg_pred_loss.item()
            rtg11_loss += rtg_pred_loss1
            rtg12_loss += rtg_pred_loss.item() - rtg_pred_loss1
            if use_var:
                var_loss += var1.item()
            loss += tot_loss.item()

            grad_pen = transformer.gradient_penalty(_expert_state, _expert_action, _offline_state, _offline_action)

            optimizer.zero_grad()
            (tot_loss + grad_pen).backward()
            # Gradient clipping before the optimizer step is currently
            # disabled, which is why 'grad_norm' is always logged as 0.
            optimizer.step()

        wandb.log({
            'disc_loss': loss/disc_step,
            'ce_loss': ce_loss/disc_step,
            'rtg_loss': rtg_loss/disc_step,
            'rtg11_loss': rtg11_loss/disc_step,
            'rtg12_loss': rtg12_loss/disc_step,
            'var_loss': var_loss/disc_step if use_var else -1,
            'grad_norm': grad_norm_log/disc_step,
        }, step=i)
        loss, ce_loss, rtg_loss = 0, 0, 0
        rtg11_loss, rtg12_loss = 0, 0
        grad_norm_log = 0
        var_loss = 0

        # ---- Phase (b): update the rew_output parameters -----------------
        for _ in range(rew_step):
            (_offline_state, _offline_action, _offline_rtg_label, p_timesteps,
             _expert_state, _expert_action, _expert_rtg_label, e_timesteps) = sample_batch()

            # train rew_output layer
            policy_d, policy_predrew = transformer(_offline_state, _offline_action, p_timesteps)
            expert_d, expert_predrew = transformer(_expert_state, _expert_action, e_timesteps)

            rtg_pred_loss = criterion(policy_predrew, _offline_rtg_label) + criterion(expert_predrew, _expert_rtg_label)
            rtg2_loss += rtg_pred_loss.item()

            rew_optimizer.zero_grad()
            rtg_pred_loss.backward()
            rew_optimizer.step()

        wandb.log({
            'rtg2_loss': rtg2_loss/rew_step if rew_step != 0 else -1,
        }, step=i)
        rtg2_loss = 0

        # Checkpoint the full model object every 5 outer iterations.
        if (i + 1) % 5 == 0:
            os.makedirs(os.path.join('logs', args.remark), exist_ok=True)
            torch.save(transformer, f'logs/{args.remark}/transformer_disc_{i}.pth')

    print('Build dataset successfully!')


# Call main() only when this file is executed as a script.
if __name__ == '__main__':
    main()
