# fractal_sampling.py
import os
import pickle
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import torch

from ..utils import BabyAI_BC, prepare_obs, get_model_paths
from ..open_utils import setup_environment
from minigrid.wrappers import FullyObsWrapper  

def obs_to_key(obs: Dict[str, Any]) -> Tuple:
    """
    Build a hashable key that identifies an observation.

    The key is a 2-tuple: the observation's ``"image"`` flattened (C order)
    into a tuple of Python scalars, followed by its ``"mission"`` value
    (``""`` when the observation has no mission).
    """
    pixels = np.asarray(obs["image"]).ravel()
    return tuple(pixels.tolist()), obs.get("mission", "")


def _shape_success_reward(terminated: bool, steps: int, max_steps: int) -> float:
    """
    Shaped reward shared by PPO and fractal sampling.

    Returns 0.0 unless ``terminated`` is truthy (success). On success the reward
    decays linearly with ``steps``: 1.0 at ``steps == 0`` down to 0.5 at
    ``steps == max_steps``.
    """
    if not terminated:
        return 0.0
    fraction_used = steps / float(max_steps)
    return 1.0 - 0.5 * fraction_used


# Rollout with reset+replay 

@torch.no_grad()
def generate_rollout(
    policy: BabyAI_BC,
    env,
    max_steps: int,
    device: torch.device,
    seed: Optional[int] = None,
    replay_actions: Optional[List[int]] = None,
    main_trunk: bool = False,
) -> List[Dict[str, Any]]:
    """
    Sample one on-policy trajectory from ``policy`` in ``env``.

    The start state is either a fresh ``env.reset`` (seeded when ``seed`` is
    given) or, for vine branches, the state reconstructed by
    ``env.reset(seed=seed)`` followed by replaying ``replay_actions``.

    The replayed prefix steps are NOT included in the returned trajectory; the
    returned transitions start at the branch-point state (i.e. after the replay).

    Args:
        policy: Policy network; called on ``prepare_obs(obs, ...)`` and used as
            the logits of a categorical action distribution.
        env: Environment with ``reset(seed=...) -> (obs, info)`` and
            ``step(a) -> (obs, reward, terminated, truncated, info)``.
        max_steps: Maximum number of policy steps sampled after the replay.
        device: Device passed to ``prepare_obs`` for the model inputs.
        seed: Seed passed to ``env.reset``; ``None`` resets without a seed.
        replay_actions: Action prefix replayed before sampling begins.
        main_trunk: ``True`` for rollouts that start at the reset state. When
            ``False``, a non-empty ``replay_actions`` is required.

    Returns:
        A list of transition dicts with keys ``observation``, ``action``,
        ``log_prob``, ``reward``, ``next_observation``, ``env_state``
        (``(agent_pos, agent_dir)``), ``done``, ``terminated``, ``truncated``.
        Empty if the replay reaches a terminated/truncated state, or if a
        non-trunk rollout was requested without replay actions.

    NOTE(review): ``reward`` is shaped with ``steps`` counted from the branch
    point and ``max_steps`` (the remaining budget for vines), not absolute
    episode time — confirm this is the intended shaping for vine segments.
    """
    # Reset (seeded when possible so the same start state can be rebuilt).
    if seed is not None:
        obs, _ = env.reset(seed=seed)
    else:
        obs, _ = env.reset()

    assert "mission" in obs, "Mission should be in the observation"
    mission0 = obs["mission"]

    # Replay the prefix to reconstruct the exact branch state.
    if replay_actions:
        for a in replay_actions:
            obs, _, terminated, truncated, _ = env.step(int(a))
            if terminated or truncated:
                # The prefix ends the episode: there is no branch state left.
                return []
            assert "mission" in obs and obs["mission"] == mission0, "Mission should remain the same during replay"
    elif not main_trunk:
        # A vine rollout without a prefix would start from the reset state
        # rather than the requested branch state; refuse instead of
        # producing mislabelled data.
        print("NO REPLAY ACTIONS PROVIDED - WE ARE NOT RECONSTRUCTING THE EXACT VINE STATE!!!")
        print(f"Seed: {seed}, replay actions: {replay_actions}, max steps: {max_steps}")
        return []

    policy.eval()
    trajectory: List[Dict[str, Any]] = []
    steps = 0
    done = False
    warned_missing_pose = False

    while (not done) and steps < max_steps:
        # Agent pose is stored per transition as ``env_state``.
        try:
            agent_pos = tuple(env.unwrapped.agent_pos)  # (x, y)
            agent_dir = int(env.unwrapped.agent_dir)
        except (AttributeError, TypeError, ValueError):
            if not warned_missing_pose:
                print(f"ENV DOES NOT EXPOSE AGENT POS AND DIR - FALLBACK TO (-1, -1), -1!!! THIS WILL CAUSE ISSUES WITH DIVERSITY ESTIMATION!!!")
                warned_missing_pose = True
            # Actually apply the advertised fallback so no stale/unbound pose leaks in.
            agent_pos, agent_dir = (-1, -1), -1

        # Model input -> categorical policy over actions.
        batch = prepare_obs(obs, device=device, use_text=policy.use_text)
        logits = policy(batch)
        dist = torch.distributions.Categorical(logits=logits)
        action_t = dist.sample()
        action = int(action_t.item())
        log_prob = float(dist.log_prob(action_t).item())

        # Step and shape the reward (positive only on successful termination).
        next_obs, _env_r, terminated, truncated, _info = env.step(action)
        shaped_r = _shape_success_reward(terminated, steps, max_steps)
        assert 0.0 <= shaped_r <= 1.0, "Shaped reward should be between 0 and 1"

        done = bool(terminated or truncated)

        trajectory.append({
            "observation": obs,
            "action": action,
            "log_prob": log_prob,
            "reward": shaped_r,
            "next_observation": next_obs,
            "env_state": (agent_pos, agent_dir),
            "done": done,
            "terminated": bool(terminated),
            "truncated": bool(truncated),
        })

        obs = next_obs
        steps += 1

    return trajectory



# Dataset generation 
def _collect_branch_points(
    trajectories: List[List[Dict[str, Any]]],
    base_seed: Optional[int],
    num_levels: int,
    main_rollout_max_steps: int,
) -> Dict[Tuple, Dict[str, Any]]:
    """Map each interior branch point ``(state_key, t)`` to the data needed to replay to it."""
    states_to_sample: Dict[Tuple, Dict[str, Any]] = {}
    for traj in trajectories:
        length = len(traj)
        if length == 0:
            continue

        # num_levels - 1 evenly spaced interior times, de-duplicated, within [1, length).
        candidates = (int(k * (length / float(num_levels))) for k in range(1, num_levels))
        timesteps = sorted({t for t in candidates if 1 <= t < length})
        print(f"Timesteps for branching (len={length}): {timesteps}")

        actions_seq = [step["action"] for step in traj]
        for t in timesteps:
            # traj[t]["observation"] is the state reached after replaying actions[:t].
            obs_t = traj[t]["observation"]
            assert obs_t.get("mission", None) is not None, "Mission should be in the observation"
            states_to_sample[(obs_to_key(obs_t), t)] = {
                "seed": base_seed,
                "prefix_actions": actions_seq[:t],
                "prefix_transitions": traj[: t + 1],
                "remaining_steps": max(0, main_rollout_max_steps - t),
            }
    return states_to_sample


def _check_vine_starts(vines: List[List[Dict[str, Any]]], anchor: Dict[str, Any]) -> None:
    """Assert every vine starts at the same agent pose, image and mission as ``anchor``."""
    for vine in vines:
        first = vine[0]
        assert first["env_state"] == anchor["env_state"], "Initial state should be the same"
        assert np.array_equal(first["observation"]["image"], anchor["observation"]["image"]), "Initial observation images should be the same"
        assert first["observation"]["mission"] == anchor["observation"]["mission"], "Mission should be the same"


def generate_dataset(
    policy: BabyAI_BC,
    env,
    num_vines_at_state: int,
    num_levels: int,
    main_rollout_max_steps: int,
    device: torch.device,
) -> Dict[Tuple, List[List[Dict[str, Any]]]]:
    """
    Generate an on-policy dataset using vine (branching) sampling in the provided env.

    The base seed is read from the first reset of ``env`` (``info["seed"]``).
    If the env does not report one, a seed is drawn from NumPy's global RNG and
    the env is re-reset with it: without a fixed seed every ``reset`` could
    produce a different level, so neither the initial rollouts nor the vine
    replays would start from the state they are filed under.

    Args:
        policy: Policy used for every rollout.
        env: Environment shared by all rollouts (reset and replayed each time).
        num_vines_at_state: Rollouts sampled at the start state and at each branch point.
        num_levels: Each initial trajectory is split into ``num_levels``
            segments; the ``num_levels - 1`` interior cut points become branch points.
        main_rollout_max_steps: Step budget of the initial rollouts; vines get
            the remaining budget after their branch time.
        device: Device passed through to ``generate_rollout``.

    Returns:
        Dict mapping ``obs_to_key`` of a start state to the list of
        trajectories sampled from that state (vine segments start at the
        branch state; the replayed prefix is not included).
    """
    dataset: Dict[Tuple, List[List[Dict[str, Any]]]] = {}
    num_trajectories_collected = 0

    # Initial reset and base seed.
    print("--- Generating initial set of trajectories (vine at t=0) ---")
    initial_obs, info = env.reset()
    base_seed = info.get("seed", None) if isinstance(info, dict) else None
    if base_seed is None:
        base_seed = int(np.random.randint(0, 2**31 - 1))
        print("[fractal] WARNING: env.reset() did not provide a 'seed' in info; "
              f"re-resetting with generated seed {base_seed} so rollouts can be reconstructed.")
        initial_obs, _ = env.reset(seed=base_seed)
    initial_state_key = obs_to_key(initial_obs)

    # Main trunk: several rollouts from the (seeded) start state.
    trajectories_from_start: List[List[Dict[str, Any]]] = []
    for i in range(num_vines_at_state):
        print(f"  Generating initial trajectory {i+1}/{num_vines_at_state}...")
        traj = generate_rollout(
            policy=policy,
            env=env,
            max_steps=main_rollout_max_steps,
            device=device,
            seed=base_seed,
            replay_actions=None,
            main_trunk=True,
        )
        if traj:
            trajectories_from_start.append(traj)

    num_trajectories_collected += len(trajectories_from_start)
    if trajectories_from_start:
        dataset[initial_state_key] = trajectories_from_start
    print(f"Generated {len(trajectories_from_start)} initial trajectories.")

    states_to_sample = _collect_branch_points(
        trajectories_from_start, base_seed, num_levels, main_rollout_max_steps
    )
    print(f"Found {len(states_to_sample)} branch points to sample from.")

    # Vines: reset + replay to each branch state, then sample fresh continuations.
    for (state_key, _t), meta in states_to_sample.items():
        anchor = meta["prefix_transitions"][-1]
        trajectories_from_vine_state: List[List[Dict[str, Any]]] = []
        for _ in range(num_vines_at_state):
            vine_segment = generate_rollout(
                policy=policy,
                env=env,
                max_steps=int(meta["remaining_steps"]),
                device=device,
                seed=meta["seed"],
                replay_actions=list(meta["prefix_actions"]),
                main_trunk=False,
            )
            if vine_segment:
                trajectories_from_vine_state.append(vine_segment)

        num_trajectories_collected += len(trajectories_from_vine_state)
        # Replay must reproduce the branch state exactly; fail loudly otherwise.
        _check_vine_starts(trajectories_from_vine_state, anchor)

        if trajectories_from_vine_state:
            dataset.setdefault(state_key, []).extend(trajectories_from_vine_state)

    print(f"Total number of trajectories collected: {num_trajectories_collected}")
    return dataset

def main():
    """
    Script entry point.

    Builds the environment and policy (loading the pretrained checkpoint when
    present), generates the fractal on-policy dataset, pickles it to
    ``tasks/minigrid/<GAME_CODE>/fractal_on_policy_dataset.pkl`` and prints
    summary counts. The environment is closed even if generation fails.
    """
    NUM_VINES_AT_STATE = 8
    NUM_LEVELS = 4
    MAIN_ROLLOUT_LENGTH = 50
    GAME_CODE = "open"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    env = setup_environment(max_steps=MAIN_ROLLOUT_LENGTH)
    try:
        n_actions = env.action_space.n
        policy = BabyAI_BC(n_actions=n_actions, use_text=False).to(device)

        model_paths = get_model_paths(GAME_CODE)
        model_path = model_paths["pretrain"]

        if os.path.exists(model_path):
            print(f"Loading pretrained policy from: {model_path}")
            load_result = policy.load_state_dict(
                torch.load(model_path, map_location=device), strict=False
            )
            # strict=False silently skips mismatched keys; surface them so a
            # partially loaded checkpoint does not go unnoticed.
            if load_result.missing_keys:
                print(f"Warning: checkpoint is missing keys: {load_result.missing_keys}")
            if load_result.unexpected_keys:
                print(f"Warning: checkpoint has unexpected keys: {load_result.unexpected_keys}")
        else:
            print(f"Warning: Pretrained model not found at {model_path}.")
            print("Running with a randomly initialized policy.")

        on_policy_dataset = generate_dataset(
            policy=policy,
            env=env,
            num_vines_at_state=NUM_VINES_AT_STATE,
            num_levels=NUM_LEVELS,
            main_rollout_max_steps=MAIN_ROLLOUT_LENGTH,
            device=device,
        )

        output_dir = os.path.join("tasks/minigrid", GAME_CODE)
        os.makedirs(output_dir, exist_ok=True)
        output_file_name = os.path.join(output_dir, "fractal_on_policy_dataset.pkl")
        with open(output_file_name, "wb") as f:
            pickle.dump(on_policy_dataset, f)
        print(f"\nDataset saved to {output_file_name}")

        num_unique_states = len(on_policy_dataset)
        total_trajectories = sum(len(trajs) for trajs in on_policy_dataset.values())
        print(f"Number of unique vine points: {num_unique_states}")
        print(f"Total number of trajectories in dataset: {total_trajectories}")
    finally:
        env.close()


# Run dataset generation only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
