import json
from tqdm import tqdm
from argparse import ArgumentParser
import os
import torch
from torch.utils.tensorboard import SummaryWriter
from actor.random_actor import RandomActor
from actor.chat_actor import ChatActor
from actor.logit_actor import LogitActor
from envs.lang_env import LangEnv
from utils.nle_utils import TASK_TO_DESC
from models import *
from omegaconf import OmegaConf

import copy
from envs.action_dict import action_dict, GAME_DESCRIPTION
import datetime
import time

import warnings
# Silence every Python warning for the whole process.
warnings.filterwarnings('ignore')

# Length-1 object-dtype array constant.
# NOTE(review): not referenced in the visible part of this file, and `np` is not
# imported explicitly here; it presumably arrives via `from models import *` - confirm.
EMPTY = np.empty(1, dtype=object)

def process_input(frame, timestep, target_device=None):
    """Stack the observation dicts in ``frame[:timestep]`` into per-key tensors.

    Every key that appears in the selected observations becomes one entry of
    the returned dict.  For each observation ``torch.tensor(obs[key])`` is
    built, and the per-key tensors are stacked along a new leading time axis
    followed by a singleton axis, giving shape ``(T, 1, *value_shape)`` where
    ``T = len(frame[:timestep])``.

    Args:
        frame: sequence of observation dicts, oldest first.
        timestep: number of leading entries of ``frame`` to use.
        target_device: device for the returned tensors.  When omitted, the
            module-level ``device`` (set in ``__main__``) is used, as before.

    Returns:
        dict mapping each observation key to a tensor on the target device.
        An empty selection yields an empty dict.
    """
    if target_device is None:
        target_device = device

    per_key = {}
    for obs in frame[:timestep]:
        for key, value in obs.items():
            # torch.tensor always copies, so the result never aliases the
            # observation arrays.
            per_key.setdefault(key, []).append(torch.tensor(value))

    # stack -> (T, *shape); unsqueeze(1) -> (T, 1, *shape), identical to the
    # former unsqueeze(0).unsqueeze(0) + cat(dim=0).
    return {
        key: torch.stack(tensors).unsqueeze(1).to(target_device)
        for key, tensors in per_key.items()
    }

def train(task, env, actor, buffer, rollout_id, controller, timestep, logger=None, cumulative_metrics=None):
    """Collect one rollout from ``env`` and update ``controller`` from it.

    During the first ``args.num_rollouts // 10`` rollouts ("imitation"
    phase) actions are sampled from a mix of the guidance actor's action
    probabilities and the controller's logits, and the controller is updated
    with ``controller.update_network``.  Afterwards actions are sampled from
    the controller's own distribution and ``controller.update_policy`` is used.

    Relies on module-level globals set in ``__main__``: ``args``, ``config``,
    ``device``, ``action_dictionary``, ``results`` and ``pbar``; ``count_tokens``
    is presumably provided by ``from models import *`` (not visible here).

    Args:
        task: key into the global ``results`` dict.
        env: environment exposing get_task/reset/get_actions/step/render.
        actor: guidance actor; ``get_action`` returns ``(action, probs)``.
        buffer: rollout buffer, cleared at the start of the rollout.
        rollout_id: 0-based index of this rollout.
        controller: PPO agent (``model``, ``value_model``, ``update_*``).
        timestep: number of stacked frames fed to the controller.
        logger: optional writer forwarded to ``log_metrics``.
        cumulative_metrics: dict with 'total_reward', 'total_successes' and
            'total_steps' accumulators, updated in place.

    Returns:
        The ``count_tokens`` total over actions executed during the imitation
        phase of this rollout (0 outside that phase).
    """
    description = env.get_task()
    obs, lang_obs_list = env.reset()

    obs["done"] = False
    obs_copy = copy.copy(obs)
    # The frame window starts as `timestep` references to the first observation.
    frame = [obs_copy for _ in range(timestep)]

    actor.reset(description)
    buffer.clear()

    cum_reward = 0
    steps = 0
    successes = 0
    token = 0  # was never initialised: `token += ...` / `return token` raised UnboundLocalError
    reward = 0  # keep `reward` / `info` bound even if no env step is taken
    info = {}
    done = False

    # The first 10% of rollouts are guided by the actor.
    imitation = rollout_id < args.num_rollouts // 10

    while not done:
        if args.max_episode_steps is not None and steps >= args.max_episode_steps:
            # Truncation, not termination: `done` stays False so the value of the
            # final state is bootstrapped below (that branch used to be unreachable).
            break

        # Every decision re-runs the controller over the whole frame window
        # starting from a fresh initial core state.
        core_state = controller.model.initial_state(batch_size=config.batch_size)
        # Tensor.to() is not in-place; rebind so the state really is on `device`.
        core_state = tuple(s.to(device) for s in core_state)
        lang_actions, env_actions = env.get_actions()

        _, probs = actor.get_action(
            lang_obs_list,
            lang_actions,
            env_actions,
            return_tuple=False,
        )

        inputs = process_input(frame, timestep)
        result, next_core_state = controller(inputs, core_state)
        dist = result["policy_logits"]
        value = result["baseline"][-1]

        controller_logits = dist.logits
        # log_softmax is the numerically stable form of log(softmax(x)).
        controller_log_prob = torch.log_softmax(controller_logits, dim=-1)

        if imitation:
            meta_controller_value = controller.value_model(inputs, core_state)

            meta_controller_probs = probs
            meta_controller_tensor = meta_controller_probs.clone().detach().to(device)
            meta_controller_logits = torch.log(meta_controller_tensor + 1e-8)
            meta_controller_probs_new = torch.softmax(meta_controller_logits, dim=0)

            # Sample from the actor's log-probs plus 0.1x the controller logits.
            sample_prob = 0.1 * controller_logits + 1 * meta_controller_logits
            soft_prob = torch.softmax(sample_prob.squeeze(), dim=-1)
            action = torch.multinomial(soft_prob, 1)
        else:
            # Sample exactly once so the executed, scored and stored action agree
            # (previously the env ran a fresh dist.sample() while the buffer got
            # result["action"][-1] and its log-prob).
            action = dist.sample()
            meta_controller_probs = torch.tensor([0])
            meta_controller_value = torch.tensor([0])
            meta_controller_probs_new = torch.tensor([0])

        log_probs = dist.log_prob(action)
        env_action = action_dictionary[action.item()]
        if not isinstance(env_action, list):
            env_action = [env_action]

        # Map cells holding glyph id 2383, taken from the pre-action observation.
        glyphs = obs["glyphs"]
        target_positions = [
            (y, x)
            for y in range(glyphs.shape[0])
            for x in range(glyphs.shape[1])
            if glyphs[y, x] == 2383
        ]
        for a in env_action:
            if imitation:
                token += count_tokens(a)
            next_obs, lang_obs_list, reward, done, info = env.step(a)

            env.render()

            if done:
                # Count the terminal (unshaped) reward and step, as evaluate() does;
                # they used to be dropped from cum_reward.
                cum_reward += reward
                steps += 1
                obs = next_obs
                break

            # Shaping: Manhattan distance from the player (blstats[0], blstats[1])
            # to the first target cell, plus a penalty growing with `steps`.
            player_x, player_y = next_obs["blstats"][0], next_obs["blstats"][1]
            if target_positions:
                target_y, target_x = target_positions[0]
                reward += -0.05 * (abs(player_x - target_x) + abs(player_y - target_y))
            reward += -0.001 * steps

            print("reward = ", reward)

            cum_reward += reward
            steps += 1
            obs = next_obs

        next_obs["done"] = done
        frame.append(next_obs)
        frame = frame[1:]

        # NOTE(review): the reward stored here is the running episode total
        # `cum_reward`, not this decision's reward - confirm Buffer expects that.
        # The core state is stored as returned by the controller; the former
        # `core_state[i].to("cpu")` loop was a no-op.
        buffer.store(
            inputs,
            process_input(frame, timestep),
            action.cpu().detach().numpy(),
            torch.tensor(cum_reward).cpu(),
            value.cpu().detach().numpy(),
            log_probs.cpu().detach().numpy(),
            controller_logits.cpu().detach().numpy(),
            controller_log_prob.cpu().detach().numpy(),
            meta_controller_probs.cpu().detach().numpy(),
            meta_controller_value.cpu().detach().numpy(),
            meta_controller_probs_new.cpu().detach().numpy(),
            done,
            next_core_state,
        )

    if done:
        last_val = 0.
    else:
        # Truncated episode: bootstrap from the *current* window, computed the same
        # way as every in-rollout value (fresh core state); no gradient is needed.
        with torch.no_grad():
            boot_state = controller.model.initial_state(batch_size=config.batch_size)
            boot_state = tuple(s.to(device) for s in boot_state)
            boot_result, _ = controller(process_input(frame, timestep), boot_state)
        # Same format as the per-step values handed to buffer.store().
        last_val = boot_result["baseline"][-1].cpu().numpy()

    buffer.finish_path(last_val=last_val)

    results[task]["reward"] += cum_reward / args.num_rollouts
    if reward > 0:
        results[task]["success"] += 1 / args.num_rollouts
        successes += 1
    elif "end_status" in info and info["end_status"] == 1:
        results[task]["death"] += 1 / args.num_rollouts
    pbar.update(1)
    pbar.set_description("Successes {}/{}".format(int(results[task]["success"] * args.num_rollouts), rollout_id + 1))

    # Accumulate metrics
    cumulative_metrics['total_reward'] += cum_reward
    cumulative_metrics['total_successes'] += successes
    cumulative_metrics['total_steps'] += steps

    start_time = time.time()
    if imitation:
        mean_losses = controller.update_network(buffer)
    else:
        mean_losses = controller.update_policy(buffer)
    end_time = time.time()

    # Log the metrics after every rollout
    log_metrics(rollout_id + 1, rollout_id + 1, logger, cumulative_metrics, token, loss = mean_losses, prefix="train")
    print("Training End. Use time ", end_time - start_time, ". Loss = ", mean_losses)

    return token

def evaluate(env, rollout_id, controller, timestep, logger=None, cumulative_metrics=None, test_num = 10):
    """Run ``test_num`` evaluation episodes and log their averaged metrics.

    Actions are sampled from the controller's policy distribution.  Each
    episode ends when the env reports ``done`` or after
    ``args.max_episode_steps`` env steps.  An episode counts as a success
    when it ends with ``done`` and a positive final reward.

    The mean episode reward and mean episode length over the ``test_num``
    episodes are added to ``cumulative_metrics`` (so each call contributes one
    evaluation-sized sample), and ``log_metrics`` is called with the success
    rate.  Relies on the globals ``args``, ``config``, ``device`` and
    ``action_dictionary`` set in ``__main__``.

    Args:
        env: environment exposing reset/step.
        rollout_id: 0-based evaluation index; used for the logging step.
        controller: PPO agent whose ``model`` provides ``initial_state``.
        timestep: number of stacked frames fed to the controller.
        logger: optional writer forwarded to ``log_metrics``.
        cumulative_metrics: dict with 'total_reward' and 'total_steps'
            accumulators, updated in place.
        test_num: number of evaluation episodes.
    """
    successes = 0
    total_reward = 0
    total_steps = 0

    for _ in range(test_num):
        # A single reset per episode (the first of two former resets was discarded).
        obs, _ = env.reset()
        obs["done"] = False
        obs_copy = copy.copy(obs)
        frame = [obs_copy for _ in range(timestep)]

        cum_reward = 0
        steps = 0
        reward = 0
        done = False

        while not done:
            # Enforce the step cap inside the episode; it used to be checked only
            # after all episodes had finished, so evaluation could run forever.
            if args.max_episode_steps is not None and steps >= args.max_episode_steps:
                break

            core_state = controller.model.initial_state(batch_size=config.batch_size)
            core_state = tuple(s.to(device) for s in core_state)
            inputs = process_input(frame, timestep)

            # Evaluation needs no autograd graph.
            with torch.no_grad():
                result, _ = controller(inputs, core_state)
            dist = result["policy_logits"]

            env_action = action_dictionary[dist.sample().item()]
            if not isinstance(env_action, list):
                env_action = [env_action]

            for a in env_action:
                next_obs, _, reward, done, _ = env.step(a)
                cum_reward += reward
                steps += 1
                if done:
                    break

            next_obs["done"] = done
            frame.append(next_obs)
            frame = frame[1:]

        if done and reward > 0:
            successes += 1
        total_reward += cum_reward
        total_steps += steps

    # Average over all episodes (previously only the last episode was counted).
    episodes = max(test_num, 1)
    cumulative_metrics['total_reward'] += total_reward / episodes
    cumulative_metrics['total_steps'] += total_steps / episodes
    successes_rate = successes / episodes
    print("*" * 30, "\nevaluate end, successes rate = ", successes_rate)
    time.sleep(1)

    # Log the metrics after every evaluation
    log_metrics(rollout_id + 1, rollout_id + 1, logger, cumulative_metrics, successes_rate = successes_rate, prefix="test")

def log_metrics(rollout_id, num_rollouts, logger, cumulative_metrics, token = 0, prefix="train", loss = 0, successes_rate = 1):
    """Write averaged rollout metrics to a TensorBoard-style logger.

    Each of the totals 'total_reward', 'total_successes' and 'total_steps' in
    ``cumulative_metrics`` is divided by ``num_rollouts``.  Those averages,
    together with ``successes_rate``, ``token`` and ``loss``, are written via
    ``logger.add_scalar`` under ``"<prefix>/<name>"`` at step ``rollout_id``.
    The averages are computed before the ``logger is None`` check, so a zero
    ``num_rollouts`` or a missing key raises even without a logger.
    """
    reward_avg, success_avg, steps_avg = (
        cumulative_metrics[name] / num_rollouts
        for name in ('total_reward', 'total_successes', 'total_steps')
    )

    if logger is None:
        return

    scalars = (
        ("Average Reward", reward_avg),
        ("Average Success Rate", success_avg),
        ("Evaluate Success Rate", successes_rate),
        ("Average Steps", steps_avg),
        ("Token", token),
        ("Loss", loss),
    )
    for name, scalar in scalars:
        logger.add_scalar(f"{prefix}/{name}", scalar, rollout_id)

def create_logger(path, subdir="train"):
    """
    Open a TensorBoard writer that logs into ``<path>/<subdir>``.

    The target directory is created when missing, its location is printed,
    and it is also exposed on the returned writer as ``logger.dir``.

    Parameters:
    - path (str): Base directory for the run's logs.
    - subdir (str): Child directory for this writer (e.g. "train" or "test").

    Returns:
    - logger (SummaryWriter): Writer with an added ``dir`` attribute.
    """
    log_dir = os.path.join(path, subdir)
    os.makedirs(log_dir, exist_ok=True)

    writer = SummaryWriter(log_dir, flush_secs=0.1)
    print(f"Logging to {log_dir}")

    setattr(writer, "dir", log_dir)
    return writer

if __name__ == "__main__":
    parser = ArgumentParser(description="Generate rollout data")
    parser.add_argument("--exp_name", type=str, default="test", help="File name for saves")
    # NOTE(review): an empty --task fails the action_dict lookup below, so the
    # "all tasks" path is currently unreachable.
    parser.add_argument("--task", type=str, default="", help="Task to evaluate on, default is all tasks")
    parser.add_argument("--actor", type=str, default="random", help="Can be random, gpt, or a path to a seq2seq huggingface model")
    parser.add_argument("--num_rollouts", type=int, default=10, help="Number of rollouts to evaluate")
    parser.add_argument("--max_episode_steps", type=int, default=150, help="Max episode steps")
    parser.add_argument("--fewshot", type=int, default=4, help="How many fewshot examples to use for gpt")
    parser.add_argument("--action_temp", type=float, default=1, help="Sampling temperature for action policy")
    parser.add_argument("--cot", action="store_true", help="Use explanations for actor")
    parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU")
    args = parser.parse_args()

    # Controller action id -> env action(s) for this task; its size also sets
    # the controller's number of actions.
    try:
        action_dictionary = action_dict[args.task]
    except KeyError as err:
        raise ValueError("Invalid task name.") from err
    ACTIONS_NUM = len(action_dictionary.keys())

    # Honour --cpu (it used to be ignored) and fall back to CPU without CUDA.
    device = "cpu" if args.cpu or not torch.cuda.is_available() else "cuda"

    config = OmegaConf.load("config.yaml")

    logger_path = "log/IHAC"
    current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    path = os.path.dirname(os.path.abspath(__file__))
    run_dir = os.path.join(path, logger_path, args.task, current_time)
    train_logger = create_logger(run_dir, subdir="train")
    test_logger = create_logger(run_dir, subdir="test")

    if args.actor == "random":
        actor = RandomActor()
    elif args.actor == "gpt":
        actor = ChatActor(fewshot=args.fewshot, use_cot=args.cot)
    else:
        actor = LogitActor(args.actor, temperature=args.action_temp)

    if args.task:
        tasks = [args.task]
    else:
        tasks = TASK_TO_DESC.keys()

    results = {
        x: dict(reward=0, success=0, death=0)
        for x in tasks
    }

    for task in tasks:
        env = LangEnv(task)
        buffer = Buffer()
        config.num_actions = ACTIONS_NUM
        controller = PPO(config)

        controller.model.to(device)
        controller.value_model.to(device)

        timestep = config.timestep
        test_iter = 0

        print("Starting Task:", task)
        # Advanced manually from train() via the global `pbar`.
        pbar = tqdm(total=args.num_rollouts)

        cumulative_metrics_train = {'total_reward': 0, 'total_successes': 0, 'total_steps': 0}
        cumulative_metrics_test = {'total_reward': 0, 'total_successes': 0, 'total_steps': 0}

        for rollout_id in range(args.num_rollouts):
            # Pass the loop's task (results is keyed by it), not args.task.
            train(task, env, actor, buffer, rollout_id, controller, timestep, logger=train_logger, cumulative_metrics=cumulative_metrics_train)

            if rollout_id % config.test_step == 0:
                evaluate(env, test_iter, controller, timestep, logger=test_logger, cumulative_metrics=cumulative_metrics_test)
                test_iter += 1

        pbar.close()

        # Final averages: train totals accumulate once per rollout, test totals
        # once per evaluation, so each is divided by its own count.
        log_metrics(args.num_rollouts, args.num_rollouts, train_logger, cumulative_metrics_train, prefix="train")
        log_metrics(args.num_rollouts, max(test_iter, 1), test_logger, cumulative_metrics_test, prefix="test")

        with open(args.exp_name + ".json", "w") as f:
            json.dump(results, f, indent=4)

    # Flush and release the TensorBoard event files.
    train_logger.close()
    test_logger.close()
