import os
import sys
base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(base_path)

import environment
import numpy as np
import torch
import random
import argparse
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
from argparse import Namespace
from environment.wrapper import LMPromptEnv
from torch.nn import functional as F
from evaluate_test.minimal_exp_data.prompt_data import get_tsp_prompt_data
from dataloader.code.tokenizer import ContinuousScalarTokenizer
from dataloader.code.dataset import RLFullDataset, _get_loss_flag_and_position_id
from dataloader.code.input_specs import RLTaskInput
from typing import List
from utils.utils import set_seed, create_folder_overwrite_if_exist, str2bool, moving_average, load_model
from data.example.make_data import get_01bp_data, EXAMPLE_RENDER
from evaluate_test.evaluate_utils import masked_logits_for_action, truncate_sequence_by_stepsize

class VNet(torch.nn.Module):
    """State-value network: a two-layer MLP that outputs one scalar per input row."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Layer names and creation order are kept so that state-dict keys and
        # seeded parameter initialisation stay the same.
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``; the last dimension of the result is 1."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

class PPO(torch.nn.Module):
    """PPO agent whose actor is the supplied ``gato`` model and whose critic is a ``VNet``.

    ``take_action`` rolls the actor out token by token to build one
    environment action; ``compute_advantage`` and ``update`` implement GAE
    and the clipped-surrogate PPO step.
    """

    def __init__(self,
        gato,
        state_dim,
        hidden_dim,
        actor_lr,
        critic_lr,
        lmbda,
        eps,
        gamma,
        epochs,
        args: Namespace,
    ):
        """Build the actor/critic pair, their Adam optimizers and the action tokenizer.

        Args:
            gato: model used as the actor; moved to ``self.device``.
            state_dim: input dimension of the critic MLP.
            hidden_dim: hidden width of the critic MLP.
            actor_lr: Adam learning rate for the actor.
            critic_lr: Adam learning rate for the critic.
            lmbda: GAE lambda.
            eps: PPO clipping range.
            gamma: discount factor.
            epochs: number of optimisation passes over one collected trajectory.
            args: run configuration; this constructor reads ``device``,
                ``tokenizer_ver``, ``num_continuous_bin``, ``discretize_mu``
                and ``discretize_M`` from it.
        """
        super().__init__()
        # The original body read ``self.args`` without ever assigning it, which
        # raised AttributeError on construction; keep the config on the agent.
        self.args = args
        use_cuda = torch.cuda.is_available() and torch.cuda.device_count() >= args.device + 1
        self.device = torch.device(f"cuda:{args.device}" if use_cuda else "cpu")
        self.actor = gato.to(self.device)
        self.critic = VNet(state_dim, hidden_dim).to(self.device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)

        self.cont_tokenizer = ContinuousScalarTokenizer(
            args.tokenizer_ver, args.num_continuous_bin, args.discretize_mu, args.discretize_M
        )
        self.gamma = gamma
        self.lmbda = lmbda      # GAE parameter
        self.epochs = epochs    # number of training passes over one trajectory
        self.eps = eps          # PPO clipping range

    def take_action(
        self,
        args,
        env,
        current_seq,
        vision_seq,
        len_fixed_prompt,
        len_fixed_prompt_img,
        model_memory,
        prompt_strategy: str = "fixed_prompt",
        action_masks: List[np.ndarray] = None,
        sample_action: bool = False
    ):
        """Generate one action autoregressively, one token per action dimension.

        Args:
            args: run configuration; reads ``n_position``, ``use_prompt`` and
                ``num_discrete_values`` (plus whatever
                ``masked_logits_for_action`` needs).
            env: environment wrapper; reads ``obs_length``, ``action_length``
                and ``dataset`` (``act_type_spec``, ``adapter``).
            current_seq: token sequence fed to the actor.
            vision_seq: image sequence fed alongside the tokens, or ``None``.
            len_fixed_prompt: number of leading tokens of ``current_seq`` that
                are kept in place under the ``"fixed_prompt"`` strategy.
            len_fixed_prompt_img: same, for ``vision_seq``.
            model_memory: actor memory, or ``None`` to run without memory.
            prompt_strategy: ``"fixed_prompt"`` keeps the prompt at the head of
                the sequence when it overflows; any other value truncates from
                the head. Must not be ``"fixed_prompt"`` when a memory is used.
            action_masks: optional per-dimension masks passed to
                ``masked_logits_for_action``.
            sample_action: sample from the softmax when true, otherwise take
                the arg-max token.

        Returns:
            ``(act, (current_seq, vision_seq), model_memory)`` with the decoded
            action, the updated sequences and the updated memory.
        """
        # NOTE(review): the forward passes below run with autograd enabled;
        # confirm whether rollouts should be wrapped in torch.no_grad().
        obs_length = env.obs_length
        action_length = env.action_length
        trans_size = action_length + obs_length + 1
        discrete_action = env.dataset.act_type_spec == 'int'

        # Generate the action vector dimension by dimension
        act_seq = []
        for i_act in range(action_length):
            if i_act == 0 or model_memory is None:
                _, pos_id = _get_loss_flag_and_position_id(
                    0, len(current_seq) - 1, obs_length, action_length, prepend_trans_num=0
                )
            else:
                pos_id = np.array([0])

            # Wrap current_seq as an RLTaskInput the model accepts
            x = RLTaskInput(
                tensor_seq=current_seq, # (seq_len, ) Note that when generating the first action (using all zero mem), seq_len can be longer than 1024
                vision_seq=vision_seq,  # None
                text_seq=None,
                attention_mask=None,
                loss_mask=None,
                label=None,
                position_id=torch.tensor(pos_id, dtype=torch.long),
            )
            x.to(device=self.device)
            x.apply(lambda t: t[None, ...])     # add a leading batch dimension

            # One generation step of the model
            res = self.actor(x, compute_loss=False, mems=model_memory)                  # (1, seq_len, output_size)
            logits = res[0]                                                             # (1, seq_len, total_vocab_size)
            if model_memory is not None:
                model_memory = res[-1]                                                  # n_layer * [(batch_size, mem_len, n_embed)]
                assert model_memory[0].shape[1] == args.n_position

            # Mask logits according to the output space of the action dimension being generated
            logits = masked_logits_for_action(
                args, logits, discrete_action,
                env_action_mask = None if action_masks is None else action_masks[i_act]
            )
            logits = logits[:, -1, :].squeeze()                                         # (total_vocab_size, )
            probs = F.softmax(logits, dim=-1)

            # Pick the predicted action token
            if sample_action:
                pred_token = torch.multinomial(probs, num_samples=1)
            else:
                _, pred_token = torch.topk(probs, k=1, dim=-1)

            # Update current_seq
            if model_memory is None:
                current_seq = torch.cat([current_seq, pred_token.cpu()], dim=0)
                if len(current_seq) > args.n_position:
                    if args.use_prompt and prompt_strategy == "fixed_prompt":
                        # With "fixed_prompt" the expert prompt at the head of the sequence is kept unchanged
                        window_seq_view = torch.roll(current_seq[len_fixed_prompt:], -trans_size)   # rotate everything after the prompt left by trans_size
                        current_seq[len_fixed_prompt:].data.copy_(window_seq_view.data)             # the oldest non-prompt transition is now at the tail
                        current_seq = current_seq[:-trans_size]                                     # drop that (formerly first) transition

                        if vision_seq is not None:
                            window_img_view = torch.roll(vision_seq[len_fixed_prompt_img:], -1)
                            vision_seq[len_fixed_prompt_img:].data.copy_(window_img_view.data)
                            vision_seq = vision_seq[:-1]
                    else:
                        # Without a prompt, or with "moving_prompt", keep only the tail as the
                        # sequence grows, so the prompt part keeps being refreshed.
                        # NOTE(review): the training loop in this file calls this helper with
                        # (seq, img, obs_length, action_length); confirm which arity is correct.
                        current_seq, vision_seq = truncate_sequence_by_stepsize(current_seq, vision_seq, trans_size)
            else:
                # although cpu() may have a new copy, prevent side effect
                # of recover_model_predict_token_to_tokenizer_raw where
                # there are some inplace operations
                # XXX(DB1): memory net uses moving prompt!
                assert prompt_strategy != "fixed_prompt"
                current_seq = pred_token.cpu().clone()
                vision_seq = None

            # Recover the model-predicted token to the tokenizer's raw range
            if not discrete_action:
                pred_token -= args.num_discrete_values
            act_seq.append(pred_token.cpu())

        # Pass the last action dimension through the model to update model_memory
        if model_memory is not None:
            x = RLTaskInput(
                tensor_seq=current_seq,
                vision_seq=None,
                text_seq=None,
                attention_mask=None,
                loss_mask=None,
                label=None,
                position_id=torch.tensor([0], dtype=torch.long),
            )
            x.to(device=self.device)
            x.apply(lambda t: t[None, ...])

            _, _, model_memory = self.actor(x, compute_loss=False, mems=model_memory)
            assert model_memory[0].shape[1] == args.n_position

        # token -> action
        if not discrete_action:
            act = self.cont_tokenizer.decode(torch.cat(act_seq), is_action=True).numpy()
        else:
            act = np.concatenate(act_seq)
            act = env.dataset.adapter.recover_raw_act(act)
        return act, (current_seq, vision_seq), model_memory

    def compute_advantage(self, gamma, lmbda, td_delta):
        """Generalized Advantage Estimation over one trajectory.

        Args:
            gamma: discount factor.
            lmbda: GAE lambda.
            td_delta: CPU tensor of TD errors, ordered by time along dim 0.

        Returns:
            Float tensor with the same shape as ``td_delta`` where entry ``t``
            is ``sum_k (gamma * lmbda)**k * td_delta[t + k]``.
        """
        td_delta = td_delta.detach().numpy()
        advantage_list = []
        advantage = 0.0
        # Accumulate backwards in time, then restore chronological order
        for delta in td_delta[::-1]:
            advantage = gamma * lmbda * advantage + delta
            advantage_list.append(advantage)
        advantage_list.reverse()
        return torch.tensor(np.array(advantage_list), dtype=torch.float)

    def update(self, transition_dict):
        """Run ``self.epochs`` clipped-surrogate PPO updates on one trajectory.

        Args:
            transition_dict: dict with list entries ``'states'``,
                ``'actions'``, ``'rewards'``, ``'next_states'`` and ``'dones'``.

        NOTE(review): this method calls ``self.actor(states)`` with a plain
        tensor and indexes the result as per-action probabilities, whereas
        ``take_action`` calls the actor with an ``RLTaskInput`` and
        ``compute_loss=False``. The only call site in this file is commented
        out; confirm the actor interface before enabling it.
        """
        states = torch.tensor(np.array(transition_dict['states']), dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(np.array(transition_dict['next_states']), dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device)

        td_target = rewards + self.gamma * self.critic(next_states) * (1-dones)
        td_delta = td_target - self.critic(states)
        advantage = self.compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)
        old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()

        # Train for `epochs` passes on the trajectory that was just collected
        for _ in range(self.epochs):
            log_probs = torch.log(self.actor(states).gather(1, actions))
            ratio = torch.exp(log_probs - old_log_probs)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage  # clipping
            actor_loss = torch.mean(-torch.min(surr1, surr2))                   # PPO loss
            # mse_loss already averages over all elements
            critic_loss = F.mse_loss(self.critic(states), td_target.detach())

            # Update network parameters
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_optimizer.step()
            self.critic_optimizer.step()

if __name__ == "__main__":
    # Script entry point: load a checkpoint, roll out episodes with the PPO
    # agent on the prompt environments, and plot the per-episode returns.

    # eval paras
    # The values this script used to force after parsing are now the argparse
    # defaults: a plain invocation behaves as before, and command-line
    # overrides are actually honoured.
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt-path", type=str, default='01BP_ORDINAL(MASK)_250_252_6_10/best/0.92_seed42_epoch100',)
    parser.add_argument("--seed", type=int, default=42,)
    parser.add_argument("--eval-iters-COP", type=int, default=100,)
    parser.add_argument("--eval-iters-RL", type=int, default=5,)
    parser.add_argument("--policy-logger", type=str2bool, default=True,)
    eval_args = parser.parse_args()
    eval_iters_dict = {
        'COPTask':eval_args.eval_iters_COP,
        'RLTask':eval_args.eval_iters_RL
    }

    # load ckpt
    set_seed(eval_args.seed)
    # First path component of the checkpoint path names the config folder.
    # split() also copes with a path that has no '/', where slicing with
    # find() == -1 would silently drop the last character.
    config_path = eval_args.ckpt_path.split('/', 1)[0]
    args, gato, current_epoch = load_model(config_path=f'{base_path}/ckpt/{config_path}/config.json', para_path=f'{base_path}/ckpt/{eval_args.ckpt_path}.pt')
    args.seeds = [eval_args.seed, ]
    args.policy_logger = eval_args.policy_logger
    eval_prompt_strat = args.prompt_strategy.split(";")[-1]     # last ';'-separated entry, e.g. moving_prompt

    # build envs for evaluation
    datasets_prompt = get_01bp_data(data_type='prompt')
    envs = {dataset.env_name: LMPromptEnv(dataset.env_name, args, dataset, eval_prompt_strat) for dataset in datasets_prompt}
    args.eval_env_names = [dataset.env_name for dataset in datasets_prompt]
    args.eval_dataset_names = [dataset.dataset_name for dataset in datasets_prompt]
    eval_iters = {env_name:eval_iters_dict[env.task_type] for env_name, env in envs.items()}

    # build episode render if we need to check generated episodes during training
    logger = None
    if args.policy_logger:
        for dataset in datasets_prompt:
            create_folder_overwrite_if_exist(f'{base_path}/visualize/eval/{dataset.env_name}/{dataset.dataset_name}')
        logger = {env_name: EXAMPLE_RENDER[env_name]() for env_name in args.eval_env_names}

    # PPO paras
    state_dim = datasets_prompt[0].obs_dim               # observation dimension of the environment
    actor_lr = 1e-3
    critic_lr = 1e-2
    hidden_dim = 256
    gamma = 0.98
    lmbda = 0.95
    epochs = 10
    eps = 0.2

    # training paras
    # Ratio of trajectories collected per environment. Keys must appear in
    # args.eval_env_names and values must be ints (both checked below).
    num_env = {'example_01BP':1, }
    num_episodes = 1000
    hard_action_constraint = False
    sample_action = True
    assert all(name in args.eval_env_names for name in num_env.keys())
    assert all(isinstance(value, int) for value in num_env.values())

    # left_range[name] is the first slot (inside a cycle of num_env_sum
    # episodes) that belongs to environment `name`.
    left_range = {env_name:0 for env_name in num_env.keys()}
    num_env_sum = 0
    for env_name, num in num_env.items():
        left_range[env_name] = num_env_sum
        num_env_sum += num

    # build agent
    agent = PPO(gato, state_dim, hidden_dim, actor_lr, critic_lr, lmbda, eps, gamma, epochs, args)

    # start training
    returns = {env_name:[] for env_name in args.eval_env_names}
    with tqdm(total=num_episodes, desc='PPO training') as pbar:
        for i in range(num_episodes):
            # get the env for training: the last env whose slot start is <= the current slot
            env_name = [k for k, v in left_range.items() if i % num_env_sum >= v][-1]
            env = envs[env_name]
            spliter_token = torch.tensor([env.dataset.spliter_token_id], dtype=torch.long)
            trans_size = env.obs_length + env.action_length + 1

            # reset env
            obs, (current_seq, current_img) = env.reset()
            obss, value_spaces = [obs,], [env.env.get_action_value_space(hard_action_constraint), ]
            current_seq = current_seq[None] if current_seq.ndim == 0 else current_seq

            # set prompt token sequence
            if args.use_prompt:
                fixed_prompt, prepend_img = env.get_prompt(strict_length=args.strict_length, minimal_expert_data=args.minimal_expert_data)

                # truncate sequence by stepsize
                len_fixed_prompt = len(fixed_prompt)                                        # NOTE(XXX): prompt length may differ from the model context length (1024)
                len_fixed_prompt_img = len(prepend_img) if prepend_img is not None else 0   # original note: num_env x num_trans x c x h x w — TODO confirm
                current_seq = torch.cat([fixed_prompt, current_seq, spliter_token])         # append current obs and spliter; the action is generated autoregressively next

                if prepend_img is not None:
                    assert prepend_img.shape[1:] == current_img.shape[1:], (prepend_img.shape, current_img.shape,)
                    current_img = torch.cat([prepend_img, current_img], dim=0)
            else:
                len_fixed_prompt = 0
                len_fixed_prompt_img = 0

            # gen an episode with current policy(GPT)
            model_memory = gato.transformer.init_mem(batch_size=1)
            episode_return, episode_length = 0, 0
            transition_dict = {
                'states': [],
                'actions': [],
                'next_states': [],
                'rewards': [],
                'dones': []
            }

            while True:
                act, (current_seq, current_img), model_memory = agent.take_action(
                    args=args,
                    env=env,
                    current_seq=current_seq,
                    vision_seq=current_img,
                    len_fixed_prompt=len_fixed_prompt,
                    len_fixed_prompt_img=len_fixed_prompt_img,
                    model_memory=model_memory,
                    prompt_strategy=eval_prompt_strat,
                    action_masks=env.get_action_mask(hard_action_constraint),
                    sample_action=sample_action
                )
                next_obs, (new_seq, new_img), reward, terminated, truncated, _ = env.step(act)
                new_seq = new_seq.unsqueeze(0) if new_seq.ndim == 0 else new_seq
                episode_return += reward
                episode_length += 1
                if isinstance(obs, dict):
                    # Flatten dict observations into one vector (0-d entries are promoted to 1-d)
                    transition_dict['states'].append(np.concatenate([v[None] if v.ndim==0 else v for v in list(obs.values())]))
                    transition_dict['next_states'].append(np.concatenate([v[None] if v.ndim==0 else v for v in list(next_obs.values())]))
                else:
                    transition_dict['states'].append(obs)
                    transition_dict['next_states'].append(next_obs)
                transition_dict['actions'].append(act)
                transition_dict['rewards'].append(reward)
                transition_dict['dones'].append(terminated or truncated)

                if terminated or truncated or (args.eval_max_step_size is not None and episode_length >= args.eval_max_step_size):
                    break

                value_spaces.append(env.env.get_action_value_space(hard_action_constraint))
                # Record the observation just returned by env.step so that obss
                # stays aligned with value_spaces; appending the stale `obs`
                # duplicated the initial observation and lagged one step behind.
                obss.append(next_obs)
                obs = next_obs

                # update current_seq and current_img
                if model_memory is None:
                    if current_img is not None:
                        current_img = torch.cat([current_img, new_img], dim=0)
                    current_seq = torch.cat([current_seq, new_seq, spliter_token])

                    if len(current_seq) > args.n_position:
                        # With "fixed_prompt" the expert prompt at the head of the sequence is kept unchanged
                        if args.use_prompt and eval_prompt_strat == "fixed_prompt":
                            window_seq_view = torch.roll(current_seq[len_fixed_prompt:], -trans_size)
                            current_seq[len_fixed_prompt:].data.copy_(window_seq_view.data)
                            current_seq = current_seq[:-trans_size]

                            if current_img is not None:
                                window_img_view = torch.roll(current_img[len_fixed_prompt_img:], -1)
                                current_img[len_fixed_prompt_img:].data.copy_(window_img_view.data)
                                current_img = current_img[:-1]
                        # Without a prompt, or with "moving_prompt", keep only the tail as the
                        # sequence grows, so the prompt part keeps being refreshed
                        else:
                            # NOTE(review): PPO.take_action calls this helper with
                            # (seq, img, trans_size); confirm which arity is correct.
                            current_seq, current_img = truncate_sequence_by_stepsize(current_seq, current_img, env.obs_length, env.action_length)
                else:
                    current_seq = torch.cat([new_seq, spliter_token])
                    current_img = new_img

            episode = {
                'observations': obss,
                'actions': np.array(transition_dict['actions']).astype(env.dataset.act_type_spec),
                'rewards': np.array(transition_dict['rewards']).astype(np.float32),
                'act_value_space': value_spaces
            }


            # On-policy update with the data collected by the current policy
            #agent.update(transition_dict)

            # Update the progress bar
            returns[env_name].append(episode_return)
            info = {'episode': i}
            # Environments that have no episode yet are skipped; indexing
            # their empty list with [-1] would raise IndexError.
            info.update({f'ret_{name}': round(ret[-1],2) for name, ret in returns.items() if ret})
            info.update({f'ret(ave)_{name}': round(np.mean(ret[-10:]),2) for name, ret in returns.items() if ret})
            pbar.set_postfix(info)
            pbar.update(1)

    # show policy performence
    create_folder_overwrite_if_exist(f'{base_path}/visualize/PPO/{eval_args.ckpt_path}')
    for env_name in returns.keys():
        return_list = returns[env_name]
        if not return_list:
            # Nothing to plot for an environment that was never trained on
            continue
        plt.figure(figsize=(12,8))
        mv_return_list = moving_average(return_list, 29)
        episodes_list = list(range(len(return_list)))
        plt.plot(episodes_list, return_list, label='raw', alpha=0.5)
        plt.plot(episodes_list, mv_return_list, label='moving ave')
        plt.xlabel('Episodes')
        plt.ylabel('Returns')
        plt.title(f'PPO fineturn on {env_name}')
        plt.legend()
        plt.savefig(f'{base_path}/visualize/PPO/{eval_args.ckpt_path}/{env_name}.png')
        plt.show()
