import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

"""
PPO Agent Evaluation Script
────────────────────────────
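This script evaluates two pretrained PPO agents in the AppleGridMDP
environment under the three disruption protocols described in the paper.
It runs multiple episodes, applies each disruption at a predefined
timestep, and compares every disrupted run against an undisrupted baseline
copy of the environment to compute cooperative resilience from several
indicators (remaining apples, equality, hunger, and cumulative rewards).
It also tracks last-apple events, i.e. episodes in which the last apple is
taken. Results are collected for analysis and can optionally be saved to
disk.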
"""

import math
import random
import torch
import numpy as np
import matplotlib.pyplot as plt

from src.utils.save_and_load import save_data
from src.metrics.hunger_and_equality import calculate_equality, calculate_hunger
from src.metrics.resilience_metrics import ResilienceMetrics
from src.agents.ppo_agent import PPOAgent
from src.environment.test.apple_grid_ppo import AppleGridMDP

# ── Setup ────────────────────────────────────────────────────────────────
# This environment is only used to read the grid dimensions; fresh
# environments are created for every episode in the evaluation loop.
env = AppleGridMDP()

grid_size = env.grid_size[0] * env.grid_size[1]
# PPOAgent(input_dim, n_actions): presumably 4 extra state features plus one
# entry per grid cell, and 4 discrete actions — confirm against PPOAgent.
new_agent = PPOAgent(4 + grid_size, 4)
new_agent_ = PPOAgent(4 + grid_size, 4)

# Load pretrained PPO agent models.
# Replace the paths with the PPO agents you want to evaluate.
# Both agents must be PPO (not QMIX).
# map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
# machines; load_state_dict then copies the weights into the model's own
# parameters, wherever they live.
new_agent.model.load_state_dict(torch.load("models/best/agent1.pth", map_location="cpu"))
new_agent_.model.load_state_dict(torch.load("models/best/agent2.pth", map_location="cpu"))

# Evaluation only: switch the networks to inference mode (matters if the
# architecture has dropout / batch-norm) and skip autograd bookkeeping for
# the millions of forward passes performed below.
new_agent.model.eval()
new_agent_.model.eval()
torch.set_grad_enabled(False)

n = 2                 # number of agents (two reward streams are tracked)
range_episode = 5000  # timesteps per episode
num_episodes = 500

# Per-episode results collected by the evaluation loop.
total_rewards_agent_1 = []
total_rewards_agent_2 = []
resiliencia_by_episode = []
minimum_rewards_agent = []
len_epsiode = []
last_apple = 0

random_state = 48
np.random.seed(random_state)
random.seed(random_state)
# Action sampling in PPOAgent.select_action presumably goes through torch;
# seed it as well so evaluation runs are reproducible.
torch.manual_seed(random_state)

# Accumulated, per-episode-normalized visit counts for each agent; shaped
# like the environment grid rather than a hard-coded 8x8.
board_agent1_total = np.zeros((env.grid_size[0], env.grid_size[1]))
board_agent2_total = np.zeros((env.grid_size[0], env.grid_size[1]))

# Per-episode timesteps at which the disruptions are switched on/off. The
# first one also marks where the baseline env stops mirroring the evaluated
# env and starts evolving on its own, without any disruption.
DISRUPTION_STEP = 1250
DISRUPTION_RATE_START, DISRUPTION_RATE_END = 2500, 2600
DISRUPTION_AGENT_START, DISRUPTION_AGENT_END = 3750, 3900

for episode in range(num_episodes):
    print(f'[Info] Start episode {episode}')
    # `env` receives the disruptions; `env_baseline` is the undisturbed
    # reference trajectory used by the resilience metric.
    env = AppleGridMDP()
    env.reset()
    env_baseline = AppleGridMDP()
    env_baseline.reset()
    rows, cols = env.grid_size[0], env.grid_size[1]

    # Random, distinct initial positions shared by both environments.
    pos1 = (random.randint(0, rows - 1), random.randint(0, cols - 1))
    while True:
        pos2 = (random.randint(0, rows - 1), random.randint(0, cols - 1))
        if pos2 != pos1:
            break
    env.agent_positions = [pos1, pos2]
    env_baseline.agent_positions = [pos1, pos2]

    state = env.get_state()
    state_baseline = env_baseline.get_state()

    rewards_agent_1 = np.zeros((range_episode,))
    rewards_agent_2 = np.zeros((range_episode,))
    rewards_agent_1_baseline = np.zeros((range_episode,))
    rewards_agent_2_baseline = np.zeros((range_episode,))

    apple_remaind = []
    apple_remaind_baseline = []
    # Step at which the last apple was first taken; stays at range_episode
    # when the apples were never depleted.
    len_epsiode_apple = range_episode
    last_apple_taken = False

    board_agent1 = np.zeros((rows, cols))
    board_agent2 = np.zeros((rows, cols))

    for i in range(range_episode):
        if i == DISRUPTION_STEP:
            env.trigger_disruption()
        if i == DISRUPTION_RATE_START:
            env.trigger_disruption_rate()
        if i == DISRUPTION_RATE_END:
            env.trigger_disruption_rate(start=False)
        if i == DISRUPTION_AGENT_START:
            env.trigger_disruption_agent()
        if i == DISRUPTION_AGENT_END:
            env.trigger_disruption_agent(start=False)

        # From the first disruption on, the baseline acts on its own state.
        if i >= DISRUPTION_STEP:
            state_tensor_baseline = torch.tensor(state_baseline, dtype=torch.float32).unsqueeze(0)
            action1_bl, _ = new_agent.select_action(state_tensor_baseline)
            action2_bl, _ = new_agent_.select_action(state_tensor_baseline)
            state_baseline, rewards_baseline = env_baseline.step([action1_bl, action2_bl])

        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        action1, _ = new_agent.select_action(state_tensor)
        action2, _ = new_agent_.select_action(state_tensor)
        state, rewards = env.step([action1, action2])

        # Before the first disruption the baseline mirrors `env`. Sync it
        # *after* the step so both environments (and their apple counts)
        # refer to the same timestep, and the baseline branches off from the
        # exact state `env` is in when the disruption is triggered.
        if i < DISRUPTION_STEP:
            env_baseline.set_state(tuple(state))
            state_baseline = state

        # A negative reward signals that the last apple was taken; it is not
        # recorded as reward.
        if rewards[0] >= 0:
            rewards_agent_1[i] = rewards[0]
        if rewards[1] >= 0:
            rewards_agent_2[i] = rewards[1]

        if i < DISRUPTION_STEP:
            rewards_agent_1_baseline[i] = rewards_agent_1[i]
            rewards_agent_2_baseline[i] = rewards_agent_2[i]
        else:
            if rewards_baseline[0] >= 0:
                rewards_agent_1_baseline[i] = rewards_baseline[0]
            if rewards_baseline[1] >= 0:
                rewards_agent_2_baseline[i] = rewards_baseline[1]

        if rewards[0] < 0 or rewards[1] < 0:
            # Count each episode at most once and keep the first depletion
            # step as the episode length.
            if not last_apple_taken:
                print(f'[Info] Last apple taken episode {episode}')
                last_apple += 1
                len_epsiode_apple = i
                last_apple_taken = True
        else:
            a11, a12 = env.agent_positions[0]
            a21, a22 = env.agent_positions[1]
            board_agent1[a11, a12] += 1
            board_agent2[a21, a22] += 1

        apple_remaind.append(np.sum(env.grid))
        apple_remaind_baseline.append(np.sum(env_baseline.grid))

    # Normalize visit counts to [0, 1] per episode; an all-zero board would
    # produce 0/0 = NaN and poison the running total, so it is skipped.
    max_board1 = np.max(board_agent1)
    if max_board1 > 0:
        board_agent1_total += board_agent1 / max_board1
    max_board2 = np.max(board_agent2)
    if max_board2 > 0:
        board_agent2_total += board_agent2 / max_board2
    len_epsiode.append(len_epsiode_apple)

    hunger = calculate_hunger(range_episode, n, [rewards_agent_1, rewards_agent_2])
    hunger_baseline = calculate_hunger(range_episode, n, [rewards_agent_1_baseline, rewards_agent_2_baseline])
    gini_ppo = calculate_equality(range_episode, n, [rewards_agent_1, rewards_agent_2])
    gini_baseline = calculate_equality(range_episode, n, [rewards_agent_1_baseline, rewards_agent_2_baseline])

    cum_rewards_1 = np.cumsum(rewards_agent_1)
    cum_rewards_2 = np.cumsum(rewards_agent_2)
    cum_rewards_1_baseline = np.cumsum(rewards_agent_1_baseline)
    cum_rewards_2_baseline = np.cumsum(rewards_agent_2_baseline)

    resilience_metric = ResilienceMetrics(
        K=5,
        numberScenarios=1,
        assemblyIndicatorFuction='harmonic'
    )

    # Indicator order must match between the reference (Rset) and the
    # disrupted run (Pset): apples left, equality, hunger, cumulative rewards.
    disturbancesIndex = {0: [DISRUPTION_STEP, DISRUPTION_RATE_START, DISRUPTION_AGENT_START]}
    Rset = {0: [list(apple_remaind_baseline), list(gini_baseline), list(hunger_baseline),
                list(cum_rewards_1_baseline), list(cum_rewards_2_baseline)]}
    Pset = {0: [list(apple_remaind), list(gini_ppo), list(hunger),
                list(cum_rewards_1), list(cum_rewards_2)]}

    resilience = resilience_metric.fit(
        disturbancesIndex=disturbancesIndex,
        Pset=Pset,
        Rset=Rset
    )

    total_rewards_agent_1.append(cum_rewards_1[-1])
    total_rewards_agent_2.append(cum_rewards_2[-1])
    minimum_rewards_agent.append(np.min(apple_remaind))
    resiliencia_by_episode.append(resilience)

# ── Aggregate & report ───────────────────────────────────────────────────
# Set to True to persist the collected results with save_data.
SAVE_RESULTS = False

# First component of each episode's resilience result; NaN scores are
# counted as 0.
resiliencia = [0 if math.isnan(float(u[0])) else float(u[0]) for u in resiliencia_by_episode]

data_better = {
    'resilience': resiliencia,
    'minimum_reward': minimum_rewards_agent,
    'agent1_reward': total_rewards_agent_1,
    'agent2_reward': total_rewards_agent_2,
    'len': len_epsiode,
    'last_apple': last_apple,
    'board1': board_agent1_total,
    'board2': board_agent2_total,
}

if SAVE_RESULTS:
    save_data(data_better, 'data_ppo.pkl')

# Mean per-agent total reward for each episode.
reward_total = [0.5 * (a + b) for a, b in zip(total_rewards_agent_1, total_rewards_agent_2)]
# An episode keeps its full length (range_episode) only when the last apple
# was never taken; `last_apple` counts the opposite (depletion events).
episodes_without_last_apple = sum(1 for length in len_epsiode if length == range_episode)

print(f'[Metrics] Cooperative Resilience: mean = {np.mean(resiliencia):.3f}, std = {np.std(resiliencia):.3f}')
print(f'[Metrics] Total Agent Reward: mean = {np.mean(reward_total):.3f}, std = {np.std(reward_total):.3f}')
print(f'[Metrics] Episode Length: mean = {np.mean(len_epsiode):.1f}, std = {np.std(len_epsiode):.1f}')
print(f'[Metrics] Episodes without last apple consumed: {episodes_without_last_apple} out of {len(len_epsiode)}')

# Shown last: plt.show() blocks until the window is closed, so metrics are
# printed (and optionally saved) first.
plt.boxplot(resiliencia)
plt.show()