import json
import numpy as np
import torch
from tqdm import tqdm
from dataclasses import dataclass


def _build_descriptions():
    """Return the scenario-name -> description-lines table used in prompts.

    Every value is a fresh list of four bullet strings, in this order: allied
    unit configuration, enemy unit configuration, situation description and
    objective. The situation and objective bullets are shared by every
    scenario, and the ``<race>_<allies>_vs_<enemies>`` scenarios differ only in
    race and unit counts, so those entries are expanded from templates.

    All text is reproduced verbatim from the hand-written table this replaces
    (including quirks such as the missing space in ``"Marines(Marines"``), so
    the prompts built from it are unchanged.
    """
    situation = "- Situation Description: The situation involves the allied team and the enemy team engaging in combat, where victory is achieved by defeating all the enemies."
    objective = "- Objective: Defeat all enemy agents while ensuring as many allied agents as possible survive."

    def entry(ally_text, enemy_text):
        # Assemble the four bullets for a single scenario.
        return [
            f"- Allied Team Agent Configuration: {ally_text}",
            f"- Enemy Team Agent Configuration: {enemy_text}",
            situation,
            objective,
        ]

    # Scenarios whose unit types are named explicitly.
    table = {
        "5m_vs_6m": entry(
            "five Marines(Marines are ranged units in StarCraft 2).",
            "six Marines(Marines are ranged units in StarCraft 2).",
        ),
        "6h_vs_8z": entry(
            "six Hydras(Hydras are long-range attack units in StarCraft 2).",
            "eight Zealots(Zealots are close-range attack units in StarCraft 2).",
        ),
        "2c_vs_64zg": entry(
            "two Colossi(Colossi are powerful ranged units in StarCraft 2).",
            "sixty-four Zerglings(Zerglings are fast melee units in StarCraft 2).",
        ),
        "corridor": entry(
            "six Zealots(Zealots are close-range attack units in StarCraft 2).",
            "twenty-four Zerglings(Zerglings are fast melee units in StarCraft 2).",
        ),
    }

    # Race -> (unit mix, role explanation) for the mixed-unit scenarios.
    race_units = {
        "protoss": (
            "Stalkers, Zealots, and Colossi",
            "Stalkers are versatile ranged units, Zealots are close-range melee units, and Colossi are powerful long-range units in StarCraft 2",
        ),
        "terran": (
            "Marines, Marauders, and Medivacs",
            "Marines are ranged units, Marauders are close-range units, and Medivacs are support units in StarCraft 2",
        ),
        "zerg": (
            "Zerglings, Roaches, and Hydras",
            "Zerglings are fast melee units, Roaches are ranged units, and Hydras are long-range units in StarCraft 2",
        ),
    }
    number_words = {5: "five", 10: "ten", 11: "eleven", 20: "twenty", 23: "twenty-three"}
    # (allies, enemies) pairs, in the order the keys should appear.
    matchups = ((5, 5), (10, 10), (10, 11), (20, 20), (20, 23))

    for race, (mix, roles) in race_units.items():
        for allies, enemies in matchups:
            table[f"{race}_{allies}_vs_{enemies}"] = entry(
                f"{number_words[allies]} units consisting of a mix of {mix} ({roles}).",
                f"{number_words[enemies]} units consisting of a mix of {mix} (same as the allied team).",
            )
    return table


# Scenario name -> list of description lines spliced into the judge context.
descriptions = _build_descriptions()


def _format_trajectory(label, trajectory):
    """Render one trajectory's final-state summary as a list of prompt lines.

    Args:
        label: Header line shown to the judge, e.g. ``"[Trajectory 1]"``.
        trajectory: 4-tuple ``(ally_healths, enemy_healths, total_reward,
            last_step)``. The health entries may be any 1-D array-like of
            numbers that ``np.asarray`` accepts (list, tuple, NumPy array, ...);
            they are converted so the element-wise ``> 0`` test works even for
            plain Python sequences.

    Returns:
        List of strings (no newlines), starting with ``label``.
    """
    ally_healths, enemy_healths, total_reward, last_step = trajectory
    ally_healths = np.asarray(ally_healths, dtype=float)
    enemy_healths = np.asarray(enemy_healths, dtype=float)

    # A unit counts as dead unless its health is strictly positive, so 0,
    # negative values and NaN are all reported as deaths.
    ally_deaths = len(ally_healths) - int(np.count_nonzero(ally_healths > 0))
    enemy_deaths = len(enemy_healths) - int(np.count_nonzero(enemy_healths > 0))

    return [
        label,
        "1. Final State Information",
        f"\t1) Allied Agents Health: {', '.join(f'{h:.3f}' for h in ally_healths)}",
        f"\t2) Enemy Agents Health: {', '.join(f'{h:.3f}' for h in enemy_healths)}",
        f"\t3) Number of Allied Deaths: {ally_deaths}",
        f"\t4) Number of Enemy Deaths: {enemy_deaths}",
        f"\t5) Total Reward: {total_reward:.3f}",
        f"2. Total Number of Steps: {last_step:.0f}",
    ]


def create_context(env_name, trajectory_x, trajectory_y):
    """Build the LLM-judge context and question for comparing two trajectories.

    Args:
        env_name: Scenario name; must be a key of the module-level
            ``descriptions`` table.
        trajectory_x: ``(ally_healths, enemy_healths, total_reward, last_step)``
            for the trajectory presented as "[Trajectory 1]".
        trajectory_y: Same layout, presented as "[Trajectory 2]".

    Returns:
        ``(context, prompt)``: two stripped strings. ``context`` holds the role
        description, scenario description and both trajectory summaries;
        ``prompt`` asks for ``Answer: #1``, ``Answer: #2`` or ``Answer: #0``.

    Raises:
        KeyError: if ``env_name`` has no entry in ``descriptions``.
    """
    try:
        scenario_lines = descriptions[env_name]
    except KeyError:
        raise KeyError(
            f"Unknown scenario {env_name!r}; expected one of {sorted(descriptions)}"
        ) from None

    context = [
        "You are a helpful and honest judge of good game playing and progress in the StarCraft Multi-Agent Challenge game. Always answer as helpfully as possible, while being truthful.",
        "If you don't know the answer to a question, please don't share false information.",
        "I'm looking to have you evaluate a scenario in the StarCraft Multi-Agent Challenge. Your role will be to assess how much the actions taken by multiple agents in a given situation have contributed to achieving victory.",
        "",
        "The basic information for the evaluation is as follows.",
        "",
        f"- Scenario: {env_name}",
        *scenario_lines,
        "",
        "I will provide you with two trajectories, and you should select the better trajectory based on the outcomes of these trajectories. Regarding the trajectory, it will inform you about the final states, and you should select the better case based on these two trajectories.",
        "",
        *_format_trajectory("[Trajectory 1]", trajectory_x),
        "",
        *_format_trajectory("[Trajectory 2]", trajectory_y),
    ]

    # NOTE(review): the first sentence says "[Trajectory1]"/"[Trajectory2]"
    # (no space) while the context headers are "[Trajectory 1]"/"[Trajectory 2]".
    # Left verbatim so the prompt text is unchanged; consider aligning.
    prompt = "\n".join([
        "Your task is to inform which one is better between [Trajectory1] and [Trajectory2] based on the information mentioned above. For example, if [Trajectory 1] seems better, output `Answer: #1`, and if [Trajectory 2] seems better, output `Answer: #2`. If it's difficult to judge or they seem similar, please output `Answer: #0`.",
        "",
        "Omit detailed explanations and just provide the answer.",
    ]).strip()
    return "\n".join(context).strip(), prompt