import sys
import os
sys.path.append(os.path.abspath('./env'))
import random
import numpy as np
import torch
import importlib
import supersuit as ss
from box.trainer import AgentBox
from tennis.trainer import AgentTennis
from env_utils import save_pickle


# Maps the short environment name accepted by rollout_Atari_main to the
# PettingZoo Atari module that provides the corresponding parallel env.
_ATARI_ENV_MODULES = {
    "box": "pettingzoo.atari.boxing_v2",
    "tennis": "pettingzoo.atari.tennis_v3",
}

# In tennis, a reward that arrives before this step index discards the frames
# gathered so far instead of ending the rollout.
_TENNIS_MIN_STEP = 32


def _make_atari_envs(env_name):
    """Build the preprocessed, vectorised two-player env for *env_name*."""
    module = importlib.import_module(_ATARI_ENV_MODULES[env_name])
    env = module.parallel_env(render_mode="rgb_array")

    env = ss.max_observation_v0(env, 2)
    env = ss.frame_skip_v0(env, 4)
    env = ss.clip_reward_v0(env, lower_bound=-1, upper_bound=1)
    env = ss.color_reduction_v0(env, mode="B")
    env = ss.resize_v1(env, x_size=84, y_size=84)
    env = ss.frame_stack_v1(env, 4)
    env = ss.agent_indicator_v0(env, type_only=False)
    env = ss.pettingzoo_env_to_vec_env_v1(env)
    envs = ss.concat_vec_envs_v1(
        env, num_vec_envs=1, num_cpus=0, base_class="gymnasium"
    )
    # Attributes read by the agent constructors
    # (presumably mirroring gymnasium's vector-env interface -- TODO confirm).
    envs.single_observation_space = envs.observation_space
    envs.single_action_space = envs.action_space
    envs.is_vector_env = True
    return envs


def _load_agent(env_name, envs, agent_weight_path, device):
    """Instantiate the agent for *env_name*, load its weights, set eval mode."""
    agent_cls = AgentBox if env_name == "box" else AgentTennis
    agent = agent_cls(envs).to(device)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    state_dict = torch.load(
        agent_weight_path, map_location=device, weights_only=True
    )
    agent.load_state_dict(state_dict)
    agent.eval()
    return agent


def _blank_history(device):
    """Return the all-zero (prev_obs, prev_action) pair used before any step."""
    prev_obs = torch.zeros(2, 84, 84, 6).to(device)
    prev_action = torch.zeros(2, dtype=torch.long).to(device)
    return prev_obs, prev_action


def _rollout_once(envs, agent, env_name, max_steps, device):
    """Play one episode of at most *max_steps* steps.

    Returns:
        A tuple ``(traj, reward, have_reward)`` where ``traj`` is the list of
        per-step dicts, ``reward`` is the reward array returned by the last
        ``envs.step`` call and ``have_reward`` tells whether a step with a
        non-zero reward for both players was accepted.
    """
    obs, _ = envs.reset()
    traj = []
    have_reward = False
    reward = None
    prev_obs, prev_action = _blank_history(device)

    for step in range(max_steps):
        with torch.no_grad():
            obs = torch.Tensor(obs).to(device)  # [2, 84, 84, 6]
            action, _, _, _ = agent.get_action_and_value(obs)  # [2]

        traj.append(
            {
                "prev_state": prev_obs,
                "prev_action": prev_action,
                "state": obs,
                "action": action,
            }
        )
        prev_obs, prev_action = obs, action

        obs, reward, termination, truncation, _ = envs.step(
            action.cpu().numpy()
        )

        # True only when both players received a non-zero reward this step.
        if all(reward):
            if env_name == "box":
                have_reward = True
            elif env_name == "tennis":
                if step >= _TENNIS_MIN_STEP:
                    have_reward = True
                    break
                # Reward came too early: drop what was gathered and restart
                # the history while the same episode keeps running.
                traj = []
                prev_obs, prev_action = _blank_history(device)

        if termination[0] or truncation[0]:
            break

    return traj, reward, have_reward


def rollout_Atari_main(env_name, agent_weight_path, all_episodes, max_steps):
    """Collect trajectories of a trained agent in a two-player Atari game.

    For every entry ``episodes`` of *all_episodes*, rollouts are repeated until
    ``episodes`` trajectories have been accepted; the batch is then passed to
    ``save_pickle``. A rollout is rejected (and retried) when its last step
    gave both players zero reward and no rewarded step was accepted during it.

    Each accepted trajectory is a list of dicts with the keys ``prev_state``,
    ``prev_action``, ``state`` and ``action`` (torch tensors kept on the
    selected device), followed by a final element naming the winner
    (``"player_0"`` or ``"player_1"``).

    Args:
        env_name: ``"box"`` or ``"tennis"``.
        agent_weight_path: Path of the state dict loaded into the agent.
        all_episodes: Iterable of trajectory counts; one batch is collected
            and saved per entry. A count of zero or less yields an empty batch.
        max_steps: Maximum number of environment steps per rollout (>= 1).

    Raises:
        ValueError: If *env_name* is not supported or *max_steps* is below 1.
    """
    if env_name not in _ATARI_ENV_MODULES:
        raise ValueError("Only use for env: 'box', 'tennis' in Atari game. ")
    # With no steps there is no reward to judge a rollout by, so fail early
    # instead of hitting an unbound variable later.
    if max_steps < 1:
        raise ValueError(f"max_steps must be at least 1, got {max_steps}.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"Collect data in env: '{env_name}'.")

    random.seed(1)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.backends.cudnn.deterministic = True

    envs = _make_atari_envs(env_name)
    try:
        agent = _load_agent(env_name, envs, agent_weight_path, device)

        for episodes in all_episodes:
            trajs = []
            collect_num = 0

            # Checking the count up front also terminates for episodes <= 0.
            while collect_num < episodes:
                if collect_num % 100 == 0:
                    print(f"--Epoch: {collect_num} --")

                traj, reward, have_reward = _rollout_once(
                    envs, agent, env_name, max_steps, device
                )

                # Nothing was scored in this rollout: try again.
                if reward[0] == reward[1] == 0 and not have_reward:
                    continue

                # The winner is decided by the rewards of the last step only;
                # on a tie "player_0" is chosen.
                # NOTE(review): for "box" the rollout keeps running after a
                # rewarded step, so the last step may be 0/0 and the label
                # then defaults to "player_0" -- confirm this is intended.
                winner = "player_1" if reward[1] > reward[0] else "player_0"
                traj.append(winner)
                trajs.append(traj)
                collect_num += 1

            print(f'Collect number: {collect_num}')

            save_pickle(Trajs_data=trajs,
                        env_name=env_name,
                        episodes=episodes)
    finally:
        # Close exactly once, after every batch has been collected; the same
        # environment is reused across all entries of all_episodes.
        envs.close()
