import sys
import os
sys.path.append(os.path.abspath('./env'))
import numpy as np
import torch
from agilerl.algorithms.dqn import DQN
from connect4.trainer import Opponent
from pettingzoo.classic import connect_four_v3, texas_holdem_v4
from env_utils import save_pickle


def rollout_Classic_main(env_name, agent_weight_path, all_episodes, opponent_difficulty=None, max_steps=None):
    """Build the requested PettingZoo classic env and run the matching rollout.

    Args:
        env_name: Either ``'connect4'`` or ``'holdem'``.
        agent_weight_path: Path to the saved agent, forwarded to the rollout
            function unchanged.
        all_episodes: Iterable of episode counts; forwarded unchanged.
        opponent_difficulty: Only forwarded to ``connect4_env``.
        max_steps: Only forwarded to ``connect4_env``.

    Raises:
        ValueError: If ``env_name`` is not one of the supported names.
    """
    # Fail loudly on a typo instead of silently collecting nothing.
    if env_name not in ('connect4', 'holdem'):
        raise ValueError(
            f"Unknown env_name {env_name!r}; expected 'connect4' or 'holdem'"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    if env_name == 'connect4':
        env = connect_four_v3.env(render_mode="rgb_array")
        # Reset before handing over: connect4_env reads env.agents right away.
        env.reset()
        connect4_env(env, agent_weight_path, all_episodes, opponent_difficulty, max_steps, device)
    else:
        env = texas_holdem_v4.env(render_mode="rgb_array")
        env.reset()
        holdem_env(env, agent_weight_path, all_episodes, device)


def holdem_env(env, agent_weight_path, all_episodes, device):
    """Collect self-play Texas Hold'em trajectories and pickle them per batch.

    For every entry ``episodes`` in ``all_episodes`` this plays that many
    games in which the same loaded agent acts for every seat, then calls
    ``save_pickle`` once with the collected list.

    Each trajectory is a list of step dicts with the keys ``prev_state``,
    ``prev_action``, ``state`` and ``action``, followed by a final string
    (``"player_0"``, ``"player_1"`` or ``"tie"``) naming the side with the
    larger summed reward.

    Args:
        env: AEC-style environment exposing ``reset``, ``agent_iter``,
            ``last``, ``step`` and ``close``.
        agent_weight_path: Path handed to ``torch.load``.
        all_episodes: Iterable of episode counts, one batch per entry.
        device: Currently unused by this function.
    """
    # NOTE: torch.load unpickles arbitrary objects; only load checkpoints
    # that come from a trusted source.
    Agent = torch.load(agent_weight_path)

    for episodes in all_episodes:
        Trajs = []
        ep = 0

        # `ep < episodes` (rather than `ep == episodes` after the fact) also
        # terminates for zero, negative or non-integral episode counts.
        while ep < episodes:
            traj = []
            reward_sums = {'player_0': 0, 'player_1': 0}
            env.reset()
            # Placeholders for the first record of the episode.
            prev_state, prev_actions, prev_act = np.zeros(72), np.array([0, 0]), 0

            for agent in env.agent_iter():
                state, reward, termination, truncation, info = env.last()
                reward_sums[agent] += reward
                if termination or truncation:
                    # Finished agents must still be stepped, with no action.
                    action = None
                else:
                    action = Agent.eval_step(state)[0]
                    # "action" pairs this move with the previously taken move.
                    actions = np.array([action, prev_act])
                    traj.append({
                        "prev_state": prev_state,
                        "prev_action": prev_actions,
                        "state": state['observation'],
                        "action": actions,
                    })
                    prev_state = state['observation']
                    prev_actions = actions
                    prev_act = action

                env.step(action)

            p1, p2 = reward_sums["player_0"], reward_sums["player_1"]
            winner = "player_0" if p1 > p2 else "player_1" if p2 > p1 else "tie"
            traj.append(winner)
            Trajs.append(traj)
            ep += 1

        print(f'Collect number: {ep}')

        # save
        save_pickle(Trajs_data=Trajs,
                    env_name='holdem',
                    episodes=episodes)

    # Close only once every batch is done; the env is reused across batches.
    env.close()
        


def _connect4_pick_action(dqn, opponent, opponent_difficulty, opponent_turn, state, action_mask, seat):
    """Return the move of whichever side is to act.

    ``seat`` is 0 for the first mover and 1 for the second mover; it is only
    forwarded to non-"self", non-"random" opponents as ``getAction(player=seat)``.
    """
    if not opponent_turn:
        return dqn.getAction(state, epsilon=0, action_mask=action_mask)[0]
    if opponent_difficulty == "self":
        # In self-play `opponent` is the very same DQN object.
        return opponent.getAction(state, epsilon=0, action_mask=action_mask)[0]
    if opponent_difficulty == "random":
        return opponent.getAction(action_mask)
    return opponent.getAction(player=seat)


def connect4_env(env, agent_weight_path, all_episodes, opponent_difficulty, max_steps, device):
    """Collect Connect Four trajectories of a DQN agent and pickle them per batch.

    For every entry ``episodes`` in ``all_episodes`` this plays that many
    games (the agent moves first in the first half of a batch, the opponent
    in the second half) and then calls ``save_pickle`` once with the list.

    Each trajectory is a list of step dicts with the keys ``prev_state``,
    ``prev_action``, ``state`` and ``action`` (the first move of a game only
    seeds the ``prev_*`` values and is not recorded on its own), followed by
    a final string ``"agent"`` or ``"adversary"``.

    Args:
        env: Already-reset AEC-style Connect Four environment.
        agent_weight_path: Checkpoint path passed to ``DQN.loadCheckpoint``.
        all_episodes: Iterable of episode counts, one batch per entry.
        opponent_difficulty: ``"self"`` reuses the DQN as opponent; any other
            value is passed to ``Opponent(env, opponent_difficulty)``.
        max_steps: Maximum number of moves per game, or ``None`` for no cap
            (the game then runs until the env reports termination/truncation).
        device: Device handed to the ``DQN`` constructor.
    """
    import itertools

    # Configure the algo input arguments
    state_dim = [
        env.observation_space(agent)["observation"].shape for agent in env.agents
    ]
    one_hot = False
    action_dim = [env.action_space(agent).n for agent in env.agents]

    # Only the first agent's dimensions are used; the observation shape is
    # flattened before being handed to the DQN constructor.
    state_dim = np.zeros(state_dim[0]).flatten().shape
    action_dim = action_dim[0]

    # Instantiate a DQN object and load the saved algorithm into it
    dqn = DQN(
        state_dim,
        action_dim,
        one_hot,
        device=device,
    )
    dqn.loadCheckpoint(agent_weight_path)

    # Create opponent
    if opponent_difficulty == "self":
        opponent = dqn
    else:
        opponent = Opponent(env, opponent_difficulty)

    print("============================================")
    print(f"Agent: {agent_weight_path}")
    print(f"Opponent: {opponent_difficulty}")

    for episodes in all_episodes:
        Trajs = []
        ep = 0

        # `ep < episodes` also terminates (and never divides by zero) when
        # the requested count is zero or negative.
        while ep < episodes:
            # Agent moves first in the first half of the batch.
            opponent_first = ep / episodes >= 0.5

            env.reset()  # Reset environment at start of episode
            observation, reward, done, truncation, _ = env.last()
            player = -1  # Tracker for which player's turn it is
            score = 0
            traj = []
            # Per-episode placeholders; 0 mirrors the initial value used by
            # holdem_env for the "previous action" slot.
            prev_state, prev_actions, prev_act = None, None, 0

            # max_steps=None means "no cap" instead of crashing in range().
            step_indices = range(max_steps) if max_steps is not None else itertools.count()
            for idx_step in step_indices:
                action_mask = observation["action_mask"]
                # Move the last axis to third-from-last, then add a leading
                # batch axis.
                state = np.moveaxis(observation["observation"], [-1], [-3])
                state = np.expand_dims(state, 0)

                seat = 0 if player < 0 else 1
                opponent_turn = opponent_first if player < 0 else not opponent_first
                action = _connect4_pick_action(
                    dqn, opponent, opponent_difficulty, opponent_turn,
                    state, action_mask, seat,
                )

                # ------------ collect ----------------
                # Pair this move with the previous one BEFORE updating
                # prev_act, so the next record's "prev_action" equals this
                # record's "action".
                actions = np.array([action, prev_act])
                if idx_step != 0:
                    traj.append({
                        "prev_state": prev_state,
                        "prev_action": prev_actions,
                        "state": state[0],
                        "action": actions,
                    })
                prev_state = state[0]
                prev_actions = actions
                prev_act = action
                # --------------------------------------

                env.step(action)  # Act in environment
                observation, reward, termination, truncation, _ = env.last()

                # NOTE(review): `reward` is read after env.step(), so it may
                # belong to the player who is next to move rather than the
                # one who just moved - confirm the sign convention.
                if not opponent_turn:
                    score += reward
                else:
                    score -= reward

                # Stop episode if any agents have terminated
                if truncation or termination:
                    break

                player *= -1

            # A score of exactly 0 is labelled "agent".
            winner = "adversary" if score < 0 else "agent"
            traj.append(winner)
            Trajs.append(traj)
            ep += 1

        print(f'Collect number: {ep}')

        # save
        save_pickle(Trajs_data=Trajs,
                    env_name='connect4',
                    episodes=episodes)

    # Close only once every batch is done; the env is reused across batches.
    env.close()