import os
import json
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
from gymnasium import spaces
import wandb
import copy
import hashlib
from collections import Counter

from utils import (
    TransformerPolicy, MODEL_PATH, FINE_TUNED_MODEL_PATH, HIDDEN_DIM, DATA_ROOT, 
    HASH_STR_LEN, TriangleTokenizer, dataset_dir, load_json, T_TRIANGLES,
    parse_triangle_sequence, is_valid_triangle, MAX_LEN, save_model, make_env, set_seed, load_model_and_tokenizer,
    validate_triangle_generation, _nan_to_none, init_wandb
)
from data_utils import generate_and_save_dataset

# Optimisation hyperparameters
LEARNING_RATE = 5e-5  # Adam learning rate for the actor (policy)
CRITIC_LR = 5e-4  # Adam learning rate for the value network
GAMMA = 1.0  # discount factor for Monte Carlo returns (1.0 = undiscounted)
VALUE_LOSS_COEF = 0.5  # scale on the critic's MSE loss
ENTROPY_COEF = 0.0  # weight of the entropy bonus in the actor loss (0.0 disables it)
KL_COEF = 0.01  # weight of the KL penalty towards the frozen reference actor
UCB_COEF = 0.005  # scale of the count-based UCB bonus added to advantages (<= 0 disables it)
MAX_GRAD_NORM = 0.5  # gradient-norm clipping threshold for both actor and critic
# DEBUG=true (case-insensitive) in the environment skips wandb initialisation.
DEBUG = os.environ.get("DEBUG", "False").lower() == "true"

# Training hyperparameters
# NOTE(review): NUM_TRAINING_EPISODES is only forwarded as train_reinforce's
# `num_episodes` argument and logged to wandb; the training loop itself runs
# for NUM_EPOCHS epochs of EPISODES_PER_UPDATE episodes each.
NUM_TRAINING_EPISODES = 1000
EPISODES_PER_UPDATE = 50  # episodes per policy update, divided (integer division) across envs
HEAVY_EVAL_INTERVAL = 25  # run validate_triangle_generation every this many epochs
MAX_STEPS_PER_EPISODE = 16  # hard cap on environment steps per episode
NUM_EVAL_EPISODES = 100  # num_samples passed to validate_triangle_generation

NUM_EPOCHS = 500  # number of collect-rollouts-then-update iterations
SEED = 42  # passed to set_seed at startup for reproducibility

# -----------------------------
# Entropy and KL-divergence helpers (categorical distributions given as logits)
# -----------------------------
def cat_kl_from_logits(p_logits: torch.Tensor, q_logits: torch.Tensor) -> torch.Tensor:
    """Row-wise KL divergence KL(P || Q) between two categorical distributions.

    Both arguments are unnormalised logits. P and Q are their softmaxes over the
    last dimension, and the divergence is summed over that dimension. The result
    therefore has the input shape with the last dimension removed.
    """
    log_p, log_q = (torch.log_softmax(z, dim=-1) for z in (p_logits, q_logits))
    return torch.sum(torch.exp(log_p) * (log_p - log_q), dim=-1)

def cat_entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Shannon entropy (in nats) of softmax(logits) along the last dimension."""
    log_probs = torch.log_softmax(logits, dim=-1)
    return -torch.sum(torch.exp(log_probs) * log_probs, dim=-1)

# -----------------------------
# Critic
# -----------------------------
class ValueNetwork(nn.Module):
    """MLP critic mapping a feature vector to a scalar state-value estimate.

    Layout: Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, hidden_dim)
    -> ReLU -> Linear(hidden_dim, 1), stored in ``self.net``.
    """

    def __init__(self, input_dim: int = HIDDEN_DIM, hidden_dim: int = HIDDEN_DIM):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # The trailing size-1 output dim is dropped: (..., input_dim) -> (...).
        values = self.net(feats)
        return values.squeeze(dim=-1)

# ----------------------
# REINFORCE Agent
# ----------------------
class REINFORCEAgent:
    """REINFORCE policy-gradient agent with a learned value baseline.

    The actor is the pretrained token policy. The critic (``value_network``) maps
    the actor backbone's hidden state at the last real (non-pad) token to a
    scalar value. ``update`` combines:

    * Monte Carlo returns minus the critic baseline (normalised advantages),
    * a count-based UCB bonus added to the advantages,
    * a KL penalty towards a frozen copy of the initial actor.
    """

    def __init__(self, actor, value_network, device, tokenizer=None):
        """
        Args:
            actor: policy network; moved to ``device`` and trained with Adam
                at ``LEARNING_RATE``.
            value_network: critic applied to backbone features; trained with
                Adam at ``CRITIC_LR``.
            device: torch device for all networks and batches.
            tokenizer: must expose ``pad_id``. It is required by ``prepare_obs`` /
                ``select_action``. Without it, features and logits are read at
                the final sequence position.
        """
        self.actor = actor.to(device)
        self.critic = value_network.to(device)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=LEARNING_RATE)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=CRITIC_LR)
        self.device = device
        self.tokenizer = tokenizer

        # Frozen reference actor for the KL penalty in `update`.
        self.ref_actor = copy.deepcopy(self.actor).to(device)
        self.ref_actor.eval()
        for p in self.ref_actor.parameters():
            p.requires_grad = False

    def get_hidden_states(self, obs_tensor, attn_mask=None):
        """Run the actor's embedding + transformer blocks + final norm.

        Args:
            obs_tensor: (B, T) token ids with T <= ``actor.max_len``.
            attn_mask: optional mask forwarded to ``actor.blocks``; the actor's
                causal mask is used when omitted.

        Returns:
            The output of ``actor.ln_f``; presumably (B, T, HIDDEN_DIM) — confirm
            against the actor definition.

        Raises:
            ValueError: if T exceeds ``actor.max_len``.
        """
        _, T = obs_tensor.shape
        if T > self.actor.max_len:
            raise ValueError(f"Sequence length {T} exceeds max_len {self.actor.max_len}")

        # Token + positional embeddings
        pos = torch.arange(T, device=obs_tensor.device).unsqueeze(0)
        x = self.actor.tok_emb(obs_tensor) + self.actor.pos_emb(pos)

        mask = attn_mask if attn_mask is not None else self.actor.causal_mask[:T, :T]
        x = self.actor.blocks(x, mask=mask)
        return self.actor.ln_f(x)

    def _last_token_index(self, obs_tensor):
        """Return a (B,) long tensor with the index of each row's last non-pad token.

        Assumes padding is trailing (the same assumption `prepare_obs` makes).
        Rows that are entirely padding map to 0. Without a tokenizer every row
        maps to the final position T - 1.
        """
        batch, seq_len = obs_tensor.shape
        if self.tokenizer is None:
            return torch.full((batch,), seq_len - 1, dtype=torch.long, device=obs_tensor.device)
        lengths = (obs_tensor != self.tokenizer.pad_id).sum(dim=1)
        return (lengths - 1).clamp(min=0)

    @staticmethod
    def _select_positions(seq, positions):
        """Pick ``seq[b, positions[b]]`` for every row b of a (B, T, ...) tensor."""
        rows = torch.arange(seq.size(0), device=seq.device)
        return seq[rows, positions]

    def get_features(self, obs_tensor, attn_mask=None):
        """Backbone hidden state at each row's last real token, used as critic input.

        Padded observations are read at their last non-pad position rather than
        at index -1, which would be a padding slot. Under a causal mask, that
        state matches what the actor saw when it sampled the action. For
        unpadded input this is identical to ``hidden_states[:, -1, :]``.
        """
        hidden_states = self.get_hidden_states(obs_tensor, attn_mask)
        return self._select_positions(hidden_states, self._last_token_index(obs_tensor))

    @torch.no_grad()
    def get_value(self, obs_tensor):
        """Critic value estimate for a batch of observations (no gradients)."""
        feats = self.get_features(obs_tensor)
        return self.critic(feats)

    def prepare_obs(self, obs):
        """Strip padding from a single observation and return a (1, L) tensor.

        L is the number of non-pad tokens. This assumes padding only appears
        at the end of the sequence — TODO confirm against the env.
        """
        pad_id = self.tokenizer.pad_id
        actual_length = (np.array(obs) != pad_id).sum()
        actual_sequence = obs[:actual_length]
        return torch.tensor(actual_sequence, dtype=torch.long).unsqueeze(0).to(self.device)

    @torch.no_grad()
    def select_action(self, obs, temperature=1.0):
        """Choose the next token for a single observation.

        Args:
            obs: one (possibly padded) observation sequence.
            temperature: 0.0 selects greedily (argmax). Otherwise sampling uses
                softmax(logits / temperature).

        Returns:
            (action, log_prob, entropy) as Python scalars. In greedy mode the
            log-prob and entropy are those of the unscaled policy.

        Raises:
            ValueError: if temperature is negative.
        """
        if temperature < 0.0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")
        obs_tensor = self.prepare_obs(obs)
        next_token_logits = self.actor(obs_tensor)[:, -1, :]

        if temperature == 0.0:
            # A distribution is still needed to report log-prob and entropy.
            dist = torch.distributions.Categorical(logits=next_token_logits)
            action = torch.argmax(next_token_logits, dim=-1)
        else:
            dist = torch.distributions.Categorical(logits=next_token_logits / temperature)
            action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item(), log_prob.item(), dist.entropy().item()

    def _state_key(self, obs):
        """Short, stable hash key for an observation, used for N(s,a) counts.

        Array-likes (NumPy arrays, tensors) are converted with ``tolist()`` first.
        ``str()`` of a large ndarray elides its middle with "...", which made
        distinct states collide. Plain lists hash exactly as before.
        Best effort: an observation that cannot be stringified maps to "unknown".
        """
        try:
            if hasattr(obs, "tolist"):
                obs = obs.tolist()
            h = hashlib.sha1(str(obs).encode()).hexdigest()[:8]
        except Exception:
            h = "unknown"
        return h

    def update(self, rollouts):
        """Run one REINFORCE update with baseline, UCB bonus and KL penalty.

        Args:
            rollouts: non-empty list of dicts with keys ``obs``, ``action``,
                ``reward`` and ``done``. Steps of each episode must be
                consecutive and in order. ``done`` must be True on every
                episode's final step so returns do not leak across episodes.

        Returns:
            (actor_loss, critic_loss) as floats.

        Raises:
            ValueError: if ``rollouts`` is empty.
        """
        if not rollouts:
            raise ValueError("update() requires at least one rollout")

        obs_list = [r['obs'] for r in rollouts]
        actions = [int(r['action']) for r in rollouts]
        actions_batch = torch.tensor(actions, dtype=torch.long, device=self.device)

        # (state, action) keys are computed once and reused for counting and bonus.
        sa_keys = [(self._state_key(ob), a) for ob, a in zip(obs_list, actions)]
        sa_counts = Counter(sa_keys)

        # --- 1. Monte Carlo returns G_t ---
        # Accumulated in Python to avoid one tiny device op per step; `done`
        # resets the running return at each episode boundary.
        returns_list = [0.0] * len(rollouts)
        running_return = 0.0
        for t in reversed(range(len(rollouts))):
            if rollouts[t]['done']:
                running_return = 0.0
            running_return = float(rollouts[t]['reward']) + GAMMA * running_return
            returns_list[t] = running_return
        returns = torch.tensor(returns_list, dtype=torch.float32, device=self.device)

        # --- 2. Advantages A_t = G_t - V(s_t) ---
        obs_batch = torch.tensor(np.array(obs_list), dtype=torch.long, device=self.device)
        last_idx = self._last_token_index(obs_batch)
        # Only the critic learns from the value loss, so the backbone features
        # are computed without building an autograd graph through the actor.
        with torch.no_grad():
            feats = self.get_features(obs_batch)
        values = self.critic(feats)  # (B,)
        advantages = returns - values.detach()

        # Normalise for stability; std() of a single element is NaN, so skip it.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # ---- UCB-style bonus on advantage (entropy proxy) ----
        # Per-(s,a) bonus: λ_ucb * min(1, sqrt(1 / N(s,a))), with N counted
        # within this batch. The bonus is a constant w.r.t. the parameters.
        if UCB_COEF > 0.0:
            ucb_bonus = torch.tensor(
                [UCB_COEF * min(1.0, (1.0 / max(1, sa_counts[k])) ** 0.5) for k in sa_keys],
                dtype=torch.float,
                device=self.device,
            )
            advantages = advantages + ucb_bonus
        else:
            ucb_bonus = None  # for logging/printing

        # --- 3. Critic update: regress V(s_t) onto the Monte Carlo returns ---
        v_loss = VALUE_LOSS_COEF * nn.functional.mse_loss(values, returns)
        self.critic_optimizer.zero_grad(set_to_none=True)
        v_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), MAX_GRAD_NORM)
        self.critic_optimizer.step()

        # --- 4. Actor update ---
        # Logits are read at each row's last real token: that is where
        # select_action sampled from (it strips the padding first).
        next_token_logits = self._select_positions(self.actor(obs_batch), last_idx)  # (B, vocab)
        dist = torch.distributions.Categorical(logits=next_token_logits)
        logp = dist.log_prob(actions_batch)
        ent = dist.entropy().mean()

        with torch.no_grad():
            ref_next_token_logits = self._select_positions(self.ref_actor(obs_batch), last_idx)
        kl = cat_kl_from_logits(next_token_logits, ref_next_token_logits).mean()

        # REINFORCE objective with entropy bonus and KL regularisation
        a_loss = -(logp * advantages.detach()).mean() - ENTROPY_COEF * ent + KL_COEF * kl

        self.actor_optimizer.zero_grad(set_to_none=True)
        a_loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), MAX_GRAD_NORM)
        self.actor_optimizer.step()

        print(
            f"[REINFORCE] Update -> actor: {a_loss.item():.4f} | "
            f"critic: {v_loss.item():.4f} | KL(ref): {kl.item():.5f}"
            + (f" | UCB(avg): {float(ucb_bonus.mean().item()):.4f}" if ucb_bonus is not None else "")
        )
        return float(a_loss.item()), float(v_loss.item())

def _run_episode(env, agent, rollouts):
    """Roll out one episode, appending one record per step to ``rollouts``.

    The episode stops when the env reports terminated or truncated, or after
    MAX_STEPS_PER_EPISODE steps. The final record always has ``done=True`` (even
    when the step cap ends the episode), so REINFORCEAgent.update does not carry
    returns across episode boundaries.

    Returns:
        (episode_reward, mean_step_entropy, success), where success means the
        episode terminated with a positive final reward.
    """
    obs, _ = env.reset()
    episode_reward = 0
    step_entropies = []
    success = False

    for step in range(1, MAX_STEPS_PER_EPISODE + 1):
        action, _logp, entropy = agent.select_action(obs, temperature=1.0)
        next_obs, reward, terminated, truncated, _info = env.step(action)
        episode_over = bool(terminated or truncated)

        rollouts.append({
            "obs": obs,
            "action": action,
            "reward": reward,
            "done": episode_over or step == MAX_STEPS_PER_EPISODE,
        })

        obs = next_obs
        episode_reward += reward
        step_entropies.append(entropy)

        if terminated:
            success = bool(reward > 0)
        if episode_over:
            # Stepping a gymnasium env after truncation/termination is invalid.
            break

    mean_entropy = float(np.mean(step_entropies)) if step_entropies else float("nan")
    return episode_reward, mean_entropy, success


def train_reinforce(envs, agent, num_episodes=1000, wandb_run=None):
    """Train the REINFORCE agent for NUM_EPOCHS collect-then-update epochs.

    Each epoch collects ``EPISODES_PER_UPDATE // len(envs)`` episodes (at least
    one) from every env and then runs one ``agent.update``. Every
    HEAVY_EVAL_INTERVAL epochs it also runs ``validate_triangle_generation``.

    Args:
        envs: non-empty list of gymnasium-style environments.
        agent: a REINFORCEAgent.
        num_episodes: accepted for API compatibility only; training length is
            governed by NUM_EPOCHS and EPISODES_PER_UPDATE.
        wandb_run: active wandb run, or None to disable wandb logging.

    Returns:
        The list of per-episode total rewards across all epochs.

    Raises:
        ValueError: if ``envs`` is empty.
    """
    if not envs:
        raise ValueError("train_reinforce requires at least one environment")

    episode_rewards = []
    episode_entropy = []
    valid_sequences = 0
    total_episodes = 0

    num_epochs = NUM_EPOCHS
    episodes_per_env = max(1, EPISODES_PER_UPDATE // len(envs))

    print(f"Training for {num_epochs} epochs with {EPISODES_PER_UPDATE} episodes per update")

    for epoch in range(num_epochs):
        print(f"\n--- Epoch {epoch+1}/{num_epochs} ---")

        # Collect rollouts for this update
        rollouts = []
        epoch_rewards = []
        epoch_entropy = []
        for env in envs:
            for _ in range(episodes_per_env):
                ep_reward, ep_entropy, success = _run_episode(env, agent, rollouts)
                epoch_rewards.append(ep_reward)
                epoch_entropy.append(ep_entropy)
                valid_sequences += int(success)

        episode_rewards.extend(epoch_rewards)
        episode_entropy.extend(epoch_entropy)
        total_episodes += len(epoch_rewards)

        # One wandb record per epoch.
        log_data = {"train/epoch": epoch + 1}

        if rollouts:
            print(f"Updating policy with {len(rollouts)} rollouts from {len(epoch_rewards)} episodes")
            actor_loss, critic_loss = agent.update(rollouts)
            log_data["train/actor_loss"] = _nan_to_none(actor_loss)
            log_data["train/critic_loss"] = _nan_to_none(critic_loss)

        # Reward/entropy are averaged over this epoch's episodes; the success
        # rate is cumulative over all training episodes so far.
        avg_entropy = float(np.mean(epoch_entropy))
        avg_reward = float(np.mean(epoch_rewards))
        success_rate = valid_sequences / total_episodes if total_episodes > 0 else 0
        print(f"Epoch {epoch+1}, Avg Reward: {avg_reward:.3f}, Success Rate: {success_rate:.3f}, Avg Entropy: {avg_entropy:.3f}")

        if wandb_run is not None:
            log_data.update({
                "train/avg_reward_update": avg_reward,
                "train/success_rate_update": success_rate,
                "train/avg_entropy_update": avg_entropy,
            })
            wandb.log(log_data)

        # Heavy evaluation
        if (epoch + 1) % HEAVY_EVAL_INTERVAL == 0:
            print(f"\n--- Heavy Evaluation at Epoch {epoch+1} ---")
            heavy_eval_success_rate = validate_triangle_generation(agent.actor, agent.tokenizer, envs, device=agent.device, num_samples=NUM_EVAL_EPISODES)
            print("--- End Heavy Evaluation ---\n")

            if wandb_run is not None:
                wandb.log({
                    "heavy_eval/epoch": epoch + 1,
                    "heavy_eval/success_rate": heavy_eval_success_rate,
                })

    return episode_rewards


def main():
    """Fine-tune the pretrained triangle policy with REINFORCE + UCB and save it.

    Steps: seed everything, ensure the dataset exists, load the pretrained
    actor/tokenizer, build three envs, evaluate, train, evaluate again, then
    save the model. When DEBUG is off, it also logs to wandb.

    Raises:
        FileNotFoundError: if the pretrained model at MODEL_PATH is missing.
    """
    # Set seeds for reproducibility
    set_seed(SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"Using seed: {SEED}")

    # Ensure dataset exists
    if not os.path.exists(DATA_ROOT):
        os.makedirs(DATA_ROOT, exist_ok=True)
        print(f"Creating dataset in {DATA_ROOT}")
        generate_and_save_dataset(num_graphs=1, fixed_hash_per_graph=True, dataset_name="pretrain")

    # Load pretrained actor and its tokenizer from the model file
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Could not find {MODEL_PATH}")
    actor, tokenizer = load_model_and_tokenizer(MODEL_PATH, device=device)
    print(f"Loaded pretrained policy from {MODEL_PATH}")
    print(f"Loaded tokenizer with vocab_size={tokenizer.vocab_size}")

    # Critic that consumes the actor's backbone features
    value_network = ValueNetwork(input_dim=HIDDEN_DIM, hidden_dim=HIDDEN_DIM).to(device)

    # Create environments and agent
    envs = [make_env(tokenizer, i, device) for i in range(3)]
    agent = REINFORCEAgent(actor=actor, value_network=value_network, device=device, tokenizer=tokenizer)

    # Initialize wandb
    wandb_config = {
        "algo": "REINFORCE_UCB",
        "env": "TriangleDiscovery",
        "actor_lr": LEARNING_RATE,
        "critic_lr": CRITIC_LR,
        "gamma": GAMMA,
        "value_loss_coef": VALUE_LOSS_COEF,
        "entropy_coef": ENTROPY_COEF,
        "kl_coef": KL_COEF,
        "ucb_coef": UCB_COEF,
        "max_grad_norm": MAX_GRAD_NORM,
        "num_training_episodes": NUM_TRAINING_EPISODES,
        "episodes_per_update": EPISODES_PER_UPDATE,
        "heavy_eval_interval": HEAVY_EVAL_INTERVAL,
        "num_eval_episodes": NUM_EVAL_EPISODES,
        "max_steps_per_episode": MAX_STEPS_PER_EPISODE,
        "vocab_size": tokenizer.vocab_size,
        # NOTE(review): architecture values below are hard-coded, not read from
        # the loaded actor — confirm they match the checkpoint.
        "d_model": 256,
        "n_layer": 4,
        "n_head": 8,
        "dim_ff": 1024,
        "dropout": 0.1,
        "max_len": MAX_LEN,
    }
    wandb_run = None if DEBUG else init_wandb(run_name="reinforce_ucb_triangle_discovery", config=wandb_config)

    # Initial evaluation
    print("Initial evaluation of pretrained policy...")
    initial_success_rate = validate_triangle_generation(agent.actor, agent.tokenizer, envs, device=device, num_samples=NUM_EVAL_EPISODES)

    # Train
    print("Starting REINFORCE training...")
    train_reinforce(envs, agent, num_episodes=NUM_TRAINING_EPISODES, wandb_run=wandb_run)

    # Evaluate trained agent
    print("\nEvaluating trained agent...")
    success_rate = validate_triangle_generation(agent.actor, agent.tokenizer, envs, device=device, num_samples=NUM_EVAL_EPISODES)

    if wandb_run is not None:
        wandb.log({
            "eval/initial_success_rate": initial_success_rate,
            "eval/final_success_rate": success_rate,
        })

    # Save fine-tuned model. Create the directory of the path actually written:
    # the 'ppo' -> 'reinforce_ucb' replacement may also change the directory.
    save_path = FINE_TUNED_MODEL_PATH.replace('ppo', 'reinforce_ucb')
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_model(agent.actor, agent.tokenizer, save_path)
    print(f"Saved fine-tuned model to {save_path}")

    # Log model artifact to wandb
    if wandb_run is not None:
        wandb_model_path = os.path.join(wandb_run.dir, "reinforce_ucb_triangle_discovery_model.pt")
        save_model(agent.actor, agent.tokenizer, wandb_model_path)
        print(f"Saved model to wandb directory: {wandb_model_path}")

        try:
            artifact = wandb.Artifact("reinforce_ucb_triangle_discovery_model", type="model")
            artifact.add_file(save_path)
            wandb.log_artifact(artifact)
        except Exception as e:
            print(f"[wandb] log_artifact failed: {e}")
        try:
            wandb.finish()
        except Exception as e:
            # Best effort: the model is already saved locally.
            print(f"[wandb] finish failed: {e}")

if __name__ == "__main__":
    main()
