import os
import wandb
import torch
import json
import numpy as np
from trl import SFTConfig, SFTTrainer

from embodied_cd.common import argparse
from embodied_cd.common.print_utils import *
from embodied_cd.common.env_utils import *
from scripts.train import *

from embodied_cd.agents import SFTAgent


def train(args):
    """Run supervised fine-tuning of the action model for one continual phase.

    The phase is the position of ``args.continual_set`` inside
    ``config["sequences"][args.continual_seq]``, where ``config`` is loaded
    from ``./configs/<env_name>/<continual_type>_train.py``.  For phase 0, or
    whenever ``args.seqft == 0``, a fresh model is prepared from
    ``args.pi_model_name`` and an adapter is added; otherwise the model saved
    by the previous phase is loaded and its ``"default"`` adapter is selected.

    The trained model and the training parameters are written to
    ``<args.output_dir>/cd-action``.

    Args:
        args: Parsed command-line namespace.  Reads ``wandb_off``,
            ``output_dir``, ``env_name``, ``continual_type``,
            ``continual_seq``, ``continual_set``, ``seqft``,
            ``pi_model_name``, ``tag``, ``dataset_dir`` and ``dataset_size``.

    Raises:
        ValueError: If ``args.continual_set`` is not part of the selected
            sequence (raised by ``list.index``).
    """
    # wandb log on/off
    if args.wandb_off:
        wandb.init(mode='disabled')
    else:
        wandb.init(project="Continual_Learning_by_Learning_to_Self_Correct")
        wandb.run.name = "Baselines_SFT_Action"

    output_dir = os.path.join(args.output_dir, "cd-action")
    os.makedirs(output_dir, exist_ok=True)

    # load configs and get continual phase
    # NOTE(review): the config file is executed as Python code via exec();
    # this is only acceptable because it is a trusted, repository-local file.
    config = {}
    config_path = f"./configs/{args.env_name}/{args.continual_type}_train.py"
    with open(config_path, "r", encoding="utf-8") as f:
        code = f.read()
    exec(code, config)
    sequence = config["sequences"][args.continual_seq]
    phase = sequence.index(args.continual_set)

    # seqft 0: AppendOnly, seqft 1: SeqFT
    if (phase == 0) or (args.seqft == 0):
        # if phase 0, initialize
        tokenizer, model = prepare_model(args.pi_model_name, init=True)
        # add adapter
        model = add_adapter(model)
    else:
        # if phase > 0, prepare model for sequential finetuning
        prev_continual_set = sequence[phase - 1]
        # NOTE(review): this path is rebuilt from the default layout and does
        # not follow a custom --output_dir; confirm that is intended.
        prev_output_dir = f"./results/SFT/{args.env_name}/{args.continual_type}_{args.continual_seq}/{prev_continual_set}_{args.pi_model_name.split('/')[-1]}_{args.tag}"
        tokenizer, model = prepare_model(os.path.join(prev_output_dir, "cd-action"), init=False)
        model.set_adapter("default")

    # prepare dataset
    dataset, collator = prepare_dataset(
        args.env_name,
        args.pi_model_name,
        args.dataset_dir,
        tokenizer,
        dataset_size=args.dataset_size,
        dataset_type="cd-action",
        few_shot_example=None,
    )

    # set configs: options shared by both model families
    common_kwargs = dict(
        output_dir=output_dir,
        save_strategy="no",
        logging_strategy="epoch",
        save_total_limit=1,
        num_train_epochs=40,
        learning_rate=1.41e-05,
        per_device_train_batch_size=4,
    )
    if 'instruct' in args.pi_model_name or 'Instruct' in args.pi_model_name:
        training_args = SFTConfig(
            **common_kwargs,
            # to avoid the default dataset process
            # we should disable 'remove_unused_columns' & 'skip_prepare_dataset'
            remove_unused_columns=False,
            dataset_kwargs={"skip_prepare_dataset": True},
        )
    else:
        training_args = SFTConfig(
            **common_kwargs,
            dataset_text_field="text",
        )

    # Snapshot the config attributes into a new dict: updating
    # training_args.__dict__ directly would add a stray "seqft" attribute to
    # the SFTConfig instance that is handed to the trainer below.
    params = dict(vars(training_args))
    params["seqft"] = args.seqft

    save_params(output_dir, params)

    # create trainer
    trainer = SFTTrainer(
        model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        data_collator=collator,
        args=training_args
    )

    # training
    print_warn("[SFT] Model Training")
    trainer.train()

    # saving
    trainer.save_model(output_dir)


def _run_evaluation_split(args, env, agent, split, tag_with_eval_set):
    """Evaluate ``agent`` on one split ("seen" or "unseen") and save results."""
    env_list, task_list, available_tasks, ps_line = load_evaluation_configs(
        args.env_name, args.continual_type, args.continual_eval_set, split)

    print_warn(f"Evaluating in {split.capitalize()} Set")

    # Result file: <output_dir>/results/evaluation_<split>[_<eval_set>][_perturb].txt
    save_path = args.output_dir + f'/results/evaluation_{split}'
    if tag_with_eval_set:
        save_path += f'_{args.continual_eval_set}'
    if args.perturb:
        save_path += '_perturb'
    save_path += '.txt'

    # evaluation_pipe is defined elsewhere; creating the directory here is a
    # no-op if the pipeline already takes care of it.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    evaluation_pipe(
        env,
        agent,
        env_list,
        task_list,
        available_tasks,
        ps_line,
        total_episode=args.eval_episode,
        save_path=save_path
    )


def evaluate(args):
    """Evaluate the fine-tuned action model on the seen and/or unseen split.

    Loads the model from ``<args.output_dir>/cd-action``, wraps it in an
    ``SFTAgent`` with greedy decoding, builds the environment and runs
    ``evaluation_pipe`` for every split selected by ``args.evaluate``
    (``"seen"``, ``"unseen"`` or ``"both"``).

    When ``args.continual_eval_set`` is ``None`` it is set to
    ``args.continual_set`` (``args`` is mutated) and the result file name
    carries no set suffix; otherwise the evaluated set name is appended to
    the file name.

    Args:
        args: Parsed command-line namespace.  Reads ``output_dir``,
            ``pi_model_name``, ``env_name``, ``continual_type``,
            ``continual_set``, ``continual_eval_set``, ``max_think_token``,
            ``perturb``, ``num_topk_edge``, ``env_ip``, ``env_port``,
            ``seed``, ``evaluate`` and ``eval_episode``.
    """
    # load model
    tokenizer, model = prepare_model(os.path.join(args.output_dir, 'cd-action'), init=False)

    # prepare agent
    agent = SFTAgent(
        args.pi_model_name,
        None,
        None,
        model,
        tokenizer,
        args.env_name,
        args.continual_type.split('_')[0],
        max_think_token=args.max_think_token,
        max_think_times=None,
        perturb=args.perturb,
        decoding_strategy='greedy'
    )

    # An explicitly requested evaluation set is reflected in the result file
    # name; the default (the training set of this phase) is not.
    explicit_eval_set = args.continual_eval_set is not None
    if not explicit_eval_set:
        args.continual_eval_set = args.continual_set

    # make env
    # NOTE(review): the split path is built for every env_name and needs
    # continual_eval_set to end in "_<int>"; the original comment marks it as
    # ALFRED-specific - confirm other environments ignore it.
    split_path = f"./externals/cl-alfred/embodied_split/{args.continual_type.split('_')[0]}_il/embodied_data_disjoint_rand1_cls1_task{int(args.continual_eval_set.split('_')[-1])-1}.json" # for alfred
    env = make_env(args.env_name, args.num_topk_edge, args.env_ip, args.seed, args.env_port, split_path)

    # "seen" = train environment, "unseen" = test environment
    for split in ("seen", "unseen"):
        if args.evaluate == split or args.evaluate == "both":
            _run_evaluation_split(args, env, agent, split, explicit_eval_set)


if __name__ == "__main__":
    args = argparse.parse_args()

    # Directory settings: derive defaults from the experiment description
    # when the paths were not given on the command line.
    if args.dataset_dir is None:
        args.dataset_dir = "./datasets/{}/{}_train/{}".format(
            args.env_name, args.continual_type, args.continual_set
        )
    if args.output_dir is None:
        args.output_dir = "./results/SFT/{}/{}_{}/{}_{}_{}".format(
            args.env_name,
            args.continual_type,
            args.continual_seq,
            args.continual_set,
            args.pi_model_name.split('/')[-1],  # model name without the hub namespace
            args.tag,
        )

    # The stages are independent flags and run in this fixed order.
    if args.train:
        train(args)
    # NOTE(review): `test` is not defined in this file; presumably it comes
    # from one of the wildcard imports - verify before relying on --test.
    if args.test:
        test(args)
    if args.evaluate:
        evaluate(args)
