# import packages

import os
import torch
import h5py
import gymnasium
import mediapy as media
from huggingface_hub import hf_hub_download

from humenv import STANDARD_TASKS, make_humenv
from metamotivo.fb_cpr.huggingface import FBcprModel
from metamotivo.wrappers.humenvbench import RewardWrapper
from metamotivo.buffers.buffers import DictBuffer

from metamotivo.wrappers.humenvbench import relabel
from humenv.env import make_from_name
import matplotlib.pyplot as plt 
import numpy as np
from PIL import Image
import cv2

os.environ["OMP_NUM_THREADS"] = "1"

def get_data(task, model):
    """Infer a latent ``z`` for ``task`` from arrays stored in a folder named ``task``.

    The folder is expected to contain ``data.npy``, ``qpos_list.npy``,
    ``qvel_list.npy``, ``action_list.npy`` and ``fragment.npy``. Rewards are
    recomputed for the stored transitions with ``relabel`` and then passed,
    together with the stored observations, to ``model.reward_wr_inference``.

    Args:
        task: Task name. It is used both as the humenv task identifier and as
            the path of the folder holding the ``.npy`` files.
        model: Object exposing ``cfg.device`` and ``reward_wr_inference``.

    Returns:
        A tuple ``(z, fragment)`` where ``z`` is the value returned by
        ``model.reward_wr_inference`` and ``fragment`` is the content of
        ``fragment.npy``.

    Raises:
        NotADirectoryError: If there is no folder named ``task``.
    """
    if not os.path.isdir(task):
        raise NotADirectoryError(f'No folder: {task}')

    data = np.load(f'{task}/data.npy')

    next_qpos = np.load(f'{task}/qpos_list.npy')
    next_qvel = np.load(f'{task}/qvel_list.npy')
    action_data = np.load(f'{task}/action_list.npy')

    fragment = np.load(f'{task}/fragment.npy')

    reward_fn = make_from_name(task)
    env, _ = make_humenv(num_envs=1, task=task, state_init="Default",
                         wrappers=[gymnasium.wrappers.FlattenObservation])

    # The environment is only needed for relabelling; release it afterwards so
    # repeated calls do not accumulate open environments.
    try:
        rewards = relabel(env, qpos=next_qpos, qvel=next_qvel, action=action_data,
                          reward_fn=reward_fn, max_workers=4)
    finally:
        env.close()

    z = model.reward_wr_inference(
        next_obs=torch.tensor(data, device=model.cfg.device, dtype=torch.float32),
        reward=torch.tensor(rewards, device=model.cfg.device, dtype=torch.float32)
    )

    return z, fragment

def get_combined_z(task_1_name, task_2_name, tast_1_ratio, task_2_ratio, model=None):
    """Return the weighted sum of the latent vectors of two tasks.

    Args:
        task_1_name: Name (and data folder) of the first task.
        task_2_name: Name (and data folder) of the second task.
        tast_1_ratio: Weight applied to the first task's ``z``.
        task_2_ratio: Weight applied to the second task's ``z``.
        model: Model forwarded to ``get_data``. When omitted, the module-level
            ``model`` is used.

    Returns:
        ``task_1_z * tast_1_ratio + task_2_z * task_2_ratio``.
    """
    if model is None:
        # The parameter shadows the module-level name, so look it up explicitly.
        model = globals()["model"]

    # get_data requires the model; the fragment it also returns is not needed here.
    task_1_z, _ = get_data(task_1_name, model)
    task_2_z, _ = get_data(task_2_name, model)

    return (task_1_z * tast_1_ratio) + (task_2_z * task_2_ratio)

def produce_frames(env, z, rew_model, frames_count=300):
    """Roll out a policy conditioned on ``z`` and collect the rendered frames.

    Args:
        env: Environment providing ``reset``, ``step`` and ``render``.
        z: Latent passed to the policy at every step.
        rew_model: Object whose ``device`` attribute decides where the
            observation tensor is created.
        frames_count: Number of environment steps to take.

    Returns:
        A list with ``frames_count + 1`` render outputs: one taken right after
        the reset and one after every step.
    """
    current_obs, _ = env.reset()
    rendered = [env.render()]

    for _ in range(frames_count):
        obs_batch = torch.tensor(
            current_obs.reshape(1, -1),
            dtype=torch.float32,
            device=rew_model.device,
        )
        # NOTE(review): the action comes from the module-level ``model``, not
        # from ``rew_model`` -- confirm this is intended.
        policy_action = model.act(obs_batch, z, mean=True).ravel()
        numpy_action = policy_action.cpu().numpy().ravel()
        current_obs, _reward, _terminated, _truncated, _info = env.step(numpy_action)
        rendered.append(env.render())

    return rendered

def get_combined_zs(multi_tasks, multi_scores, model):
    """Return the score-weighted sum of the latent vectors of several tasks.

    Args:
        multi_tasks: Task names; each is forwarded to ``get_data``.
        multi_scores: One weight per task, paired with ``multi_tasks`` by position.
        model: Model forwarded to ``get_data``.

    Returns:
        A ``[1, 256]`` tensor holding ``sum(z_i * score_i)``. The sum is not
        divided by the number of tasks. With no tasks, a CPU zero tensor is
        returned.
    """
    combined_z = None

    for task, score in zip(multi_tasks, multi_scores):
        task_z, _ = get_data(task, model)
        if combined_z is None:
            # Allocate the accumulator on the device of the inferred latent so
            # the in-place add also works when the model is not on the CPU.
            combined_z = torch.zeros([1, 256], device=task_z.device)
        combined_z += (task_z * score)

    if combined_z is None:
        combined_z = torch.zeros([1, 256])

    return combined_z

def merge_frames(frames_1, frames_2, frames_3):
    """Join three frame sequences side by side, one merged frame per position.

    Frames at the same position are concatenated along axis 1 and the result
    is resized with PIL to ``(960, 240)``. Iteration stops at the shortest of
    the three sequences.

    Returns:
        A list of numpy arrays, one per merged frame.
    """
    merged = []

    for triple in zip(frames_1, frames_2, frames_3):
        side_by_side = np.concatenate(triple, axis=1)
        resized = Image.fromarray(side_by_side).resize((960, 240))
        merged.append(np.array(resized))

    return merged
  
def add_text_to_frames(frames, text):
    """Return copies of ``frames`` with ``text`` drawn on each one.

    The input frames are left untouched; every output frame is a copy with the
    text drawn in white at position ``(10, 40)``.
    """
    font_face = cv2.FONT_HERSHEY_SIMPLEX

    def _stamp(frame):
        # Draw on a copy so the caller's frame is not modified.
        canvas = frame.copy()
        cv2.putText(canvas, text, (10, 40), font_face, 1.2, (255, 255, 255), 3, cv2.LINE_AA)
        return canvas

    return [_stamp(frame) for frame in frames]

def save_interval_frames(new_combine_frames, task_1, task_2, interval=15, output_dir='images'):
    """Write every ``interval``-th frame to ``output_dir`` as a PNG file.

    ``len(new_combine_frames) // interval`` images are written, taken from
    frame indices ``0, interval, 2 * interval, ...``. Files are named
    ``{task_1}_{task_2}_{i}.png`` where ``i`` is the running image index.

    Args:
        new_combine_frames: Sequence of image arrays.
        task_1: First task name, used in the file name.
        task_2: Second task name, used in the file name.
        interval: Distance, in frames, between two saved images.
        output_dir: Directory receiving the images; created when missing.

    Raises:
        OSError: If OpenCV reports that an image could not be written.
    """
    # cv2.imwrite does not create directories and only signals failure through
    # its return value, so make sure the target exists first.
    os.makedirs(output_dir, exist_ok=True)

    saved_count = len(new_combine_frames) // interval

    for i in range(saved_count):
        interval_frame = new_combine_frames[i * interval]

        # cv2.imwrite expects BGR channel order; the frames are presumably RGB
        # (TODO confirm against the renderer), so swap the channels.
        bgr_frame = cv2.cvtColor(interval_frame, cv2.COLOR_RGB2BGR)
        target = f'{output_dir}/{task_1}_{task_2}_{i}.png'
        if not cv2.imwrite(target, bgr_frame):
            raise OSError(f'Could not write image: {target}')

# Load the pretrained model on the CPU.
model = FBcprModel.from_pretrained("facebook/metamotivo-S-1", device="cpu")

# Download the inference buffer that ships with the model repository.
local_dir = "metamotivo-S-1-datasets"
dataset = "buffer_inference_500000.hdf5"
buffer_path = hf_hub_download(
    repo_id="facebook/metamotivo-S-1",
    filename=f"data/{dataset}",
    repo_type="model",
    local_dir=local_dir)

# Copy every dataset into memory, then close the file instead of leaving the
# handle open for the rest of the run.
with h5py.File(buffer_path, "r") as hf:
    data = {k: v[:] for k, v in hf.items()}

buffer = DictBuffer(capacity=data["qpos"].shape[0], device="cpu")
buffer.extend(data)

# Wrap the model so reward-based inference can sample from the buffer.
rew_model = RewardWrapper(
    model=model,
    inference_dataset=buffer,
    num_samples_per_inference=100_000,
    inference_function="reward_wr_inference",
    max_workers=40,
    process_executor=True,
    process_context="forkserver")

# Tasks whose latents are mixed, and the weight given to each of them.
task_1 = 'move-ego-0-2'
task_2 = 'sitonground'

task_1_ratio = 1.0
task_2_ratio = 2.0

# A single environment, created for the second task, is reused for all rollouts.
env, _ = make_humenv(
    num_envs=1,
    task=task_2,
    state_init="DefaultAndFall",
    wrappers=[gymnasium.wrappers.FlattenObservation],
)

# Get frames: one 100-step rollout per weighting of the two tasks
# (task 1 only, task 2 only, task 1 minus task 2).
# A rollout can be previewed with media.show_video(frames_1, fps=30).
score_sets = [
    [task_1_ratio, 0.0],
    [0.0, task_2_ratio],
    [task_1_ratio, -1.0 * task_2_ratio],
]

rollouts = []
for scores in score_sets:
    z = get_combined_zs([task_1, task_2], scores, model)
    rollouts.append(produce_frames(env, z, rew_model, 100))

frames_1, frames_2, frames_3 = rollouts
    
file_name = 'display_frames.png'

#############################
### Set the clipped range ###
#############################

# x_range slices the first image axis (rows) and y_range the second (columns).
x_range = (150, 450)
y_range = (225, 475)

x_length = x_range[1] - x_range[0]
y_length = y_range[1] - y_range[0]

### Number of frames sampled from each rollout ###

frame_length = 8

# Distance between two sampled frames, derived from the first rollout.
interval = len(frames_1) // frame_length

# Canvas with one row per rollout and one column per sampled frame.
display_frames = np.ones((x_length * 3, y_length * frame_length, 3), dtype=np.uint8)

for index in range(frame_length):
    for row, rollout in enumerate((frames_1, frames_2, frames_3)):
        clip_frame = rollout[index * interval][x_range[0]:x_range[1], y_range[0]:y_range[1], :]
        display_frames[x_length * row: x_length * (row + 1),
                       y_length * index: y_length * (index + 1), :] = clip_frame

# cv2.imwrite expects BGR channel order.
bgr_img = cv2.cvtColor(display_frames, cv2.COLOR_RGB2BGR)

# cv2.imwrite does not create directories and only reports failure through its
# return value, so create the folder and check the result.
os.makedirs("imgs", exist_ok=True)
if not cv2.imwrite(f"imgs/{file_name}", bgr_img):
    raise OSError(f"Could not write image: imgs/{file_name}")
