import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

"""
PPO Agent Evaluation Script (16×16)
──────────────────────────────────
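This script evaluates pretrained PPO agents in the 16×16 AppleGridMDP16
environment under a disruption protocol, using uniformly random agents as a
reference policy. Each episode runs a disrupted environment alongside an
undisrupted baseline copy that mirrors it until the disruption timestep, then
computes cooperative resilience metrics (apple counts, hunger, equality,
cumulative rewards) by comparing the two trajectories, and also tracks
last-apple events and episode lengths. Results are summarized and plotted,
and can optionally be saved to disk.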
"""

import math
import random
import torch
import numpy as np
import matplotlib.pyplot as plt

from src.utils.save_and_load import save_data
from src.metrics.hunger_and_equality import calculate_equality, calculate_hunger
from src.metrics.resilience_metrics import ResilienceMetrics
from src.agents.ppo_large_agent import PPOAgent
from src.environment.test.apple_grid_large import AppleGridMDP16

# Environment and agent setup
# This instance is only used to read the grid dimensions; each episode below
# creates its own fresh environments.
env = AppleGridMDP16()
grid_size = env.grid_size[0] * env.grid_size[1]
input_dim = 8 + grid_size  # 4 agents × 2 coordinates + flattened grid
output_dim = 4             # action space

# Load 4 pretrained PPO agents.
# map_location="cpu" lets checkpoints written from a GPU run be deserialized on
# a CPU-only machine; load_state_dict then copies the tensors into the model's
# parameters on whatever device those parameters already live on.
agents = [PPOAgent(input_dim=input_dim, output_dim=output_dim) for _ in range(4)]
for i, agent in enumerate(agents):
    checkpoint_path = f"models/best/agent{i}_16_16.pth"
    agent.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

# ── Experiment settings ───────────────────────────────────────────────────────
n = 4                  # number of agents acting in each environment
range_episode = 2000   # timesteps simulated per episode
num_episodes = 50      # number of evaluation episodes
disruption_step = 300  # timestep at which the disruption is triggered

# Running counters of last-apple events (PPO policy / random policy)
last_apple_ppo = last_apple_random = 0

# ── Logs accumulated across episodes ──────────────────────────────────────────
len_ppo = []
len_random = []
resiliencia_by_episode = []
resiliencia_by_episode_random = []
# One independent list per agent (a comprehension avoids aliasing one list n times)
total_rewards_ppo = [list() for _agent in range(n)]
total_rewards_random = [list() for _agent in range(n)]

# Helpers for the main loop
def _policy_actions(policy_agents, obs):
    """Query every PPO agent on the same observation and return their actions.

    The observation is converted once to a float32 tensor with a leading batch
    dimension of 1; element ``[0]`` of each ``select_action`` result is taken as
    that agent's action (any other returned values are ignored here).
    """
    obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
    return [agent.select_action(obs_tensor)[0] for agent in policy_agents]


def _store_nonnegative(series, t, step_rewards):
    """Write each agent's reward into ``series[agent][t]``, skipping negative values.

    Negative rewards are treated by this script as last-apple events (tracked
    separately), so those slots keep their pre-initialised value of 0.
    """
    for agent_series, r in zip(series, step_rewards):
        if r >= 0:
            agent_series[t] = r


def _fit_resilience(apples, rewards, apples_bl, rewards_bl):
    """Fit cooperative resilience of a disrupted run (P) against its baseline (R).

    Both sets hold the apple-count series, the equality and hunger indicators,
    and each agent's cumulative reward; ``disruption_step`` is passed as the
    disturbance index. Returns whatever ``ResilienceMetrics.fit`` returns.
    """
    hunger = calculate_hunger(range_episode, n, rewards)
    hunger_bl = calculate_hunger(range_episode, n, rewards_bl)
    gini = calculate_equality(range_episode, n, rewards)
    gini_bl = calculate_equality(range_episode, n, rewards_bl)

    metric = ResilienceMetrics(K=5, numberScenarios=1, assemblyIndicatorFuction='harmonic')
    return metric.fit(
        disturbancesIndex={0: [disruption_step]},
        Pset={0: [apples, gini, hunger] + [np.cumsum(r) for r in rewards]},
        Rset={0: [apples_bl, gini_bl, hunger_bl] + [np.cumsum(r) for r in rewards_bl]},
    )


# Main loop
for episode in range(num_episodes):
    # First timestep with a last-apple event (negative reward), or None if none occurred.
    finish_ppo, finish_random = None, None

    # PPO environments: `env` receives the disruption, `env_baseline` does not.
    env = AppleGridMDP16()
    env_baseline = AppleGridMDP16()
    env.reset()
    env_baseline.reset()

    # Random-policy environments, with the same disrupted / baseline pairing.
    env_random = AppleGridMDP16()
    env_random_baseline = AppleGridMDP16()
    env_random.reset()
    env_random_baseline.reset()

    # Initial states
    state = env.get_state()
    state_baseline = env_baseline.get_state()
    state_random = env_random.get_state()

    # Per-agent reward series for each run, zero-initialised
    rewards_ppo = [np.zeros(range_episode) for _ in range(n)]
    rewards_ppo_baseline = [np.zeros(range_episode) for _ in range(n)]
    rewards_random = [np.zeros(range_episode) for _ in range(n)]
    rewards_random_baseline = [np.zeros(range_episode) for _ in range(n)]

    # Apple counts per timestep
    apples_ppo, apples_ppo_baseline = [], []
    apples_random, apples_random_baseline = [], []

    for t in range(range_episode):
        if t == disruption_step:
            env.trigger_disruption()
            env_random.trigger_disruption()

        # PPO baseline after the disruption: evolves on its own, undisrupted.
        if t >= disruption_step:
            actions_bl = _policy_actions(agents, state_baseline)
            state_baseline, rewards_bl = env_baseline.step(actions_bl)

        # PPO agents (disrupted run)
        actions = _policy_actions(agents, state)
        state, rewards = env.step(actions)

        # Before the disruption the baseline mirrors the disrupted run. Sync
        # *after* env.step (as the random baseline below does) so both runs hold
        # the same post-step state; syncing before the step left the baseline
        # one step behind the disrupted run.
        if t < disruption_step:
            env_baseline.set_state(tuple(state))
            state_baseline = state

        # Random agents (disrupted run)
        actions_random = [random.randint(0, 3) for _ in range(n)]
        state_random, rewards_r = env_random.step(actions_random)

        # Random baseline: mirror before the disruption, independent afterwards.
        if t < disruption_step:
            env_random_baseline.set_state(tuple(state_random))
        else:
            actions_random_bl = [random.randint(0, 3) for _ in range(n)]
            _, rewards_r_bl = env_random_baseline.step(actions_random_bl)

        # Store rewards (baselines copy the disrupted runs until the disruption)
        _store_nonnegative(rewards_ppo, t, rewards)
        _store_nonnegative(rewards_random, t, rewards_r)
        if t < disruption_step:
            for i in range(n):
                rewards_ppo_baseline[i][t] = rewards_ppo[i][t]
                rewards_random_baseline[i][t] = rewards_random[i][t]
        else:
            _store_nonnegative(rewards_ppo_baseline, t, rewards_bl)
            _store_nonnegative(rewards_random_baseline, t, rewards_r_bl)

        # Last-apple events: remember only the first occurrence in this episode
        if finish_ppo is None and any(r < 0 for r in rewards):
            finish_ppo = t
        if finish_random is None and any(r < 0 for r in rewards_r):
            finish_random = t

        # Track apples
        apples_ppo.append(np.sum(env.grid))
        apples_ppo_baseline.append(np.sum(env_baseline.grid))
        apples_random.append(np.sum(env_random.grid))
        apples_random_baseline.append(np.sum(env_random_baseline.grid))

    # Count episodes (not timesteps) in which the last apple was consumed.
    if finish_ppo is not None:
        last_apple_ppo += 1
    if finish_random is not None:
        last_apple_random += 1

    # Episode lengths: step of the last-apple event, or the full horizon.
    len_ppo.append(range_episode if finish_ppo is None else finish_ppo)
    len_random.append(range_episode if finish_random is None else finish_random)

    # Resilience metrics
    resilience = _fit_resilience(apples_ppo, rewards_ppo, apples_ppo_baseline, rewards_ppo_baseline)
    resilience_r = _fit_resilience(
        apples_random, rewards_random, apples_random_baseline, rewards_random_baseline
    )

    resiliencia_by_episode.append(resilience)
    resiliencia_by_episode_random.append(resilience_r)

    for i in range(n):
        total_rewards_ppo[i].append(np.sum(rewards_ppo[i]))
        total_rewards_random[i].append(np.sum(rewards_random[i]))

    print(f"[Episode {episode}] Resilience PPO: {resilience[0]:.3f}, Random: {resilience_r[0]:.3f}")

# Results summary (PPO policy)
resiliencia_ppo = [res[0] for res in resiliencia_by_episode]
data_ppo = {
    'resilience': resiliencia_ppo,
    # One entry per agent: 'agent1_reward', 'agent2_reward', ...
    **{f'agent{i + 1}_reward': total_rewards_ppo[i] for i in range(n)},
    'len': len_ppo,
    'last_apple': last_apple_ppo,
}

# Set to True to persist the summary dictionary to disk.
save_results = False
if save_results:
    save_data(data_ppo, 'data_ppo_16x16.pkl')

# Mean per-agent reward of each episode (works for any number of agents).
reward_total = [np.mean(episode_rewards) for episode_rewards in zip(*total_rewards_ppo)]

# Print metrics before plotting: plt.show() may block until the window is closed.
print(f'[Metrics] Cooperative Resilience: mean = {np.mean(resiliencia_ppo):.3f}, std = {np.std(resiliencia_ppo):.3f}')
print(f'[Metrics] Total Agent Reward: mean = {np.mean(reward_total):.3f}, std = {np.std(reward_total):.3f}')
print(f'[Metrics] Episode Length: mean = {np.mean(len_ppo):.1f}, std = {np.std(len_ppo):.1f}')
print(f'[Metrics] Episodes with last apple consumed: {last_apple_ppo} out of {num_episodes}')

plt.boxplot(resiliencia_ppo)
plt.title('Cooperative resilience per episode (PPO)')
plt.ylabel('Resilience')
plt.show()
