import os
import time
import wandb
import torch
import json
import numpy as np

from embodied_cd.common import argparse
from embodied_cd.common.print_utils import *
from embodied_cd.common.env_utils import *
from embodied_cd.trl.ecoc import ThinkWholeTrainer
from embodied_cd.agents.ecoc import ThinkWholeAgent

from scripts.train import *


### ==================================###
#           Phase 1: Policy             #
### ==================================###
def prepare(args, path):
    """Build the policy model together with its two adapters.

    A model/tokenizer pair is created from ``args.pi_model_name`` and two
    adapters, ``reasoning_policy`` and ``planning_policy``, are attached.
    When ``path`` is given, each adapter's weights are then loaded from
    ``<path>/policy/<adapter name>``.

    Args:
        args: parsed command-line arguments (only ``pi_model_name`` is read).
        path: results directory holding previously saved adapters, or
            ``None`` to keep the freshly added adapters.

    Returns:
        ``(tokenizer, model)``.
    """
    adapter_names = ("reasoning_policy", "planning_policy")

    tokenizer, model = prepare_model(args.pi_model_name, init=True)
    for adapter_name in adapter_names:
        model = add_adapter(model, name=adapter_name)

    if path is None:
        return tokenizer, model

    # All adapters are attached before any saved weights are loaded into them.
    for adapter_name in adapter_names:
        model.load_adapter(path + "/policy/" + adapter_name, adapter_name=adapter_name)
    return tokenizer, model

def train_policy(args):
    """Train the Think-Whole policy for the current continual-learning phase.

    Steps:
      1. Start a wandb run (``mode="disabled"`` when ``args.wandb_off``).
      2. Build the "cd-think" dataset from ``args.dataset_dir``.
      3. Find the index of ``args.continual_set`` in the configured continual
         sequence. For the first phase, or when ``args.seqft == 0``, train
         freshly initialised adapters with no base model. Otherwise load a
         base model from an earlier phase's results directory. When several
         earlier phases exist, the one with the highest
         ``ThinkWholeTrainer._setup_base_model`` score is chosen. The trainable
         model is then either freshly initialised (``"BaseOnly"`` in
         ``args.tag``) or loaded from that same checkpoint.
      4. Train with ``ThinkWholeTrainer`` and save params and weights under
         ``<args.output_dir>/policy``.

    Args:
        args: parsed command-line arguments.
    """
    if args.wandb_off:
        wandb.init(mode="disabled")
    else:
        wandb.init(project="Continual_Learning_by_Learning_to_Self_Correct")
        wandb.run.name = "ECoC-Think-Whole"

    # This tokenizer is only needed while the dataset is being built.
    dataset_tokenizer = AutoTokenizer.from_pretrained(args.pi_model_name)
    dataset, _collator = prepare_dataset(
        args.env_name,
        args.pi_model_name,
        args.dataset_dir,
        dataset_tokenizer,
        dataset_type="cd-think",
        dataset_size=args.dataset_size,
        few_shot_example=None,
        ablation=args.dataset_ablation,
    )
    del dataset_tokenizer

    # The config file is Python source; executing it fills ``config`` with its
    # top-level names, of which ``sequences`` is used here.
    config = {}
    config_path = f"./configs/{args.env_name}/{args.continual_type}_train.py"
    with open(config_path, "r") as config_file:
        config_code = config_file.read()
    exec(config_code, config)
    sequence = config["sequences"][args.continual_seq]
    phase = sequence.index(args.continual_set)

    def phase_dir(phase_name):
        # Results directory of an earlier phase trained under ``args.prev_tag``.
        model_short = args.pi_model_name.split('/')[-1]
        return (
            f"./results/ECoC-Fine/Think-Whole/{args.env_name}/"
            f"{args.continual_type}_{args.continual_seq}/"
            f"{phase_name}_{model_short}_{args.prev_tag}"
        )

    fresh_start = phase == 0 or args.seqft == 0
    if fresh_start:
        base_tokenizer = base_model = None
        tokenizer, model = prepare(args, path=None)
        prev_phase = None
    else:
        candidates = sequence[:phase]
        if len(candidates) == 1:
            prev_phase = candidates[0]
        else:
            # Score each earlier phase's checkpoint on the current dataset and
            # keep the best one as the base model.
            scores = []
            for candidate in candidates:
                cand_tokenizer, cand_model = prepare(args, path=phase_dir(candidate))
                scores.append(
                    ThinkWholeTrainer._setup_base_model(
                        args.pi_model_name,
                        cand_tokenizer,
                        cand_model,
                        dataset,
                        args.max_think_token,
                    )
                )
                del cand_tokenizer, cand_model

            print(candidates)
            print(scores)
            prev_phase = candidates[np.argmax(scores)]

        print_pass(f"[Loading Base Model] at {prev_phase}")
        base_path = phase_dir(prev_phase)
        print(base_path)
        base_tokenizer, base_model = prepare(args, path=base_path)

        # The trainable model either starts fresh or from the base checkpoint.
        if "BaseOnly" in args.tag:
            print_error("Use base rationale generation, but initialize LoRA")
            init_path = None
        else:
            print_error("Initialize with base reasoning policy")
            init_path = base_path
        tokenizer, model = prepare(args, path=init_path)

    output_dir = args.output_dir + "/policy"

    params = dict(
        total_epochs=20,
        inner_epochs=4,
        learning_rate=5e-5,
        batch_size=4,
        max_think_token=args.max_think_token,
        # 0: no ablation, 1: only generation, 2: only base correction, 3: only self correction
        ablation=args.ablation,
        total_think=5,
        base_continual_set=prev_phase,  # phase the base model was loaded from (None if none)
        base_continual_tag=args.prev_tag,
        dataset_size=args.dataset_size,
        dataset_ablation=args.dataset_ablation,
    )

    trainer = ThinkWholeTrainer(
        args.env_name, args.pi_model_name,
        base_tokenizer, base_model,
        tokenizer, model,
        dataset, output_dir,
        **params,
    )
    save_params(output_dir, trainer.params)

    print_warn("[ECoC] Policy Training Start!")
    trainer.train()
    trainer.save_pretrained(output_dir)


def evaluate(args):
    """Evaluate the trained Think-Whole policy in the environment.

    Loads the policy adapters from ``args.output_dir``. When the saved training
    params record a ``base_continual_set``, it also loads the base model from
    that earlier phase's results directory. It then builds a
    ``ThinkWholeAgent`` with greedy decoding and runs ``evaluation_pipe`` on
    the "seen" and/or "unseen" split, as selected by ``args.evaluate``
    ("seen", "unseen" or "both").

    Results are written to
    ``<output_dir>/results/evaluation_<split>[_<eval_set>][_perturb].txt``.
    The ``_<eval_set>`` suffix is only added when ``args.continual_eval_set``
    was given explicitly.

    Side effect: when ``args.continual_eval_set`` is ``None`` it is set to
    ``args.continual_set``.

    Args:
        args: parsed command-line arguments.
    """
    # Base model: only present if training recorded an earlier phase.
    params = load_params(args.output_dir + '/policy')
    prev_phase = params.get("base_continual_set")
    if prev_phase is None:
        base_tokenizer, base_model = None, None
    else:
        prev_tag = params["base_continual_tag"]
        print_warn(f"[Loading Base Model] at {prev_phase}")
        path = f"./results/ECoC-Fine/Think-Whole/{args.env_name}/{args.continual_type}_{args.continual_seq}/{prev_phase}_{args.pi_model_name.split('/')[-1]}_{prev_tag}"
        base_tokenizer, base_model = prepare(args, path=path)

    # Policy being evaluated.
    tokenizer, model = prepare(args, path=args.output_dir)

    agent = ThinkWholeAgent(
        args.pi_model_name,
        base_model,
        base_tokenizer,
        model,
        tokenizer,
        plan_model=None,
        plan_tokenizer=None,
        rag_pipe=None,
        env_name=args.env_name,
        cl_type=args.continual_type.split('_')[0],
        max_think_token=args.max_think_token,
        # NOTE(review): ``correct`` receives ``args.no_correct``. This is only
        # right if that flag is declared with action="store_false"; confirm in
        # embodied_cd.common.argparse.
        correct=args.no_correct,
        no_critic=args.no_critic,
        perturb=args.perturb,
        decoding_strategy="greedy",
    )

    # Without an explicit evaluation set, evaluate on the training set and
    # keep the un-suffixed result file names.
    explicit_eval_set = args.continual_eval_set is not None
    if not explicit_eval_set:
        args.continual_eval_set = args.continual_set

    # ALFRED split file: the trailing number of the evaluation set name gives
    # the 1-based task, converted here to the 0-based file index.
    cl_kind = args.continual_type.split('_')[0]
    task_index = int(args.continual_eval_set.split('_')[-1]) - 1
    split_path = f"./externals/cl-alfred/embodied_split/{cl_kind}_il/embodied_data_disjoint_rand1_cls1_task{task_index}.json"
    env = make_env(args.env_name, args.num_topk_edge, args.env_ip, args.seed, args.env_port, split_path)

    # "seen" uses the train environments, "unseen" the test environments.
    splits = [split for split in ("seen", "unseen") if args.evaluate in (split, "both")]
    for split in splits:
        env_list, task_list, available_tasks, ps_line = load_evaluation_configs(
            args.env_name, args.continual_type, args.continual_eval_set, split)

        print_warn(f"Evaluating in {split.capitalize()} Set")
        save_path = args.output_dir + f'/results/evaluation_{split}'
        if explicit_eval_set:
            save_path += f'_{args.continual_eval_set}'
        if args.perturb:
            save_path += '_perturb'
        save_path += '.txt'

        evaluation_pipe(
            env,
            agent,
            env_list,
            task_list,
            available_tasks,
            ps_line,
            total_episode=args.eval_episode,
            save_path=save_path,
        )


if __name__ == "__main__":
    args = argparse.parse_args()
    # The think-token budget is fixed here, overriding any parsed value.
    args.max_think_token = 180

    # Default dataset directory: the training split of the current continual set.
    if args.dataset_dir is None:
        args.dataset_dir = "./datasets/{}/{}_train/{}".format(
            args.env_name, args.continual_type, args.continual_set)

    # Default results directory, keyed by sequence, set, model and tag.
    if args.output_dir is None:
        run_group = "{}_{}".format(args.continual_type, args.continual_seq)
        model_short = args.pi_model_name.split('/')[-1]
        run_name = "{}_{}_{}".format(args.continual_set, model_short, args.tag)
        args.output_dir = "./results/ECoC-Fine/Think-Whole/{}/{}/{}".format(
            args.env_name, run_group, run_name)

    # Earlier phases are looked up under the current tag unless told otherwise.
    if args.prev_tag is None:
        args.prev_tag = args.tag

    # Training (if requested) always runs before evaluation.
    if args.train:
        train_policy(args)
    if args.evaluate:
        evaluate(args)
