import os
import pickle
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import torch
import gymnasium as gym

from utils import TransformerPolicy, TriangleTokenizer, HASH_STR_LEN, MODEL_PATH, META_FIXED_SEED, make_env, load_model_and_tokenizer
from utils import FixedGraphTriangleEnvironment as TriangleEnvironment

# Verbose-diagnostics switch: enabled only when the DEBUG environment variable
# equals "true" (compared case-insensitively); unset or any other value disables it.
DEBUG = os.getenv("DEBUG", "False").lower() == "true"

# -------------------------------------------------------
# Keys / helpers
# -------------------------------------------------------
def obs_to_key(obs: np.ndarray) -> Tuple:
    """
    Build a hashable dictionary key from an observation array.

    The array is converted to plain Python values via ``tolist()`` and then
    wrapped in a tuple, so equal observations map to equal keys.
    """
    as_python_values = obs.tolist()
    return tuple(as_python_values)

# -------------------------------------------------------
# Rollout with optional reset+replay (canonical vine sampling)
# -------------------------------------------------------
def prepare_obs(tokenizer, obs, device):
    """
    Strip padding from an observation and turn it into a batched tensor.

    The number of entries different from ``tokenizer.pad_id`` is counted and
    that many leading entries of ``obs`` are kept, so the result only equals
    "the sequence without padding" when all padding sits at the end.

    Returns a ``torch.long`` tensor of shape (1, n_kept) placed on ``device``.
    """
    non_pad_mask = np.array(obs) != tokenizer.pad_id
    n_kept = non_pad_mask.sum()
    trimmed = obs[:n_kept]

    # Add a leading batch dimension before moving to the target device.
    batched = torch.tensor(trimmed, dtype=torch.long).unsqueeze(0)
    return batched.to(device)
        
@torch.no_grad()
def select_action(policy, obs, tokenizer, temperature=1.0, device="cuda"):
    """
    Pick the next action (token) for a single observation.

    Args:
        policy: Callable model; it is invoked on the un-padded observation
            tensor and its output is indexed as ``[:, -1, :]`` to obtain the
            logits for the next token.
        obs: Observation sequence; padding is stripped by ``prepare_obs``.
        tokenizer: Object exposing ``pad_id`` (consumed by ``prepare_obs``).
        temperature: ``0.0`` selects the arg-max action greedily; any other
            value divides the logits by it before sampling.
        device: Device on which the observation tensor is placed.

    Returns:
        ``(action, log_prob, entropy)`` as Python scalars. In greedy mode the
        log-probability and entropy are those of the unscaled distribution
        (dividing by a zero temperature is undefined).
    """
    obs_tensor = prepare_obs(tokenizer, obs, device)
    action_logits = policy(obs_tensor)
    next_token_logits = action_logits[:, -1, :]

    if temperature == 0.0:
        # Greedy decoding. A distribution is still built so that the entropy
        # returned below is defined; previously this branch left `dist`
        # unbound and the return statement raised NameError.
        dist = torch.distributions.Categorical(logits=next_token_logits)
        action = torch.argmax(next_token_logits, dim=-1)
    else:
        dist = torch.distributions.Categorical(logits=next_token_logits / temperature)
        action = dist.sample()

    # Categorical normalises logits with log-softmax, which is numerically
    # safer than taking log(softmax(...)) explicitly.
    log_prob = dist.log_prob(action)

    # .item() requires single-element tensors: prepare_obs always yields batch size 1.
    return action.item(), log_prob.item(), dist.entropy().item()
    
@torch.no_grad()
def generate_rollout(
    policy: TransformerPolicy,
    tokenizer: TriangleTokenizer,
    env: TriangleEnvironment,
    max_steps: int,
    device: torch.device,
    seed: Optional[int] = None,
    replay_actions: Optional[List[int]] = None,
    main_trunk: bool = False, # Boolean, if True, we are reconstructing the exact vine state
) -> List[Dict[str, Any]]:
    """
    Generate a trajectory starting either from a fresh reset (with an optional seed),
    or from the environment state reconstructed by:
        env.reset(seed=seed) + replay(prefix actions)

    The replayed prefix steps are NOT included in the returned trajectory; the returned
    transitions start at the branch point state (i.e., after the replay).

    Args:
        policy: Model used to choose actions; switched to eval mode before rolling.
        tokenizer: Tokenizer handed to ``select_action`` (used for pad stripping).
        env: Environment to reset and step. Besides ``reset``/``step`` it must
            expose ``env.env.current_sequence`` and ``env.graph_idx``, which are
            recorded in every transition.
        max_steps: Upper bound on the number of policy-driven steps.
        device: Device used for policy inference.
        seed: Seed forwarded to ``env.reset``; when ``None`` a warning is printed
            and an unseeded reset is performed.
        replay_actions: Actions replayed after the reset to reach the branch point.
        main_trunk: When ``True``, suppresses the "no replay actions" warning
            (a trunk rollout legitimately has no prefix).

    Returns:
        A list of transition dicts with keys ``observation``, ``action``,
        ``log_prob``, ``reward``, ``next_observation``, ``env_state``,
        ``graph_idx`` and ``done``. An empty list is returned if the episode
        terminates or truncates while replaying the prefix.
    """
    # 1) Reset (respect the provided seed if any)
    if seed is not None:
        obs, _ = env.reset(seed=seed)
    else:
        print("NO SEED PROVIDED - WE ARE NOT RECONSTRUCTING THE EXACT VINE STATE!!!")
        obs, _ = env.reset()

    # 2) Replay prefix to reconstruct the branch state (if provided)
    if replay_actions:
        for a in replay_actions:
            obs, _, terminated, truncated, _ = env.step(int(a))
            if terminated or truncated:
                # The prefix itself ended the episode, so there is nothing to branch from.
                return []
    elif not main_trunk:
        print("NO REPLAY ACTIONS PROVIDED - WE ARE NOT RECONSTRUCTING THE EXACT VINE STATE!!!")
        print("we should check whether seed is none - if so, we are not reconstructing the exact vine state")

    # 3) Roll from the branch point
    policy.eval()
    trajectory: List[Dict[str, Any]] = []
    steps = 0
    done = False

    while not done and steps < max_steps:
        # Sample at temperature 1.0 (set to 0 for deterministic debugging);
        # select_action builds the device tensor itself, so none is created here.
        action, log_prob, _ = select_action(policy, obs, tokenizer, device=device, temperature=1.0)

        next_obs, reward, terminated, truncated, _ = env.step(action)
        done = bool(terminated or truncated)

        trajectory.append({
            "observation": obs.copy(),
            "action": action,
            "log_prob": log_prob,
            "reward": reward,
            "next_observation": next_obs.copy(),
            # After the step, current_sequence already includes the action just
            # taken; dropping the last element records the pre-action state.
            "env_state": env.env.current_sequence[:-1].copy(),
            "graph_idx": env.graph_idx,
            "done": done,
        })

        obs = next_obs
        steps += 1

    return trajectory


# -------------------------------------------------------
# Dataset generation (vine/fractal sampling)
# NOTE: signature and return type remain EXACTLY the same.
# -------------------------------------------------------
def generate_dataset(
    policy: TransformerPolicy,
    tokenizer: TriangleTokenizer,
    env: TriangleEnvironment,
    num_vines_at_state: int,
    num_levels: int,
    main_rollout_max_steps: int,
    device: torch.device,
) -> Dict[Tuple, List[List[Dict[str, Any]]]]:
    """
    Generate an on-policy dataset using vine (branching) sampling in the PROVIDED env.

    A single base seed is read from the ``info`` dict of the first ``env.reset()``
    (if present) and reused for every later reset, so that each branch point can be
    reconstructed by resetting with that seed and replaying the trunk's action prefix.

    Procedure:
        1. Roll ``num_vines_at_state`` trunk trajectories from the initial state.
        2. In each trunk, choose evenly spaced interior transitions as branch points.
        3. From each branch point, roll ``num_vines_at_state`` further trajectories.

    Args:
        policy: Policy used for all rollouts.
        tokenizer: Tokenizer forwarded to the rollout routine.
        env: Environment that is reset and stepped in place.
        num_vines_at_state: Rollouts generated per start/branch state.
        num_levels: Number of segments each trunk is split into; branch points
            are derived from the ``num_levels - 1`` interior boundaries.
        main_rollout_max_steps: Step budget for a trunk rollout; a branch at
            index ``t`` gets ``main_rollout_max_steps - t`` steps.
        device: Device used for policy inference.

    Returns:
        Mapping from the observation key of a start/branch state to the list of
        trajectories that begin at that state.

    Raises:
        AssertionError: If a generated trajectory does not start at the state it
            is filed under (consistency checks; skipped under ``python -O``).
    """
    dataset: Dict[Tuple, List[List[Dict[str, Any]]]] = {}

    # ---------- Step 1: Initial reset of the PROVIDED env ----------
    print("--- Generating initial set of trajectories (vine at t=0) ---")
    initial_obs, info = env.reset()
    initial_state_key = obs_to_key(initial_obs)

    # Try to capture the seed used by this env reset for exact reconstruction.
    base_seed = info.get("seed", None) if isinstance(info, dict) else None
    if base_seed is None:
        print("[fractal] WARNING: env.reset() did not provide a 'seed' in info; "
              "exact reconstruction across resets may not be guaranteed.")

    # Trunk rollouts: every one resets with base_seed (when known).
    trajectories_from_start: List[List[Dict[str, Any]]] = []
    for _ in range(num_vines_at_state):
        traj = generate_rollout(
            policy=policy,
            tokenizer=tokenizer,
            env=env,
            max_steps=main_rollout_max_steps,
            device=device,
            seed=base_seed,          # Use the seed from the first reset (if available)
            replay_actions=None,
            main_trunk=True,
        )
        if traj:
            trajectories_from_start.append(traj)

    num_trajectories_collected = len(trajectories_from_start)
    if trajectories_from_start:
        dataset[initial_state_key] = trajectories_from_start
    print(f"Generated {len(trajectories_from_start)} initial trajectories.")

    # Every trunk trajectory must start from the same initial observation.
    for traj in trajectories_from_start:
        assert traj[0]['observation'].tolist() == list(initial_state_key), f"Initial state should be the same. Expected: {initial_state_key}, Got: {traj[0]['observation'].tolist()}"

    # ---------- Step 2: Identify branch points (levels within each trajectory) ----------
    states_to_sample: Dict[Tuple, Dict[str, Any]] = {}

    # Only non-empty trajectories were kept above, so L >= 1 here.
    for traj in trajectories_from_start:
        L = len(traj)

        # Evenly spaced interior points, shifted by one so the branch starts at
        # t instead of t+1 for this env.
        timesteps = [int(k * (L / float(num_levels))) - 1 for k in range(1, num_levels)]
        # Bound AFTER the shift: for short trajectories (L < num_levels) the
        # shift yields -1, which would silently index from the end, produce an
        # empty transition prefix and an over-long step budget.
        timesteps = [t for t in timesteps if 0 <= t < L]
        if DEBUG:
            print(f"Timesteps for branching (len={L}): {timesteps}")

        # Build the action prefix for exact reconstruction
        actions_seq = [step["action"] for step in traj]

        # Reconstruct using the SAME base_seed captured at the start
        for t in timesteps:
            obs_t = traj[t]["observation"]              # state BEFORE taking action[t]
            state_key = obs_to_key(obs_t)

            # Keyed by (observation, t): identical branch states reached at the
            # same index by different trunks collapse into one entry (last wins).
            states_to_sample[(state_key, t)] = {
                "seed": base_seed,                                  # seed of the episode (may be None)
                "prefix_actions": actions_seq[:t],                   # replay up to s_t
                "prefix_transitions": traj[: t + 1],                 # prefix including transition t
                "remaining_steps": max(0, main_rollout_max_steps - t),
            }

    print(f"Found {len(states_to_sample)} branch points to sample from.")

    # ---------- Step 3: Generate vines from each branch point ----------
    for (state_key, _t), meta in states_to_sample.items():
        seed = meta["seed"]  # can be None if env didn't provide it
        prefix_actions = list(meta["prefix_actions"])
        original_prefix = list(meta["prefix_transitions"])
        remaining_steps = int(meta["remaining_steps"])

        trajectories_from_vine_state: List[List[Dict[str, Any]]] = []
        for _ in range(num_vines_at_state):
            vine_segment = generate_rollout(
                policy=policy,
                tokenizer=tokenizer,
                env=env,
                max_steps=remaining_steps,
                device=device,
                seed=seed,                 # reuse the same seed from the very first reset
                replay_actions=prefix_actions,
                main_trunk=False,
            )
            if not vine_segment:
                continue

            num_trajectories_collected += 1
            trajectories_from_vine_state.append(vine_segment)

        # All vines from this branch point must start at the trunk's state at t.
        branch_transition = original_prefix[-1]
        vine_env_state = branch_transition['env_state']
        for traj in trajectories_from_vine_state:
            # np.array_equal works whether env_state is a list or an ndarray
            # (a bare == on arrays is ambiguous inside assert).
            assert np.array_equal(traj[0]['env_state'], vine_env_state), f"Initial state should be the same. Expected: {vine_env_state}, Got: {traj[0]['env_state']}"
            assert np.array_equal(traj[0]['observation'], branch_transition['observation']), "Initial observation should be the same"

        if trajectories_from_vine_state:
            dataset.setdefault(state_key, []).extend(trajectories_from_vine_state)

    # Check that the entire dataset is consistent keys => observations (includes initial states)
    for vine_state_key, trajectories in dataset.items():
        for traj in trajectories:
            assert traj[0]['observation'].tolist() == list(vine_state_key), f"Initial state should be the same. Expected: {vine_state_key}, Got: {traj[0]['observation'].tolist()}"

    print(f"Total number of trajectories collected: {num_trajectories_collected}")
    return dataset


# -------------------------------------------------------
# Script entry (left for convenience; uses the provided env in generate_dataset)
# -------------------------------------------------------
def main():
    """
    Script entry point: load the pretrained policy, generate a vine-sampled
    dataset in each of three environments, pickle the result and print a summary.

    Raises:
        FileNotFoundError: If ``MODEL_PATH`` does not exist.
    """
    # --- Hyperparameters ---
    NUM_VINES_AT_STATE = 8      # Number of new trajectories to generate from each sampled state
    NUM_LEVELS = 3              # Number of segments to split trajectories into (branch at 1..NUM_LEVELS-1)
    # NOTE(review): an earlier note here mentioned sequence lengths 16 and 11;
    # the value actually in use is 9 - confirm which is intended.
    MAIN_ROLLOUT_LENGTH = 9     # Step budget passed as main_rollout_max_steps

    # --- Setup ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load policy and tokenizer from the model file
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Could not find {MODEL_PATH}")

    policy, tokenizer = load_model_and_tokenizer(MODEL_PATH, device=device)
    # Report success only after the load has actually happened.
    print(f"Loaded pretrained policy from {MODEL_PATH}")
    print(f"Loaded tokenizer with vocab_size={tokenizer.vocab_size}")

    envs = []
    datasets = []
    try:
        # One dataset per environment index.
        for i in range(3):
            env = make_env(tokenizer, i, device)
            envs.append(env)

            # --- Generate the Dataset (in the provided env) ---
            on_policy_dataset = generate_dataset(
                policy=policy,
                tokenizer=tokenizer,
                env=env,
                num_vines_at_state=NUM_VINES_AT_STATE,
                num_levels=NUM_LEVELS,
                main_rollout_max_steps=MAIN_ROLLOUT_LENGTH,
                device=device,
            )
            datasets.append(on_policy_dataset)

        # --- Save the Dataset ---
        # NOTE(review): only the dataset from the LAST environment is written and
        # summarised; the other entries of `datasets` are discarded. The on-disk
        # format is kept as-is because downstream readers are not visible here -
        # confirm whether all three datasets were meant to be saved.
        output_dir = "rollthedice/triangle_discovery"
        os.makedirs(output_dir, exist_ok=True)
        output_file_name = os.path.join(output_dir, "fractal_on_policy_dataset.pkl")
        with open(output_file_name, "wb") as f:
            pickle.dump(on_policy_dataset, f)
        print(f"\nDataset saved to {output_file_name}")

        # --- Print Summary ---
        num_unique_states = len(on_policy_dataset)
        total_trajectories = sum(len(trajs) for trajs in on_policy_dataset.values())
        print(f"Number of unique vine points: {num_unique_states}")
        print(f"Total number of trajectories in dataset: {total_trajectories}")
    finally:
        # Close every environment that was created (previously only the last
        # one was closed), including when generation or saving fails.
        for created_env in envs:
            created_env.close()


# Run the dataset-generation script only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
