import json
from tqdm import tqdm
from argparse import ArgumentParser

from actor.random_actor import RandomActor
from actor.chat_actor import ChatActor
from actor.logit_actor import LogitActor
from envs.lang_env import LangEnv
from utils.nle_utils import TASK_TO_DESC
from torch.distributions import Categorical
from models import *
from omegaconf import OmegaConf
import copy
from envs.action_dict import action_dict

from skills.lava_cross import LavaCrossWithPotionBehaviorTree
from skills.WoD import WoDBehaviorTree

import warnings
# Silence every warning category for the whole script.
warnings.simplefilter("ignore")

# Module-level constant: a length-1 numpy array with dtype=object.
EMPTY = np.empty(1, dtype=object)

def process_input(frame, timestep):
    """Stack the first ``timestep`` observation dicts of ``frame`` key by key.

    Each value is converted with ``torch.tensor`` and given two leading
    singleton axes. The converted values for a key are then concatenated
    along axis 0 in frame order. A value of shape ``S`` therefore yields a
    tensor of shape ``(n, 1, *S)``, where ``n`` is the number of the used
    dicts that contain that key.

    Args:
        frame: sliceable sequence of dicts (observations).
        timestep: number of leading entries of ``frame`` to use.

    Returns:
        dict mapping each key, in first-seen order, to its stacked tensor.
        An empty window yields an empty dict.
    """
    per_key = {}
    for observation in frame[:timestep]:
        for name, raw in observation.items():
            per_key.setdefault(name, []).append(torch.tensor(raw))

    # ``t[None, None]`` is the same view as ``t.unsqueeze(0).unsqueeze(0)``.
    return {
        name: torch.cat([item[None, None] for item in series], dim=0)
        for name, series in per_key.items()
    }

if __name__ == "__main__":
    # Script entry point: parse CLI args, then for each task run
    # ``--num_rollouts`` episodes driven by a WoDBehaviorTree, accumulate
    # reward/success/death statistics, and call the PPO controller's update
    # after every episode. Results are (re)written to ``<exp_name>.json``
    # after each task.
    parser = ArgumentParser(description="Generate rollout data")
    parser.add_argument("--exp_name", type=str, default="test", help="File name for saves")
    parser.add_argument("--task", type=str, default="", help="Task to evaluate on, default is all tasks")
    parser.add_argument("--actor", type=str, default="random", help="Can be random, gpt, or a path to a seq2seq huggingface model")
    parser.add_argument("--num_rollouts", type=int, default=10, help="Number of rollouts to evaluate")
    parser.add_argument("--max_episode_steps", type=int, default=50, help="Max episode steps")
    parser.add_argument("--fewshot", type=int, default=4, help="How many fewshot examples to use for gpt")
    parser.add_argument("--action_temp", type=float, default=1, help="Sampling temperature for action policy")
    parser.add_argument("--cot", action="store_true", help="Use explanaitons for actor")
    parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU")
    # parser.add_argument("--timestep", type=int, default=5, help="time step for LSTM")
    args = parser.parse_args()

    # Resolve the task list before validating it, so the documented default
    # (empty --task => every task in TASK_TO_DESC) works. Previously the
    # action table was looked up with the empty string and always failed.
    tasks = [args.task] if args.task else list(TASK_TO_DESC.keys())

    # Fail fast, before any actor/model is built, if a task has no action table.
    if any(task not in action_dict for task in tasks):
        raise ValueError("Invalid task name.")

    # Only referenced by the disabled controller code inside the episode loop.
    device = "cpu" if args.cpu else "cuda"

    config = OmegaConf.load("config.yaml")

    if args.actor == "random":
        actor = RandomActor()
    elif args.actor == "gpt":
        actor = ChatActor(fewshot=args.fewshot, use_cot=args.cot)
    else:
        actor = LogitActor(args.actor, temperature=args.action_temp)

    results = {
        x: dict(reward=0, success=0, death=0)
        for x in tasks
    }

    for task in tasks:
        env = LangEnv(task)

        buffer = Buffer()

        # Size the controller's action head from this task's action table.
        action_dictionary = action_dict[task]
        ACTIONS_NUM = len(action_dictionary)
        config.num_actions = ACTIONS_NUM

        controller = PPO(config)

        timestep = config.timestep

        print("Starting Task:", task)
        # The bar is advanced manually once per rollout, so give it a total
        # instead of wrapping a range that is never iterated.
        pbar = tqdm(total=args.num_rollouts)
        for rollout_id in range(args.num_rollouts):

            # Reset exactly once per episode, then read that episode's task text.
            obs, lang_obs_list = env.reset()
            description = env.get_task()
            obs["done"] = False
            # Seed the sliding window with `timestep` references to a single
            # shallow copy of the first observation.
            obs_copy = copy.copy(obs)
            frame = [obs_copy for _ in range(timestep)]

            actor.reset(description)
            buffer.clear()

            cum_reward = 0
            steps = 0
            done = False       # terminal signal returned by env.step
            truncated = False  # episode cut off by --max_episode_steps
            # Defaults so the success/death bookkeeping below is always defined,
            # even if the behaviour tree never produced an action.
            reward = 0
            info = {}

            # The first num_rollouts // 10 rollouts train with update_network,
            # the rest with update_policy.
            Imitation = rollout_id < args.num_rollouts // 10

            bt = WoDBehaviorTree()

            while not (done or truncated):

                # initialization #
                core_state = controller.model.initial_state(batch_size = config.batch_size)
                lang_actions, env_actions = env.get_actions()

                # env_action, probs = actor.get_action(
                #     lang_obs_list, 
                #     lang_actions, 
                #     env_actions, 
                #     return_tuple=False
                # )
                # # print(env_actions)
                
                # # Process meta logits #
                # _prob = [.0] * ACTIONS_NUM
                # for i, action in enumerate(env_actions):
                #     # print("action = ", action)
                #     if action in action_dictionary.values():
                #         for keys in action_dictionary.keys() :
                #             if action_dictionary[keys] == action :
                #                 _prob[keys] = probs[i]
                # _prob[-1] = torch.tensor(_prob[-1])
                # probs = torch.stack(_prob)                       
                # # if not isinstance(probs, list):
                # #     probs = [probs]
                
                # inputs = process_input(frame, timestep)
                
                # step #
                # output = controller(inputs, core_state)
                # meta_controller_value = controller.value_model(inputs, core_state)
                # result, next_core_state = output
                # dist = result["policy_logits"]
                # value = result["baseline"][-1]
                # action = result["action"][-1]

                # controller_logits = dist.logits
                # # controller_logits +=  1e-8
                # controller_log_prob = torch.softmax(controller_logits, dim=-1)
                # controller_log_prob = torch.log(controller_log_prob)
                
                # meta_controller_probs = probs
                # meta_controller_tensor = meta_controller_probs.clone().detach().to(device)
                # meta_controller_logits = torch.log(meta_controller_tensor + 1e-8)
                # meta_controller_probs_new = torch.softmax(meta_controller_logits, dim=0)
                
                # if Imitation:
                #     # print(f"controller = {controller_logits}, {controller_logits.shape}\nmeta = {meta_controller_logits}, {meta_controller_logits.shape}")
                #     sample_prob = controller_logits + 0.9 * meta_controller_logits.cpu()
                #     soft_prob = torch.softmax(sample_prob.squeeze(), dim = -1)
                #     action = torch.multinomial(soft_prob, 1)
                #     log_probs = dist.log_prob(action)
                #     env_action = action
                # else :
                #     env_action = dist.sample()
                
                # env_action = action_dictionary[env_action.item()]
                env_action, _ = bt.step(obs)

                # The behaviour tree may return a single action or a sequence.
                if not isinstance(env_action, list):
                    env_action = [env_action]

                for a in env_action:
                    obs, lang_obs_list, reward, done, info = env.step(a)
                    cum_reward += reward
                    steps += 1
                    if done:
                        break
                    # Enforce the step budget inside the episode. The old check
                    # ran only after the episode had ended, so it never applied.
                    if args.max_episode_steps is not None and steps >= args.max_episode_steps:
                        truncated = True
                        break

                env.render()

                # Slide the observation window forward by one entry.
                obs["done"] = done
                frame.append(obs)
                frame = frame[1:]

                next_inputs = process_input(frame, timestep)

                # core_state = next_core_state

                # NOTE(review): with the controller path above disabled, nothing
                # in this loop calls buffer.store, so the updates below see a
                # buffer touched only by clear()/finish_path — confirm intended.
                # buffer.store(
                #             inputs, 
                #             next_inputs,
                #             action, 
                #             torch.tensor(cum_reward), 
                #             value.to("cpu").detach().numpy(), 
                #             log_probs.to("cpu").detach().numpy(),
                #             controller_logits.to("cpu").detach().numpy(), 
                #             controller_log_prob.to("cpu").detach().numpy(),
                #             meta_controller_probs.to("cpu").detach().numpy(),
                #             meta_controller_value.to("cpu").detach().numpy(),
                #             meta_controller_probs_new.to("cpu").detach().numpy(),
                #             done,
                #             core_state
                # )

            if done:
                value = 0.
            else:
                # Truncated episode: bootstrap the return from the critic's
                # estimate for the latest observation window. (This previously
                # referenced an undefined `inputs`.) No gradients are needed.
                with torch.no_grad():
                    output = controller(next_inputs, core_state)
                value = output[0]["baseline"][-1]

            buffer.finish_path(last_val=value)

            results[task]["reward"] += cum_reward / args.num_rollouts

            # Success/death are judged from the last step of the episode.
            if reward > 0:
                results[task]["success"] += 1 / args.num_rollouts
            elif "end_status" in info and info["end_status"] == 1:
                results[task]["death"] += 1 / args.num_rollouts
            pbar.update(1)
            pbar.set_description("Successes {}/{}".format(int(results[task]["success"] * args.num_rollouts), rollout_id + 1))

            if Imitation:
                controller.update_network(buffer)
            else:
                controller.update_policy(buffer)
            # tqdm.write keeps the progress bar intact, unlike a bare print.
            tqdm.write("Training End...")

        pbar.close()

        # Rewritten after every task so partial results survive an interrupted run.
        with open(args.exp_name + ".json", "w") as f:
            json.dump(results, f, indent=4)
