import os
import time
import wandb
import torch
import json
import numpy as np
import tqdm

from embodied_cd.common import argparse
from embodied_cd.common.print_utils import *
from embodied_cd.common.env_utils import *
from embodied_cd.common.rag_utils import RAGPipeline
from embodied_cd.common.dataset_utils import PromptTemplate
from embodied_cd.trl.ecoc import ECoCTrainer
from embodied_cd.trl.ecoc.feedback_trainer import FeedbackTrainer
from embodied_cd.agents.ecoc import ECoCAgent

from scripts.train import *

### ==================================###
#           Phase 1: Policy             #
### ==================================###
def prepare(args, path):
    """Build the policy model with its two adapters, optionally restoring saved weights.

    Args:
        args: parsed run arguments; only ``args.pi_model_name`` is read here.
        path: run directory that holds ``policy/reasoning_policy`` and
            ``policy/planning_policy``, or ``None`` to keep the freshly added
            adapters as they are.

    Returns:
        The ``(tokenizer, model)`` pair produced by ``prepare_model`` with both
        adapters attached.
    """
    adapter_names = ("reasoning_policy", "planning_policy")

    tokenizer, model = prepare_model(args.pi_model_name, init=True)
    # Attach every adapter first, then (optionally) load the saved weights.
    for adapter in adapter_names:
        model = add_adapter(model, name=adapter)

    if path is None:
        return tokenizer, model

    for adapter in adapter_names:
        model.load_adapter(path + f'/policy/{adapter}', adapter_name=adapter)
    return tokenizer, model

# Run tags of the 'S5P' checkpoints, keyed by environment name.
# _S5P_INIT_TAGS: checkpoint used to initialize the policy in the first phase.
# _S5P_BASE_TAGS: checkpoint used for the first phase when scoring base-model candidates.
_S5P_INIT_TAGS = {"virtualhome": "A4SeqFT_0", "alfred": "A4SeqFT_0"}
_S5P_BASE_TAGS = {"virtualhome": "A2SeqFT_5", "alfred": "A2SeqFT_2"}


def _s5p_tag(env_name, tags):
    """Return the 'S5P' checkpoint tag for `env_name`, failing loudly for unknown environments."""
    try:
        return tags[env_name]
    except KeyError:
        raise ValueError(
            f"No 'S5P' checkpoint tag is defined for env_name={env_name!r}; "
            f"expected one of {sorted(tags)}"
        ) from None


def _policy_run_dir(args, phase_name, tag):
    """Return the results directory of the run trained on `phase_name` with run tag `tag`."""
    model_name = args.pi_model_name.split('/')[-1]
    return (
        f"./results/ECoC/{args.env_name}/{args.continual_type}_{args.continual_seq}/"
        f"{phase_name}_{model_name}_{tag}"
    )


def train_policy(args):
    """Train the reasoning/planning policy for one continual-learning phase.

    Steps:
      1. Initialize wandb (disabled when ``args.wandb_off`` is set).
      2. Build the "cd-think" dataset.
      3. Read ``./configs/<env>/<continual_type>_train.py`` and find the index
         of ``args.continual_set`` inside ``sequences[args.continual_seq]``.
      4. Choose the checkpoint to initialize from and, unless this is the first
         phase / ``args.seqft == 0`` / ``args.ablation == 4``, a base model
         taken from one of the previous phases.
      5. Train with ``ECoCTrainer`` and save into ``<args.output_dir>/policy``.

    Args:
        args: parsed run arguments. Attributes read: ``wandb_off``,
            ``pi_model_name``, ``env_name``, ``dataset_dir``, ``dataset_size``,
            ``dataset_ablation``, ``continual_type``, ``continual_seq``,
            ``continual_set``, ``seqft``, ``tag``, ``prev_tag``, ``ablation``,
            ``base_ablation``, ``max_think_token``, ``total_think``,
            ``output_dir``. ``args.prev_tag`` is overwritten when the selected
            base phase is the 'S5P' checkpoint.

    Raises:
        ValueError: an 'S5P' checkpoint is required but ``args.env_name`` is
            neither ``virtualhome`` nor ``alfred``.
        NotImplementedError: ``args.base_ablation`` is not one of 0-3 when a
            base phase has to be selected among several candidates.
    """
    # Local import: `random` is not imported explicitly at the top of this file.
    import random

    # wandb log on/off
    if args.wandb_off:
        wandb.init(mode="disabled")
    else:
        wandb.init(project="Continual_Learning_by_Learning_to_Self_Correct")
        wandb.run.name = "ECoC_Policy"

    # prepare dataset (the tokenizer is only needed to build it)
    tokenizer = AutoTokenizer.from_pretrained(args.pi_model_name)
    dataset, _ = prepare_dataset(
        args.env_name,
        args.pi_model_name,
        args.dataset_dir,
        tokenizer,
        dataset_type="cd-think",
        dataset_size=args.dataset_size,
        few_shot_example=None,
        ablation=args.dataset_ablation,
    )
    del tokenizer

    # load configs and get continual phase
    config = {}
    with open(f"./configs/{args.env_name}/{args.continual_type}_train.py", "r") as f:
        code = f.read()
    # NOTE(security): the config file is executed as Python code; it must only
    # ever come from a trusted location in this repository.
    exec(code, config)
    sequence = config["sequences"][args.continual_seq]
    phase = sequence.index(args.continual_set)

    if (phase == 0) or (args.seqft == 0):
        # First phase, or sequential fine-tuning disabled: no base model.
        base_tokenizer, base_model = None, None
        prev_phase = None
        if 'S5P' in args.tag:
            path = _policy_run_dir(
                args, sequence[0], _s5p_tag(args.env_name, _S5P_INIT_TAGS))
            print_error(f"Initialize with {path}")
            tokenizer, model = prepare(args, path=path)
        else:
            tokenizer, model = prepare(args, path=None)
    else:
        # phase >= 1 here, so there is always at least one previous phase.
        prev_phases = sequence[:phase]
        if args.ablation == 4:
            # Continue from the immediately preceding phase without a base model.
            prev_phase = prev_phases[-1]
            base_tokenizer, base_model = None, None
            print_pass(f"[Loading Base Model] at {prev_phase}")
            path = _policy_run_dir(args, prev_phase, args.prev_tag)
            print_pass(path)
        else:
            # Phase that was scored from its 'S5P' checkpoint; -1 means "none".
            s5p_scored_phase = -1
            scores = []
            if len(prev_phases) == 1:
                prev_phase = prev_phases[0]
            else:
                # "Last" (2) and "Random" (3) never look at scores, so the
                # per-candidate scoring pass is skipped for them.
                if args.base_ablation not in (2, 3):
                    for idx, candidate in enumerate(prev_phases):
                        if idx == 0 and 'S5P' in args.tag:
                            path = _policy_run_dir(
                                args, candidate, _s5p_tag(args.env_name, _S5P_BASE_TAGS))
                            s5p_scored_phase = candidate
                        else:
                            # load model from previous phase
                            path = _policy_run_dir(args, candidate, args.prev_tag)
                        tokenizer, model = prepare(args, path=path)
                        scores.append(
                            ECoCTrainer._setup_base_model(
                                args.pi_model_name, tokenizer, model, dataset, args.max_think_token)
                        )
                        # Drop the candidate before loading the next one.
                        del tokenizer, model

                print(prev_phases)
                print(scores)
                if args.base_ablation == 0:
                    print_check("Base Ablation 0: Argmax")
                    prev_phase = prev_phases[np.argmax(scores)]
                elif args.base_ablation == 1:
                    print_check("Base Ablation 1: Argmin")
                    prev_phase = prev_phases[np.argmin(scores)]
                elif args.base_ablation == 2:
                    print_check("Base Ablation 2: Last")
                    prev_phase = prev_phases[-1]
                elif args.base_ablation == 3:
                    print_check("Base Ablation 3: Random")
                    prev_phase = random.choice(prev_phases)
                else:
                    raise NotImplementedError

            print_pass(f"[Loading Base Model] at {prev_phase}")
            if s5p_scored_phase == prev_phase:
                # The winner is the 'S5P' checkpoint: record its tag so that it
                # is stored in the saved params as "base_continual_tag".
                args.prev_tag = _s5p_tag(args.env_name, _S5P_BASE_TAGS)
            path = _policy_run_dir(args, prev_phase, args.prev_tag)
            print_pass(path)
            base_tokenizer, base_model = prepare(args, path=path)

        # prepare the model that is actually trained
        if "BaseOnly" in args.tag:
            print_error("Use base rationale generation, but initialize LoRA")
            tokenizer, model = prepare(args, path=None)
        else:
            print_error("Initialize with base reasoning policy")
            tokenizer, model = prepare(args, path=path)

    # set output_dir
    output_dir = args.output_dir + "/policy"

    # set configs
    params = {
        "total_epochs": 20, # few-shot 10 epoch
        "inner_epochs": 4,
        "learning_rate": 5e-5,
        "batch_size": 4,
        "max_think_token": args.max_think_token,
        "ablation": args.ablation, # 0: no ablation, 1: only generation, 2: only base correction, 3: only self correction
        "total_think": args.total_think,
        "base_continual_set": prev_phase, # phase the base model was taken from (None if there is none)
        "base_continual_tag": args.prev_tag,
        "dataset_size": args.dataset_size,
        "dataset_ablation": args.dataset_ablation,
    }
    print(params)

    # create trainer
    trainer = ECoCTrainer(
        args.env_name,
        args.pi_model_name,
        base_tokenizer,
        base_model,
        tokenizer,
        model,
        dataset,
        output_dir,
        **params,
    )

    # save params
    save_params(output_dir, trainer.params)

    print_warn("[ECoC] Policy Training Start!")
    trainer.train()

    # save model
    trainer.save_pretrained(output_dir)

### ==================================###
#           Phase 2: Feedback           #
### ==================================###
def prepare_feedback(args, path):
    """Build the model carrying the single ``feedback_policy`` adapter.

    Args:
        args: parsed run arguments; only ``args.pi_model_name`` is read here.
        path: run directory that holds ``feedback/feedback_policy``, or ``None``
            to keep the freshly added adapter as it is.

    Returns:
        The ``(tokenizer, model)`` pair produced by ``prepare_model`` with the
        feedback adapter attached.
    """
    adapter = "feedback_policy"

    tokenizer, model = prepare_model(args.pi_model_name, init=True)
    model = add_adapter(model, name=adapter)

    if path is None:
        return tokenizer, model

    model.load_adapter(path + f'/feedback/{adapter}', adapter_name=adapter)
    return tokenizer, model

def train_feedback(args):
    """Train the feedback policy on top of an already trained policy run.

    The current policy is loaded from ``args.output_dir``. The base policy is
    taken from the phase recorded in the policy run's saved params
    (``base_continual_set`` / ``base_continual_tag``); when no base phase was
    recorded, the current policy checkpoint is loaded a second time and used
    as the base, and the feedback adapter starts from scratch.

    Args:
        args: parsed run arguments. Attributes read: ``pi_model_name``,
            ``env_name``, ``dataset_dir``, ``dataset_size``,
            ``dataset_ablation``, ``continual_type``, ``continual_seq``,
            ``max_think_token``, ``total_think``, ``output_dir``.
    """
    # wandb logging is always disabled for this phase
    wandb.init(mode="disabled")
    print_pass("Training Feedback Policy!!")

    # prepare dataset (the tokenizer is only needed to build it)
    tokenizer = AutoTokenizer.from_pretrained(args.pi_model_name)
    dataset, _ = prepare_dataset(
        args.env_name,
        args.pi_model_name,
        args.dataset_dir,
        tokenizer,
        dataset_type="cd-think",
        dataset_size=args.dataset_size,
        few_shot_example=None,
        ablation=args.dataset_ablation,
    )
    del tokenizer

    # params saved by the policy phase tell us which base phase was used
    policy_params = load_params(args.output_dir + '/policy')
    curr_tokenizer, curr_model = prepare(args, path=args.output_dir)
    if policy_params.get("base_continual_set") is None:
        # No base phase recorded: use a second copy of the current policy as base.
        base_tokenizer, base_model = prepare(args, path=args.output_dir)
        tokenizer, model = prepare_feedback(args, path=None)
    else:
        prev_phase = policy_params["base_continual_set"]
        prev_tag = policy_params["base_continual_tag"]
        print_warn(f"[Loading Base Model] at {prev_phase}")
        path = f"./results/ECoC/{args.env_name}/{args.continual_type}_{args.continual_seq}/{prev_phase}_{args.pi_model_name.split('/')[-1]}_{prev_tag}"
        base_tokenizer, base_model = prepare(args, path=path)
        # Start the feedback adapter from the one saved with the base phase.
        tokenizer, model = prepare_feedback(args, path=path)

    # set output_dir
    output_dir = args.output_dir + "/feedback"

    # set configs
    params = {
        "total_epochs": 10, # few-shot 10 epoch
        "inner_epochs": 4,
        "learning_rate": 5e-5,
        "batch_size": 4,
        "max_think_token": args.max_think_token,
        "total_think": args.total_think,
        "dataset_size": args.dataset_size,
        "dataset_ablation": args.dataset_ablation,
    }
    print(params)

    trainer = FeedbackTrainer(
        args.env_name,
        args.pi_model_name,
        base_tokenizer,
        base_model,
        curr_tokenizer,
        curr_model,
        tokenizer,
        model,
        dataset,
        output_dir,
        **params,
    )

    # save params (once; the original called this twice in a row)
    save_params(output_dir, trainer.params)

    print_warn("[ECoC] Feedback Training Start!")
    trainer.train()

    # save model
    trainer.save_pretrained(output_dir)


### ==================================###
#           Evaluation & Test           #
### ==================================###
def test(args):
    """Collect per-think-step action probabilities and write ``thresholds.json``.

    For every dataset sample the agent generates its list of think segments.
    For each prefix of that list (first segment, first two, ...) the action is
    predicted; when the prediction equals the ground-truth action, the returned
    probability is recorded in the bucket of that prefix length. The mean and
    standard deviation per bucket, together with the raw values, are written to
    ``<args.output_dir>/thresholds.json``.

    Args:
        args: parsed run arguments. Attributes read: ``output_dir``,
            ``env_name``, ``pi_model_name``, ``dataset_dir``, ``dataset_size``,
            ``dataset_ablation``, ``continual_type``, ``max_think_token``,
            ``correct``, ``no_critic``, ``perturb``.
    """
    # The file imports the `tqdm` *module*; bind the progress-bar callable
    # explicitly so the loop below works whatever the module-level name is.
    from tqdm import tqdm as progress_bar

    base_tokenizer, base_model = None, None
    # prepare model
    tokenizer, model = prepare(args, path=args.output_dir)

    # prepare dataset
    dataset, _ = prepare_dataset(
        args.env_name,
        args.pi_model_name,
        args.dataset_dir,
        tokenizer,
        dataset_type="cd-think",
        dataset_size=args.dataset_size,
        few_shot_example=None,
        ablation=args.dataset_ablation,
    )
    rag_pipe = None

    # prepare agent
    agent = ECoCAgent(
        args.pi_model_name,
        base_model,
        base_tokenizer,
        model,
        tokenizer,
        plan_model=None,
        plan_tokenizer=None,
        rag_pipe=rag_pipe,
        env_name=args.env_name,
        cl_type=args.continual_type.split('_')[0],
        max_think_token=args.max_think_token,
        correct=args.correct,
        no_critic=args.no_critic,
        perturb=args.perturb,
        decoding_strategy="greedy",
    )

    # One bucket per think step. Five buckets are always present; more are
    # appended if the agent ever returns a longer think list (previously an
    # IndexError).
    thresholds = [[] for _ in range(5)]
    for data in progress_bar(dataset, desc="threshold"):
        instruction, action, history = data["instruction"], data["action"], data["history"]
        state = PromptTemplate.preprocess(data["state"])
        with torch.no_grad():
            _, think_list = agent.get_think(
                agent.model, agent.tokenizer, instruction, state, history, sample=False)
            for step in range(len(think_list)):
                if step >= len(thresholds):
                    thresholds.append([])
                # Use the first `step + 1` think segments as the rationale.
                think = " ".join(think_list[:step + 1])
                action_pred, prob = agent.get_action(instruction, state, history, think)
                if action_pred == action:
                    thresholds[step].append(prob)

    # NOTE: a bucket that never saw a correct prediction yields NaN here
    # (numpy emits a RuntimeWarning) and is written as NaN by json.dump.
    average = [np.mean(bucket) for bucket in thresholds]
    deviation = [np.std(bucket) for bucket in thresholds]

    data = {
        "average": average,
        "deviation": deviation,
        "raw": thresholds,
    }
    with open(args.output_dir + '/thresholds.json', 'w') as f:
        json.dump(data, f)


def _evaluate_split(args, env, agent, split, explicit_eval_set):
    """Run the evaluation pipeline on one split and write its result file.

    Args:
        args: parsed run arguments.
        env: environment returned by ``make_env``.
        agent: the ``ECoCAgent`` under evaluation; its ``evaluate`` attribute is
            set to `split`.
        split: ``"seen"`` or ``"unseen"``.
        explicit_eval_set: True when the user passed ``continual_eval_set``; the
            set name is then appended to the result file name.
    """
    agent.evaluate = split
    # Runs whose tag contains "S5" use the 5-shot evaluation configs.
    eval_tag = f"shot5_{split}" if "S5" in args.tag else split
    print_error(f"Eval Tag {eval_tag}")
    env_list, task_list, available_tasks, ps_line = load_evaluation_configs(
        args.env_name, args.continual_type, args.continual_eval_set, eval_tag)

    print_warn(f"Evaluating in {split.capitalize()} Set")

    # The file name encodes every option that changes the evaluation.
    save_path = args.output_dir + f'/results/evaluation_{split}'
    if explicit_eval_set:
        save_path += f'_{args.continual_eval_set}'
    if args.perturb:
        save_path += '_perturb'
    if args.test_time_control:
        save_path += '_control'
    if args.total_think != 5:
        save_path += f'_think{args.total_think}'
    if args.correct:
        save_path += '_correct'
    save_path += '.txt'

    evaluation_pipe(
        env,
        agent,
        env_list,
        task_list,
        available_tasks,
        ps_line,
        total_episode=args.eval_episode,
        save_path=save_path,
    )


def evaluate(args):
    """Evaluate the trained policy on the seen and/or unseen split.

    ``args.evaluate`` selects the split(s): ``"seen"``, ``"unseen"`` or
    ``"both"``. When ``args.correct`` is set, the ``feedback_policy`` adapter is
    loaded from ``<args.output_dir>/feedback``. When ``args.test_time_control``
    is set, ``<args.output_dir>/thresholds.json`` is read and handed to the
    agent. If ``args.continual_eval_set`` is ``None`` it is overwritten with
    ``args.continual_set``.

    Args:
        args: parsed run arguments.
    """
    # The returned params are not used below; the call is kept as in the
    # original (presumably it fails early when the policy run is missing —
    # TODO confirm).
    params = load_params(args.output_dir + '/policy')

    # No base model and no RAG pipeline are used at evaluation time.
    base_tokenizer, base_model = None, None
    rag_pipe = None

    # prepare model
    tokenizer, model = prepare(args, path=args.output_dir)

    if args.correct:
        # Self-correction needs the feedback adapter next to the policy adapters.
        model = add_adapter(model, name="feedback_policy")
        model.load_adapter(args.output_dir + '/feedback/feedback_policy', adapter_name='feedback_policy')

    # read threshold for test-time control
    test_time_thresh = None
    if args.test_time_control:
        with open(args.output_dir + '/thresholds.json', 'r') as f:
            test_time_thresh = json.load(f)

    # prepare agent
    agent = ECoCAgent(
        args.pi_model_name,
        base_model,
        base_tokenizer,
        model,
        tokenizer,
        plan_model=None,
        plan_tokenizer=None,
        rag_pipe=rag_pipe,
        env_name=args.env_name,
        cl_type=args.continual_type.split('_')[0],
        max_think_token=args.max_think_token,
        total_think=args.total_think,
        correct=args.correct,
        no_critic=args.no_critic,
        perturb=args.perturb,
        decoding_strategy="greedy",
        test_time_thresh=test_time_thresh,
    )

    # Default the evaluation set to the training set; remember whether the
    # user chose it explicitly because that changes the result file name.
    explicit_eval_set = args.continual_eval_set is not None
    if not explicit_eval_set:
        args.continual_eval_set = args.continual_set

    # make env
    split_path = f"./externals/cl-alfred/embodied_split/{args.continual_type.split('_')[0]}_il/embodied_data_disjoint_rand1_cls1_task{int(args.continual_eval_set.split('_')[-1])-1}.json" # for alfred
    env = make_env(args.env_name, args.num_topk_edge, args.env_ip, args.seed, args.env_port, split_path)

    if args.evaluate in ("seen", "both"):
        # Evaluate on Train Environment
        _evaluate_split(args, env, agent, "seen", explicit_eval_set)
    if args.evaluate in ("unseen", "both"):
        # Evaluate on Test Environment
        _evaluate_split(args, env, agent, "unseen", explicit_eval_set)


if __name__ == "__main__":
    # Script entry point: parse arguments, fill in defaults, then run the
    # requested stages (train / test / evaluate) in that order.
    args = argparse.parse_args()
    # The think-token budget is fixed here regardless of the parsed value.
    args.max_think_token = 40

    # directory settings: derive the default locations from the run description
    if args.dataset_dir is None:
        dataset_group = f"{args.continual_type}_train"
        args.dataset_dir = f"./datasets/{args.env_name}/{dataset_group}/{args.continual_set}"
    if args.output_dir is None:
        model_name = args.pi_model_name.split('/')[-1]
        sequence_dir = f"./results/ECoC/{args.env_name}/{args.continual_type}_{args.continual_seq}"
        args.output_dir = f"{sequence_dir}/{args.continual_set}_{model_name}_{args.tag}"

    # Without an explicit previous tag, earlier phases are assumed to share this run's tag.
    if args.prev_tag is None:
        args.prev_tag = args.tag

    if args.train:
        # phase 0: critic, phase 1: policy, phase 2: feedback
        if args.phase not in (0, 1, 2):
            raise NotImplementedError
        if args.phase == 0:
            train_critic(args)
        elif args.phase == 1:
            train_policy(args)
        else:
            train_feedback(args)

    if args.test:
        test(args)

    if args.evaluate:
        evaluate(args)
