import os
import json
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
from gymnasium import spaces
import wandb
import copy
import hashlib
from collections import Counter

from utils import (
    TransformerPolicy, MODEL_PATH, FINE_TUNED_MODEL_PATH, HIDDEN_DIM, DATA_ROOT, HASH_STR_LEN, MAX_LEN,
    parse_triangle_sequence, is_valid_triangle, save_model, make_env, set_seed, load_model_and_tokenizer,
    validate_triangle_generation, _nan_to_none, init_wandb
)
from data_utils import generate_and_save_dataset


# PPO hyperparameters (read by PPOAgent unless noted otherwise)
PPO_EPOCHS = 2  # passes over each collected batch inside PPOAgent.update
MINIBATCH_SIZE = 64  # transitions per gradient step inside PPOAgent.update
GAMMA = 0.99  # discount factor used in PPOAgent.compute_gae
GAE_LAMBDA = 0.95  # lambda of Generalized Advantage Estimation
CLIP_EPSILON = 0.2  # PPO ratio is clamped to [1 - eps, 1 + eps]
ACTOR_LR = 5e-5  # Adam learning rate for the actor
CRITIC_LR = 5e-4  # Adam learning rate for the value network
VALUE_LOSS_COEF = 0.5  # multiplier on the critic's MSE loss
ENTROPY_COEF = 1e-4 # candidate values: 0.0001, 0.0005, 0.005, 0.05 (presumably from a sweep)
KL_COEF = 0.005  # weight of the KL(actor || reference actor) penalty
UCB_COEF = 0.005  # scale of the count-based bonus added to advantages; 0 disables it
MAX_GRAD_NORM = 0.5  # gradient-norm clip applied to both actor and critic
# When the DEBUG environment variable is "true" (case-insensitive), main() skips wandb.
DEBUG = os.environ.get("DEBUG", "False").lower() == "true"

# training hyperparameters
NUM_TRAINING_EPISODES = 2000  # only recorded in the wandb config in this file
EPISODES_PER_COLLECTION = 390  # episodes gathered per epoch, split evenly across envs
STEPS_PER_COLLECTION = 4096  # only recorded in the wandb config in this file
HEAVY_EVAL_INTERVAL = 25  # run validate_triangle_generation every this many epochs
MAX_STEPS_PER_EPISODE = 16  # hard cap on env steps per episode in train_ppo
NUM_EVAL_EPISODES = 100  # num_samples passed to validate_triangle_generation

NUM_EPOCHS = 500  # train_ppo runs NUM_EPOCHS + 1 collect/update iterations
CHECKPOINT_INTERVAL = 100  # checkpoint when the 0-based epoch index is a multiple of this

SEED = 42  # passed to set_seed in main()

# ----------------------
# PPO Agent
# ----------------------
class ValueNetwork(nn.Module):
    """Three-layer MLP critic head mapping a feature vector to a scalar value.

    Args:
        input_dim: Size of the incoming feature vector.
        hidden_dim: Width of the two hidden layers.
    """

    def __init__(self, input_dim: int = HIDDEN_DIM, hidden_dim: int = HIDDEN_DIM):
        super().__init__()
        # Layers are created in the same order as they are applied, so the
        # Sequential indices (and therefore state-dict keys) stay net.0 .. net.4.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        """Return one value per input row, with the trailing size-1 axis removed."""
        values = self.net(feats)
        return values.squeeze(-1)  # (B,)

class PPOAgent:
    """PPO agent pairing a transformer actor with a separate value network.

    The critic consumes features taken from the actor's backbone. A frozen copy
    of the actor (``ref_actor``) is kept as the reference for a KL penalty, and
    a count-based UCB-style bonus is added to the advantages during updates.

    Args:
        actor: Policy network. This class reads ``max_len``, ``tok_emb``,
            ``pos_emb``, ``blocks``, ``causal_mask`` and ``ln_f`` from it and
            calls it on a ``(B, T)`` token tensor to obtain per-position logits.
        value_network: Module mapping ``(B, feature_dim)`` features to values.
        device: Torch device both networks are moved to.
        tokenizer: Object exposing ``pad_id``; required by ``prepare_obs`` /
            ``select_action``.
    """

    def __init__(self, actor, value_network, device, tokenizer=None):
        self.actor = actor.to(device)
        self.critic = value_network.to(device)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=ACTOR_LR)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=CRITIC_LR)
        self.device = device
        self.tokenizer = tokenizer

        # Frozen snapshot of the actor used as the KL-divergence reference.
        self.ref_actor = copy.deepcopy(self.actor).to(device)
        self.ref_actor.eval()
        for p in self.ref_actor.parameters():
            p.requires_grad = False

    def get_hidden_states(self, obs_tensor, attn_mask=None):
        """Run the actor's embedding, blocks and final layer norm on ``obs_tensor``.

        Args:
            obs_tensor: Token ids of shape ``(B, T)``.
            attn_mask: Optional mask passed to ``actor.blocks``; when omitted the
                top-left ``T x T`` slice of ``actor.causal_mask`` is used.

        Returns:
            The output of ``actor.ln_f`` for every position.

        Raises:
            ValueError: If ``T`` exceeds ``actor.max_len``.
        """
        B, T = obs_tensor.shape
        if T > self.actor.max_len:
            raise ValueError(f"Sequence length {T} exceeds max_len {self.actor.max_len}")

        # Token + positional embeddings
        pos = torch.arange(T, device=obs_tensor.device).unsqueeze(0)
        x = self.actor.tok_emb(obs_tensor) + self.actor.pos_emb(pos)

        if attn_mask is not None:
            x = self.actor.blocks(x, mask=attn_mask)
        else:
            causal = self.actor.causal_mask[:T, :T]
            x = self.actor.blocks(x, mask=causal)
        x = self.actor.ln_f(x)

        return x

    def get_features(self, obs_tensor, attn_mask=None):
        """Return the hidden state at the final sequence position for each row."""
        hidden_states = self.get_hidden_states(obs_tensor, attn_mask)
        # Use the last hidden state as features
        return hidden_states[:, -1, :]  # (B, HIDDEN_DIM)

    @torch.no_grad()
    def get_value(self, obs_tensor):
        """Return the critic's value estimate for each row, without gradients."""
        feats = self.get_features(obs_tensor)
        return self.critic(feats)

    def prepare_obs(self, obs):
        """Strip padding from one observation and return it as a ``(1, L)`` tensor.

        ``L`` is the number of entries different from ``tokenizer.pad_id``; the
        first ``L`` entries are kept, so padding is assumed to be trailing.
        """
        pad_id = self.tokenizer.pad_id
        actual_length = (np.array(obs) != pad_id).sum()
        actual_sequence = obs[:actual_length]

        obs_tensor = torch.tensor(actual_sequence, dtype=torch.long).unsqueeze(0).to(self.device)
        return obs_tensor

    def _last_token_logits(self, logits, obs_batch):
        """Pick, per row, the logits at the last non-pad position.

        This is the position ``select_action`` samples from after
        ``prepare_obs`` strips padding, so the update evaluates the policy at
        the same place it acted from. Without a tokenizer/pad id the final
        position is used.
        """
        pad_id = getattr(self.tokenizer, "pad_id", None)
        if pad_id is None:
            return logits[:, -1, :]
        lengths = (obs_batch != pad_id).sum(dim=1)
        last_idx = torch.clamp(lengths - 1, min=0).to(logits.device)
        rows = torch.arange(logits.shape[0], device=logits.device)
        return logits[rows, last_idx]

    @torch.no_grad()
    def select_action(self, obs, temperature=1.0):
        """Choose the next token for a single observation.

        Args:
            obs: One (possibly padded) token sequence.
            temperature: ``0.0`` selects the arg-max token; any other value
                divides the logits before sampling.

        Returns:
            ``(action, log_prob, entropy)`` as Python scalars. In greedy mode
            the log-probability and entropy refer to the unscaled distribution.
        """
        obs_tensor = self.prepare_obs(obs)
        action_logits = self.actor(obs_tensor)
        next_token_logits = action_logits[:, -1, :]

        if temperature == 0.0:
            # Build the distribution here too: the return statement needs its
            # entropy, and previously `dist` was unbound on this path.
            dist = torch.distributions.Categorical(logits=next_token_logits)
            action = torch.argmax(next_token_logits, dim=-1)
        else:
            dist = torch.distributions.Categorical(logits=next_token_logits / temperature)
            action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item(), log_prob.item(), dist.entropy().item()

    def compute_gae(self, rewards, values, next_values, terminateds, truncateds):
        """Compute GAE advantages and returns over a flat batch of transitions.

        Termination zeroes the bootstrap term; truncation still bootstraps from
        ``next_values``. Either kind of episode end stops the lambda-weighted
        accumulation from flowing into the preceding episode.

        Returns:
            ``(advantages, returns)`` where ``returns = advantages + values``.
        """
        bootstrap_mask = (1.0 - terminateds)   # 1 if not terminated, 0 if terminated

        episode_ends = torch.clamp(terminateds + truncateds, max=1.0)  # 1 at either type of end
        advantages = torch.zeros_like(rewards)
        lastgaelam = 0.0
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + GAMMA * next_values[t] * bootstrap_mask[t] - values[t]
            lastgaelam = delta + GAMMA * GAE_LAMBDA * (1.0 - episode_ends[t]) * lastgaelam
            advantages[t] = lastgaelam
        returns = advantages + values
        return advantages, returns

    def _state_key(self, obs):
        """
        Compact, stable key for counting N(s,a) from triangle discovery observations.
        Uses the first 8 hex digits of the SHA-1 of ``str(obs)``.
        """
        try:
            # Convert observation to string and hash it for a stable key
            obs_str = str(obs)
            h = hashlib.sha1(obs_str.encode()).hexdigest()[:8]
        except Exception:
            # Best effort: an unprintable observation shares one bucket.
            h = "unknown"
        return h

    def update(self, rollouts):
        """Run PPO updates of critic and actor on a batch of transitions.

        Args:
            rollouts: List of dicts with keys ``obs``, ``next_obs``, ``action``,
                ``log_prob``, ``reward``, ``terminated`` and ``truncated``.
                Observations must stack into a rectangular array.

        Returns:
            ``(avg_actor_loss, avg_critic_loss, avg_kl)`` over all minibatches;
            ``(0, 0, 0.0)`` when ``rollouts`` is empty.
        """
        if not rollouts:
            return 0, 0, 0.0

        def stack(key, dtype):
            return torch.tensor(np.array([r[key] for r in rollouts]), dtype=dtype).to(self.device)

        obs_list = [r['obs'] for r in rollouts]
        obs_batch = stack('obs', torch.long)
        next_obs_batch = stack('next_obs', torch.long)
        actions_batch = stack('action', torch.long)
        log_probs_old_batch = stack('log_prob', torch.float32)
        rewards_batch = stack('reward', torch.float32)
        terminateds = stack('terminated', torch.float32)
        truncateds = stack('truncated', torch.float32)

        # reshape(-1) keeps a 1-D tensor even for a single transition, where
        # squeeze() would collapse to 0-D and break indexing in compute_gae.
        with torch.no_grad():
            current_values = self.get_value(obs_batch).reshape(-1)
            next_values = self.get_value(next_obs_batch).reshape(-1)

        advantages, returns = self.compute_gae(rewards_batch, current_values, next_values, terminateds, truncateds)
        if advantages.numel() > 1:
            # std() of a single element is NaN, so only normalise real batches.
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # UCB bonus: depends only on in-batch (state, action) counts, so it is
        # computed once here rather than re-hashed for every minibatch.
        if UCB_COEF > 0.0:
            sa_keys = [(self._state_key(ob), int(a)) for ob, a in zip(obs_list, actions_batch.tolist())]
            sa_counts = Counter(sa_keys)
            # N >= 1 for sampled pairs; guard anyway
            bonus_all = torch.tensor(
                [UCB_COEF * min(1.0, (1.0 / max(1, sa_counts[k])) ** 0.5) for k in sa_keys],
                device=self.device,
                dtype=torch.float,
            )
        else:
            bonus_all = None

        actor_losses = []
        critic_losses = []
        kl_history = []

        for _ in range(PPO_EPOCHS):
            indices = torch.randperm(len(rollouts))
            for start_idx in range(0, len(rollouts), MINIBATCH_SIZE):
                batch_indices = indices[start_idx:start_idx + MINIBATCH_SIZE]

                mb_obs = obs_batch[batch_indices]
                mb_actions = actions_batch[batch_indices]
                mb_log_probs_old = log_probs_old_batch[batch_indices]
                mb_returns = returns[batch_indices]
                mb_advantages = advantages[batch_indices]

                if bonus_all is not None:
                    mb_advantages_aug = mb_advantages + bonus_all[batch_indices]
                else:
                    mb_advantages_aug = mb_advantages

                # Update critic. Features are treated as constants for the
                # critic, so skip building the backbone's autograd graph.
                with torch.no_grad():
                    mb_feats = self.get_features(mb_obs)
                v = self.critic(mb_feats).reshape(-1)
                value_loss = VALUE_LOSS_COEF * nn.functional.mse_loss(v, mb_returns)
                self.critic_optimizer.zero_grad(set_to_none=True)
                value_loss.backward()
                nn.utils.clip_grad_norm_(self.critic.parameters(), MAX_GRAD_NORM)
                self.critic_optimizer.step()
                critic_losses.append(value_loss.item())

                # Update actor. Logits are read at the last non-pad token so the
                # ratio compares against the distribution the action was
                # sampled from (assumes the actor attends causally, so trailing
                # pads do not affect that position - TODO confirm).
                action_logits = self.actor(mb_obs)
                next_token_logits = self._last_token_logits(action_logits, mb_obs)
                dist = torch.distributions.Categorical(logits=next_token_logits)

                with torch.no_grad():
                    ref_logits = self.ref_actor(mb_obs)
                    ref_next_token_logits = self._last_token_logits(ref_logits, mb_obs)
                ref_dist = torch.distributions.Categorical(logits=ref_next_token_logits)
                kl_div = torch.distributions.kl.kl_divergence(dist, ref_dist).mean()
                kl_history.append(kl_div.item())

                current_log_probs = dist.log_prob(mb_actions)
                entropy = dist.entropy().mean()
                ratio = torch.exp(current_log_probs - mb_log_probs_old)
                surr1 = ratio * mb_advantages_aug
                surr2 = torch.clamp(ratio, 1.0 - CLIP_EPSILON, 1.0 + CLIP_EPSILON) * mb_advantages_aug
                actor_loss = -torch.min(surr1, surr2).mean() - ENTROPY_COEF * entropy + KL_COEF * kl_div
                self.actor_optimizer.zero_grad(set_to_none=True)
                actor_loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), MAX_GRAD_NORM)
                self.actor_optimizer.step()
                actor_losses.append(actor_loss.item())

        avg_actor_loss = np.mean(actor_losses) if actor_losses else 0
        avg_critic_loss = np.mean(critic_losses) if critic_losses else 0
        avg_kl = float(np.mean(kl_history)) if kl_history else 0.0

        if avg_kl > 0.1:  # If KL divergence is too high, update reference
            self.ref_actor.load_state_dict(self.actor.state_dict())
            print(f"Updated reference actor due to high KL divergence: {avg_kl:.5f}")

        print( f"Avg Actor Loss: {avg_actor_loss:.4f}, "
              f"Avg Critic Loss: {avg_critic_loss:.4f}, "
              f"Avg KL(ref): {avg_kl:.5f}")
        return avg_actor_loss, avg_critic_loss, avg_kl

def train_ppo(envs, agent, num_episodes=1000, wandb_run=None):
    """Alternate between collecting episodes and running PPO updates.

    Each epoch gathers ``EPISODES_PER_COLLECTION // len(envs)`` episodes from
    every environment, calls ``agent.update`` on all collected transitions,
    logs metrics, and periodically runs a heavier evaluation and a checkpoint.

    Args:
        envs: Environments with the Gymnasium ``reset`` / ``step`` interface.
        agent: Agent providing ``select_action``, ``get_value`` and ``update``.
        num_episodes: Currently unused; kept for interface compatibility.
        wandb_run: Active wandb run, or ``None`` to disable wandb logging and
            checkpointing.

    Returns:
        List with the total reward of every episode collected.
    """
    episode_rewards = []
    episode_entropy = []
    valid_sequences = 0
    total_episodes = 0

    # The loop count is driven by NUM_EPOCHS (plus one extra iteration), not by
    # `num_episodes`.
    num_epochs = NUM_EPOCHS + 1

    print(f"Training for {num_epochs} epochs with {EPISODES_PER_COLLECTION} episodes per collection")

    episodes_per_env = EPISODES_PER_COLLECTION // len(envs) if len(envs) > 0 else 0

    for epoch in range(num_epochs):
        print(f"\n--- Epoch {epoch+1}/{num_epochs} ---")

        # Collect episodes for this epoch
        rollouts = []
        epoch_episodes = 0

        for env in envs:
            for _ in range(episodes_per_env):
                obs, _ = env.reset()
                episode_reward = 0
                episode_ent = []
                done = False
                step = 0

                while not done and step < MAX_STEPS_PER_EPISODE:
                    # Select action
                    action, log_prob, entropy = agent.select_action(obs)
                    obs_tensor = torch.tensor(obs, dtype=torch.long).unsqueeze(0).to(agent.device)
                    value = agent.get_value(obs_tensor).squeeze().item()

                    # Take step
                    next_obs, reward, env_terminated, env_truncated, _info = env.step(action)
                    step += 1

                    terminated = bool(env_terminated)
                    # Hitting the step cap (or an env-side truncation) ends the
                    # episode without a terminal state. Flag it so GAE bootstraps
                    # from next_obs but does not leak into the next episode.
                    truncated = (not terminated) and (
                        bool(env_truncated) or step >= MAX_STEPS_PER_EPISODE
                    )
                    done = terminated or truncated

                    # Store rollout
                    rollouts.append({
                        'obs': obs,
                        'action': action,
                        'log_prob': log_prob,
                        'reward': reward,
                        'done': done,
                        'terminated': terminated,
                        'truncated': truncated,
                        'value': value,
                        'next_obs': next_obs,
                    })

                    obs = next_obs
                    episode_reward += reward
                    episode_ent.append(entropy)

                    # A positive reward on the terminal step counts as a success.
                    if terminated and reward > 0:
                        valid_sequences += 1

                episode_entropy.append(np.mean(episode_ent)) # avg entropy per episode
                episode_rewards.append(episode_reward)
                total_episodes += 1
                epoch_episodes += 1

        # Update policy after collecting episodes for this epoch
        if rollouts:
            print(f"Updating policy with {len(rollouts)} rollouts from {epoch_episodes} episodes")
            actor_loss, critic_loss, kl_div = agent.update(rollouts)

            # Log training metrics to wandb
            if wandb_run is not None:
                wandb.log({
                    "train/epoch": epoch + 1,
                    "train/actor_loss": actor_loss,
                    "train/critic_loss": critic_loss,
                    "train/kl_divergence": kl_div,
                })

        # Log progress. Reward and entropy are averaged over the most recent
        # collection; the success rate is cumulative over the whole run.
        avg_entropy = np.mean(episode_entropy[-EPISODES_PER_COLLECTION:]) if len(episode_entropy) >= EPISODES_PER_COLLECTION else np.mean(episode_entropy)
        avg_reward = np.mean(episode_rewards[-EPISODES_PER_COLLECTION:]) if len(episode_rewards) >= EPISODES_PER_COLLECTION else np.mean(episode_rewards)
        success_rate = valid_sequences / total_episodes if total_episodes > 0 else 0
        print(f"Epoch {epoch+1}, Avg Reward: {avg_reward:.3f}, Success Rate: {success_rate:.3f}, Avg Entropy: {avg_entropy:.3f}")

        # Log training metrics to wandb after every epoch
        if wandb_run is not None:
            wandb.log({
                "train/epoch": epoch + 1,
                "train/avg_reward_epoch": avg_reward,
                "train/success_rate_epoch": success_rate,
                "train/avg_entropy_epoch": avg_entropy,
            })

        # Heavy evaluation (like PPO_multiseed)
        if (epoch + 1) % HEAVY_EVAL_INTERVAL == 0:
            print(f"\n--- Heavy Evaluation at Epoch {epoch+1} ---")
            heavy_eval_success_rate = validate_triangle_generation(agent.actor, agent.tokenizer, envs, device=agent.device, num_samples=NUM_EVAL_EPISODES)
            print("--- End Heavy Evaluation ---\n")

            # Log heavy evaluation metrics to wandb
            if wandb_run is not None:
                wandb.log({
                    "heavy_eval/epoch": epoch + 1,
                    "heavy_eval/success_rate": heavy_eval_success_rate,
                })

        # (Optional) Save checkpoint; only possible when a wandb run directory exists.
        if epoch % CHECKPOINT_INTERVAL == 0 and wandb_run is not None:
            checkpoint_path = os.path.join(wandb_run.dir, f"epoch_{epoch}.pt")
            save_model(agent.actor, agent.tokenizer, checkpoint_path)
            print(f"Saving checkpoint at epoch {epoch} to {checkpoint_path}")

    return episode_rewards


def main():
    """Fine-tune the pretrained policy with PPO + UCB and save the result.

    Steps: seed everything, make sure the dataset directory exists, load the
    pretrained actor and tokenizer, build three environments and the agent,
    evaluate, train, evaluate again, then save the model (and upload it as a
    wandb artifact unless DEBUG is set).

    Raises:
        FileNotFoundError: If ``MODEL_PATH`` does not exist.
    """
    # Set seeds for reproducibility
    set_seed(SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"Using seed: {SEED}")

    # Ensure dataset exists
    if not os.path.exists(DATA_ROOT):
        os.makedirs(DATA_ROOT, exist_ok=True)
        print(f"Creating dataset in {DATA_ROOT}")
        generate_and_save_dataset(num_graphs=1, fixed_hash_per_graph=True, dataset_name="pretrain")

    # Load actor and tokenizer from the pretrained model file
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Could not find {MODEL_PATH}")
    actor, tokenizer = load_model_and_tokenizer(MODEL_PATH, device=device)
    # Report success only once the load has actually happened.
    print(f"Loaded pretrained policy from {MODEL_PATH}")
    print(f"Loaded tokenizer with vocab_size={tokenizer.vocab_size}")

    # Create value network that uses actor features
    value_network = ValueNetwork(input_dim=HIDDEN_DIM, hidden_dim=HIDDEN_DIM).to(device)

    # Create environments and agent
    envs = [make_env(tokenizer, i, device) for i in range(3)]
    agent = PPOAgent(actor=actor, value_network=value_network, device=device, tokenizer=tokenizer)

    # Initialize wandb
    wandb_config = {
        "algo": "PPO_UCB",
        "env": "TriangleDiscovery",
        "ppo_epochs": PPO_EPOCHS,
        "minibatch_size": MINIBATCH_SIZE,
        "gamma": GAMMA,
        "gae_lambda": GAE_LAMBDA,
        "clip_epsilon": CLIP_EPSILON,
        "actor_lr": ACTOR_LR,
        "critic_lr": CRITIC_LR,
        "value_loss_coef": VALUE_LOSS_COEF,
        "entropy_coef": ENTROPY_COEF,
        "kl_coef": KL_COEF,
        "ucb_coef": UCB_COEF,
        "max_grad_norm": MAX_GRAD_NORM,
        "num_training_episodes": NUM_TRAINING_EPISODES,
        "steps_per_collection": STEPS_PER_COLLECTION,
        "heavy_eval_interval": HEAVY_EVAL_INTERVAL,
        "num_eval_episodes": NUM_EVAL_EPISODES,
        "max_steps_per_episode": MAX_STEPS_PER_EPISODE,
        "episodes_per_collection": EPISODES_PER_COLLECTION,
        "vocab_size": tokenizer.vocab_size,
        # NOTE(review): the architecture values below are hard-coded here rather
        # than read from the loaded model - confirm they match the checkpoint.
        "d_model": 256,
        "n_layer": 4,
        "n_head": 8,
        "dim_ff": 1024,
        "dropout": 0.1,
        "max_len": MAX_LEN,
    }
    # DEBUG runs skip wandb entirely (no logging, checkpoints or artifacts).
    wandb_run = None if DEBUG else init_wandb(run_name="ppo_ucb_triangle_discovery", config=wandb_config)

    # Initial evaluation
    print("Initial evaluation of pretrained policy...")
    validate_triangle_generation(agent.actor, agent.tokenizer, envs, device=device, num_samples=NUM_EVAL_EPISODES)

    # Train
    print("Starting PPO training...")
    train_ppo(envs, agent, wandb_run=wandb_run)

    # Evaluate trained agent
    print("\nEvaluating trained agent...")
    validate_triangle_generation(agent.actor, agent.tokenizer, envs, device=device, num_samples=NUM_EVAL_EPISODES)

    # Save fine-tuned model; derive the output path once so the save, the log
    # message and the artifact upload cannot drift apart.
    fine_tuned_path = FINE_TUNED_MODEL_PATH.replace('ppo', 'ppo_ucb')
    save_model(agent.actor, agent.tokenizer, fine_tuned_path)
    print(f"Saved fine-tuned model to {fine_tuned_path}")

    # Log model artifact to wandb
    if wandb_run is not None:
        wandb_model_path = os.path.join(wandb_run.dir, "ppo_ucb_triangle_discovery_model.pt")
        save_model(agent.actor, agent.tokenizer, wandb_model_path)
        print(f"Saved model to wandb directory: {wandb_model_path}")

        # Uploading and closing the run are best effort: a wandb failure must
        # not discard a finished training run, but it should be visible.
        try:
            artifact = wandb.Artifact("ppo_ucb_triangle_discovery_model", type="model")
            artifact.add_file(fine_tuned_path)
            wandb.log_artifact(artifact)
        except Exception as e:
            print(f"[wandb] log_artifact failed: {e}")
        try:
            wandb.finish()
        except Exception as e:
            print(f"[wandb] finish failed: {e}")

if __name__ == "__main__":
    main()
