# fractal_sampling.py
import os
import pickle
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F

from pretrain import MLPPolicy
from ..utils import prepare_obs, get_model_paths, HIDDEN_DIM
from ..train_utils import setup_environment as setup_environment_multiseed  

def obs_to_key(obs: Dict[str, Any]) -> Tuple:
    """
    Build a hashable dictionary key from an observation.

    The key pairs the observation's image, flattened to a tuple of Python
    scalars, with its mission (``""`` when the observation has no mission).
    """
    pixels = np.asarray(obs["image"]).ravel().tolist()
    return tuple(pixels), obs.get("mission", "")


def _shape_success_reward(terminated: bool, steps: int, max_steps: int) -> float:
    """
    Reward shaping used in PPO/fractal: positive only on success termination.
    """
    if terminated:
        return 1.0 - 0.5 * (steps / float(max_steps))
    return 0.0


def _obs_to_feature(prepped: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Convert prepared obs to a single feature row:
    - Flattened image (C,H,W) -> (C*H*W)
    - Direction one-hot (size 4)
    """
    img: torch.Tensor = prepped["image"]          
    x_img = img.reshape(1, -1).float()            

    dir_tensor: torch.Tensor = prepped.get("direction")
    if dir_tensor is None:
        dir_tensor = torch.zeros((1,), dtype=torch.long, device=img.device)
    dir_tensor = dir_tensor.clamp(0, 3)
    dir_oh = F.one_hot(dir_tensor, num_classes=4).float()  

    return torch.cat([x_img, dir_oh], dim=1)    

def _batch_to_features(
    obs_list: List[Dict[str, Any]],
    device: torch.device,
    use_text: bool = False,
) -> torch.Tensor:
    """
    Stack the feature rows of several raw observations into one tensor.

    Each observation is passed through ``prepare_obs`` (on ``device``, with
    ``use_text`` forwarded) and then ``_obs_to_feature``. The resulting rows
    are concatenated along dim 0, in input order.
    """
    rows = [
        _obs_to_feature(prepare_obs(o, device=device, use_text=use_text))
        for o in obs_list
    ]
    return torch.cat(rows, dim=0)

@torch.no_grad()
def generate_rollout(
    policy: MLPPolicy,
    env,
    max_steps: int,
    device: torch.device,
    seed: Optional[int] = None,
    replay_actions: Optional[List[int]] = None,
    main_trunk: bool = False,
    temperature: float = 1.0,
    use_text: bool = False,
) -> List[Dict[str, Any]]:
    """
    Sample one trajectory from ``policy`` in ``env``.

    The rollout starts either from a fresh reset (seeded when ``seed`` is given),
    or from the exact environment state reconstructed by:
        env.reset(seed=seed) + replay(prefix actions)

    The replayed prefix steps are NOT included in the returned trajectory; the
    returned transitions start at the branch point state (i.e., after the replay).

    Args:
        policy: network called on a feature row built by ``_batch_to_features``;
            its output is used as categorical logits over actions.
        env: env whose ``reset`` returns ``(obs, info)`` and whose ``step``
            returns ``(obs, reward, terminated, truncated, info)``.
        max_steps: maximum number of transitions sampled after the branch point.
        device: device on which policy inputs are built.
        seed: seed passed to ``env.reset``. With ``None`` the reset is unseeded,
            so replaying a prefix only reproduces a state if the env resets to
            the same layout every time.
        replay_actions: actions replayed (and not recorded) before sampling.
        main_trunk: must be True to sample without a replay prefix; with an
            empty/None prefix and ``main_trunk=False`` the result is ``[]``.
        temperature: when > 0, logits are divided by ``max(1e-6, temperature)``.
            Values <= 0 leave the logits unscaled (i.e. sampling at temperature
            1.0), NOT greedy decoding.
        use_text: forwarded to the feature builder.

    Returns:
        A list of transition dicts with keys ``observation``, ``action``,
        ``log_prob``, ``reward``, ``next_observation``, ``env_state``, ``done``,
        ``terminated`` and ``truncated``. It is empty when the replay ends the
        episode or changes the mission, or when there is no prefix and
        ``main_trunk`` is False.

    The policy's train/eval mode is restored before returning.
    """
    if seed is not None:
        obs, _ = env.reset(seed=seed)
    else:
        obs, _ = env.reset()

    mission0 = obs.get("mission", None)

    if replay_actions:
        for a in replay_actions:
            obs, _, terminated, truncated, _ = env.step(int(a))
            if terminated or truncated:
                return []
            if mission0 is not None and obs.get("mission", mission0) != mission0:
                return []
    elif not main_trunk:
        return []

    trajectory: List[Dict[str, Any]] = []

    # Sample in eval mode, but give the caller back the mode it had. Otherwise a
    # training loop that collects data between updates would silently keep
    # training in eval mode.
    was_training = policy.training
    policy.eval()
    try:
        steps = 0
        done = False
        while not done and steps < max_steps:
            # Agent pose for `obs` (captured before stepping). The fallback covers
            # envs without `unwrapped.agent_pos` / `unwrapped.agent_dir`.
            try:
                agent_pos = tuple(env.unwrapped.agent_pos)
                agent_dir = int(env.unwrapped.agent_dir)
            except Exception:
                agent_pos, agent_dir = (-1, -1), -1

            feats = _batch_to_features([obs], device=device, use_text=use_text)
            logits = policy(feats)
            if temperature > 0:
                logits = logits / max(1e-6, float(temperature))
            dist = torch.distributions.Categorical(logits=logits)
            action_t = dist.sample()
            action = int(action_t.item())
            log_prob = float(dist.log_prob(action_t).item())

            next_obs, _env_reward, terminated, truncated, _info = env.step(action)

            # Stop without recording this transition if the mission no longer
            # matches the one observed at reset.
            if mission0 is not None and next_obs.get("mission", mission0) != mission0:
                break

            # NOTE(review): `steps` counts from the branch point (replayed prefix
            # actions are not counted). For a replayed rollout the success
            # shaping is therefore relative to the branch point, not the episode
            # start — confirm this is intended.
            shaped_r = max(0.0, min(1.0, _shape_success_reward(terminated, steps, max_steps)))

            done = bool(terminated or truncated)

            trajectory.append({
                "observation": obs,
                "action": action,
                "log_prob": log_prob,
                "reward": shaped_r,
                "next_observation": next_obs,
                "env_state": (agent_pos, agent_dir),
                "done": done,
                "terminated": bool(terminated),
                "truncated": bool(truncated),
            })

            obs = next_obs
            steps += 1
    finally:
        policy.train(was_training)

    return trajectory



# Dataset generation
def _branch_timesteps(length: int, num_levels: int) -> List[int]:
    """Return the sorted, de-duplicated interior split indices (1 <= t < length) of a trajectory."""
    candidates = (int(k * (length / float(num_levels))) for k in range(1, num_levels))
    return sorted({t for t in candidates if 1 <= t < length})


def _vine_start_mismatch(vine: List[Dict[str, Any]], branch_step: Dict[str, Any]) -> Optional[str]:
    """Describe how a vine's first transition differs from the recorded branch-point step, or return None if it matches."""
    first = vine[0]
    if first["env_state"] != branch_step["env_state"]:
        return "Initial env_state mismatch across vines."
    img_ref = np.array(branch_step["observation"]["image"])
    if not np.array_equal(first["observation"]["image"], img_ref):
        return "Initial image mismatch."
    mission_ref = branch_step["observation"].get("mission", None)
    if mission_ref is not None and first["observation"].get("mission", mission_ref) != mission_ref:
        return "Mission mismatch."
    return None


def generate_dataset(
    policy: MLPPolicy,
    env,
    num_vines_at_state: int,
    num_levels: int,
    main_rollout_max_steps: int,
    device: torch.device,
    temperature: float = 1.0,
    use_text: bool = False,
) -> Dict[Tuple, List[List[Dict[str, Any]]]]:
    """
    Generate an on-policy dataset using vine (branching) sampling in the PROVIDED env.

    We DO NOT create a new base seed. We read the seed from the first reset of
    'env' (via info['seed'] if present) and reuse it for deterministic
    reconstruction.

    Procedure:
      1. Sample ``num_vines_at_state`` main-trunk trajectories from the start state.
      2. Pick up to ``num_levels - 1`` evenly spaced interior time steps of each
         trunk as branch points.
      3. At each branch point, rebuild the state with reset(seed) plus a replay
         of the trunk's action prefix, then sample ``num_vines_at_state`` vine
         segments from it with the remaining step budget.

    Args:
        policy: policy used for every rollout.
        env: environment to sample in (reset repeatedly).
        num_vines_at_state: trajectories sampled at the start state and at
            each branch point.
        num_levels: the trunk is split into ``num_levels`` segments; values
            <= 1 disable branching.
        main_rollout_max_steps: step budget of a trunk. A vine branching at
            step t gets ``main_rollout_max_steps - t`` steps.
        device, temperature, use_text: forwarded to ``generate_rollout``.

    Returns:
        A mapping from ``obs_to_key`` of a state to the trajectories sampled
        from it. Trunks are stored under the key of the initial reset. Vine
        segments are stored under the key of their branch-point observation;
        vines from different time steps that share a key are merged.

    Raises:
        RuntimeError: if a vine does not start from the recorded branch-point
            state (agent pose, image or mission), i.e. reconstruction failed.
    """
    dataset: Dict[Tuple, List[List[Dict[str, Any]]]] = {}
    num_trajectories_collected = 0

    initial_obs, info = env.reset()
    initial_state_key = obs_to_key(initial_obs)

    base_seed = info.get("seed", None) if isinstance(info, dict) else None
    if base_seed is None:
        # NOTE(review): without a seed, every trunk rollout and every replay does
        # an unseeded env.reset(). Reconstruction then only works if the env
        # resets to the same layout each time; otherwise the start-state check
        # below raises.
        print("[fractal] WARNING: env.reset() did not provide 'seed'; exact reconstruction may be approximate.")

    trajectories_from_start: List[List[Dict[str, Any]]] = []
    for _ in range(num_vines_at_state):
        traj = generate_rollout(
            policy=policy,
            env=env,
            max_steps=main_rollout_max_steps,
            device=device,
            seed=base_seed,
            replay_actions=None,
            main_trunk=True,
            temperature=temperature,
            use_text=use_text,
        )
        if traj:
            trajectories_from_start.append(traj)

    num_trajectories_collected += len(trajectories_from_start)
    if trajectories_from_start:
        dataset[initial_state_key] = trajectories_from_start
    print(f"[fractal] Generated {len(trajectories_from_start)} initial trajectories at t=0.")

    # (observation key, t) -> reconstruction recipe. A later trunk that reaches
    # the same observation at the same t overwrites an earlier one.
    states_to_sample: Dict[Tuple, Dict[str, Any]] = {}
    for traj in trajectories_from_start:
        actions_seq = [step["action"] for step in traj]
        for t in _branch_timesteps(len(traj), num_levels):
            branch_step = traj[t]
            states_to_sample[(obs_to_key(branch_step["observation"]), t)] = {
                "seed": base_seed,
                # Replaying the first t actions reproduces the observation at index t.
                "prefix_actions": actions_seq[:t],
                "branch_step": branch_step,
                "remaining_steps": max(0, main_rollout_max_steps - t),
            }

    print(f"[fractal] Found {len(states_to_sample)} branch points to sample from.")

    for (state_key, t), meta in states_to_sample.items():
        prefix_actions = list(meta["prefix_actions"])
        branch_step = meta["branch_step"]
        remaining_steps = int(meta["remaining_steps"])

        vines: List[List[Dict[str, Any]]] = []
        for _ in range(num_vines_at_state):
            # NOTE(review): the vine's shaped rewards are computed relative to
            # the branch point with budget `remaining_steps`, so they are not on
            # the trunk's absolute-time scale — confirm this is intended.
            vine_segment = generate_rollout(
                policy=policy,
                env=env,
                max_steps=remaining_steps,
                device=device,
                seed=meta["seed"],
                replay_actions=prefix_actions,
                main_trunk=False,
                temperature=temperature,
                use_text=use_text,
            )
            if not vine_segment:
                continue

            # Explicit check (not `assert`) so reconstruction is still validated
            # under `python -O`; a mismatched vine would otherwise be stored
            # under the wrong state key.
            mismatch = _vine_start_mismatch(vine_segment, branch_step)
            if mismatch is not None:
                raise RuntimeError(f"[fractal] {mismatch} (branch point t={t})")
            vines.append(vine_segment)

        num_trajectories_collected += len(vines)
        if vines:
            dataset.setdefault(state_key, []).extend(vines)

    print(f"[fractal] Total number of trajectories collected: {num_trajectories_collected}")
    return dataset


def main():
    """
    Generate a fractal (vine-sampled) on-policy dataset and pickle it.

    Steps:
      1. Build the env.
      2. Size the policy input from one featurised observation.
      3. Load pretrained weights from ``get_model_paths(GAME_CODE)["pretrain"]``
         when that file exists; otherwise keep the random init.
      4. Run ``generate_dataset`` and write the result to
         ``tasks/minigrid/<GAME_CODE>/fractal_on_policy_dataset.pkl``.

    The env is closed even if generation or saving fails.
    """
    NUM_VINES_AT_STATE = 8
    NUM_LEVELS = 4
    MAIN_ROLLOUT_LENGTH = 50
    GAME_CODE = "minigrid"
    SEED = 1234

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    env = setup_environment_multiseed(GAME_CODE, max_steps=MAIN_ROLLOUT_LENGTH, meta_fixed_seed=SEED)
    try:
        # Featurise one observation to discover the policy's input width.
        obs0, _ = env.reset()
        feats0 = _batch_to_features([obs0], device, use_text=False)
        input_dim = int(feats0.shape[1])

        n_actions = env.action_space.n
        policy = MLPPolicy(input_dim=input_dim, output_dim=n_actions, hidden_dim=HIDDEN_DIM).to(device)

        try:
            model_paths = get_model_paths(GAME_CODE)
            model_path = model_paths.get("pretrain", "")
            if model_path and os.path.exists(model_path):
                print(f"Loading pretrained policy from: {model_path}")
                # NOTE: torch.load unpickles the file; only load trusted checkpoints.
                state = torch.load(model_path, map_location=device)
                # strict=False tolerates key mismatches. Report them instead of
                # silently running with partially random weights.
                load_result = policy.load_state_dict(state, strict=False)
                if load_result.missing_keys or load_result.unexpected_keys:
                    print(
                        "WARNING: pretrained checkpoint does not fully match the policy "
                        f"(missing keys: {list(load_result.missing_keys)}, "
                        f"unexpected keys: {list(load_result.unexpected_keys)})."
                    )
            else:
                print(f"Pretrained model not found at {model_path}. Using random init.")
        except Exception as e:
            print(f"Pretrained load skipped ({e}). Using random init.")

        on_policy_dataset = generate_dataset(
            policy=policy,
            env=env,
            num_vines_at_state=NUM_VINES_AT_STATE,
            num_levels=NUM_LEVELS,
            main_rollout_max_steps=MAIN_ROLLOUT_LENGTH,
            device=device,
            temperature=1.0,
            use_text=False,
        )

        output_dir = os.path.join("tasks", "minigrid", GAME_CODE)
        os.makedirs(output_dir, exist_ok=True)
        output_file_name = os.path.join(output_dir, "fractal_on_policy_dataset.pkl")
        with open(output_file_name, "wb") as f:
            pickle.dump(on_policy_dataset, f)
        print(f"\nDataset saved to {output_file_name}")

        num_unique_states = len(on_policy_dataset)
        total_trajectories = sum(len(trajs) for trajs in on_policy_dataset.values())
        print(f"Number of unique vine points: {num_unique_states}")
        print(f"Total number of trajectories in dataset: {total_trajectories}")
    finally:
        env.close()


if __name__ == "__main__":
    # Run dataset generation only when executed as a script, not on import.
    main()
