import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

"""
QMIX Agent Evaluation Script
────────────────────────────
This script evaluates QMIX agents in the AppleGridMDP environment 
under three disruption protocols, as described in the paper. It runs 
multiple episodes, applies disruptions at predefined timesteps, and 
computes cooperative resilience metrics (hunger, equality, cumulative 
rewards, last-apple events, etc.). Results are stored for analysis and 
can optionally be saved to disk.
"""

import math
import random
import numpy as np
import matplotlib.pyplot as plt

from src.utils.save_and_load import save_data
from src.metrics.hunger_and_equality import calculate_equality, calculate_hunger
from src.metrics.resilience_metrics import ResilienceMetrics
from src.agents.qmix.get_policy import reset_policy, act
from src.environment.test.apple_grid_qmix import AppleGridMDP

# Build and reset one environment up front. The episode loop below creates
# fresh environments and re-reads observations before using them, so these
# two names are rebound there.
env = AppleGridMDP()
obs = env.reset()

# ── Experiment configuration ────────────────────────────────────────────
n = 2                  # agent count handed to the hunger/equality metrics
range_episode = 5000   # timesteps simulated per episode
num_episodes = 5       # number of evaluation episodes

# ── Per-episode result accumulators ─────────────────────────────────────
total_rewards_agent_1, total_rewards_agent_2 = [], []
resiliencia_by_episode = []
minimum_rewards_agent = []
len_epsiode = []

# ── Reproducibility: seed both RNGs with the same value ─────────────────
random_state = 48
random.seed(random_state)
np.random.seed(random_state)

# ── Cross-episode aggregates ────────────────────────────────────────────
# One 8x8 visit map per agent, summed over episodes.
board_agent1_total, board_agent2_total = np.zeros((8, 8)), np.zeros((8, 8))
# Counter incremented by the episode loop.
last_apple = 0

# Disruption schedule (timesteps inside an episode). Defined once so that the
# trigger calls in the loop and the indices handed to ResilienceMetrics
# cannot drift apart.
disruption_step = 1250
rate_disruption_start = 2500
rate_disruption_end = 2600
agent_disruption_start = 3750
agent_disruption_end = 3900


def _available_actions(environment):
    """Return one available-actions entry per agent, or None.

    None is returned when ``environment`` does not expose
    ``get_avail_actions_agent``; otherwise that method is called once for
    each index in ``range(environment.n_agents)``.
    """
    if not hasattr(environment, "get_avail_actions_agent"):
        return None
    return [environment.get_avail_actions_agent(agent_idx)
            for agent_idx in range(environment.n_agents)]


# Each episode runs two environments side by side: `env`, which receives the
# disruptions, and `env_baseline`, which never does and serves as the
# reference trajectory for the resilience computation.
for episode in range(num_episodes):
    print(f'[Info] Start episode {episode}')
    env = AppleGridMDP(episode_limit=range_episode)
    env_baseline = AppleGridMDP(episode_limit=range_episode)
    reset_policy()

    # Draw two distinct start cells; both environments get the same ones.
    pos1 = (random.randint(0, 7), random.randint(0, 7))
    pos2 = (random.randint(0, 7), random.randint(0, 7))
    while pos2 == pos1:
        pos2 = (random.randint(0, 7), random.randint(0, 7))

    env.agent_positions = [pos1, pos2]
    env_baseline.agent_positions = [pos1, pos2]

    # Per-timestep reward traces for the disrupted run and the baseline run.
    rewards_agent_1 = np.zeros((range_episode,))
    rewards_agent_2 = np.zeros((range_episode,))
    rewards_agent_1_baseline = np.zeros((range_episode,))
    rewards_agent_2_baseline = np.zeros((range_episode,))

    # Per-timestep sum of each environment's grid.
    apple_remaind = []
    apple_remaind_baseline = []

    # Timestep at which `done` was first reported before the final step;
    # stays at range_episode when that never happens.
    len_epsiode_apple = range_episode
    done = False
    last_apple_taken = False
    board_agent1 = np.zeros((8, 8))
    board_agent2 = np.zeros((8, 8))

    for i in range(range_episode):
        # Disruptions are applied to `env` only.
        if i == disruption_step:
            env.trigger_disruption()
        if i == rate_disruption_start:
            env.trigger_disruption_rate()
        if i == rate_disruption_end:
            env.trigger_disruption_rate(start=False)
        if i == agent_disruption_start:
            env.trigger_disruption_agent()
        if i == agent_disruption_end:
            env.trigger_disruption_agent(start=False)

        if i < disruption_step:
            # Before the first disruption the baseline simply mirrors `env`.
            # NOTE(review): the copy happens before `env.step` below, so the
            # baseline holds env's pre-step state and its recorded grid sum
            # lags the disrupted run by one step — confirm this is intended.
            env_baseline.set_state(tuple(env.get_state()))
        else:
            # From the first disruption on, the baseline evolves on its own.
            # NOTE(review): `act` serves both rollouts after one
            # `reset_policy()`; if the policy keeps internal per-episode
            # state the two rollouts would share it — verify in get_policy.
            avail_baseline = _available_actions(env_baseline)
            obs_baseline = env_baseline.get_obs()
            actions_baseline = act(obs_baseline, avail_actions=avail_baseline)
            rewards_baseline, _, _, _ = env_baseline.step(actions_baseline)

        avail = _available_actions(env)
        obs = env.get_obs()
        actions = act(obs, avail_actions=avail)
        rewards, reward_team, done, _ = env.step(actions)

        rewards_agent_1[i] = int(rewards[0])
        rewards_agent_2[i] = int(rewards[1])
        # While the baseline mirrors `env` it has no rewards of its own, so
        # the disrupted run's rewards are recorded for it.
        reference_rewards = rewards if i < disruption_step else rewards_baseline
        rewards_agent_1_baseline[i] = int(reference_rewards[0])
        rewards_agent_2_baseline[i] = int(reference_rewards[1])

        # Only the first `done` that arrives before the final timestep counts.
        if done and i < range_episode - 1 and not last_apple_taken:
            print(f'[Info] Last apple taken episode {episode}')
            len_epsiode_apple = i
            last_apple_taken = True

        apple_remaind.append(np.sum(env.grid))
        apple_remaind_baseline.append(np.sum(env_baseline.grid))

        # Visit counts stop accumulating once the early `done` was seen.
        if not last_apple_taken:
            a11, a12 = env.agent_positions[0]
            a21, a22 = env.agent_positions[1]
            board_agent1[a11, a12] += 1
            board_agent2[a21, a22] += 1

    # Add this episode's visit map, scaled so its busiest cell equals 1.
    # A map with no visits (early `done` at step 0) is skipped: dividing by
    # its zero maximum would put NaN into the running totals.
    for board_total, board in ((board_agent1_total, board_agent1),
                               (board_agent2_total, board_agent2)):
        peak_visits = np.max(board)
        if peak_visits > 0:
            board_total += board / peak_visits

    if len_epsiode_apple < range_episode:
        last_apple += 1

    len_epsiode.append(len_epsiode_apple)

    agent_rewards = [rewards_agent_1, rewards_agent_2]
    agent_rewards_baseline = [rewards_agent_1_baseline, rewards_agent_2_baseline]
    hunger_qmix = calculate_hunger(range_episode, n, agent_rewards)
    hunger_baseline = calculate_hunger(range_episode, n, agent_rewards_baseline)
    gini_qmix = calculate_equality(range_episode, n, agent_rewards)
    gini_baseline = calculate_equality(range_episode, n, agent_rewards_baseline)

    # Cumulative rewards are needed both as indicator series and for totals.
    cum_rewards_1 = np.cumsum(rewards_agent_1)
    cum_rewards_2 = np.cumsum(rewards_agent_2)

    resilience_metric = ResilienceMetrics(
        K=5,
        numberScenarios=1,
        assemblyIndicatorFuction='harmonic'
    )

    # Single scenario (key 0) containing the three disruption start steps.
    disturbancesIndex = {
        0: [disruption_step, rate_disruption_start, agent_disruption_start],
    }

    # Reference (baseline) and observed (disrupted) indicator series, in the
    # same order: grid sum, equality, hunger, cumulative reward per agent.
    Rset = {0: [list(apple_remaind_baseline), list(gini_baseline), list(hunger_baseline),
                list(np.cumsum(rewards_agent_1_baseline)), list(np.cumsum(rewards_agent_2_baseline))]}

    Pset = {0: [list(apple_remaind), list(gini_qmix), list(hunger_qmix),
                list(cum_rewards_1), list(cum_rewards_2)]}

    resilience = resilience_metric.fit(
        disturbancesIndex=disturbancesIndex,
        Pset=Pset,
        Rset=Rset
    )

    total_rewards_agent_1.append(cum_rewards_1[-1])
    total_rewards_agent_2.append(cum_rewards_2[-1])
    minimum_rewards_agent.append(np.min(apple_remaind))
    resiliencia_by_episode.append(resilience)

# Reduce each episode's resilience result to a plain float taken from its
# first element; NaN results are replaced by 0 so the plot and the summary
# statistics below stay finite.
resiliencia = []
for episode_resilience in resiliencia_by_episode:
    resilience_value = float(episode_resilience[0])
    resiliencia.append(0 if math.isnan(resilience_value) else resilience_value)

plt.boxplot(resiliencia)
plt.show()

# Everything collected during the run, bundled for optional persistence.
data_qmix = {
    'resilience': resiliencia,
    'minimum_reward': minimum_rewards_agent,
    'agent1_reward': total_rewards_agent_1,
    'agent2_reward': total_rewards_agent_2,
    'len': len_epsiode,
    'last_apple': last_apple,
    'board1': board_agent1_total,
    'board2': board_agent2_total,
}

# Uncomment to write the results to disk.
# save_data(data_qmix, 'data_qmix.pkl')

# Per-episode reward averaged over the two agents.
reward_total = [0.5 * (a + b) for a, b in zip(total_rewards_agent_1, total_rewards_agent_2)]
print(f'[Metrics] Cooperative Resilience: mean = {np.mean(resiliencia):.3f}, std = {np.std(resiliencia):.3f}')
print(f'[Metrics] Total Agent Reward: mean = {np.mean(reward_total):.3f}, std = {np.std(reward_total):.3f}')
print(f'[Metrics] Episode Length: mean = {np.mean(len_epsiode):.1f}, std = {np.std(len_epsiode):.1f}')
# `last_apple` is incremented only for episodes in which the
# "Last apple taken" branch fired, so it counts episodes WITH the last apple
# consumed; the previous wording ("without") reported the opposite.
print(f'[Metrics] Episodes with last apple consumed: {last_apple} out of {len(len_epsiode)}')
