import sys
import os
sys.path.append(os.path.abspath('./'))
sys.path.append(os.path.abspath('./env'))

import random
import numpy as np
import torch
import importlib
import supersuit as ss
from datetime import datetime
from env_utils import save_log, _label_with_episode_number
from load_agent import Load_model


def _make_envs(env_name):
    """Build the preprocessed two-player Atari game as a SuperSuit vector env.

    The PettingZoo parallel env is wrapped with SuperSuit preprocessing:
    max over 2 observations, frame skip 4, reward clipping to [-1, 1], colour
    reduction (mode 'B'), 84x84 resize, 4-frame stacking and agent-indicator
    channels. It is then converted with ``pettingzoo_env_to_vec_env_v1`` /
    ``concat_vec_envs_v1`` so that each player occupies one slot of the
    vector env (the caller indexes obs/reward/termination by player
    position).

    Args:
        env_name: ``'tennis'`` or ``'box'``.

    Returns:
        The vector env, with ``single_observation_space``,
        ``single_action_space`` and ``is_vector_env`` attached (presumably
        read by the expert agent's constructor — confirm in ``*.trainer``).

    Raises:
        ValueError: if ``env_name`` is not one of the supported games.
    """
    env_ids = {'tennis': "tennis_v3", 'box': "boxing_v2"}
    if env_name not in env_ids:
        raise ValueError("Atari only support 'tennis', 'box' ")
    env_id = env_ids[env_name]

    env = importlib.import_module(f"pettingzoo.atari.{env_id}").parallel_env(render_mode="rgb_array")
    env = ss.max_observation_v0(env, 2)
    env = ss.frame_skip_v0(env, 4)
    env = ss.clip_reward_v0(env, lower_bound=-1, upper_bound=1)
    env = ss.color_reduction_v0(env, mode="B")
    env = ss.resize_v1(env, x_size=84, y_size=84)
    env = ss.frame_stack_v1(env, 4)
    env = ss.agent_indicator_v0(env, type_only=False)
    env = ss.pettingzoo_env_to_vec_env_v1(env)
    envs = ss.concat_vec_envs_v1(
        env, num_vec_envs=1, num_cpus=0, base_class="gymnasium"
        )
    envs.single_observation_space = envs.observation_space
    envs.single_action_space = envs.action_space
    envs.is_vector_env = True
    return envs


def _load_expert(args, envs, device):
    """Construct the game-specific actor-critic and load its weights.

    Args:
        args: config; reads ``args.env_name`` and ``args.eval.ac_path``.
        envs: the vector env from ``_make_envs`` (passed to the constructor).
        device: torch device the model and loaded tensors are placed on.

    Returns:
        The expert model in eval mode.

    Raises:
        ValueError: if ``args.env_name`` has no expert class.
    """
    if args.env_name == 'tennis':
        from tennis.trainer import AgentTennis
        expert = AgentTennis(envs).to(device)
    elif args.env_name == 'box':
        from box.trainer import AgentBox
        expert = AgentBox(envs).to(device)
    else:
        raise ValueError("Atari only support 'tennis', 'box' ")

    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    state = torch.load(args.eval.ac_path, map_location=device, weights_only=True)
    expert.load_state_dict(state)
    expert.eval()
    return expert


def _compose_action(args, agent, is_single_agent, obs, oact, infos):
    """Return the joint action tensor for both players at one step.

    Args:
        args: config; reads ``args.eval.IsAgent`` and ``args.player_type``.
        agent: loaded model(s) from ``Load_model`` (``None`` when
            ``args.eval.IsAgent`` is false). In single-agent mode it is called
            directly; otherwise it is indexed by ``'agent_0'`` / ``'agent_1'``.
        is_single_agent: whether only one seat is played by ``agent``.
        obs: batched observation tensor; ``obs[i]`` is player ``i``'s obs.
        oact: the expert's actions for both players; ``oact[i]`` is player i's.
        infos: extra info dict forwarded to ``agent.get_action``.
    """
    if not args.eval.IsAgent:
        return oact

    if is_single_agent:
        # player_type was validated by main(); anything but agent_0 is agent_1.
        seat = 0 if args.player_type == 'agent_0' else 1
        mact = agent.get_action(obs[seat], infos)
        if mact.dim() == 1:
            mact = mact.squeeze(0)
        if seat == 0:
            return torch.stack([mact, oact[1]])
        return torch.stack([oact[0], mact])

    act0 = agent['agent_0'].get_action(obs[0], infos)
    act1 = agent['agent_1'].get_action(obs[1], infos)
    if act0.dim() == 1 or act1.dim() == 1:
        return torch.cat([act0, act1], dim=-1)
    return torch.stack([act0, act1])


def main(args, is_single_agent=False):
    """Evaluate agents on a two-player Atari game and save a text log.

    Each of ``args.eval.rounds`` rounds plays ``args.eval.episodes`` episodes.
    Each episode is capped at ``args.traj_length`` steps, doubled for
    ``'box'``. Actions come from a pretrained expert (``args.eval.ac_path``)
    and, when ``args.eval.IsAgent`` is set, from model(s) returned by
    ``Load_model``:

    * ``is_single_agent=True``: the loaded agent plays the seat named by
      ``args.player_type`` (``'agent_0'`` or ``'agent_1'``); the expert plays
      the other seat.
    * ``is_single_agent=False``: ``agent['agent_0']`` and ``agent['agent_1']``
      play both seats.
    * ``args.eval.IsAgent`` false: the expert plays both seats.

    Per-round and overall win rates and average rewards are printed. They are
    also written via ``save_log`` to ``{args.eval.result_path}/{args.env_name}``.

    Args:
        args: nested config object (attribute access on ``args.eval`` /
            ``args.model``; exact type not visible here).
        is_single_agent: see above.

    Raises:
        ValueError: unsupported ``args.env_name`` or invalid
            ``args.player_type`` in single-agent mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    random.seed(1)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.backends.cudnn.deterministic = True

    if args.eval.IsAgent and is_single_agent and args.player_type not in ('agent_0', 'agent_1'):
        raise ValueError(
            f"player_type must be 'agent_0' or 'agent_1', got {args.player_type!r}"
        )

    envs = _make_envs(args.env_name)

    # ---------------------------------------------------
    # Agent
    agent = None
    model_list = []
    if args.eval.IsAgent:
        agent = Load_model(args, is_single_agent)
        model_list.append(args.model_name)

    # ---------------------------------------------------
    # AC (expert)
    expert = _load_expert(args, envs, device)
    if is_single_agent or not args.eval.IsAgent:
        model_list.append(f'PPO: {args.eval.ac_path}')

    # ---------------- log prepare ----------------
    agent_0, agent_1 = list(args.player_list)
    # NOTE: these count episodes won across all rounds (despite the
    # "Total rounds" label used when logging them below).
    round_win_counts = {agent_0: 0, agent_1: 0, "tie": 0}
    round_win_rate_per_game = {agent_0: [], agent_1: [], "tie": []}
    per_agent_avg_rewards = {}
    agent_wins_per_round = []   # episodes won by the tracked seat, per round
    decisive_per_round = []     # non-tied episodes, per round

    log_lines = []
    today_str = datetime.now().strftime("%Y-%m-%d")
    log_lines.append(f"Data: {today_str}")
    log_lines.append(f"Total rounds: {args.eval.rounds} with per {args.eval.episodes} episodes")
    log_lines.append(f"Model: {model_list}")
    if args.model_name == 'ddgi':
        log_lines.append(f"sample step: {args.model.sample_steps} / alpha: {args.model.sample_alpha} / d weight: {args.model.d_weight}")
    log_lines.append("\n" + "*" * 50 + "\n")

    # ---------------- eval loop parameters ----------------
    total_episodes = args.eval.episodes
    max_steps = args.traj_length * 2 if args.env_name == 'box' else args.traj_length
    player_list = list(args.player_list)

    # Seat whose wins feed the "Agent win rate": the learned agent's seat in
    # single-agent mode, otherwise seat 0.
    tracked = 1 if (args.eval.IsAgent and is_single_agent and args.player_type == 'agent_1') else 0
    opponent = 1 - tracked

    # ---------------- eval loop for inference ----------------
    try:
        for r in range(1, args.eval.rounds + 1):
            envs.reset(seed=int(r * 100))
            # NOTE: frames are only consumed by the (currently disabled) GIF
            # export below, yet every rendered frame of the round is kept in memory.
            frames = []
            agents_results = {agent_id + '_reward': [] for agent_id in player_list}
            indi_agent_rewards = {agent_id: [] for agent_id in player_list}
            agent_win_num_per_round = 0
            per_round = 0  # decisive (non-tied) episodes in this round

            print(f'\n====================== round {r} ======================\n')
            log_lines.append("-" + f"round {r}" + "-")

            for ep in range(total_episodes):
                obs, _ = envs.reset(seed=int(r * 100 + ep))
                agent_reward = {agent_id: 0 for agent_id in player_list}
                # False only on the first step of an episode; flipped to True
                # afterwards and forwarded to the agent via infos.
                done = False

                for _ in range(max_steps):
                    with torch.no_grad():
                        obs = torch.Tensor(obs).to(device)
                        oact, _, _, _ = expert.get_action_and_value(obs)
                        infos = {'done': done, 'episodes': ep + 1}
                        action = _compose_action(args, agent, is_single_agent, obs, oact, infos)
                        done = True

                    frame = envs.render()
                    frames.append(_label_with_episode_number(frame, ep))
                    obs, reward, termination, truncation, _ = envs.step(action.cpu().numpy())

                    # Accumulate each player's reward for this episode.
                    for i, agent_id in enumerate(player_list):
                        agent_reward[agent_id] += reward[i]

                    if termination[0] or truncation[0]:
                        break

                # Record agent specific episodic reward
                for agent_id in player_list:
                    indi_agent_rewards[agent_id].append(agent_reward[agent_id])

                print("-" * 15, f"Episode: {ep + 1}", "-" * 15)

                sl = []
                for agent_id, reward_list in indi_agent_rewards.items():
                    print(f"{agent_id} reward: {reward_list[-1]}")
                    sl.append(reward_list[-1])

                for agent_id in player_list:
                    agents_results[agent_id + '_reward'].append(indi_agent_rewards[agent_id][-1])

                if sl[tracked] > sl[opponent]:
                    agent_win_num_per_round += 1
                if sl[0] != sl[1]:
                    per_round += 1

            print(f'\n====================================================\n')

            r0 = np.array(agents_results[f'{agent_0}_reward'])
            r1 = np.array(agents_results[f'{agent_1}_reward'])
            k = len(r0)

            avg_mean0 = r0.mean()
            avg_mean1 = r1.mean()

            win0 = np.sum(r0 > r1)
            win1 = np.sum(r1 > r0)
            tie = np.sum(r0 == r1)

            round_win_counts[agent_0] += win0
            round_win_counts[agent_1] += win1
            round_win_counts["tie"] += tie

            round_win_rate_per_game[agent_0].append(win0 / k)
            round_win_rate_per_game[agent_1].append(win1 / k)
            round_win_rate_per_game["tie"].append(tie / k)

            per_agent_avg_rewards.setdefault(agent_0, []).append(avg_mean0)
            per_agent_avg_rewards.setdefault(agent_1, []).append(avg_mean1)

            log_lines.append(f"{agent_0} - win rates: {win0} / {k} = {(win0 / k):.3f}")
            log_lines.append(f"{agent_1} - win rates: {win1} / {k} = {(win1 / k):.3f}")
            log_lines.append(f"{agent_0} - avg reward: {avg_mean0:.3f}")
            log_lines.append(f"{agent_1} - avg reward: {avg_mean1:.3f}")
            log_lines.append("-" * 50 + "\n")

            print(f'agent_0 win rates: {(win0 / k):.3f}')
            print(f"agent_0 reward: {avg_mean0:.3f}")
            print(f'agent_1 win rates: {(win1 / k):.3f}')
            print(f"agent_1 reward: {avg_mean1:.3f}")

            agent_wins_per_round.append(agent_win_num_per_round)
            decisive_per_round.append(per_round)
            if per_round > 0:
                print(f'Agent win rate per round: {agent_win_num_per_round / per_round}')
            else:
                print('Agent win rate per round: n/a (every episode was a tie)')

            # NOTE: GIF export is disabled; it used to be
            # save_gif(total_episodes, name, frames, gif_path=args.eval.gif_path)
    finally:
        # Close once, after all rounds (previously closed inside the loop and
        # then reused by later rounds).
        envs.close()

    # ---------------------- Total Save ----------------------
    print(f'\n====================================================\n')
    log_lines.append("\n=== Total statistic ===\n")

    avg_win_rate = {
        key: np.mean(round_win_rate_per_game[key])
        for key in [agent_0, agent_1, "tie"]
    }
    std_win_rate = {
        key: np.std(round_win_rate_per_game[key])
        for key in [agent_0, agent_1, "tie"]
    }

    for key in [agent_0, agent_1, "tie"]:
        log_lines.append(f"{key}: ")
        log_lines.append(f"  Total rounds: {round_win_counts[key]}")
        log_lines.append(f"  Avg win rates: {avg_win_rate[key]:.3f}")
        log_lines.append(f"  Std win rates: {std_win_rate[key]:.3f}\n")
        print(f"{key}:")
        print(f"  Avg win rates = {avg_win_rate[key]:.3f}")

    for agent_id in per_agent_avg_rewards:
        avg_arr = np.array(per_agent_avg_rewards[agent_id])
        avg_mean = avg_arr.mean()
        avg_std = avg_arr.std()
        log_lines.append(f"{agent_id}:")
        log_lines.append(f"  avg reward = {avg_mean:.3f} ± {avg_std:.3f}")
        print(f"{agent_id}:")
        print(f"  Avg reward = {avg_mean:.3f} ± {avg_std:.3f}")

    log_lines.append(f'\n====================================================\n')

    # Per-round win rate of the tracked seat over decisive episodes; rounds
    # where every episode tied have no defined rate and are excluded.
    wins = np.asarray(agent_wins_per_round, dtype=float)
    decisive = np.asarray(decisive_per_round, dtype=float)
    valid = decisive > 0
    if valid.any():
        wr = wins[valid] / decisive[valid]
        avg_wr, std_wr = wr.mean(), wr.std()
    else:
        avg_wr = std_wr = float('nan')

    print(f"Agent win rate = {avg_wr:.3f} ± {std_wr:.3f}")
    log_lines.append(f"Agent win rate = {avg_wr:.3f} ± {std_wr:.3f}")

    if args.eval.IsAgent:
        if is_single_agent:
            save_name = f"{args.model_name}_{args.player_type}_s{args.strength}_{args.eval.rounds}per{args.eval.episodes}.txt"
        else:
            save_name = f"{args.model_name}_all_s{args.strength}_{args.eval.rounds}per{args.eval.episodes}.txt"
    else:
        save_name = f"origin_{args.eval.episodes}.txt"

    log_text = "\n".join(log_lines)
    save_path = f"{args.eval.result_path}/{args.env_name}"
    save_log(log_text, save_name, save_path)

    # NOTE: disabled for 'ddgi': agent.ouptut_lam()