from typing import Dict
import numpy as np
import torch

from achievement_distillation.wrapper import VecPyTorch
from achievement_distillation.storage import RolloutStorage
from achievement_distillation.model.base import BaseModel
from utils import ObservationToOption, text_obs, OptionToAction, gen_prompt, check_task_done

def sample_rollouts(
    venv: VecPyTorch,
    model: BaseModel,
    storage: RolloutStorage,
    lambda_t: float,
    lambda_decay: float,
    imitation_phase: bool,
    task: str,
    video_writers=None
) -> Dict[str, np.ndarray]:
    """
    Run the policy for ``storage.nstep`` environment steps, record every
    transition in ``storage`` and return statistics of the finished episodes.

    Parameters:
    - venv: Vectorised environment; ``venv.step`` and ``venv.envs`` are used.
    - model: Policy/value model; ``act`` and (in the imitation phase) ``vf_head`` are called.
    - storage: Rollout buffer supplying model inputs and receiving the outputs.
    - lambda_t: Scale applied to the π_meta logits before the softmax/argmax.
    - lambda_decay: Accepted for interface compatibility; not used in this function.
    - imitation_phase: When True the executed action comes from π_meta
      (see ``get_pi_meta``); otherwise the action sampled by the model is executed.
    - task: Task name forwarded to ``get_pi_meta`` and ``check_task_done``.
    - video_writers: Optional sequence of writers; writer ``k`` receives the
      rendered frame of ``venv.envs[k]`` after every step.

    Returns:
    - Dict with the stacked per-episode arrays ``episode_lengths``,
      ``episode_rewards``, ``achievements`` and ``successes``, plus
      ``total_tokens``, the plain sum of the token counts reported by
      ``get_pi_meta`` over all steps (not an array).

    Raises:
    - ValueError: from ``np.stack`` when no episode finished during the rollout.
    """
    # Set model to eval mode
    model.eval()

    # Per-episode statistics, appended whenever an environment reports done
    episode_rewards = []
    episode_lengths = []
    achievements = []
    successes = []
    total_tokens = 0  # Cumulative token count over the whole rollout

    for step in range(storage.nstep):
        # Get model output (π_θ)
        inputs = storage.get_inputs(step)
        outputs = model.act(**inputs)

        # Tokens consumed at this step; stays 0 outside the imitation phase
        tokens = 0

        if imitation_phase:
            pi_meta_logits, tokens = get_pi_meta(inputs["obs"], venv, task)
            # π_θ's logits are only used here to pick the target device
            pi_meta_logits = pi_meta_logits.to(outputs["pi_logits"].device)

            # NOTE(review): only the π_meta logits enter the "combined" policy;
            # π_θ's logits are not mixed in and lambda_decay is never applied.
            pi_combined_logits = lambda_t * pi_meta_logits
            pi_combined_probs = torch.softmax(pi_combined_logits, dim=-1)
            actions = torch.argmax(pi_combined_probs, dim=-1).unsqueeze(-1)

            # NOTE(review): outputs["actions"] still holds π_θ's action, so the
            # storage receives that one rather than the executed meta action.
            outputs["v_meta"] = model.vf_head(outputs["latents"]).detach()
            outputs["pi_meta"] = pi_meta_logits
        else:
            actions = outputs["actions"]

        # Step the environment
        obs, rewards, dones, infos = venv.step(actions)

        # Save each frame if video_writers are provided
        if video_writers:
            for env_id, writer in enumerate(video_writers):
                frame = venv.envs[env_id].render()  # assumed to return an RGB frame
                writer.append_data(frame)

        # Mark an environment as done once its task is completed.
        # `infos` is indexed by string keys below, so enumerating it would walk
        # its keys; iterate over the environment indices of `dones` instead.
        for i in range(len(dones)):
            if check_task_done(venv.envs[i], task):
                dones[i] = True

        # Update cumulative tokens
        total_tokens += tokens

        outputs["obs"] = obs
        outputs["rewards"] = rewards
        outputs["masks"] = 1.0 - dones
        outputs["successes"] = infos["successes"]
        outputs["total_tokens"] = tokens  # Tokens for this step only

        # Update storage
        storage.insert(**outputs, model=model)

        # Update stats
        for i, done in enumerate(dones):
            if done:
                episode_lengths.append(infos["episode_lengths"][i].cpu().numpy())
                episode_rewards.append(infos["episode_rewards"][i].cpu().numpy())
                achievements.append(infos["achievements"][i].cpu().numpy())
                successes.append(infos["successes"][i].cpu().numpy())

    # Compute final state value
    inputs = storage.get_inputs(step=-1)
    outputs = model.act(**inputs)
    storage.vpreds[-1].copy_(outputs["vpreds"])

    # Stack stats (np.stack raises ValueError if no episode finished)
    episode_lengths = np.stack(episode_lengths, axis=0).astype(np.int32)
    episode_rewards = np.stack(episode_rewards, axis=0).astype(np.float32)
    achievements = np.stack(achievements, axis=0).astype(np.int32)
    successes = np.stack(successes, axis=0).astype(np.int32)

    # Rollout stats, including the total tokens consumed
    rollout_stats = {
        "episode_lengths": episode_lengths,
        "episode_rewards": episode_rewards,
        "achievements": achievements,
        "successes": successes,
        "total_tokens": total_tokens,
    }

    return rollout_stats



def evaluate(
    venv: VecPyTorch,
    model: BaseModel,
    storage: RolloutStorage,
    task: str,
    video_writers=None
) -> Dict[str, np.ndarray]:
    """
    Roll the model's own policy out for ``storage.nstep`` steps and collect
    statistics of every episode that finishes along the way.

    Parameters:
    - venv: Vectorised environment; ``venv.step`` and ``venv.envs`` are used.
    - model: Model whose ``act`` output supplies the executed actions.
    - storage: Rollout buffer supplying model inputs and receiving the outputs.
    - task: Not used by this function (no task-completion check is performed).
    - video_writers: Optional sequence of writers; writer ``k`` receives the
      rendered frame of ``venv.envs[k]`` after every step.

    Returns:
    - Dict with the stacked per-episode arrays ``episode_lengths``,
      ``episode_rewards``, ``achievements`` and ``successes``.
    """
    model.eval()

    # One accumulator per statistic; keys match the entries read from `infos`
    # and the insertion order fixes the order of the returned dict.
    finished = {
        "episode_lengths": [],
        "episode_rewards": [],
        "achievements": [],
        "successes": [],
    }
    # Target dtype of each stacked statistic
    stat_dtypes = {
        "episode_lengths": np.int32,
        "episode_rewards": np.float32,
        "achievements": np.int32,
        "successes": np.int32,
    }

    for t in range(storage.nstep):
        # Query the policy (π_θ) and execute its action
        policy_out = model.act(**storage.get_inputs(t))
        obs, rewards, dones, infos = venv.step(policy_out["actions"])

        # Record one frame per writer when video capture is requested
        if video_writers:
            for env_idx, writer in enumerate(video_writers):
                writer.append_data(venv.envs[env_idx].render())

        # Attach the transition to the model outputs and store everything
        policy_out["obs"] = obs
        policy_out["rewards"] = rewards
        policy_out["masks"] = 1.0 - dones
        policy_out["successes"] = infos["successes"]
        storage.insert(**policy_out, model=model)

        # Harvest the statistics of every environment that just finished
        for env_idx, done in enumerate(dones):
            if not done:
                continue
            for key, bucket in finished.items():
                bucket.append(infos[key][env_idx].cpu().numpy())

    # Stack each statistic along a new leading axis and cast it
    return {
        key: np.stack(values, axis=0).astype(stat_dtypes[key])
        for key, values in finished.items()
    }




def get_pi_meta(obs_batch, venv, task):
    """
    Generate one-hot π_meta logits for a batch of observations by querying the
    meta-controller once per environment.

    Parameters:
    - obs_batch: A batch of observations; only ``obs_batch.shape[0]`` (the
      number of rows to produce) is read.
    - venv: The parallel environments wrapper; ``venv.envs[i]`` and
      ``venv.action_space.n`` are used.
    - task: The task (string) passed to the prompt builder and the LLM query.

    Returns:
    - pi_meta_logits: Tensor of shape ``(batch_size, venv.action_space.n)``
      with 1.0 at the chosen action of each row and 0.0 elsewhere.
    - total_token: Sum of the token counts returned by the LLM queries that
      completed.

    If anything fails for an environment, a uniformly random action is used
    for that row and the failure is logged.
    """
    import logging  # local import: the module header does not import logging

    logger = logging.getLogger(__name__)

    batch_size = obs_batch.shape[0]
    # The action-space size is the same for every row, so read it once. It is
    # also the fallback range, which keeps the index valid for `logits` and
    # avoids touching `env`, which may be unbound when `venv.envs[i]` fails.
    num_actions = venv.action_space.n
    pi_meta_logits = []
    total_token = 0

    for i in range(batch_size):
        try:
            env = venv.envs[i]
            facing_dir = env._direction
            env_info = get_info(env)

            # Call meta-controller for the current environment
            res = text_obs(env_info)
            obs_to_option = ObservationToOption()
            prompt = gen_prompt(res, task)
            option, token = obs_to_option.querry_LLM(prompt, task)

            # Count the tokens even if the option cannot be mapped to an action
            total_token += token

            option_to_action = OptionToAction(env_info, facing_dir, mem={})
            meta_action = option_to_action.option_to_action(option)
        except Exception as exc:
            # Best-effort fallback; KeyboardInterrupt/SystemExit still propagate
            logger.warning(
                "get_pi_meta: meta-controller failed for env %d (%r); "
                "falling back to a random action",
                i,
                exc,
            )
            meta_action = np.random.randint(0, num_actions)

        logits = torch.zeros(num_actions)
        logits[meta_action] = 1.0
        pi_meta_logits.append(logits)

    pi_meta_logits = torch.stack(pi_meta_logits, dim=0)
    return pi_meta_logits, total_token

def get_info(env):
    """
    Snapshot the parts of an environment's state used by the meta-controller.

    Parameters:
    - env: Environment exposing ``_player`` (with ``inventory``,
      ``achievements`` and ``pos``) and a ``_sem_view()`` method.

    Returns:
    - Dict with keys ``inventory`` and ``achievements`` (copies of the
      player's containers), ``semantic`` (result of ``env._sem_view()``) and
      ``player_pos`` (the player's ``pos``, not copied).
    """
    player = env._player
    return dict(
        inventory=player.inventory.copy(),
        achievements=player.achievements.copy(),
        semantic=env._sem_view(),
        player_pos=player.pos,
    )
