import os
import time
import wandb
import torch
import json
import numpy as np
import tqdm

from embodied_cd.common import argparse
from embodied_cd.common.print_utils import *
from embodied_cd.common.env_utils import *
from embodied_cd.common.rag_utils import RAGPipeline
from embodied_cd.common.dataset_utils import PromptTemplate
from embodied_cd.agents.ecoc import CAMAAgent

from scripts.train import *


def prepare(args, path):
    """Build the tokenizer and the base model carrying both policy adapters.

    Two adapters, ``reasoning_policy`` and ``planning_policy``, are always
    attached to the model. Saved weights are loaded into them only when
    ``path`` is given.

    Args:
        args: Parsed options; reads ``pi_model_name`` and, when ``path`` is
            not None, ``output_dir``.
        path: Directory of the reasoning run (weights are read from its
            ``policy/reasoning_policy`` sub-directory), or None to keep the
            freshly added adapters.

    Returns:
        The ``(tokenizer, model)`` pair.
    """
    tokenizer, model = prepare_model(args.pi_model_name, init=True)
    for adapter in ("reasoning_policy", "planning_policy"):
        model = add_adapter(model, name=adapter)

    if path is None:
        return tokenizer, model

    reasoning_dir = path + "/policy/reasoning_policy"
    model.load_adapter(reasoning_dir, adapter_name="reasoning_policy")

    # The planning adapter is read from the "CAMA-Think" counterpart of the
    # current "CAMA-Think-v2" output directory, under its "actor" folder.
    planning_root = args.output_dir.replace("CAMA-Think-v2", "CAMA-Think")
    planning_dir = planning_root + "/actor"
    model.load_adapter(planning_dir, adapter_name="planning_policy")
    return tokenizer, model


def _evaluation_save_path(args, split, per_set):
    """Return the result-file path for one evaluation split.

    Args:
        args: Parsed options; reads ``output_dir``, ``continual_eval_set``,
            ``perturb``, ``test_time_control``, ``total_think`` and
            ``correct``.
        split: ``"seen"`` or ``"unseen"``.
        per_set: True when an explicit evaluation set was requested, in which
            case its name is appended to the file stem.

    Returns:
        Path of the ``.txt`` file the evaluation results are written to.
    """
    save_path = args.output_dir + f"/results/evaluation_{split}"
    if per_set:
        save_path += f"_{args.continual_eval_set}"
    # Each non-default option contributes a suffix so runs do not overwrite
    # each other's results.
    if args.perturb:
        save_path += "_perturb"
    if args.test_time_control:
        save_path += "_control"
    if args.total_think != 5:
        save_path += f"_think{args.total_think}"
    if args.correct:
        save_path += "_correct"
    return save_path + ".txt"


def _evaluate_split(args, env, agent, split, per_set):
    """Run the evaluation pipeline on the ``split`` ("seen"/"unseen") set."""
    agent.evaluate = split
    env_list, task_list, available_tasks, ps_line = load_evaluation_configs(
        args.env_name, args.continual_type, args.continual_eval_set, f"shot5_{split}")

    print_warn(f"Evaluating in {split.capitalize()} Set")
    evaluation_pipe(
        env,
        agent,
        env_list,
        task_list,
        available_tasks,
        ps_line,
        total_episode=args.eval_episode,
        save_path=_evaluation_save_path(args, split, per_set),
    )


def evaluate(args):
    """Evaluate the reasoning/planning agent on the seen and/or unseen sets.

    Loads the continual-learning config, restores the policy adapters, builds
    the agent and the environment, then runs the evaluation pipeline for each
    split selected by ``args.evaluate`` (``"seen"``, ``"unseen"`` or
    ``"both"``).

    Args:
        args: Parsed options. When ``args.continual_eval_set`` is None it is
            overwritten in place with ``args.continual_set``.

    Raises:
        ValueError: If ``args.continual_set`` is not found in the configured
            sequence ``args.continual_seq`` (raised by ``.index``).
    """
    # load configs and check the continual set
    # NOTE: the config file is executed as Python code; only point this at
    # trusted configuration files.
    config = {}
    with open(f"./configs/{args.env_name}/{args.continual_type}_train.py", "r", encoding="utf-8") as f:
        code = f.read()
    exec(code, config)
    # Fail fast when the requested set is not part of the sequence.
    config["sequences"][args.continual_seq].index(args.continual_set)

    # prepare reasoning-model
    # NOTE(review): the checkpoint suffix is fixed to "A4SeqFT_S5P" for every
    # phase -- confirm the first phase should not use a different tag.
    model_name = args.pi_model_name.split('/')[-1]
    reasoning_output_dir = f"./results/ECoC/{args.env_name}/{args.continual_type}_{args.continual_seq}/{args.continual_set}_{model_name}_A4SeqFT_S5P"
    tokenizer, model = prepare(args, path=reasoning_output_dir)

    cl_type = args.continual_type.split('_')[0]

    # prepare agent
    agent = CAMAAgent(
        args.pi_model_name,
        base_model=None, # you can use this model as planning-policy
        base_tokenizer=None,
        model=model,
        tokenizer=tokenizer,
        plan_model=None,
        plan_tokenizer=None,
        rag_pipe=None,
        env_name=args.env_name,
        cl_type=cl_type,
        max_think_token=args.max_think_token,
        total_think=args.total_think,
        correct=args.correct,
        no_critic=args.no_critic,
        perturb=args.perturb,
        decoding_strategy="greedy",
    )

    # Remember whether an explicit evaluation set was requested before
    # falling back to the training set; it decides the result-file name.
    per_set = args.continual_eval_set is not None
    if not per_set:
        args.continual_eval_set = args.continual_set

    # make env
    task_index = int(args.continual_eval_set.split('_')[-1]) - 1
    split_path = f"./externals/cl-alfred/embodied_split/{cl_type}_il/embodied_data_disjoint_rand1_cls1_task{task_index}.json" # for alfred
    env = make_env(args.env_name, args.num_topk_edge, args.env_ip, args.seed, args.env_port, split_path)

    # "seen" is always evaluated before "unseen" when both are requested.
    for split in ("seen", "unseen"):
        if args.evaluate in (split, "both"):
            _evaluate_split(args, env, agent, split, per_set)


if __name__ == "__main__":
    args = argparse.parse_args()
    # The thinking budget is pinned here, overriding whatever was parsed.
    args.max_think_token = 40

    # directory settings: fill in defaults derived from the run options
    if args.dataset_dir is None:
        args.dataset_dir = (
            f"./datasets/{args.env_name}/"
            f"{args.continual_type}_train/{args.continual_set}"
        )
    if args.output_dir is None:
        model_name = args.pi_model_name.split('/')[-1]
        output_dir = (
            f"./results/CAMA-Think-v2/{args.env_name}/"
            f"{args.continual_type}_{args.continual_seq}/"
            f"{args.continual_set}/{model_name}"
        )
        args.output_dir = os.path.join(output_dir, args.tag)

    # Without an explicit previous tag, reuse the current one.
    if args.prev_tag is None:
        args.prev_tag = args.tag

    # Evaluation only runs when an evaluation mode was requested.
    if args.evaluate:
        evaluate(args)
