import av
import hydra
import argparse
from hydra import compose, initialize
from tqdm import tqdm
from pathlib import Path
from typing import Dict, Tuple, Optional
from mcu.arm.models import ConditionedAgent
from mcu.steveI.agents import SteveITextAgent
from mcu.stark_tech.env_interface import MinecraftWrapper
import pdb
from functools import partial

# Relative directory holding the policy YAML configs loaded by
# get_config_from_yaml (e.g. "steve", "vpt_native", "groot_eff_1x").
# NOTE(review): Hydra's `initialize` resolves config_path relative to the
# calling module's directory rather than the CWD — confirm the layout.
POLICY_CONFIG_DIR = "arm/configs/policy"

def write(frames, path, fps=20, width=640, height=360):
    """Encode a sequence of RGB frames into an H.264 MP4 file at *path*.

    Args:
        frames: Iterable of arrays accepted by
            ``av.VideoFrame.from_ndarray(..., format='rgb24')``
            (presumably ``(H, W, 3)`` uint8 — confirm against the
            environment's ``info['pov']``).
        path: Destination file path.
        fps: Frame rate of the output stream (default 20, as before).
        width: Encoded frame width in pixels (default 640).
        height: Encoded frame height in pixels (default 360).
    """
    # Using the container as a context manager guarantees it is closed
    # (and the file handle released) even if encoding raises midway.
    with av.open(path, mode='w', format='mp4') as container:
        stream = container.add_stream('h264', rate=fps)
        stream.width = width
        stream.height = height
        stream.pix_fmt = 'yuv420p'
        for frame in frames:
            video_frame = av.VideoFrame.from_ndarray(frame, format='rgb24')
            for packet in stream.encode(video_frame):
                container.mux(packet)
        # Calling encode() with no frame flushes packets still buffered
        # inside the encoder.
        for packet in stream.encode():
            container.mux(packet)
    print("save video succeed!")

def get_config_from_yaml(config_name: str):
    """Load the Hydra policy config ``<POLICY_CONFIG_DIR>/<config_name>.yaml``.

    Any previously initialised global Hydra instance is cleared first, so this
    can be called repeatedly in one process. The composed config is printed
    and then returned.

    Args:
        config_name: Config file name without the ``.yaml`` suffix.
    """
    yaml_file = Path(POLICY_CONFIG_DIR).joinpath(f"{config_name}.yaml")
    # Hydra refuses to initialise twice; wipe whatever a prior call left.
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    initialize(config_path=str(yaml_file.parent), version_base='1.3')
    cfg = compose(config_name=yaml_file.stem)
    print(cfg)
    return cfg

def run(
    env: str,
    level,
    policy_config: str,
    agent_creator,
    resolution: Optional[Tuple[int, int]] = None,
    output_path: str = "inference_video.mp4",
):
    """Roll out one episode with a freshly created agent and save a video.

    Args:
        env: Environment config name passed to ``MinecraftWrapper``.
        level: Level identifier forwarded to ``MinecraftWrapper``.
        policy_config: Not used in this function. Despite the ``str``
            annotation, callers in this file pass the object returned by
            ``get_config_from_yaml``. It is kept for interface compatibility.
        agent_creator: Callable taking ``obs_space``/``action_space`` keywords
            and returning an object with ``cuda()``. The result of ``cuda()``
            must provide ``eval()``, ``initial_state()`` and
            ``get_action()``.
        resolution: If given, assigned to ``env.resize_resolution`` before
            the environment is reset.
        output_path: Where the recorded episode video is written.
    """
    agent = agent_creator(
        obs_space=MinecraftWrapper.get_obs_space(),
        action_space=MinecraftWrapper.get_action_space(),
    ).cuda()
    agent.eval()
    state = agent.initial_state()

    # ``env`` arrives as a config name and is rebound to the live wrapper.
    env = MinecraftWrapper(env, level, prev_action_obs=True)
    if resolution is not None:
        env.resize_resolution = resolution

    frames = []
    # Becomes 1 once any step yields a positive reward.
    # NOTE(review): the original comment said "not for steve" — reward-based
    # success may not be meaningful for the text agent; confirm.
    success = 0
    bar = tqdm()
    try:
        obs, info = env.reset()
        terminated = truncated = False
        while not (terminated or truncated):
            bar.set_description(f"Frame {len(frames)}")
            action, state = agent.get_action(obs, state, first=None, input_shape="*")
            obs, reward, terminated, truncated, info = env.step(action)
            if reward > 0:
                success = 1
            frames.append(info['pov'])
            bar.update(1)
    finally:
        # Release the progress bar and the environment even if a step fails.
        bar.close()
        env.close()

    print('\n' + 'Task success: ' + str(success))
    write(frames, output_path)

def run_agent_with_text_instruction(level):
    """Run ``SteveITextAgent`` (config ``steve``) on ``mine_grass`` at 128x128.

    Args:
        level: Level identifier forwarded to :func:`run`.
    """
    cfg = get_config_from_yaml("steve")
    creator = partial(SteveITextAgent, policy_config=cfg)
    run(
        env="mine_grass",
        level=level,
        policy_config=cfg,
        agent_creator=creator,
        resolution=(128, 128),
    )

def run_agent(level):
    """Run ``ConditionedAgent`` (config ``vpt_native``) on ``eat_apple`` at 128x128.

    Args:
        level: Level identifier forwarded to :func:`run`.
    """
    cfg = get_config_from_yaml("vpt_native")
    creator = partial(ConditionedAgent, policy_config=cfg)
    run(
        env="eat_apple",
        level=level,
        policy_config=cfg,
        agent_creator=creator,
        resolution=(128, 128),
    )



def run_agent_with_video_instruction(level):
    """Run ``ConditionedAgent`` (config ``groot_eff_1x``) on ``eat_apple`` at 224x224.

    Args:
        level: Level identifier forwarded to :func:`run`.
    """
    policy_config = get_config_from_yaml("groot_eff_1x")
    # Implicit continuation inside the call parentheses; no backslashes needed.
    run(
        "eat_apple",
        level,
        policy_config,
        agent_creator=partial(ConditionedAgent, policy_config=policy_config),
        resolution=(224, 224),
    )

if __name__ == '__main__':
    # Map each ``--test`` choice to the entry point it launches.
    tests = {
        'run_agent': run_agent,
        'run_agent_video': run_agent_with_video_instruction,
        'run_agent_text': run_agent_with_text_instruction,
    }
    parser = argparse.ArgumentParser()
    # ``choices`` makes argparse reject unknown test names with a usage error
    # instead of the script silently exiting without running anything.
    parser.add_argument('--test', type=str, default='run_agent', choices=list(tests))
    parser.add_argument('--level', type=str, default='simple')
    args = parser.parse_args()
    tests[args.test](args.level)
    