import sys
import os
sys.path.append(os.path.abspath('./'))
sys.path.append(os.path.abspath('./env'))

from datetime import datetime
import numpy as np
import torch
from agilerl.algorithms.dqn import DQN
from connect4.trainer import Opponent
from pettingzoo.classic import connect_four_v3, leduc_holdem_v4, texas_holdem_v4
from env_utils import save_log, _label_with_episode_number_connect4, round_stats, stat_summary, round_stat_summary
from load_agent import Load_model


def main_connect(args, is_single_agent=True):
    """Evaluate two Connect Four players against each other and save a text log.

    The players are chosen from ``args``:

    * ``args.eval.IsAgent`` false: a DQN restored from a fixed checkpoint
      (agent_0) plays the "weak" ``Opponent`` (agent_1).
    * ``args.eval.IsAgent`` true and ``is_single_agent`` true: the model
      returned by ``Load_model`` (agent_1) plays ``args.eval.opponent``
      (agent_0), which is either an ``Opponent`` or a DQN restored from
      ``args.eval.dqn_path``.
    * ``args.eval.IsAgent`` true and ``is_single_agent`` false: both players
      come from the mapping returned by ``Load_model``.

    ``args.eval.rounds`` rounds of ``args.eval.episodes`` episodes are played.
    agent_1 moves first during the first half of each round's episodes and
    agent_0 moves first during the second half. Per-round and overall win
    rates plus episode-length statistics are written with ``save_log`` under
    ``{args.eval.result_path}/{args.env_name}``.

    Args:
        args: Configuration object. Attributes read here: ``eval.IsAgent``,
            ``eval.opponent``, ``eval.dqn_path``, ``eval.rounds``,
            ``eval.episodes``, ``eval.result_path``, ``model_name``,
            ``player_list``, ``strength``, ``env_name`` and, when
            ``model_name == 'ddgi'``, ``model.sample_steps``,
            ``model.sample_alpha`` and ``model.d_weight``.
        is_single_agent: Whether ``Load_model`` supplies one model (True) or
            a mapping with ``'agent_0'`` and ``'agent_1'`` entries (False).

    Raises:
        ValueError: If ``env.last()`` reports a reward of exactly 1 after a step.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    env = connect_four_v3.env(render_mode="rgb_array")
    env.reset()

    # Configure the algo input arguments
    state_dim = [
        env.observation_space(agent)["observation"].shape for agent in env.agents
    ]
    one_hot = False
    action_dim = [env.action_space(agent).n for agent in env.agents]

    # Pre-process dimensions for pytorch layers
    # We will use self-play, so we only need to worry about the state dim of a single agent
    # We flatten the 6x7x2 observation as input to the agent's neural network
    state_dim = np.zeros(state_dim[0]).flatten().shape
    action_dim = action_dim[0]


    # ---------------------- Agent ----------------------
    model_list = []
    if args.eval.IsAgent:
        # `agent` is only bound on this path; the baseline branch below never reads it
        agent = Load_model(args, is_single_agent)

    # ---------------------------------------------------
    # Fixed to "inference", so the `== "dqn"` check in the episode loop never fires
    opponent_difficulty = "inference" # ["random", "weak", "strong", "dqn"]

    # Create opponent
    if not args.eval.IsAgent:
        # Baseline: DQN restored from a hard-coded checkpoint vs. the "weak" Opponent
        # Instantiate an DQN object
        #agent_0 = Opponent(env, "random")
        agent_0 = DQN(state_dim, action_dim, one_hot, device=device,)
        agent_0.loadCheckpoint('./env/connect4/weight/lesson4_agent.pt')
        #agent_1 = DQN(state_dim, action_dim, one_hot, device=device,)
        #agent_1.loadCheckpoint(args.eval.dqn_path)
        agent_1 = Opponent(env, "weak")
        a0, a1 = 'dqn', 'weak'

    elif is_single_agent:
        # Loaded model (agent_1) vs. the opponent named in args.eval.opponent (agent_0)
        if args.eval.opponent != 'dqn':
            agent_0 = Opponent(env, args.eval.opponent)
        else:
            agent_0 = DQN(state_dim, action_dim, one_hot, device=device,)
            agent_0.loadCheckpoint(args.eval.dqn_path)
            print("load DQN")
        agent_1 = agent
        a0, a1 = args.eval.opponent, args.model_name

    elif not is_single_agent:
        # Both seats are filled from the mapping returned by Load_model
        agent_0 = agent['agent_0']
        agent_1 = agent['agent_1']
        a0, a1 = args.model_name, args.model_name

    # a0 / a1 are the type tags that action_agent_type dispatches on
    model_list.append(a0)
    model_list.append(a1)

    # ---------------- log prepare ----------------
    # Accumulators across rounds; the keys are hard-coded seat names
    round_win_counts = {'agent_0': 0, 'agent_1': 0, "tie": 0}
    round_win_rate_per_game = {'agent_0': [], 'agent_1': [], "tie": []}
    round_score_stats = {'agent_0': [], 'agent_1': []}
    round_lost_stats = {'agent_0': [], 'agent_1': []}
    round_counts = []

    log_lines = []
    today_str = datetime.now().strftime("%Y-%m-%d")
    log_lines.append(f"Data: {today_str}")
    log_lines.append(f"Total rounds: {args.eval.rounds} with per {args.eval.episodes} episodes")
    log_lines.append(f"Model: {model_list}")
    if args.model_name == 'ddgi':
        log_lines.append(f"sample step: {args.model.sample_steps} / alpha: {args.model.sample_alpha} / d weight: {args.model.d_weight}")
    log_lines.append("\n" + "*" * 50 + "\n")

    # Define eval loop parameters
    total_episodes = args.eval.episodes #2  # Number of episodes to eval agent on
    max_steps = 42  # Max number of steps to take in the environment in each episode
    player_list = list(args.player_list)


    print("============================================")
    print(f"Agent_0: {a0}")
    print(f"Agent_1: {a1}")

    for r in range(1, args.eval.rounds+1):
        env.reset(seed=int(r * 100))
        # NOTE(review): frames are collected below but never read in this function
        frames = []
        # The fixed keys "agent_0_reward"/"agent_1_reward" are used further down,
        # so player_list is assumed to contain 'agent_0' and 'agent_1' - TODO confirm
        agents_results = {name + '_reward': [] for name in player_list}
        step_results = {'step': []}

        print(f'\n====================== round {r} ======================\n')
        log_lines.append("-" + f"round {r}" + "-")

        # eval loop for inference
        for ep in range(total_episodes):
            # First half of the episodes: agent_1 moves first; second half: agent_0 moves first
            if ep / total_episodes < 0.5:
                opponent_first = False
                p = 1
            else:
                opponent_first = True
                p = 2
            if opponent_difficulty == "dqn":
                p = None

            env.reset(seed=int(r * 100 + ep))  # Reset environment at start of episode
            frame = env.render()
            frames.append(
                _label_with_episode_number_connect4(frame, episode_num=ep, frame_no=0, p=p)
            )
            observation, reward, done, truncation, _ = env.last()
            # `done` is forwarded in `infos`: False for the first move of an episode, True afterwards
            done = False
            player = -1  # Tracker for which player's turn it is
            score = 0

            for idx_step in range(max_steps):
                action_mask = observation["action_mask"]
                infos = {'done': done, 'episodes': ep+1, 'action_mask': action_mask}
                # player < 0: the side that moves first in this episode
                if player < 0:
                    # Move the last observation axis to third-from-last, then add a leading batch axis
                    state = np.moveaxis(observation["observation"], [-1], [-3])
                    state = np.expand_dims(state, 0)

                    if opponent_first:
                        action = action_agent_type(state, infos, opponent_first, agent_type=a0, agent=agent_0)
                    else:
                        action = action_agent_type(state, infos, opponent_first, agent_type=a1, agent=agent_1)


                # player > 0: the side that moves second in this episode
                if player > 0:
                    state = np.moveaxis(observation["observation"], [-1], [-3])
                    # NOTE(review): this assigns channels 0 and 1 onto themselves, so `state`
                    # is unchanged - confirm whether a swap ([1, 0]) was intended
                    state[[0, 1], :, :] = state[[0, 1], :, :]
                    state = np.expand_dims(state, 0)

                    if not opponent_first:
                        action = action_agent_type(state, infos, opponent_first, agent_type=a0, agent=agent_0)
                    else:
                        action = action_agent_type(state, infos, opponent_first, agent_type=a1, agent=agent_1)

                # NOTE(review): for "weak"/"strong" agents action_agent_type maps
                # opponent_first=True to player=0, which matches agent_0's seat; when the
                # heuristic is agent_1 (baseline branch above) the index looks inverted -
                # confirm against Opponent.getAction
                done = True
                env.step(action)  # Act in environment
                # Values reported by env.last() after the step; presumably they belong to
                # the agent that acts next - TODO confirm against the PettingZoo AEC API
                observation, reward, termination, truncation, _ = env.last()
                # Save the frame for this step and append to frames list
                frame = env.render()
                frames.append(
                    _label_with_episode_number_connect4(
                        frame, episode_num=ep, frame_no=idx_step, p=p
                    )
                )
                # A reward of exactly 1 here is treated as an invalid state
                if reward == 1:
                    raise ValueError('Reward should be -1')
                # True when agent_1 made the move just stepped (see the dispatch above);
                # the reported reward is then subtracted, otherwise added, so `score`
                # is tracked from agent_1's point of view
                if (player > 0 and opponent_first) or (player < 0 and not opponent_first):
                    score -= reward
                else:
                    score += reward

                # Stop episode if any agents have terminated
                if truncation or termination:
                    break

                player *= -1

            print("-" * 15, f"Episode: {ep+1}", "-" * 15)
            print(f"Episode length: {idx_step+1}")
            print(f"Score: {score}") # agent 1

            agents_results["agent_0_reward"].append(-score) # agent 0
            agents_results["agent_1_reward"].append(score) # agent 1
            step_results["step"].append(idx_step+1)
            #frames = resize_frames(frames, 0.5)
            print("============================================")

        # The env is closed at the end of every round and reset again at the top of the next one
        env.close()
        print(f'\n====================================================\n')

        r0 = np.array(agents_results['agent_0_reward'])
        r1 = np.array(agents_results['agent_1_reward'])
        # Despite the name, `rounds` holds the step count of each episode in this round
        rounds = np.array(step_results["step"])
        k = len(r0)

        # r0 is the negation of r1, so equal entries (both 0) are counted as ties
        win0 = np.sum(r0 > r1)
        win1 = np.sum(r1 > r0)
        tie = np.sum(r0 == r1)

        round_win_counts['agent_0'] += win0
        round_win_counts['agent_1'] += win1
        round_win_counts["tie"] += tie

        # Episode lengths split by which seat ended the episode with a score of 1
        round_on_score_0 = rounds[r0 == 1]
        round_on_score_1 = rounds[r1 == 1]
        round_on_lost_0 = rounds[r1 == 1]
        round_on_lost_1 = rounds[r0 == 1]

        round_score_stats['agent_0'].append(round_stats(round_on_score_0))
        round_lost_stats['agent_0'].append(round_stats(round_on_lost_0))
        round_score_stats['agent_1'].append(round_stats(round_on_score_1))
        round_lost_stats['agent_1'].append(round_stats(round_on_lost_1))

        round_win_rate_per_game['agent_0'].append(win0 / k)
        round_win_rate_per_game['agent_1'].append(win1 / k)
        round_win_rate_per_game["tie"].append(tie / k)
        # NOTE(review): this records the number of episodes in the round, not a step
        # count, yet it is logged below as "Number of rounds per game" - confirm intent
        round_counts.append(len(rounds))

        log_lines.append(f"{'agent_0'} - win rates: {win0} / {k} = {(win0 / k):.3f}")
        log_lines.append(f"{'agent_1'} - win rates: {win1} / {k} = {(win1 / k):.3f}")
        log_lines.append("-" * 50 + "\n")

        print(f'agent_0 win rates: {(win0 / k):.3f}')
        print(f'agent_1 win rates: {(win1 / k):.3f}')


    # ---------------------- Total Save ----------------------
    print(f'\n====================================================\n')
    log_lines.append("\n=== Total statistic ===\n")

    # Mean / standard deviation of the per-round win rates
    avg_win_rate = {
        key: np.mean(round_win_rate_per_game[key])
        for key in ['agent_0', 'agent_1', 'tie']
    }
    std_win_rate = {
        key: np.std(round_win_rate_per_game[key])
        for key in ['agent_0', 'agent_1', 'tie']
    }

    for key in ['agent_0', 'agent_1', "tie"]:
        log_lines.append(f"{key}: ")
        # NOTE(review): the value printed as "Total rounds" is the accumulated win
        # (or tie) count over all rounds
        log_lines.append(f"  Total rounds: {round_win_counts[key]}")
        log_lines.append(f"  Avg win rates: {avg_win_rate[key]:.3f}")
        log_lines.append(f"  Std win rates: {std_win_rate[key]:.3f}\n")

    # NOTE(review): the two stat groups below are appended without the agent key,
    # so the log does not say which agent each group belongs to
    for key in ['agent_0', 'agent_1']:
        r = round_stat_summary(round_score_stats[key])
        log_lines.append(f"  Number of steps when get points: avg= {r['avg']:.2f}, min= {r['min']}, max= {r['max']}")
        r = round_stat_summary(round_lost_stats[key])
        log_lines.append(f"  Number of steps when losing points: avg= {r['avg']:.2f}, min= {r['min']}, max= {r['max']}")
        log_lines.append("")

    round_stat = stat_summary(round_counts)
    log_lines.append(f"   Number of rounds per game: avg= {round_stat['avg']:.2f}, min= {round_stat['min']}, max= {round_stat['max']}")
    log_lines.append(f'\n====================================================\n')


    # ---------------------- Save ----------------------
    # The file name encodes the evaluation mode, the players and the rounds/episodes setting
    if args.eval.IsAgent:
        if is_single_agent:
            save_name = f"{a0}_vs_{a1}_s{args.strength}_{args.eval.rounds}per{args.eval.episodes}_.txt"
        else:
            save_name = f"{args.model_name}_all_s{args.strength}_{args.eval.rounds}per{args.eval.episodes}.txt"
    else:
        save_name = f"origin_{args.eval.rounds}per{args.eval.episodes}.txt"


    log_text = "\n".join(log_lines)
    save_path = f"{args.eval.result_path}/{args.env_name}"
    save_log(log_text, save_name, save_path)

    # The string statement below is disabled code kept as a literal; it has no effect
    """if args.model_name == 'ddgi':
        agent_1.ouptut_lam()"""
   
    
def action_agent_type(state, infos, opponent_first, agent_type, agent):
    """Ask ``agent`` for its next move using the call convention of ``agent_type``.

    Args:
        state: Observation forwarded to "dqn" agents and to custom agents.
        infos: Per-step dict; ``infos['action_mask']`` is read for "dqn" and
            "random" agents, and the whole dict is forwarded to custom agents.
        opponent_first: Only used for "weak"/"strong" agents, where True is
            passed on as ``player=0`` and False as ``player=1``.
        agent_type: One of "dqn", "random", "weak", "strong"; any other value
            selects the custom-agent path.
        agent: The object that produces the action.

    Returns:
        The action returned by the agent. For "dqn" this is element ``[0]``
        of what ``getAction`` returns.
    """
    if agent_type == "dqn":
        # getAction is called with epsilon=0 and the mask; only its first entry is used
        chosen = agent.getAction(state, epsilon=0, action_mask=infos['action_mask'])
        return chosen[0]

    if agent_type == "random":
        # The random agent only receives the action mask
        return agent.getAction(infos['action_mask'])

    if agent_type in ("weak", "strong"):
        # Rule-based agents are told which player index to act as
        seat = 0 if opponent_first else 1
        return agent.getAction(player=seat)

    # Anything else is treated as a custom agent exposing get_action(state, infos)
    return agent.get_action(state, infos)


# ====================================================================================


def _select_holdem_players(args, is_single_agent):
    """Pick the two hold'em players for the requested evaluation mode.

    Exactly one of three mutually exclusive modes is used:

    * ``args.eval.IsAgent`` false: NFSP vs. NFSP, both restored from
      ``args.eval.nfsp_path``; ``Load_model`` is not called.
    * ``args.eval.IsAgent`` true and ``is_single_agent`` true: the model
      returned by ``Load_model`` (player_0) vs. NFSP (player_1).
    * ``args.eval.IsAgent`` true and ``is_single_agent`` false: both players
      come from the mapping returned by ``Load_model``.

    Returns:
        Tuple ``(Agent_0, Agent_1, a0, a1, model_list)`` where ``a0``/``a1``
        are the type tags used to choose the action API and ``model_list``
        holds the labels written to the log header.
    """
    # NOTE(review): torch.load unpickles arbitrary Python objects, so only
    # trusted checkpoints should be used here. Recent PyTorch releases default
    # to weights_only=True, which may reject whole-agent pickles - confirm
    # against the installed version.
    if not args.eval.IsAgent:
        first = torch.load(args.eval.nfsp_path)
        second = torch.load(args.eval.nfsp_path)
        print(f'nfsp: {args.eval.nfsp_path}')
        labels = [f'nfsp: {args.eval.nfsp_path}', f'nfsp: {args.eval.nfsp_path}']
        return first, second, 'nfsp', 'nfsp', labels

    loaded = Load_model(args, is_single_agent)

    if is_single_agent:
        nfsp_player = torch.load(args.eval.nfsp_path)
        print(f'nfsp: {args.eval.nfsp_path}')
        labels = [args.model_name, f'nfsp: {args.eval.nfsp_path}']
        return loaded, nfsp_player, args.model_name, 'nfsp', labels

    labels = [args.model_name, args.model_name]
    return loaded['agent_0'], loaded['agent_1'], args.model_name, args.model_name, labels


def main_holdem(args, is_single_agent=True):
    """Evaluate two Texas Hold'em players against each other and save a text log.

    ``args.eval.rounds`` rounds of ``args.eval.episodes`` episodes are played
    on ``texas_holdem_v4``. For every round the win/tie rates and the mean
    reward of both seats are logged, followed by the mean and standard
    deviation over all rounds. The log is written with ``save_log`` under
    ``{args.eval.result_path}/{args.env_name}``.

    Previously the baseline mode (``args.eval.IsAgent`` false) fell through
    into the single/multi-agent setup and failed on the unassigned model
    variable; player selection is now a single exclusive choice.

    Args:
        args: Configuration object. Attributes read here: ``eval.IsAgent``,
            ``eval.nfsp_path``, ``eval.rounds``, ``eval.episodes``,
            ``eval.result_path``, ``model_name``, ``player_list``,
            ``player_type``, ``strength``, ``env_name`` and, when
            ``model_name == 'ddgi'``, ``model.sample_steps``,
            ``model.sample_alpha`` and ``model.d_weight``.
        is_single_agent: Whether ``Load_model`` supplies one model (True) or
            a mapping with ``'agent_0'`` and ``'agent_1'`` entries (False).
            Ignored for player selection when ``args.eval.IsAgent`` is false.

    Raises:
        ValueError: If ``args.player_list`` does not hold exactly two entries.
    """
    env = texas_holdem_v4.env(render_mode="rgb_array")  # leduc_holdem_v4

    # ---------------- players ----------------
    Agent_0, Agent_1, a0, a1, model_list = _select_holdem_players(args, is_single_agent)

    # ---------------- log prepare ----------------
    # The accumulators and the reward bookkeeping below use the fixed seat names
    # 'agent_0'/'agent_1', so player_list is assumed to be exactly those two
    # names - TODO confirm against the config
    player_list = list(args.player_list)
    id_0, id_1 = player_list
    round_win_counts = {'agent_0': 0, 'agent_1': 0, "tie": 0}
    round_win_rate_per_game = {'agent_0': [], 'agent_1': [], "tie": []}
    per_agent_avg_rewards = {}

    log_lines = []
    today_str = datetime.now().strftime("%Y-%m-%d")
    log_lines.append(f"Data: {today_str}")
    log_lines.append(f"Total rounds: {args.eval.rounds} with per {args.eval.episodes} episodes")
    log_lines.append(f"Model: {model_list}")
    if args.model_name == 'ddgi':
        log_lines.append(f"sample step: {args.model.sample_steps} / alpha: {args.model.sample_alpha} / d weight: {args.model.d_weight}")
    log_lines.append("\n" + "*" * 50 + "\n")

    # ---------------- eval loop parameters ----------------
    total_episodes = args.eval.episodes

    # ---------------- eval loop for inference ----------------
    for r in range(1, args.eval.rounds + 1):
        env.reset(seed=int(r * 100))
        agents_results = {agent_id + '_reward': [] for agent_id in player_list}

        print(f'\n====================== round {r} ======================\n')
        log_lines.append("-" + f"round {r}" + "-")

        for ep in range(total_episodes):
            env.reset(seed=int(r * 100 + ep))
            agent_reward = {agent_id: 0 for agent_id in player_list}
            # Forwarded in `infos`: False for the first turn of an episode, True afterwards
            done = False

            for current_player in env.agent_iter():
                state, reward, termination, truncation, _ = env.last()
                infos = {'done': done, 'episodes': ep + 1, 'action_mask': state['action_mask']}

                # Credit the reward reported for this turn to the matching seat
                if current_player == 'player_0':
                    agent_reward['agent_0'] += reward
                else:
                    agent_reward['agent_1'] += reward

                if termination or truncation:
                    # A finished agent is stepped with None
                    action = None
                elif current_player == 'player_0':
                    # NFSP players use eval_step; every other model uses get_action
                    action = Agent_0.eval_step(state)[0] if a0 == 'nfsp' else Agent_0.get_action(state['observation'], infos)
                elif current_player == 'player_1':
                    action = Agent_1.eval_step(state)[0] if a1 == 'nfsp' else Agent_1.get_action(state['observation'], infos)

                env.step(action)
                done = True

            # Report and record this episode's reward for each seat
            print("-" * 15, f"Episode: {ep + 1}", "-" * 15)
            for agent_id, episode_reward in agent_reward.items():
                print(f"{agent_id} reward: {episode_reward}")
            for agent_id, episode_reward in agent_reward.items():
                agents_results[agent_id + '_reward'].append(episode_reward)

        # The env is closed at the end of every round and reset again at the top of the next one
        env.close()
        # ---------------------- end one round ----------------------
        print(f'\n====================================================\n')

        r0 = np.array(agents_results['agent_0_reward'])
        r1 = np.array(agents_results['agent_1_reward'])
        k = len(r0)

        avg_mean0 = r0.mean()
        avg_mean1 = r1.mean()

        # An episode is won by the seat with the larger accumulated reward
        win0 = np.sum(r0 > r1)
        win1 = np.sum(r1 > r0)
        tie = np.sum(r0 == r1)

        round_win_counts[id_0] += win0
        round_win_counts[id_1] += win1
        round_win_counts["tie"] += tie

        round_win_rate_per_game[id_0].append(win0 / k)
        round_win_rate_per_game[id_1].append(win1 / k)
        round_win_rate_per_game["tie"].append(tie / k)

        per_agent_avg_rewards.setdefault('agent_0', []).append(avg_mean0)
        per_agent_avg_rewards.setdefault('agent_1', []).append(avg_mean1)

        log_lines.append(f"{id_0} - win rates: {win0} / {k} = {(win0 / k):.3f}")
        log_lines.append(f"{id_1} - win rates: {win1} / {k} = {(win1 / k):.3f}")
        log_lines.append(f"{id_0} - avg reward: {avg_mean0:.3f}")
        log_lines.append(f"{id_1} - avg reward: {avg_mean1:.3f}")
        log_lines.append("-" * 50 + "\n")

        print(f'agent_0 win rates: {(win0 / k):.3f}')
        print(f"agent_0 reward: {avg_mean0:.3f}")
        print(f'agent_1 win rates: {(win1 / k):.3f}')
        print(f"agent_1 reward: {avg_mean1:.3f}")

    # ---------------------- Total Save ----------------------
    print(f'\n====================================================\n')
    log_lines.append("\n=== Total statistic ===\n")

    summary_keys = [id_0, id_1, "tie"]
    # Mean / standard deviation of the per-round win rates
    avg_win_rate = {key: np.mean(round_win_rate_per_game[key]) for key in summary_keys}
    std_win_rate = {key: np.std(round_win_rate_per_game[key]) for key in summary_keys}

    for key in summary_keys:
        log_lines.append(f"{key}: ")
        # NOTE(review): the value printed as "Total rounds" is the accumulated win
        # (or tie) count over all rounds
        log_lines.append(f"  Total rounds: {round_win_counts[key]}")
        log_lines.append(f"  Avg win rates: {avg_win_rate[key]:.3f}")
        log_lines.append(f"  Std win rates: {std_win_rate[key]:.3f}\n")
        print(f"{key}:")
        print(f"  Avg win rates = {avg_win_rate[key]:.3f}")

    # Mean and spread of the per-round average rewards
    for agent_id, round_means in per_agent_avg_rewards.items():
        avg_arr = np.array(round_means)
        avg_mean = avg_arr.mean()
        avg_std = avg_arr.std()
        log_lines.append(f"{agent_id}:")
        log_lines.append(f"  avg reward = {avg_mean:.3f} ± {avg_std:.3f}")
        print(f"{agent_id}:")
        print(f"  Avg reward = {avg_mean:.3f} ± {avg_std:.3f}")

    log_lines.append(f'\n====================================================\n')

    # ---------------------- Save ----------------------
    # The file name encodes the evaluation mode and the rounds/episodes setting
    if args.eval.IsAgent:
        if is_single_agent:
            save_name = f"{args.model_name}_{args.player_type}_s{args.strength}_{args.eval.rounds}per{args.eval.episodes}.txt"
        else:
            save_name = f"{args.model_name}_all_s{args.strength}_{args.eval.rounds}per{args.eval.episodes}.txt"
    else:
        save_name = f"origin_{args.eval.episodes}.txt"

    log_text = "\n".join(log_lines)
    save_path = f"{args.eval.result_path}/{args.env_name}"
    save_log(log_text, save_name, save_path)