import argparse
import json
import threading
import time

import cv2
import numpy as np
import torch
from models.mlp import TrajectoryMLP
from moving_out.evaluation import MovingOutEvaluator

# Run on the CUDA device when PyTorch can see one, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class mlp_warpper:
    """Inference wrapper pairing a policy model with a partner ("another") model.

    On every step the partner model is run first with
    ``return_inner_state=True``, and its inner state is passed to the policy
    model as extra conditioning. The instance is handed to
    ``MovingOutEvaluator.evaluate_ids`` as the model; presumably the evaluator
    calls ``get_action`` once per step (TODO confirm against the evaluator).

    Args:
        model_path: Path of the policy checkpoint. It is loaded with
            ``torch.load`` onto the module-level ``device`` and switched to
            eval mode.
        another_model: Already-loaded partner module. It is switched to eval
            mode here but is not moved between devices.
        temperature: Positive softmax temperature applied to the discrete
            action logits before sampling. The default of 1.0 leaves the
            logits unchanged.
        sample: If True (default), discrete actions are sampled from the
            softmax distribution. If False, the arg-max action is taken.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """

    def __init__(
        self, model_path, another_model, temperature=1.0, sample=True
    ) -> None:
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature!r}")
        # NOTE: torch.load unpickles arbitrary objects; only load trusted files.
        self.model = torch.load(model_path, map_location=device).eval()
        self.another_model = another_model.eval()
        self.temperature = temperature
        self.sample = sample

    @torch.no_grad()
    def get_input(self, states, past_states):
        """Convert raw observations to batched tensors and query the partner model.

        Args:
            states: Current observation(s); anything ``torch.tensor`` accepts.
            past_states: Observation history; anything ``torch.tensor`` accepts.

        Returns:
            ``(states, past_states, another_state)``. The first two are float32
            tensors on ``device`` with a leading batch dimension of size 1.
            ``another_state`` is the inner state returned by the partner model.
        """
        # Inference only: no_grad avoids building an autograd graph every step.
        states = torch.tensor(states, dtype=torch.float32, device=device).unsqueeze(0)
        past_states = torch.tensor(
            past_states, dtype=torch.float32, device=device
        ).unsqueeze(0)

        # Only the partner's inner state is used as conditioning; its own
        # continuous/discrete predictions (first two outputs) are discarded.
        _, _, another_state = self.another_model(
            past_states, states, return_inner_state=True
        )
        return states, past_states, another_state

    @torch.no_grad()
    def get_action(self, states, past_states):
        """Predict the next continuous output and discrete action(s).

        Args:
            states: Current observation(s), forwarded to ``get_input``.
            past_states: Observation history, forwarded to ``get_input``.

        Returns:
            ``(mse_output, predicted_actions)``. ``mse_output`` is the policy's
            continuous output clipped to ``[-1, 1]`` with the batch dimension
            removed. ``predicted_actions`` holds one discrete action index per
            row of the first batch element's logits.
        """
        states, past_states, another_state = self.get_input(states, past_states)
        mse_output, ce_logits = self.model(past_states, states, another_state)
        mse_output = torch.clip(mse_output, -1, 1)

        # The squeeze(2)/squeeze(0) below require a (1, rows, 1) tensor, so
        # ce_logits must be 3-D with batch size 1 — presumably
        # (batch, horizon, num_actions); TODO confirm against the model.
        if self.sample:
            actions_probabilities = torch.softmax(ce_logits / self.temperature, dim=-1)
            # multinomial on the (rows, num_actions) slice gives (rows, 1).
            predicted_actions = torch.multinomial(
                actions_probabilities[0], 1
            ).unsqueeze(0)
        else:
            # keepdim makes the greedy result (1, rows, 1), like the sampled one.
            predicted_actions = torch.argmax(ce_logits, dim=-1, keepdim=True)

        predicted_actions = predicted_actions.squeeze(2).squeeze(0)
        mse_output = mse_output.squeeze(0)
        return mse_output, predicted_actions


def main(ids, model, evaluation_times, precition_horizon, file_name):
    """Evaluate ``model`` on ``ids`` with ``MovingOutEvaluator`` and print the result.

    Args:
        ids: IDs forwarded to ``MovingOutEvaluator.evaluate_ids``.
        model: Policy object forwarded to the evaluator.
        evaluation_times: Forwarded as ``evaluate_times``.
        precition_horizon: Forwarded as ``model_horizon``.
        file_name: Forwarded as ``file_name``.

    Each evaluation is capped at ``max_steps=2000``.
    """
    evaluator = MovingOutEvaluator()
    eval_kwargs = dict(
        evaluate_times=evaluation_times,
        max_steps=2000,
        model_horizon=precition_horizon,
        file_name=file_name,
    )
    print(evaluator.evaluate_ids(ids, model, **eval_kwargs))


if __name__ == "__main__":
    # NOTE(review): `procedure` is never printed or otherwise used below; it
    # looks like leftover participant instructions from a human-study script.
    procedure = """
    ===================================================
            Experiment Participation Procedure
    ===================================================
    
    Welcome to participate in our experiment. This 
    experiment aims to study the behavior when people 
    play the game 'Moving Out' with an AI agent.
    
    Operation Instructions:
    You need to use a joystick to control the robot's 
    movement and item handling. 
    Use the joystick to control the direction of the 
    robot's movement and hold the R button to grab or 
    release items.
    
    After the study, you will receive a questionnaire. 
    Please answer the questions based on your experience.
    
    Thank you!
    ===================================================
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a trajectory MLP policy paired with a partner model."
    )
    parser.add_argument(
        "--id_number",
        type=int,
        # nargs="+" yields a list, so the default must be a list as well.
        default=[0],
        nargs="+",
        help="One or more IDs passed to MovingOutEvaluator.evaluate_ids",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="/p/langdiffuse/moving_out_magical/moving_out_AI/baselines/only_0_all.pt",
        help="Path to the policy model checkpoint (loaded with torch.load)",
    )
    parser.add_argument(
        "--another_model_save_path",
        type=str,
        # Without this, a missing flag crashed later in torch.load(None).
        required=True,
        help="Path to the partner model checkpoint (loaded with torch.load)",
    )

    parser.add_argument(
        "--evaluation_times",
        type=int,
        default=1,
        help="Number of evaluation runs (passed as evaluate_times)",
    )
    parser.add_argument(
        "--precition_horizon",
        "--prediction_horizon",
        dest="precition_horizon",
        type=int,
        default=1,
        help="Model prediction horizon (passed as model_horizon); the "
        "misspelled --precition_horizon is kept for compatibility",
    )
    # parse_known_args tolerates stray argv entries (presumably the reason for
    # the old fallback, e.g. arguments injected when run from a notebook)
    # without swallowing --help or genuine parse errors the way a bare
    # `except:` around parse_args() did, and it keeps the valid arguments.
    args, unknown_args = parser.parse_known_args()
    if unknown_args:
        print(f"Ignoring unrecognized arguments: {unknown_args}")

    # map_location lets a checkpoint saved on GPU load on a CPU-only host,
    # matching how mlp_warpper loads the policy model.
    # NOTE: torch.load unpickles arbitrary objects; only load trusted files.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
    # fully pickled modules; pass weights_only=False there if loading fails.
    another_model = torch.load(args.another_model_save_path, map_location=device)
    model = mlp_warpper(args.model_path, another_model)
    file_name = str(args.model_path)
    main(
        args.id_number, model, args.evaluation_times, args.precition_horizon, file_name
    )
