import os
import json
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
from gymnasium import spaces
import wandb
import copy
import hashlib
from collections import Counter

from utils import (
    TransformerPolicy, MODEL_PATH, FINE_TUNED_MODEL_PATH, HIDDEN_DIM, DATA_ROOT, 
    HASH_STR_LEN, parse_triangle_sequence, is_valid_triangle, MAX_LEN, save_model, make_env, set_seed,
    load_model_and_tokenizer, validate_triangle_generation, _nan_to_none, init_wandb
)
from data_utils import generate_and_save_dataset
from fractal_sampling import generate_dataset as generate_fractal_dataset, obs_to_key

# PPO hyperparameters
PPO_EPOCHS = 2  # passes over the combined (base + polychromic) batch per update_from_dataset call
MINIBATCH_SIZE = 64
GAMMA = 0.99  # discount used in compute_gae
GAE_LAMBDA = 0.95  # GAE lambda used in compute_gae
CLIP_EPSILON = 0.2  # PPO probability-ratio clip range
ACTOR_LR = 1e-5  # lower LR for polychromic
CRITIC_LR = 1e-4
VALUE_LOSS_COEF = 0.5  # multiplier on the critic's MSE loss
ENTROPY_COEF = 0.0  # 0.0 disables the entropy bonus in the actor loss
KL_COEF = 0.005  # weight of the KL(actor || frozen reference actor) penalty
UCB_COEF = 0.005  # scale of the count-based min(1, 1/sqrt(N(s,a))) bonus; 0.0 disables it
MAX_GRAD_NORM = 0.5  # gradient-norm clip applied to both actor and critic

# Polychromic hyperparameters
NUM_VINES_AT_STATE = 8 # 16 (alternative value); forwarded to generate_fractal_dataset(num_vines_at_state=...)
NUM_LEVELS = 3  # forwarded to generate_fractal_dataset(num_levels=...)
POLYCHROME_WINDOW = 1  # Set to 1 for advantage to apply only to current timestep
# (i.e. only the first step of each grouped trajectory, the one taken from the vine state)

# training hyperparameters
MAIN_ROLLOUT_LENGTH = 9  # forwarded to generate_fractal_dataset(main_rollout_max_steps=...)
NUM_TRAINING_EPISODES = 1000  # logged in the wandb config; train_polychromic_ppo ignores its num_episodes arg
EPISODES_PER_COLLECTION = 130  # printed/logged; also the window (in epochs) of the running-average reward
STEPS_PER_COLLECTION = 4096  # NOTE(review): only recorded in the wandb config
EVAL_INTERVAL = 50  # NOTE(review): only recorded in the wandb config
HEAVY_EVAL_INTERVAL = 25  # run validate_triangle_generation every N epochs during training
MAX_STEPS_PER_EPISODE = 16  # NOTE(review): only recorded in the wandb config
NUM_EVAL_EPISODES = 100  # num_samples passed to validate_triangle_generation

NUM_EPOCHS = 500  # outer epochs: each generates fresh fractal datasets and runs one update
CHECKPOINT_INTERVAL = 100  # NOTE(review): not referenced anywhere in this file

SEED = 42  # passed to set_seed in main
# DEBUG=true (case-insensitive) disables wandb and prints per-group diversity values
DEBUG = os.environ.get("DEBUG", "False").lower() == "true"

# --- Disable-dropout helper ---
def disable_dropout(module: nn.Module):
    """Zero the drop probability of every ``nn.Dropout`` under ``module``.

    ``module.modules()`` yields the module itself and all of its descendants, so
    ``nn.Dropout`` layers nested anywhere (e.g. inside Transformer blocks) are
    covered. The module is modified in place and nothing is returned.

    NOTE(review): dropout applied functionally (e.g. a float ``dropout``
    attribute on an attention layer) is not an ``nn.Dropout`` and is untouched.
    """
    dropout_layers = (sub for sub in module.modules() if isinstance(sub, nn.Dropout))
    for layer in dropout_layers:
        layer.p = 0.0


def cluster_seq_diversity(sequences: list[list[int]], tokenizer) -> float:
    """
    Return the fraction of distinct sequences in ``sequences``.

    Diversity is ``#unique sequences / #sequences``, with two special cases:
    a group whose sequences are all identical (exactly one unique sequence)
    scores 0.0 rather than ``1 / len(sequences)``, so a collapsed group earns no
    diversity credit; and an empty group scores 0.0 instead of raising
    ``ZeroDivisionError``.

    Args:
        sequences: Sequences of triangle vertices (need not be complete). Each
            element is converted to a tuple so it can be compared/hashed.
        tokenizer: Unused; kept so existing callers keep working.

    Returns:
        A float in ``[0.0, 1.0]``.
    """
    n_total = len(sequences)
    if n_total == 0:
        return 0.0
    n_unique = len({tuple(seq) for seq in sequences})
    if n_unique == 1:  # all sequences are the same -> no diversity
        return 0.0
    return n_unique / n_total  # already divided by len(group) in group score

# ----------------------
# PPO Agent
# ----------------------
class ValueNetwork(nn.Module):
    """Scalar value head applied to feature vectors.

    A three-layer MLP (Linear -> ReLU -> Linear -> ReLU -> Linear) mapping the
    last dimension from ``input_dim`` to 1; the trailing singleton dimension is
    squeezed, so a ``(B, input_dim)`` batch yields a ``(B,)`` tensor.
    """

    def __init__(self, input_dim: int = HIDDEN_DIM, hidden_dim: int = HIDDEN_DIM):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        """Return one value per feature vector (last dim of ``feats`` is ``input_dim``)."""
        values = self.net(feats)  # (..., 1)
        return values.squeeze(-1)  # (B,) for a (B, input_dim) input

class PolychromicPPOAgent:
    """PPO agent for triangle discovery with polychromic (reward x diversity) advantages.

    The actor is a Transformer policy; the critic (``value_network``) is applied
    to the actor's final hidden state at the last position. A frozen,
    dropout-free deep copy of the actor is the KL reference.

    ``update_from_dataset`` trains on every collected step with GAE advantages
    (clipped PPO), plus extra "polychromic" samples whose advantage is a group
    score (mean trajectory return x sequence diversity) relative to the mean
    score of the other groups drawn from the same vine state. A count-based UCB
    bonus is added to all advantages.
    """

    # Trajectories per polychromic group, and groups sampled per vine state.
    POLY_GROUP_SIZE = 4
    POLY_NUM_GROUPS = 4
    # The reference actor is re-synced to the live actor when the mean KL of an
    # update exceeds this value.
    REF_KL_RESET_THRESHOLD = 0.1

    def __init__(self, actor, value_network, device, tokenizer=None):
        """
        Args:
            actor: Policy network; moved to ``device``, optimized with ``ACTOR_LR``.
            value_network: Critic over actor features; optimized with ``CRITIC_LR``.
            device: Torch device for the networks and all batches.
            tokenizer: Tokenizer exposing ``pad_id``; required by ``prepare_obs``
                and therefore by ``select_action``.
        """
        self.actor = actor.to(device)
        self.critic = value_network.to(device)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=ACTOR_LR)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=CRITIC_LR)
        self.device = device
        self.tokenizer = tokenizer

        # Hard-disable dropout on the live actor
        disable_dropout(self.actor)

        # Reference actor for KL divergence (frozen, eval mode, no grads)
        self.ref_actor = copy.deepcopy(self.actor).to(device)
        disable_dropout(self.ref_actor)           # make sure ref has p=0.0 too
        self.ref_actor.eval()
        for p in self.ref_actor.parameters():
            p.requires_grad = False

    def get_hidden_states(self, obs_tensor, attn_mask=None):
        """Run the actor's embeddings, Transformer blocks and final ``ln_f`` layer.

        NOTE(review): presumably mirrors the actor's own forward pass minus its
        output head — confirm against ``TransformerPolicy``.

        Args:
            obs_tensor: Token ids of shape ``(B, T)``.
            attn_mask: Optional mask forwarded to ``self.actor.blocks``; when
                omitted, the actor's causal mask cropped to ``T x T`` is used.

        Returns:
            The output of ``self.actor.ln_f`` (presumably ``(B, T, HIDDEN_DIM)``).

        Raises:
            ValueError: If ``T`` exceeds ``self.actor.max_len``.
        """
        B, T = obs_tensor.shape
        if T > self.actor.max_len:
            raise ValueError(f"Sequence length {T} exceeds max_len {self.actor.max_len}")

        # Token + positional embeddings
        pos = torch.arange(T, device=obs_tensor.device).unsqueeze(0)
        x = self.actor.tok_emb(obs_tensor) + self.actor.pos_emb(pos)

        if attn_mask is not None:
            x = self.actor.blocks(x, mask=attn_mask)
        else:
            causal = self.actor.causal_mask[:T, :T]
            x = self.actor.blocks(x, mask=causal)
        x = self.actor.ln_f(x)

        return x

    def get_features(self, obs_tensor, attn_mask=None):
        """Return the hidden state at the LAST position of each sequence, ``(B, HIDDEN_DIM)``.

        NOTE(review): for right-padded observations the last position is a pad
        token — confirm this is intended.
        """
        hidden_states = self.get_hidden_states(obs_tensor, attn_mask)
        return hidden_states[:, -1, :]

    @torch.no_grad()
    def get_value(self, obs_tensor):
        """Critic value estimate of ``obs_tensor`` (no gradients) from actor features."""
        feats = self.get_features(obs_tensor)
        return self.critic(feats)

    def prepare_obs(self, obs):
        """Strip padding from a single observation and return it as a ``(1, L)`` long tensor.

        The length ``L`` is the number of tokens different from ``tokenizer.pad_id``;
        the first ``L`` tokens are kept, which assumes padding only appears at the
        end of ``obs`` (same convention as the generate function).
        """
        pad_id = self.tokenizer.pad_id
        actual_length = (np.array(obs) != pad_id).sum()
        actual_sequence = obs[:actual_length]

        # No attention mask is built: the actor's default (causal) masking is used.
        obs_tensor = torch.tensor(actual_sequence, dtype=torch.long).unsqueeze(0).to(self.device)
        return obs_tensor

    def _state_key(self, obs):
        """
        Compact, stable key for counting N(s,a) from triangle discovery observations.
        Uses the first 8 hex chars of the SHA-1 of ``str(obs)``; falls back to
        ``"unknown"`` if that conversion fails.
        """
        try:
            obs_str = str(obs)
            h = hashlib.sha1(obs_str.encode()).hexdigest()[:8]
        except Exception:
            h = "unknown"
        return h

    @torch.no_grad()
    def select_action(self, obs, temperature=1.0):
        """Choose the next token for one observation (padding stripped via ``prepare_obs``).

        Args:
            obs: A single observation sequence of token ids.
            temperature: ``0.0`` takes the argmax token; otherwise logits are
                divided by ``temperature`` and the action is sampled.

        Returns:
            ``(action, log_prob, entropy)`` as Python numbers. Log-prob and
            entropy come from the distribution the action was chosen from (the
            unscaled one in the greedy case).
        """
        obs_tensor = self.prepare_obs(obs)
        action_logits = self.actor(obs_tensor)
        next_token_logits = action_logits[:, -1, :]

        if temperature == 0.0:
            # Greedy: previously `dist` was never created on this path, so the
            # return statement raised UnboundLocalError.
            dist = torch.distributions.Categorical(logits=next_token_logits)
            action = torch.argmax(next_token_logits, dim=-1)
        else:
            dist = torch.distributions.Categorical(logits=next_token_logits / temperature)
            action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item(), log_prob.item(), dist.entropy().item()

    def compute_gae(self, rewards, values, next_values, terminateds, truncateds):
        """Generalized Advantage Estimation over one trajectory.

        All inputs are equal-length 1-D tensors. ``terminateds`` disables the
        bootstrap from ``next_values``; either ``terminateds`` or ``truncateds``
        stops the recursion from carrying advantage across an episode end.

        Returns:
            ``(advantages, returns)`` with ``returns = advantages + values``.
        """
        bootstrap_mask = (1.0 - terminateds)   # 1 if not terminated, 0 if terminated

        episode_ends = torch.clamp(terminateds + truncateds, max=1.0)  # 1 at either type of end
        advantages = torch.zeros_like(rewards)
        lastgaelam = 0.0
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + GAMMA * next_values[t] * bootstrap_mask[t] - values[t]
            lastgaelam = delta + GAMMA * GAE_LAMBDA * (1.0 - episode_ends[t]) * lastgaelam
            advantages[t] = lastgaelam
        returns = advantages + values
        return advantages, returns

    @staticmethod
    def _normalize_advantages(adv):
        """Standardize ``adv`` to zero mean / unit std.

        With fewer than two elements the (unbiased) std is NaN, so only the mean
        is removed instead of propagating NaNs into the loss.
        """
        centered = adv - adv.mean()
        if adv.numel() < 2:
            return centered
        return centered / (adv.std() + 1e-8)

    @staticmethod
    def _flatten_dataset(on_policy_dataset):
        """Flatten ``{vine_state_key: [trajectory, ...]}`` into parallel per-step lists.

        Returns:
            ``(obs, actions, old_log_probs, rewards, dones, next_obs, traj_indices)``
            where ``traj_indices`` holds one ``(start, end)`` slice per trajectory
            into the step lists, in dataset iteration order.
        """
        obs, actions, old_logps, rewards, dones, next_obs = [], [], [], [], [], []
        traj_indices = []
        for trajectories in on_policy_dataset.values():
            for traj in trajectories:
                start_idx = len(obs)
                for step in traj:
                    obs.append(step['observation'])
                    actions.append(step['action'])
                    old_logps.append(step['log_prob'])
                    rewards.append(step['reward'])
                    dones.append(step['done'])
                    next_obs.append(step['next_observation'])
                traj_indices.append((start_idx, len(obs)))
        return obs, actions, old_logps, rewards, dones, next_obs, traj_indices

    def _polychromic_samples(self, on_policy_dataset, polychrome_window):
        """Build the polychromic augmentation samples.

        For each vine state with at least ``POLY_GROUP_SIZE`` trajectories, draws
        ``POLY_NUM_GROUPS`` random groups (without replacement inside a group),
        scores each as ``mean trajectory return * diversity`` of the final
        ``env_state`` values, and assigns ``score - mean score of this state's
        groups`` as the advantage of the first ``polychrome_window`` steps of
        every trajectory in the group (fewer if the trajectory is shorter).

        Returns:
            ``(obs, actions, old_log_probs, advantages, group_diversities)``.
        """
        aug_obs, aug_actions, aug_logps, aug_adv = [], [], [], []
        group_diversities = []
        group_size = self.POLY_GROUP_SIZE
        for trajectories in on_policy_dataset.values():
            if len(trajectories) < group_size:
                continue
            # Sample groups of trajectories from the same vine state
            groups = [random.sample(trajectories, group_size) for _ in range(self.POLY_NUM_GROUPS)]
            group_values = []
            for group in groups:
                reward = sum(sum(s['reward'] for s in t) for t in group)
                group_sequences = [t[-1]['env_state'] for t in group]

                # Diversity of trajectories from the same vine state
                diversity = cluster_seq_diversity(group_sequences, self.tokenizer)
                group_value = (reward / len(group)) * diversity  # TODO: should be dividing by len(group) here?
                group_values.append(group_value)
                if DEBUG:
                    print(f"diversity: {diversity} group_value: {group_value}")
                group_diversities.append(diversity)

            if not group_values:
                continue
            baseline = sum(group_values) / len(group_values)

            for group, gval in zip(groups, group_values):
                advantage = float(gval - baseline)
                for traj in group:
                    # Slicing (instead of indexing up to polychrome_window) avoids
                    # IndexError on trajectories shorter than the window.
                    for step in traj[:polychrome_window]:
                        aug_obs.append(step['observation'])
                        aug_actions.append(step['action'])
                        aug_logps.append(step['log_prob'])
                        aug_adv.append(advantage)
        return aug_obs, aug_actions, aug_logps, aug_adv, group_diversities

    def _ucb_bonus(self, obs_list, actions, sa_counts, coef):
        """Per-sample bonus ``coef * min(1, 1/sqrt(max(1, N(s, a))))`` as a list of floats.

        ``N(s, a)`` is looked up in ``sa_counts`` under ``(_state_key(obs), int(action))``;
        unseen pairs count as 1.
        """
        bonuses = []
        for ob, a in zip(obs_list, actions):
            count = sa_counts[(self._state_key(ob), int(a))]
            bonuses.append(coef * min(1.0, (1.0 / max(1, count)) ** 0.5))
        return bonuses

    def update_from_dataset(self, on_policy_dataset, ppo_epochs, minibatch_size, value_loss_coef, polychrome_window=5):
        """
        Update actor and critic with PPO plus polychromic advantages computed from a
        fractal-sampling dataset. Based on the BabyAI implementation.

        Args:
            on_policy_dataset: Mapping ``vine_state_key -> list of trajectories``;
                each trajectory is a list of step dicts with keys ``observation``,
                ``action``, ``log_prob``, ``reward``, ``done``, ``next_observation``
                and (read on the last step) ``env_state``. Observations must all
                have the same length so they can be stacked.
            ppo_epochs: Passes over the combined batch.
            minibatch_size: Samples per gradient step.
            value_loss_coef: Multiplier on the critic's MSE loss.
            polychrome_window: Leading steps of each grouped trajectory that
                receive the group advantage.

        Returns:
            ``(avg_actor_loss, avg_critic_loss, avg_kl, group_diversities, entropies)``;
            the last two are lists (per group / per minibatch). An empty dataset
            returns ``(0.0, 0.0, 0.0, [], [])``.
        """
        self.actor.train()
        self.critic.train()
        self.ref_actor.eval()

        # 1. Flatten dataset from dictionary structure
        (all_obs, all_actions, all_old_logps, all_rewards, all_dones,
         all_next_obs, traj_indices) = self._flatten_dataset(on_policy_dataset)
        if not all_obs:
            # Same 5-tuple shape as the normal return so callers can always unpack it
            # (previously a 4-tuple, which broke the 5-way unpack in the caller).
            return 0.0, 0.0, 0.0, [], []

        # 2. State-action counts for the UCB bonus (base steps only)
        sa_counts = Counter(
            (self._state_key(ob), int(a)) for ob, a in zip(all_obs, all_actions)
        )

        # 3. GAE advantages, trajectory by trajectory
        obs_batch = torch.tensor(np.array(all_obs), dtype=torch.long).to(self.device)
        next_obs_batch = torch.tensor(np.array(all_next_obs), dtype=torch.long).to(self.device)
        # reshape(-1) instead of squeeze(-1): squeezing a (1,) result produced a
        # 0-dim tensor that cannot be sliced per trajectory below.
        values_batch = self.get_value(obs_batch).reshape(-1)
        next_values_batch = self.get_value(next_obs_batch).reshape(-1)

        rewards_batch = torch.tensor(all_rewards, device=self.device, dtype=torch.float)
        dones_batch = torch.tensor(all_dones, device=self.device, dtype=torch.float)
        terminateds = dones_batch  # In triangle discovery, done means terminated
        truncateds = torch.zeros_like(dones_batch)  # Triangle discovery doesn't use truncation

        advantages_list, returns_list = [], []
        for start, end in traj_indices:
            adv, ret = self.compute_gae(
                rewards_batch[start:end], values_batch[start:end], next_values_batch[start:end],
                terminateds[start:end], truncateds[start:end],
            )
            advantages_list.append(adv)
            returns_list.append(ret)

        advantages_batch = self._normalize_advantages(torch.cat(advantages_list))
        returns_batch = torch.cat(returns_list)
        actions_batch = torch.tensor(all_actions, device=self.device, dtype=torch.long)
        log_probs_old_batch = torch.tensor(all_old_logps, device=self.device)

        # 4. Polychromic advantage calculation
        aug_obs, aug_actions, aug_logps, aug_adv, avg_group_diversity = self._polychromic_samples(
            on_policy_dataset, polychrome_window
        )

        # 5. Combine base PPO data and augmented polychromic data
        if aug_obs:
            aug_obs_batch = torch.tensor(np.array(aug_obs), dtype=torch.long).to(self.device)
            # Critic target for augmented rows is the critic's own pre-update
            # estimate (one batched forward instead of one per step).
            aug_returns = self.get_value(aug_obs_batch).reshape(-1)
            obs_batch = torch.cat([obs_batch, aug_obs_batch], dim=0)
            actions_batch = torch.cat([actions_batch, torch.tensor(aug_actions, device=self.device, dtype=torch.long)])
            log_probs_old_batch = torch.cat([log_probs_old_batch, torch.tensor(aug_logps, device=self.device)])
            returns_batch = torch.cat([returns_batch, aug_returns])
            aug_adv_tensor = self._normalize_advantages(
                torch.tensor(aug_adv, device=self.device, dtype=torch.float)
            )
            advantages_batch = torch.cat([advantages_batch, aug_adv_tensor])

        # The UCB bonus depends only on (state, action), so compute it once for
        # every row rather than per minibatch per epoch.
        n = actions_batch.size(0)
        if UCB_COEF > 0.0:
            bonus_batch = torch.tensor(
                self._ucb_bonus(all_obs + aug_obs, actions_batch.tolist(), sa_counts, UCB_COEF),
                device=self.device, dtype=torch.float,
            )
        else:
            bonus_batch = torch.zeros(n, device=self.device, dtype=torch.float)

        # 6. PPO updates
        actor_losses, critic_losses, kl_hist, episode_ent = [], [], [], []
        for _ in range(ppo_epochs):
            indices = torch.randperm(n)
            for start in range(0, n, minibatch_size):
                mb_idx = indices[start:start + minibatch_size]
                mb_obs = obs_batch[mb_idx]
                mb_actions = actions_batch[mb_idx]
                mb_old_logp = log_probs_old_batch[mb_idx]
                mb_returns = returns_batch[mb_idx]
                mb_adv_aug = advantages_batch[mb_idx] + bonus_batch[mb_idx]

                # Critic update on detached actor features: the critic loss does
                # not train the actor backbone.
                with torch.no_grad():
                    mb_feats = self.get_features(mb_obs)
                v = self.critic(mb_feats).reshape(-1)
                v_loss = value_loss_coef * nn.functional.mse_loss(v, mb_returns)
                self.critic_optimizer.zero_grad(set_to_none=True)
                v_loss.backward()
                nn.utils.clip_grad_norm_(self.critic.parameters(), MAX_GRAD_NORM)
                self.critic_optimizer.step()
                critic_losses.append(v_loss.item())

                # Actor update.
                # NOTE(review): logits are read at the last position of the stored
                # (possibly right-padded) observation, whereas select_action strips
                # padding first. Confirm generate_fractal_dataset computed
                # step['log_prob'] at the same position, or the PPO ratio compares
                # different positions.
                action_logits = self.actor(mb_obs)
                next_token_logits = action_logits[:, -1, :]
                dist = torch.distributions.Categorical(logits=next_token_logits)

                with torch.no_grad():
                    ref_logits = self.ref_actor(mb_obs)
                    ref_next_token_logits = ref_logits[:, -1, :]
                ref_dist = torch.distributions.Categorical(logits=ref_next_token_logits)
                kl_div = torch.distributions.kl.kl_divergence(dist, ref_dist).mean()
                kl_hist.append(kl_div.item())

                current_log_probs = dist.log_prob(mb_actions)
                entropy = dist.entropy().mean()
                ratio = torch.exp(current_log_probs - mb_old_logp)
                surr1 = ratio * mb_adv_aug
                surr2 = torch.clamp(ratio, 1.0 - CLIP_EPSILON, 1.0 + CLIP_EPSILON) * mb_adv_aug
                actor_loss = -torch.min(surr1, surr2).mean() - ENTROPY_COEF * entropy + KL_COEF * kl_div
                self.actor_optimizer.zero_grad(set_to_none=True)
                actor_loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), MAX_GRAD_NORM)
                self.actor_optimizer.step()
                actor_losses.append(actor_loss.item())
                episode_ent.append(entropy.item())

        avg_actor = float(np.mean(actor_losses)) if actor_losses else 0.0
        avg_critic = float(np.mean(critic_losses)) if critic_losses else 0.0
        avg_kl = float(np.mean(kl_hist)) if kl_hist else 0.0
        print(f"[Poly-PPO] Update -> actor: {avg_actor:.4f} | critic: {avg_critic:.4f} | KL(ref): {avg_kl:.5f}")

        # TODO: provisional heuristic — re-sync the reference when the actor drifts too far.
        if avg_kl > self.REF_KL_RESET_THRESHOLD:
            self.ref_actor.load_state_dict(self.actor.state_dict())
            print(f"Updated reference actor due to high KL divergence: {avg_kl:.5f}")

        return avg_actor, avg_critic, avg_kl, avg_group_diversity, episode_ent

def train_polychromic_ppo(envs, agent, num_episodes=1000, wandb_run=None):
    """
    Train ``agent`` with polychromic PPO on fractal-sampling datasets.

    Each of ``NUM_EPOCHS`` epochs does the following:
    1. Generates a fresh on-policy dataset per environment with ``generate_fractal_dataset``.
    2. Regroups all non-empty trajectories by the key of their first observation.
    3. Runs one ``agent.update_from_dataset`` call.
    4. Logs progress.

    Every ``HEAVY_EVAL_INTERVAL`` epochs, ``validate_triangle_generation`` is also run.

    Args:
        envs: Environments to sample from (one dataset per env per epoch).
        agent: A ``PolychromicPPOAgent``.
        num_episodes: Unused; kept for interface compatibility. The number of
            epochs is controlled by ``NUM_EPOCHS``.
        wandb_run: Active wandb run, or ``None`` to disable wandb logging.

    Returns:
        List with the mean per-trajectory reward of every epoch that produced data.
    """
    episode_rewards = []
    valid_sequences = 0
    total_episodes = 0
    episode_entropy = []

    num_epochs = NUM_EPOCHS

    print(f"Training for {num_epochs} epochs with {EPISODES_PER_COLLECTION} episodes per collection")

    for epoch in range(num_epochs):
        print(f"\n--- Epoch {epoch+1}/{num_epochs} ---")

        datasets = []
        for env in envs:
            env_dataset = generate_fractal_dataset(
                policy=agent.actor,
                tokenizer=agent.tokenizer,
                env=env,
                num_vines_at_state=NUM_VINES_AT_STATE,
                num_levels=NUM_LEVELS,
                main_rollout_max_steps=MAIN_ROLLOUT_LENGTH,
                device=agent.device,
            )
            print(f"Generated dataset with {sum(len(v) for v in env_dataset.values())} trajectories")
            datasets.append(env_dataset)

        # Combine datasets from all environments, regrouping every non-empty
        # trajectory by the key of its first observation (its vine state).
        on_policy_dataset = {}
        for dataset in datasets:
            for trajectories in dataset.values():
                for traj in trajectories:
                    if not traj:
                        continue
                    vine_state_key = obs_to_key(traj[0]['observation'])
                    on_policy_dataset.setdefault(vine_state_key, []).append(traj)

        # Reset every epoch. It used to be assigned only when the epoch produced
        # data, so an empty first epoch raised UnboundLocalError in the wandb log
        # below and later empty epochs logged a stale count.
        episode_valid_sequences = 0

        if on_policy_dataset:
            # Update using polychromic advantage
            actor_loss, critic_loss, kl_div, group_div_list, episode_ent = agent.update_from_dataset(
                on_policy_dataset=on_policy_dataset,
                ppo_epochs=PPO_EPOCHS,
                minibatch_size=MINIBATCH_SIZE,
                value_loss_coef=VALUE_LOSS_COEF,
                polychrome_window=POLYCHROME_WINDOW
            )

            # Calculate episode metrics from dataset
            episode_reward = []
            n_trajectories = 0
            for trajectories in on_policy_dataset.values():
                n_trajectories += len(trajectories)
                for traj in trajectories:
                    traj_reward = sum(step['reward'] for step in traj)
                    episode_reward.append(traj_reward)
                    if traj_reward > 0:
                        episode_valid_sequences += 1

            episode_rewards.append(np.mean(episode_reward))  # mean reward per trajectory
            valid_sequences += episode_valid_sequences
            total_episodes += n_trajectories
            if episode_ent:  # avoid np.mean([]) -> nan + RuntimeWarning
                episode_entropy.append(np.mean(episode_ent))  # avg entropy of this update

            # Log training metrics to wandb
            if wandb_run is not None:
                avg_group_diversity = float(np.mean(group_div_list)) if group_div_list else 0.0
                wandb.log({
                    "train/actor_loss": actor_loss,
                    "train/critic_loss": critic_loss,
                    "train/kl_divergence": kl_div,
                    "train/average_group_diversity": avg_group_diversity,
                })

        # Log progress. The slice already covers the "fewer than
        # EPISODES_PER_COLLECTION entries" case; only the empty case needs a guard.
        recent_rewards = episode_rewards[-EPISODES_PER_COLLECTION:]
        avg_reward = np.mean(recent_rewards) if recent_rewards else float("nan")
        avg_entropy = np.mean(episode_entropy) if episode_entropy else float("nan")
        success_rate = valid_sequences / total_episodes if total_episodes > 0 else 0
        print(f"Epoch {epoch+1}, Avg Reward: {avg_reward:.3f}, Success Rate: {success_rate:.3f}, Avg Entropy: {avg_entropy:.3f}")

        # Log training metrics to wandb after every epoch
        if wandb_run is not None:
            wandb.log({
                "train/epoch": epoch + 1,
                "train/avg_reward_epoch": avg_reward,
                "train/valid_sequences": episode_valid_sequences,
                "train/success_rate_epoch": success_rate,
                "train/avg_entropy_epoch": avg_entropy,
            })

        # Heavy evaluation (like PPO_multiseed)
        if (epoch + 1) % HEAVY_EVAL_INTERVAL == 0:
            print(f"\n--- Heavy Evaluation at Epoch {epoch+1} ---")
            heavy_eval_success_rate = validate_triangle_generation(agent.actor, agent.tokenizer, envs, device=agent.device, num_samples=NUM_EVAL_EPISODES)
            print("--- End Heavy Evaluation ---\n")

            # Log heavy evaluation metrics to wandb
            if wandb_run is not None:
                wandb.log({
                    "heavy_eval/epoch": epoch + 1,
                    "heavy_eval/success_rate": heavy_eval_success_rate,
                    "train/avg_entropy_epoch": avg_entropy,
                })

    return episode_rewards


def main():
    """Fine-tune the pretrained triangle-discovery policy with polychromic PPO.

    Steps:
    1. Seed the RNGs and ensure the dataset under ``DATA_ROOT`` exists.
    2. Load the pretrained actor/tokenizer from ``MODEL_PATH``.
    3. Build a critic, three environments and the agent.
    4. Evaluate, train, then evaluate again.
    5. Save the fine-tuned model locally and, unless ``DEBUG``, to wandb.

    Raises:
        FileNotFoundError: If ``MODEL_PATH`` does not exist.
    """
    set_seed(SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"Using seed: {SEED}")

    # Ensure dataset exists
    if not os.path.exists(DATA_ROOT):
        os.makedirs(DATA_ROOT, exist_ok=True)
        print(f"Creating dataset in {DATA_ROOT}")
        generate_and_save_dataset(num_graphs=1, fixed_hash_per_graph=True)

    # Load the pretrained actor and its tokenizer. Raise explicitly rather than
    # `assert`, which is stripped under `python -O`.
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Could not find {MODEL_PATH}")
    actor, tokenizer = load_model_and_tokenizer(MODEL_PATH, device=device)
    print(f"Loaded pretrained policy from {MODEL_PATH}")
    print(f"Loaded tokenizer with vocab_size={tokenizer.vocab_size}")

    # Create value network that uses actor features
    value_network = ValueNetwork(input_dim=HIDDEN_DIM, hidden_dim=HIDDEN_DIM).to(device)

    # Create environments and agent.
    # NOTE(review): the meaning of the index passed to make_env (0..2) lives in utils — confirm.
    envs = [make_env(tokenizer, i, device) for i in range(3)]

    agent = PolychromicPPOAgent(actor=actor, value_network=value_network, device=device, tokenizer=tokenizer)

    # Initialize wandb
    wandb_config = {
        "algo": "PolychromicPPO_UCB",
        "env": "TriangleDiscovery",
        "ppo_epochs": PPO_EPOCHS,
        "minibatch_size": MINIBATCH_SIZE,
        "gamma": GAMMA,
        "gae_lambda": GAE_LAMBDA,
        "clip_epsilon": CLIP_EPSILON,
        "actor_lr": ACTOR_LR,
        "critic_lr": CRITIC_LR,
        "value_loss_coef": VALUE_LOSS_COEF,
        "entropy_coef": ENTROPY_COEF,
        "kl_coef": KL_COEF,
        "ucb_coef": UCB_COEF,
        "max_grad_norm": MAX_GRAD_NORM,
        "num_training_episodes": NUM_TRAINING_EPISODES,
        "steps_per_collection": STEPS_PER_COLLECTION,
        "eval_interval": EVAL_INTERVAL,
        "heavy_eval_interval": HEAVY_EVAL_INTERVAL,
        "num_eval_episodes": NUM_EVAL_EPISODES,
        "max_steps_per_episode": MAX_STEPS_PER_EPISODE,
        "episodes_per_collection": EPISODES_PER_COLLECTION,
        # Polychromic / schedule settings that actually drive training.
        "num_epochs": NUM_EPOCHS,
        "num_vines_at_state": NUM_VINES_AT_STATE,
        "num_levels": NUM_LEVELS,
        "polychrome_window": POLYCHROME_WINDOW,
        "main_rollout_length": MAIN_ROLLOUT_LENGTH,
        "seed": SEED,
        "vocab_size": tokenizer.vocab_size,
        # NOTE(review): architecture values below are hard-coded, not read from the
        # loaded actor; the agent also sets every nn.Dropout to p=0.0 for fine-tuning.
        "d_model": 256,
        "n_layer": 4,
        "n_head": 8,
        "dim_ff": 1024,
        "dropout": 0.1,
        "max_len": MAX_LEN,
    }
    if DEBUG:
        wandb_run = None
    else:
        wandb_run = init_wandb(run_name="polychromic_ppo_ucb_triangle_discovery", config=wandb_config)

    # NOTE(review): if FINE_TUNED_MODEL_PATH contains no "ppo", this equals it and
    # would overwrite the plain-PPO checkpoint.
    POLYCHROMIC_FINE_TUNED_MODEL_PATH = FINE_TUNED_MODEL_PATH.replace("ppo", "polychromic_ppo_ucb")

    # Initial evaluation
    print("Initial evaluation of pretrained policy...")
    initial_success_rate = validate_triangle_generation(agent.actor, agent.tokenizer, envs, device=device, num_samples=NUM_EVAL_EPISODES)

    # Train
    print("Starting Polychromic PPO training...")
    train_polychromic_ppo(envs, agent, num_episodes=NUM_TRAINING_EPISODES, wandb_run=wandb_run)

    # Evaluate trained agent
    print("\nEvaluating trained agent...")
    success_rate = validate_triangle_generation(agent.actor, agent.tokenizer, envs, device=device, num_samples=NUM_EVAL_EPISODES)

    # Save the fine-tuned model BEFORE the wandb block: the artifact below adds
    # this file, which previously did not exist yet at that point.
    save_model(agent.actor, agent.tokenizer, POLYCHROMIC_FINE_TUNED_MODEL_PATH)
    print(f"Saved fine-tuned model to {POLYCHROMIC_FINE_TUNED_MODEL_PATH}")

    if wandb_run is not None:
        # Record the before/after evaluation results that were previously discarded.
        wandb.log({
            "eval/initial_success_rate": initial_success_rate,
            "eval/final_success_rate": success_rate,
        })

        # Also save to wandb run directory
        wandb_model_path = os.path.join(wandb_run.dir, "polychromic_ppo_ucb_triangle_discovery_model.pt")
        save_model(agent.actor, agent.tokenizer, wandb_model_path)
        print(f"Saved model to wandb directory: {wandb_model_path}")

        # Log model artifact to wandb (best-effort: failures are reported, not raised)
        try:
            artifact = wandb.Artifact("polychromic_ppo_ucb_triangle_discovery_model", type="model")
            artifact.add_file(POLYCHROMIC_FINE_TUNED_MODEL_PATH)
            wandb.log_artifact(artifact)
        except Exception as e:
            print(f"[wandb] log_artifact failed: {e}")
        try:
            wandb.finish()
        except Exception as e:
            print(f"[wandb] finish failed: {e}")
    

# Run the fine-tuning pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
