import sys
import os
sys.path.append(os.path.abspath('./env'))
import numpy as np
import torch
from agilerl.algorithms.matd3 import MATD3
from pettingzoo.mpe import simple_tag_v3, simple_push_v3, simple_spread_v3
from env_utils import save_pickle


_MPE_ENV_NAMES = ("tag", "push", "spread")


def _make_mpe_env(env_name, max_steps):
    """Build the discrete-action parallel MPE environment selected by ``env_name``."""
    if env_name not in _MPE_ENV_NAMES:
        raise ValueError("Only use for env: 'tag', 'push', 'spread' in MPE game. ")
    # The module table is built only after validation so that an unsupported
    # name fails immediately with the ValueError above.
    env_modules = {
        "tag": simple_tag_v3,
        "push": simple_push_v3,
        "spread": simple_spread_v3,
    }
    return env_modules[env_name].parallel_env(
        max_cycles=max_steps, continuous_actions=False, render_mode="rgb_array"
    )


def _describe_spaces(env):
    """Return (state_dim, one_hot, action_dim, discrete_actions, max_action, min_action)."""
    agents = env.agents
    try:
        # Spaces exposing ``.n`` are treated as discrete / one-hot encoded.
        state_dim = [env.observation_space(agent).n for agent in agents]
        one_hot = True
    except AttributeError:
        state_dim = [env.observation_space(agent).shape for agent in agents]
        one_hot = False
    try:
        action_dim = [env.action_space(agent).n for agent in agents]
        discrete_actions = True
        max_action = None
        min_action = None
    except AttributeError:
        action_dim = [env.action_space(agent).shape[0] for agent in agents]
        discrete_actions = False
        max_action = [env.action_space(agent).high for agent in agents]
        min_action = [env.action_space(agent).low for agent in agents]
    return state_dim, one_hot, action_dim, discrete_actions, max_action, min_action


def _flatten_per_agent(per_agent):
    """Concatenate the per-agent values of a dict into one flat array."""
    return np.concatenate([np.atleast_1d(v) for v in per_agent.values()])


def _pick_winner(episode_reward):
    """Return "adversary" if the adversaries out-scored the other agents, else "agent"."""
    adversary_ids = [a for a in episode_reward if a.startswith("adversary")]
    if not adversary_ids:
        # No adversaries in this episode (the 'spread' case in the original
        # code): the winner label is always "agent".
        return "agent"
    adversary_score = sum(episode_reward[a] for a in adversary_ids)
    agent_score = sum(
        r for a, r in episode_reward.items() if a not in adversary_ids
    )
    return "adversary" if adversary_score > agent_score else "agent"


def _collect_episode(env, matd3, agent_ids, max_steps, prev_state_size):
    """Roll out one episode; return its step records followed by the winner label."""
    state, info = env.reset()
    episode_reward = {agent_id: 0 for agent_id in agent_ids}
    traj = []
    # The first record has no predecessor, so its "prev_*" fields are zeros.
    prev_state = np.zeros(prev_state_size, dtype=np.float32)
    prev_action = np.zeros(len(agent_ids), dtype=np.int64)

    for _ in range(max_steps):
        cont_actions, discrete_action = matd3.getAction(
            state,
            epsilon=0,
            agent_mask=info.get("agent_mask"),
            env_defined_actions=info.get("env_defined_actions"),
        )
        action = discrete_action if matd3.discrete_actions else cont_actions

        flat_state = _flatten_per_agent(state)
        flat_action = _flatten_per_agent(action)
        traj.append(
            {
                "prev_state": prev_state,
                "prev_action": prev_action,
                "state": flat_state,
                "action": flat_action,
            }
        )
        # Copy so that consecutive records do not share one array object.
        prev_state = flat_state.copy()
        prev_action = flat_action.copy()

        # Take action in environment
        state, reward, termination, truncation, info = env.step(action)
        for agent_id, r in reward.items():
            episode_reward[agent_id] += r

        # Stop the episode as soon as any agent is truncated or terminated.
        if any(truncation.values()) or any(termination.values()):
            break

    traj.append(_pick_winner(episode_reward))
    return traj


def rollout_MPE_main(env_name, agent_weight_path, all_episodes, max_steps):
    """Collect MATD3 rollouts in an MPE environment and hand them to ``save_pickle``.

    For every entry ``episodes`` of ``all_episodes`` this runs that many
    episodes with the loaded MATD3 agents (``epsilon=0``) and calls
    ``save_pickle(Trajs_data=..., env_name=..., episodes=...)`` with the result.
    Each trajectory is a list of per-step dicts with the keys ``"prev_state"``,
    ``"prev_action"``, ``"state"`` and ``"action"`` (per-agent values
    concatenated into flat arrays), followed by a final winner label
    (``"adversary"`` or ``"agent"``).

    Args:
        env_name: One of ``'tag'``, ``'push'`` or ``'spread'``.
        agent_weight_path: Checkpoint path passed to ``MATD3.loadCheckpoint``.
        all_episodes: Iterable of episode counts; one collection is saved per
            entry. An entry ``<= 0`` yields an empty collection.
        max_steps: Maximum number of steps per episode (also used as the
            environment's ``max_cycles``).

    Raises:
        ValueError: If ``env_name`` is not a supported environment.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"Collect data in env: '{env_name}'. ")

    env = _make_mpe_env(env_name, max_steps)
    try:
        # Reset once so that ``env.agents`` is populated before introspection.
        env.reset()
        (
            state_dim,
            one_hot,
            action_dim,
            discrete_actions,
            max_action,
            min_action,
        ) = _describe_spaces(env)

        n_agents = env.num_agents
        # Snapshot the ids: the environment owns (and may rebind) its list.
        agent_ids = list(env.agents)

        # Instantiate an MATD3 object and load the saved weights into it.
        matd3 = MATD3(
            state_dim,
            action_dim,
            one_hot,
            n_agents,
            agent_ids,
            max_action,
            min_action,
            discrete_actions,
            device=device,
        )
        matd3.loadCheckpoint(agent_weight_path)

        # Loop-invariant size of the flattened joint observation.
        # NOTE(review): assumes each state_dim entry is a shape tuple (the
        # non-one-hot branch) -- confirm if one-hot observations are ever used.
        prev_state_size = sum(s[0] for s in state_dim)

        for episodes in all_episodes:
            trajs = []
            collect_num = 0
            # ``<`` (rather than ``==`` inside ``while True``) terminates even
            # for non-positive or non-integral episode counts.
            while collect_num < episodes:
                trajs.append(
                    _collect_episode(
                        env, matd3, agent_ids, max_steps, prev_state_size
                    )
                )
                collect_num += 1

            print(f'Collect number: {collect_num}')

            # save
            save_pickle(Trajs_data=trajs, env_name=env_name, episodes=episodes)
    finally:
        # Close once, after every collection: closing inside the loop left
        # later collections running on an already-closed environment.
        env.close()