import os
import torch.nn.functional as F
import os
import numpy as np
import torch
import copy
from baseline_policy.ai_policy.policy_prefrence import Policy
from envs.Overcooked_Env_new import Overcooked_NEW
from envs.Overcooked_Env import Overcooked
from replay_buffer import ReplayBuffer

def get_script_ZSC_Distant_Tomato(env_name, ckpt_path, script_name=None, save_path=None, test_num=1):
    """Collect one episode of scripted-agent + policy play and save it to disk.

    Action slot 0 is filled by the scripted agent registered under
    ``script_name``; action slot 1 is filled by the ``Policy`` restored from
    ``ckpt_path``, which is fed the second agent's observation. Every
    transition is added to a ``ReplayBuffer`` that is saved under
    ``save_path`` with the name ``distant_tomato_<test_num>`` as soon as
    either agent reports ``done``.

    This mirrors the other ``get_script_ZSC_*`` collectors in this file and
    differs only in the environment class, the flattened observation size and
    the saved file prefix.

    Args:
        env_name: Layout name passed to ``Overcooked_NEW`` and ``Policy``.
        ckpt_path: Checkpoint path handed to ``Policy.load_checkpoint``.
        script_name: Key of the scripted agent in ``SCRIPT_AGENTS``.
        save_path: Destination handed to ``ReplayBuffer.save``.
        test_num: Suffix appended to the saved buffer name.

    Returns:
        None. The collected data is written by ``ReplayBuffer.save``.

    Raises:
        KeyError: If ``script_name`` is not registered in ``SCRIPT_AGENTS``.
    """
    from hsp.envs.overcooked_new.script_agent import SCRIPT_AGENTS
    from hsp.envs.overcooked_new.src.overcooked_ai_py.mdp.actions import Action

    # Fail fast with a readable message before the environment, the checkpoint
    # and the CUDA-backed buffer are created.
    if script_name not in SCRIPT_AGENTS:
        raise KeyError(
            "script_name {!r} is not registered in SCRIPT_AGENTS".format(script_name)
        )

    # Load the environment: Overcooked
    env = Overcooked_NEW(env_name, seed=3, featurize_type=("ppo", "ppo"))
    # AI policy: load the model trained with the HSP algorithm
    policy = Policy(env_name)
    policy.load_checkpoint(ckpt_path)
    policy.prep_rollout()

    # Initialise the ReplayBuffer: two agents, 7*5*26 flattened observation
    # values each (presumably width * height * feature channels -- confirm).
    # NOTE(review): capacity is fixed at 400; confirm the episode horizon
    # never exceeds it.
    replay_buffer = ReplayBuffer(obs_shape=[2, 7 * 5 * 26],
                                 action_shape=[2, 1],
                                 reward_shape=[2, 1],
                                 dones_shape=[2, 1],
                                 capacity=400,
                                 device='cuda')

    # Initial observation, recurrent state and mask for the policy
    obs, _, _ = env.reset()
    obs = np.stack(obs)
    rnn_state = np.zeros((1, 1, 64), dtype=np.float32)
    mask = np.ones((1, 1), dtype=np.float32)

    # The trailing 0 is presumably the scripted agent's player index -- confirm.
    agent = SCRIPT_AGENTS[script_name]()
    agent.reset(env.base_mdp, env.base_env.state, 0)

    episode_rewards = 0
    step = 0
    while True:
        # Scripted agent: convert its action into an index into ALL_ACTIONS
        agent_action = Action.ALL_ACTIONS.index(
            agent.step(env.base_mdp, env.base_env.state, 0)
        )

        # Snapshot the observation so the stored transition never aliases an
        # array that is reused on a later step.
        now_obs = copy.deepcopy(obs)
        input_obs = torch.Tensor(now_obs).cuda()
        # Pure inference: no_grad keeps any tensor-valued recurrent state from
        # dragging an autograd graph across steps.
        with torch.no_grad():
            ai_action, rnn_state = policy.act(input_obs[1:], rnn_state, mask)
        ai_action = ai_action.cpu().numpy()

        actions = np.array([[agent_action], ai_action[0]])
        next_obs, _share_obs, rewards, dones, _infos, _available_actions = env.step(actions)
        next_obs = np.stack(next_obs)
        last_reward = rewards[-1]

        # Store the transition; the last reward entry is recorded for both agents
        replay_buffer.add(
            now_obs.reshape(2, -1),
            np.array([agent_action, ai_action[0][0]]).reshape(2, -1),
            np.array([last_reward, last_reward]).reshape(2, -1),
            next_obs.reshape(2, -1),
            np.array(dones).reshape(2, -1),
        )

        episode_rewards += last_reward
        # next_obs is a fresh array from np.stack and is copied again at the
        # top of the next iteration, so no extra copy is needed here.
        obs = next_obs
        if last_reward > 0:
            print('now:', step, last_reward)
        step += 1

        if dones[0] or dones[1]:
            print('episode_reward:', episode_rewards)
            replay_buffer.save(save_path, 'distant_tomato_' + str(test_num))
            break

def get_script_ZSC_ManyOrders(env_name, ckpt_path, script_name=None, save_path=None, test_num=1):
    """Collect one episode of scripted-agent + policy play and save it to disk.

    Action slot 0 is filled by the scripted agent registered under
    ``script_name``; action slot 1 is filled by the ``Policy`` restored from
    ``ckpt_path``, which is fed the second agent's observation. Every
    transition is added to a ``ReplayBuffer`` that is saved under
    ``save_path`` with the name ``manyOrders_<test_num>`` as soon as either
    agent reports ``done``.

    This mirrors the other ``get_script_ZSC_*`` collectors in this file and
    differs only in the environment class, the flattened observation size and
    the saved file prefix.

    Args:
        env_name: Layout name passed to ``Overcooked_NEW`` and ``Policy``.
        ckpt_path: Checkpoint path handed to ``Policy.load_checkpoint``.
        script_name: Key of the scripted agent in ``SCRIPT_AGENTS``.
        save_path: Destination handed to ``ReplayBuffer.save``.
        test_num: Suffix appended to the saved buffer name.

    Returns:
        None. The collected data is written by ``ReplayBuffer.save``.

    Raises:
        KeyError: If ``script_name`` is not registered in ``SCRIPT_AGENTS``.
    """
    from hsp.envs.overcooked_new.script_agent import SCRIPT_AGENTS
    from hsp.envs.overcooked_new.src.overcooked_ai_py.mdp.actions import Action

    # Fail fast with a readable message before the environment, the checkpoint
    # and the CUDA-backed buffer are created.
    if script_name not in SCRIPT_AGENTS:
        raise KeyError(
            "script_name {!r} is not registered in SCRIPT_AGENTS".format(script_name)
        )

    # Load the environment: Overcooked
    env = Overcooked_NEW(env_name, seed=3, featurize_type=("ppo", "ppo"))
    # AI policy: load the model trained with the HSP algorithm
    policy = Policy(env_name)
    policy.load_checkpoint(ckpt_path)
    policy.prep_rollout()

    # Initialise the ReplayBuffer: two agents, 5*5*26 flattened observation
    # values each (presumably width * height * feature channels -- confirm).
    # NOTE(review): capacity is fixed at 400; confirm the episode horizon
    # never exceeds it.
    replay_buffer = ReplayBuffer(obs_shape=[2, 5 * 5 * 26],
                                 action_shape=[2, 1],
                                 reward_shape=[2, 1],
                                 dones_shape=[2, 1],
                                 capacity=400,
                                 device='cuda')

    # Initial observation, recurrent state and mask for the policy
    obs, _, _ = env.reset()
    obs = np.stack(obs)
    rnn_state = np.zeros((1, 1, 64), dtype=np.float32)
    mask = np.ones((1, 1), dtype=np.float32)

    # The trailing 0 is presumably the scripted agent's player index -- confirm.
    agent = SCRIPT_AGENTS[script_name]()
    agent.reset(env.base_mdp, env.base_env.state, 0)

    episode_rewards = 0
    step = 0
    while True:
        # Scripted agent: convert its action into an index into ALL_ACTIONS
        agent_action = Action.ALL_ACTIONS.index(
            agent.step(env.base_mdp, env.base_env.state, 0)
        )

        # Snapshot the observation so the stored transition never aliases an
        # array that is reused on a later step.
        now_obs = copy.deepcopy(obs)
        input_obs = torch.Tensor(now_obs).cuda()
        # Pure inference: no_grad keeps any tensor-valued recurrent state from
        # dragging an autograd graph across steps.
        with torch.no_grad():
            ai_action, rnn_state = policy.act(input_obs[1:], rnn_state, mask)
        ai_action = ai_action.cpu().numpy()

        actions = np.array([[agent_action], ai_action[0]])
        next_obs, _share_obs, rewards, dones, _infos, _available_actions = env.step(actions)
        next_obs = np.stack(next_obs)
        last_reward = rewards[-1]

        # Store the transition; the last reward entry is recorded for both agents
        replay_buffer.add(
            now_obs.reshape(2, -1),
            np.array([agent_action, ai_action[0][0]]).reshape(2, -1),
            np.array([last_reward, last_reward]).reshape(2, -1),
            next_obs.reshape(2, -1),
            np.array(dones).reshape(2, -1),
        )

        episode_rewards += last_reward
        # next_obs is a fresh array from np.stack and is copied again at the
        # top of the next iteration, so no extra copy is needed here.
        obs = next_obs
        if last_reward > 0:
            print('now:', step, last_reward)
        step += 1

        if dones[0] or dones[1]:
            print('episode_reward:', episode_rewards)
            replay_buffer.save(save_path, 'manyOrders_' + str(test_num))
            break

def get_script_ZSC_Random3(env_name, ckpt_path, script_name=None, save_path=None, test_num=1):
    """Collect one episode of scripted-agent + policy play and save it to disk.

    Action slot 0 is filled by the scripted agent registered under
    ``script_name``; action slot 1 is filled by the ``Policy`` restored from
    ``ckpt_path``, which is fed the second agent's observation. Every
    transition is added to a ``ReplayBuffer`` that is saved under
    ``save_path`` with the name ``random3_<test_num>`` as soon as either
    agent reports ``done``.

    This mirrors the other ``get_script_ZSC_*`` collectors in this file; it
    uses the ``Overcooked`` environment and the ``hsp.envs.overcooked``
    script agents rather than the ``overcooked_new`` variants.

    Args:
        env_name: Layout name passed to ``Overcooked`` and ``Policy``.
        ckpt_path: Checkpoint path handed to ``Policy.load_checkpoint``.
        script_name: Key of the scripted agent in ``SCRIPT_AGENTS``.
        save_path: Destination handed to ``ReplayBuffer.save``.
        test_num: Suffix appended to the saved buffer name.

    Returns:
        None. The collected data is written by ``ReplayBuffer.save``.

    Raises:
        KeyError: If ``script_name`` is not registered in ``SCRIPT_AGENTS``.
    """
    from hsp.envs.overcooked.script_agent import SCRIPT_AGENTS
    from hsp.envs.overcooked.overcooked_ai_py.mdp.actions import Action

    # Fail fast with a readable message before the environment, the checkpoint
    # and the CUDA-backed buffer are created.
    if script_name not in SCRIPT_AGENTS:
        raise KeyError(
            "script_name {!r} is not registered in SCRIPT_AGENTS".format(script_name)
        )

    # Load the environment: Overcooked
    env = Overcooked(env_name, seed=3, featurize_type=("ppo", "ppo"))
    # AI policy: load the model trained with the HSP algorithm
    policy = Policy(env_name)
    policy.load_checkpoint(ckpt_path)
    policy.prep_rollout()

    # Initialise the ReplayBuffer: two agents, 8*5*20 flattened observation
    # values each (presumably width * height * feature channels -- confirm).
    # NOTE(review): capacity is fixed at 400; confirm the episode horizon
    # never exceeds it.
    replay_buffer = ReplayBuffer(obs_shape=[2, 8 * 5 * 20],
                                 action_shape=[2, 1],
                                 reward_shape=[2, 1],
                                 dones_shape=[2, 1],
                                 capacity=400,
                                 device='cuda')

    # Initial observation, recurrent state and mask for the policy
    obs, _, _ = env.reset()
    obs = np.stack(obs)
    rnn_state = np.zeros((1, 1, 64), dtype=np.float32)
    mask = np.ones((1, 1), dtype=np.float32)

    # The trailing 0 is presumably the scripted agent's player index -- confirm.
    agent = SCRIPT_AGENTS[script_name]()
    agent.reset(env.base_mdp, env.base_env.state, 0)

    episode_rewards = 0
    step = 0
    while True:
        # Scripted agent: convert its action into an index into ALL_ACTIONS
        agent_action = Action.ALL_ACTIONS.index(
            agent.step(env.base_mdp, env.base_env.state, 0)
        )

        # Snapshot the observation so the stored transition never aliases an
        # array that is reused on a later step.
        now_obs = copy.deepcopy(obs)
        input_obs = torch.Tensor(now_obs).cuda()
        # Pure inference: no_grad keeps any tensor-valued recurrent state from
        # dragging an autograd graph across steps.
        with torch.no_grad():
            ai_action, rnn_state = policy.act(input_obs[1:], rnn_state, mask)
        ai_action = ai_action.cpu().numpy()

        actions = np.array([[agent_action], ai_action[0]])
        next_obs, _share_obs, rewards, dones, _infos, _available_actions = env.step(actions)
        next_obs = np.stack(next_obs)
        last_reward = rewards[-1]

        # Store the transition; the last reward entry is recorded for both agents
        replay_buffer.add(
            now_obs.reshape(2, -1),
            np.array([agent_action, ai_action[0][0]]).reshape(2, -1),
            np.array([last_reward, last_reward]).reshape(2, -1),
            next_obs.reshape(2, -1),
            np.array(dones).reshape(2, -1),
        )

        episode_rewards += last_reward
        # next_obs is a fresh array from np.stack and is copied again at the
        # top of the next iteration, so no extra copy is needed here.
        obs = next_obs
        if last_reward > 0:
            print('now:', step, last_reward)
        step += 1

        if dones[0] or dones[1]:
            print('episode_reward:', episode_rewards)
            replay_buffer.save(save_path, 'random3_' + str(test_num))
            break
            
def get_script_ZSC_Soup_Coordination(env_name, ckpt_path, script_name=None, save_path=None, test_num=1):
    """Collect one episode of scripted-agent + policy play and save it to disk.

    Action slot 0 is filled by the scripted agent registered under
    ``script_name``; action slot 1 is filled by the ``Policy`` restored from
    ``ckpt_path``, which is fed the second agent's observation. Every
    transition is added to a ``ReplayBuffer`` that is saved under
    ``save_path`` with the name ``soupCoordination_<test_num>`` as soon as
    either agent reports ``done``.

    This mirrors the other ``get_script_ZSC_*`` collectors in this file and
    differs only in the environment class, the flattened observation size and
    the saved file prefix.

    Args:
        env_name: Layout name passed to ``Overcooked_NEW`` and ``Policy``.
        ckpt_path: Checkpoint path handed to ``Policy.load_checkpoint``.
        script_name: Key of the scripted agent in ``SCRIPT_AGENTS``.
        save_path: Destination handed to ``ReplayBuffer.save``.
        test_num: Suffix appended to the saved buffer name.

    Returns:
        None. The collected data is written by ``ReplayBuffer.save``.

    Raises:
        KeyError: If ``script_name`` is not registered in ``SCRIPT_AGENTS``.
    """
    from hsp.envs.overcooked_new.script_agent import SCRIPT_AGENTS
    from hsp.envs.overcooked_new.src.overcooked_ai_py.mdp.actions import Action

    # Fail fast with a readable message before the environment, the checkpoint
    # and the CUDA-backed buffer are created.
    if script_name not in SCRIPT_AGENTS:
        raise KeyError(
            "script_name {!r} is not registered in SCRIPT_AGENTS".format(script_name)
        )

    # Load the environment: Overcooked
    env = Overcooked_NEW(env_name, seed=3, featurize_type=("ppo", "ppo"))
    # AI policy: load the model trained with the HSP algorithm
    policy = Policy(env_name)
    policy.load_checkpoint(ckpt_path)
    policy.prep_rollout()

    # Initialise the ReplayBuffer: two agents, 11*5*26 flattened observation
    # values each (presumably width * height * feature channels -- confirm).
    # NOTE(review): capacity is fixed at 400; confirm the episode horizon
    # never exceeds it.
    replay_buffer = ReplayBuffer(obs_shape=[2, 11 * 5 * 26],
                                 action_shape=[2, 1],
                                 reward_shape=[2, 1],
                                 dones_shape=[2, 1],
                                 capacity=400,
                                 device='cuda')

    # Initial observation, recurrent state and mask for the policy
    obs, _, _ = env.reset()
    obs = np.stack(obs)
    rnn_state = np.zeros((1, 1, 64), dtype=np.float32)
    mask = np.ones((1, 1), dtype=np.float32)

    # The trailing 0 is presumably the scripted agent's player index -- confirm.
    agent = SCRIPT_AGENTS[script_name]()
    agent.reset(env.base_mdp, env.base_env.state, 0)

    episode_rewards = 0
    step = 0
    while True:
        # Scripted agent: convert its action into an index into ALL_ACTIONS
        agent_action = Action.ALL_ACTIONS.index(
            agent.step(env.base_mdp, env.base_env.state, 0)
        )

        # Snapshot the observation so the stored transition never aliases an
        # array that is reused on a later step.
        now_obs = copy.deepcopy(obs)
        input_obs = torch.Tensor(now_obs).cuda()
        # Pure inference: no_grad keeps any tensor-valued recurrent state from
        # dragging an autograd graph across steps.
        with torch.no_grad():
            ai_action, rnn_state = policy.act(input_obs[1:], rnn_state, mask)
        ai_action = ai_action.cpu().numpy()

        actions = np.array([[agent_action], ai_action[0]])
        next_obs, _share_obs, rewards, dones, _infos, _available_actions = env.step(actions)
        next_obs = np.stack(next_obs)
        last_reward = rewards[-1]

        # Store the transition; the last reward entry is recorded for both agents
        replay_buffer.add(
            now_obs.reshape(2, -1),
            np.array([agent_action, ai_action[0][0]]).reshape(2, -1),
            np.array([last_reward, last_reward]).reshape(2, -1),
            next_obs.reshape(2, -1),
            np.array(dones).reshape(2, -1),
        )

        episode_rewards += last_reward
        # next_obs is a fresh array from np.stack and is copied again at the
        # top of the next iteration, so no extra copy is needed here.
        obs = next_obs
        if last_reward > 0:
            print('now:', step, last_reward)
        step += 1

        if dones[0] or dones[1]:
            print('episode_reward:', episode_rewards)
            replay_buffer.save(save_path, 'soupCoordination_' + str(test_num))
            break
  
def get_script_ZSC_Unident_S(env_name, ckpt_path, script_name=None, save_path=None, test_num=1):
    """Collect one episode of scripted-agent + policy play and save it to disk.

    Action slot 0 is filled by the scripted agent registered under
    ``script_name``; action slot 1 is filled by the ``Policy`` restored from
    ``ckpt_path``, which is fed the second agent's observation. Every
    transition is added to a ``ReplayBuffer`` that is saved under
    ``save_path`` with the name ``unident_s_<test_num>`` as soon as either
    agent reports ``done``.

    This mirrors the other ``get_script_ZSC_*`` collectors in this file; it
    uses the ``Overcooked`` environment and the ``hsp.envs.overcooked``
    script agents rather than the ``overcooked_new`` variants.

    Args:
        env_name: Layout name passed to ``Overcooked`` and ``Policy``.
        ckpt_path: Checkpoint path handed to ``Policy.load_checkpoint``.
        script_name: Key of the scripted agent in ``SCRIPT_AGENTS``.
        save_path: Destination handed to ``ReplayBuffer.save``.
        test_num: Suffix appended to the saved buffer name.

    Returns:
        None. The collected data is written by ``ReplayBuffer.save``.

    Raises:
        KeyError: If ``script_name`` is not registered in ``SCRIPT_AGENTS``.
    """
    from hsp.envs.overcooked.script_agent import SCRIPT_AGENTS
    from hsp.envs.overcooked.overcooked_ai_py.mdp.actions import Action

    # Fail fast with a readable message before the environment, the checkpoint
    # and the CUDA-backed buffer are created.
    if script_name not in SCRIPT_AGENTS:
        raise KeyError(
            "script_name {!r} is not registered in SCRIPT_AGENTS".format(script_name)
        )

    # Load the environment: Overcooked
    env = Overcooked(env_name, seed=3, featurize_type=("ppo", "ppo"))
    # AI policy: load the model trained with the HSP algorithm
    policy = Policy(env_name)
    policy.load_checkpoint(ckpt_path)
    policy.prep_rollout()

    # Initialise the ReplayBuffer: two agents, 9*5*20 flattened observation
    # values each (presumably width * height * feature channels -- confirm).
    # NOTE(review): capacity is fixed at 400; confirm the episode horizon
    # never exceeds it.
    replay_buffer = ReplayBuffer(obs_shape=[2, 9 * 5 * 20],
                                 action_shape=[2, 1],
                                 reward_shape=[2, 1],
                                 dones_shape=[2, 1],
                                 capacity=400,
                                 device='cuda')

    # Initial observation, recurrent state and mask for the policy
    obs, _, _ = env.reset()
    obs = np.stack(obs)
    rnn_state = np.zeros((1, 1, 64), dtype=np.float32)
    mask = np.ones((1, 1), dtype=np.float32)

    # The trailing 0 is presumably the scripted agent's player index -- confirm.
    agent = SCRIPT_AGENTS[script_name]()
    agent.reset(env.base_mdp, env.base_env.state, 0)

    episode_rewards = 0
    step = 0
    while True:
        # Scripted agent: convert its action into an index into ALL_ACTIONS
        agent_action = Action.ALL_ACTIONS.index(
            agent.step(env.base_mdp, env.base_env.state, 0)
        )

        # Snapshot the observation so the stored transition never aliases an
        # array that is reused on a later step.
        now_obs = copy.deepcopy(obs)
        input_obs = torch.Tensor(now_obs).cuda()
        # Pure inference: no_grad keeps any tensor-valued recurrent state from
        # dragging an autograd graph across steps.
        with torch.no_grad():
            ai_action, rnn_state = policy.act(input_obs[1:], rnn_state, mask)
        ai_action = ai_action.cpu().numpy()

        actions = np.array([[agent_action], ai_action[0]])
        next_obs, _share_obs, rewards, dones, _infos, _available_actions = env.step(actions)
        next_obs = np.stack(next_obs)
        last_reward = rewards[-1]

        # Store the transition; the last reward entry is recorded for both agents
        replay_buffer.add(
            now_obs.reshape(2, -1),
            np.array([agent_action, ai_action[0][0]]).reshape(2, -1),
            np.array([last_reward, last_reward]).reshape(2, -1),
            next_obs.reshape(2, -1),
            np.array(dones).reshape(2, -1),
        )

        episode_rewards += last_reward
        # next_obs is a fresh array from np.stack and is copied again at the
        # top of the next iteration, so no extra copy is needed here.
        obs = next_obs
        if last_reward > 0:
            print('now:', step, last_reward)
        step += 1

        if dones[0] or dones[1]:
            print('episode_reward:', episode_rewards)
            replay_buffer.save(save_path, 'unident_s_' + str(test_num))
            break

if __name__ == '__main__':
    # Command-line entry point.
    #
    # Usage: python <this_script> ENV_NAME SAVE_PATH SCRIPT_NAME CKPT_PATH
    #
    # Example values (taken from the former inline configuration):
    #   ENV_NAME     distant_tomato
    #   SAVE_PATH    testesteste
    #   SCRIPT_NAME  distant_tomato_place_onion_in_pot1
    #   CKPT_PATH    ./baseline_policy/distant_tomato/fcp_adaptive_seed1.pt
    import sys

    print('parameters:', sys.argv)
    # Report missing arguments with a usage line instead of a bare IndexError.
    if len(sys.argv) < 5:
        sys.exit('usage: {} ENV_NAME SAVE_PATH SCRIPT_NAME CKPT_PATH'.format(sys.argv[0]))
    env_name, save_path, script_name, ckpt_path = sys.argv[1:5]

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_path, exist_ok=True)
    print('ckpt_path:', ckpt_path)

    # Script -- ZSC
    # NOTE(review): only the Distant Tomato collector is invoked, whatever
    # env_name is; its buffer is sized for 7*5*26 observations. Confirm whether
    # other layouts should dispatch to their own get_script_ZSC_* function.
    print('************************')
    get_script_ZSC_Distant_Tomato(env_name, ckpt_path, script_name, save_path, test_num=1)
    print('************************')

    
    
    
    