import sys
import os
sys.path.append(os.path.abspath('./'))
sys.path.append(os.path.abspath('./env'))

import torch
import numpy as np
from datetime import datetime
from agilerl.algorithms.matd3 import MATD3
from pettingzoo.mpe import simple_push_v3, simple_tag_v3, simple_spread_v3, simple_reference_v3
from env_utils import save_gif, save_log, _label_with_episode_number
from load_agent import Load_model


def _make_env(env_name):
    """Create the PettingZoo MPE parallel environment selected by ``env_name``.

    Args:
        env_name: One of ``'push'``, ``'tag'``, ``'spread'`` or ``'reference'``.

    Returns:
        The parallel environment, configured with discrete actions and
        ``render_mode="rgb_array"``.

    Raises:
        ValueError: If ``env_name`` is not one of the supported names.
    """
    if env_name == 'push':
        return simple_push_v3.parallel_env(
            continuous_actions=False, render_mode="rgb_array"
        )
    if env_name == 'tag':
        return simple_tag_v3.parallel_env(
            num_good=1, num_adversaries=3, max_cycles=25, num_obstacles=2,
            continuous_actions=False, render_mode="rgb_array"
        )
    if env_name == 'spread':
        return simple_spread_v3.parallel_env(
            N=3, max_cycles=25, continuous_actions=False, render_mode="rgb_array"
        )
    if env_name == 'reference':
        return simple_reference_v3.parallel_env(
            max_cycles=25, continuous_actions=False, render_mode="rgb_array"
        )
    raise ValueError(
        f"Unsupported env_name {env_name!r}; "
        "expected one of 'push', 'tag', 'spread', 'reference'."
    )


def _select_actions(args, is_single_agent, agent, expert_actions, state, infos, agent_ids):
    """Build the joint action dict passed to ``env.step`` for one time step.

    Args:
        args: Run configuration; reads ``args.eval.IsAgent`` and ``args.player_type``.
        is_single_agent: If true, only ``args.player_type`` is driven by ``agent``
            and every other agent follows ``expert_actions``. If false, ``agent``
            is indexed by agent id and each entry drives its own agent.
        agent: Loaded model(s), or ``None`` when ``args.eval.IsAgent`` is false.
        expert_actions: Per-agent actions produced by the MATD3 expert.
        state: Per-agent observations for the current step.
        infos: Extra dict forwarded to every ``get_action`` call.
        agent_ids: Ids of all agents in the environment.

    Returns:
        A dict mapping each agent id to its action.
    """
    if not args.eval.IsAgent:
        return expert_actions

    if is_single_agent:
        # Start from the expert's action for every agent, then override only the
        # evaluated player. Building the dict generically avoids the per-branch
        # copy-paste tables (one of which sent adversary_1's expert action to
        # adversary_0).
        action = {agent_id: expert_actions[agent_id] for agent_id in agent_ids}
        player = args.player_type
        action[player] = agent.get_action(state[player], infos)
        return action

    return {
        agent_id: agent[agent_id].get_action(state[agent_id], infos)
        for agent_id in agent_ids
    }


def main(args, is_single_agent=True, is_save_gif=False):
    """Evaluate loaded agent(s) and/or the MATD3 expert on an MPE environment.

    Runs ``args.eval.rounds`` rounds of ``args.eval.episodes`` episodes, each
    capped at ``args.traj_length`` steps. Per-round and overall reward
    statistics are printed and written through ``save_log``. For the ``'tag'``
    and ``'push'`` environments a win count is also tracked: an episode counts
    as a win when the summed adversary reward is <= the reward of ``agent_0``.

    Args:
        args: Run configuration. Fields read here: ``env_name``, ``model_name``,
            ``player_type``, ``player_list``, ``traj_length``, ``strength``,
            ``eval.IsAgent``, ``eval.matd3_path``, ``eval.rounds``,
            ``eval.episodes``, ``eval.result_path`` and, when ``model_name`` is
            ``'ddgi'``, ``model.sample_steps``, ``model.sample_alpha`` and
            ``model.d_weight``.
        is_single_agent: If true, only ``args.player_type`` is controlled by the
            loaded agent and the remaining agents follow the MATD3 expert;
            otherwise every agent is controlled by its own loaded model.
        is_save_gif: Currently unused; frame capture and GIF export are disabled.

    Raises:
        ValueError: If ``args.env_name`` is unsupported, or if a single agent is
            evaluated and ``args.player_type`` is not an agent of the environment.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---------------- Configure the environment ----------------
    env = _make_env(args.env_name)
    env.reset()

    # Snapshot the ids so later changes to ``env.agents`` cannot affect bookkeeping.
    agent_ids = list(env.agents)
    n_agents = env.num_agents

    # Spaces exposing ``.n`` are treated as discrete; otherwise fall back to shapes.
    try:
        state_dim = [env.observation_space(agent_id).n for agent_id in agent_ids]
        one_hot = True
    except AttributeError:
        state_dim = [env.observation_space(agent_id).shape for agent_id in agent_ids]
        one_hot = False
    try:
        action_dim = [env.action_space(agent_id).n for agent_id in agent_ids]
        discrete_actions = True
        max_action = None
        min_action = None
    except AttributeError:
        action_dim = [env.action_space(agent_id).shape[0] for agent_id in agent_ids]
        discrete_actions = False
        max_action = [env.action_space(agent_id).high for agent_id in agent_ids]
        min_action = [env.action_space(agent_id).low for agent_id in agent_ids]

    # Fail fast instead of hitting an unbound ``action`` inside the step loop.
    if args.eval.IsAgent and is_single_agent and args.player_type not in agent_ids:
        raise ValueError(
            f"player_type {args.player_type!r} is not an agent of env "
            f"{args.env_name!r}; available agents: {agent_ids}"
        )

    # ---------------- Import agent ----------------
    model_list = []
    agent = None
    if args.eval.IsAgent:
        agent = Load_model(args, is_single_agent)
        model_list.append(args.model_name)

    # Expert
    matd3 = MATD3(state_dim, action_dim, one_hot, n_agents, agent_ids, max_action, min_action, discrete_actions, device=device)
    path = args.eval.matd3_path
    matd3.loadCheckpoint(path)
    print(f'matd3: {path}')

    # The expert actually drives at least one agent only in these two modes.
    if is_single_agent or not args.eval.IsAgent:
        model_list.append(f'matd3: {path}')

    # ---------------- log prepare ----------------
    per_agent_avg_rewards = {}
    per_agent_min_rewards = {}
    per_agent_max_rewards = {}
    per_agent_win_rate = []

    log_lines = []
    today_str = datetime.now().strftime("%Y-%m-%d")
    log_lines.append(f"Data: {today_str}")
    log_lines.append(f"Total rounds: {args.eval.rounds} with per {args.eval.episodes} episodes")
    log_lines.append(f"Model: {model_list}")
    if args.model_name == 'ddgi':
        log_lines.append(f"sample step: {args.model.sample_steps} / alpha: {args.model.sample_alpha} / d weight: {args.model.d_weight}")
    log_lines.append("\n" + "*" * 50 + "\n")

    # ---------------- eval loop parameters ----------------
    total_episodes = args.eval.episodes
    max_steps = args.traj_length
    player_list = list(args.player_list)

    # ---------------- eval loop for inference ----------------
    for r in range(1, args.eval.rounds + 1):
        env.reset(seed=int(r * 100))
        agents_results = {name + '_reward': [] for name in player_list}
        episodic_rewards = []
        indi_agent_rewards = {agent_id: [] for agent_id in agent_ids}
        agent_win_num_per_round = 0  # episodes won by the non-adversary side

        print(f'\n====================== round {r} ======================\n')

        log_lines.append("-" + f"round {r}" + "-")

        # ---------------------- begin one round ----------------------
        for ep in range(total_episodes):
            state, info = env.reset(seed=int(r * 100 + ep))
            agent_reward = {agent_id: 0 for agent_id in agent_ids}
            # NOTE: despite its name, this flag is False only while the first
            # action of the episode is chosen and True for all later steps; it
            # is forwarded to the agent through ``infos['done']``.
            done = False

            for _ in range(max_steps):
                agent_mask = info.get("agent_mask")
                env_defined_actions = info.get("env_defined_actions")

                # The expert is always queried, even when every agent is
                # model-driven, exactly as in the original control flow.
                cont_actions, discrete_action = matd3.getAction(
                    state,
                    epsilon=0,
                    agent_mask=agent_mask,
                    env_defined_actions=env_defined_actions,
                )
                oact = discrete_action if matd3.discrete_actions else cont_actions

                infos = {'done': done, 'episodes': ep + 1}
                action = _select_actions(
                    args, is_single_agent, agent, oact, state, infos, agent_ids
                )
                if args.eval.IsAgent:
                    done = True

                # Take action in environment
                state, reward, termination, truncation, info = env.step(action)

                # Accumulate each agent's reward for this episode
                for agent_id, step_reward in reward.items():
                    agent_reward[agent_id] += step_reward

                if any(truncation.values()) or any(termination.values()):
                    break

            score = sum(agent_reward.values())
            episodic_rewards.append(score)
            for agent_id in agent_ids:
                indi_agent_rewards[agent_id].append(agent_reward[agent_id])

            print("-" * 15, f"Episode: {ep + 1}", "-" * 15)
            print("Episodic Reward: ", score)

            for agent_id in agent_ids:
                print(f"{agent_id} reward: {agent_reward[agent_id]}")

            # ``player_list`` entries are paired with ``agent_ids`` by position.
            for target_key, source_key in zip(agents_results, indi_agent_rewards):
                agents_results[target_key].append(indi_agent_rewards[source_key][-1])

            # Win bookkeeping is only defined for the adversarial environments.
            if args.env_name in ('tag', 'push'):
                adversary_score = sum(
                    agent_reward[agent_id]
                    for agent_id in agent_ids
                    if agent_id.startswith('adversary_')
                )
                agent_score = agent_reward['agent_0']
                if adversary_score <= agent_score:
                    agent_win_num_per_round += 1

        agents_results['episodic_reward'] = episodic_rewards

        # ---------------------- end one round ----------------------
        print(f'\n====================================================\n')

        for columns, values in agents_results.items():
            values = np.array(values)
            avg = values.mean()
            min_r = values.min()
            max_r = values.max()

            print(f"{columns}: avg = {avg:.3f}, min = {min_r:.3f}, max = {max_r:.3f}")
            log_lines.append(f"{columns}: avg = {avg:.3f}, min = {min_r:.3f}, max = {max_r:.3f}")

            per_agent_avg_rewards.setdefault(columns, []).append(avg)
            per_agent_min_rewards.setdefault(columns, []).append(min_r)
            per_agent_max_rewards.setdefault(columns, []).append(max_r)

        per_agent_win_rate.append(agent_win_num_per_round)
        print(f'Agent win rate per round: {agent_win_num_per_round}')

        log_lines.append("-" * 50 + "\n")

    # Close once, after all rounds; previously the environment was closed inside
    # the round loop and then reset again for the next round.
    env.close()

    # ---------------------- Total Save ----------------------
    print(f'\n====================================================\n')
    log_lines.append("\n=== Total statistic ===\n")

    for columns in per_agent_avg_rewards:
        avg_arr = np.array(per_agent_avg_rewards[columns])
        min_arr = np.array(per_agent_min_rewards[columns])
        max_arr = np.array(per_agent_max_rewards[columns])

        avg_mean = avg_arr.mean()
        avg_std = avg_arr.std()

        min_mean = min_arr.mean()
        min_std = min_arr.std()

        max_mean = max_arr.mean()
        max_std = max_arr.std()

        print(f"{columns}:")
        print(f"  avg reward = {avg_mean:.3f} ± {avg_std:.3f}")
        print(f"  min reward = {min_mean:.3f} ± {min_std:.3f}")
        print(f"  max reward = {max_mean:.3f} ± {max_std:.3f}")

        log_lines.append(f"{columns}:")
        log_lines.append(f"  avg reward = {avg_mean:.3f} ± {avg_std:.3f}")
        log_lines.append(f"  min reward = {min_mean:.3f} ± {min_std:.3f}")
        log_lines.append(f"  max reward = {max_mean:.3f} ± {max_std:.3f}\n")

    log_lines.append(f'\n====================================================\n')

    # Win rate: per-round win counts normalised by episodes per round.
    wr = np.array(per_agent_win_rate) / total_episodes
    avg_wr = wr.mean()
    std_wr = wr.std()
    print()
    print(f"Agent win rate = {avg_wr:.3f} ± {std_wr:.3f}")

    if args.eval.IsAgent:
        if is_single_agent:
            save_name = f"{args.model_name}_{args.player_type}_s{args.strength}_{args.eval.rounds}per{args.eval.episodes}.txt"
        else:
            save_name = f"{args.model_name}_all_s{args.strength}_{args.eval.rounds}per{args.eval.episodes}.txt"
    else:
        save_name = f"origin_{args.eval.episodes}.txt"

    log_text = "\n".join(log_lines)
    save_path = f"{args.eval.result_path}/{args.env_name}"
    save_log(log_text, save_name, save_path)

    # Only call into the agent when one was actually loaded; without this guard
    # an expert-only run with model_name == 'ddgi' hit an unbound ``agent``.
    # NOTE(review): the method name ``ouptut_lam`` is kept as spelled in the
    # original call - confirm it matches the agent class.
    if args.eval.IsAgent and args.model_name == 'ddgi':
        agent.ouptut_lam()