import os
import json
import torch
import yaml
from transformers import AutoTokenizer, AutoModel, GPT2LMHeadModel, AutoModelForCausalLM
from trl import DataCollatorForCompletionOnlyLM
from peft import LoraConfig, PeftModel, get_peft_model

from embodied_cd.common import dataset_utils
from embodied_cd.common.print_utils import *
from embodied_cd.trl.models import * 

# Policy-model loaders used by `prepare_model` when `vf` is False.
# Each key is matched as a substring of the model name; the value is the class
# whose `from_pretrained` is called to load the model.
_PI_MODEL_ALIASES = {
    'gpt2': GPT2LMHeadModel,
    'Qwen2.5': AutoModelForCausalLM,
    'Llama': AutoModelForCausalLM, 
}

# Value-head model loaders used by `prepare_model` when `vf` is True.
# Keys are matched as substrings of the model name, exactly like
# `_PI_MODEL_ALIASES`; the classes come from `embodied_cd.trl.models`.
_VF_MODEL_ALIASES = {
    'gpt2': GPT2HeadWithValueModel,
    'Qwen2.5': Qwen2HeadWithValueModel,
    'Llama': LlamaHeadWithValueModel,
}


def prepare_dataset(env_name, model_name, dataset_dir, tokenizer, dataset_type="default", few_shot_example=None, dataset_percent=1.0, dataset_size=0.25, shuffle=False, ablation=False):
    """Load an environment dataset and build the matching training collator.

    Args:
        env_name: ``"virtualhome"`` or ``"alfred"``.
        model_name: Model identifier. If it contains ``"instruct"`` or
            ``"Instruct"`` the chat-template path is used, otherwise the
            completion-template path.
        dataset_dir: Directory containing ``full_dataset.jsonl`` and/or
            ``part_{dataset_size}_dataset.jsonl``.
        tokenizer: Tokenizer handed to ``as_chat`` / ``as_completion`` and
            used by the collator.
        dataset_type: Forwarded to ``dataset.set_template``.
        few_shot_example: Forwarded to ``dataset.set_template``.
        dataset_percent: Forwarded to the dataset ``load`` call.
        dataset_size: ``1.0`` selects ``full_dataset.jsonl``; any other value
            selects ``part_{dataset_size}_dataset.jsonl``. ``None`` falls back
            to 0.125 for virtualhome and 0.25 for alfred.
        shuffle: Forwarded to the dataset ``load`` call.
        ablation: Only consulted for virtualhome. No ablation loader exists,
            so ``True`` raises ``NotImplementedError``.

    Returns:
        Tuple ``(dataset, collator)``.

    Raises:
        NotImplementedError: For an unknown ``env_name`` or for
            ``ablation=True`` with virtualhome.
    """
    if env_name not in ("virtualhome", "alfred"):
        raise NotImplementedError(f"Unsupported environment: {env_name!r}")
    if env_name == "virtualhome" and ablation:
        # Previously this path left `dataset` unassigned and crashed with an
        # UnboundLocalError a few lines later; fail with a clear message instead.
        raise NotImplementedError("Ablation dataset loading is not implemented for virtualhome")

    if dataset_size is None:
        dataset_size = 0.125 if env_name == "virtualhome" else 0.25

    # load dataset
    if env_name == "virtualhome":
        dataset_cls = dataset_utils.VirtualHomeDataset
    else:
        dataset_cls = dataset_utils.AlfredDataset
    if dataset_size == 1.0:
        file_name = 'full_dataset.jsonl'
    else:
        file_name = f'part_{dataset_size}_dataset.jsonl'
    dataset = dataset_cls.load(dataset_dir + '/' + file_name, dataset_percent, shuffle)
    print_warn(f"[dataset] Loading total {len(dataset)} dataset")
    dataset.set_template(dataset_type, few_shot_example)

    if 'instruct' in model_name or 'Instruct' in model_name:
        # if instruct model
        print_warn("[Dataset] Using Chat Template")
        dataset = dataset.as_chat(tokenizer)

        def collate_fn(examples):
            texts = [example["text"] for example in examples]
            batch = tokenizer(texts, return_tensors="pt", padding=True)
            labels = batch['input_ids'].clone()
            if 'attention_mask' in batch:
                # padding positions must not contribute to the loss
                labels[batch['attention_mask'] == 0] = -100
            for label, example in zip(labels, examples):
                # we mask the query part in the loss computation
                # (assumes the query sits at the start, i.e. right padding -- TODO confirm)
                label[:len(example['query_ids'][0])] = -100
            batch['labels'] = labels
            return batch

        collator = collate_fn
    else:
        # else completion
        print_warn("[Dataset] Using Completion Template")
        dataset = dataset.as_completion(tokenizer)
        # NOTE(review): the VirtualHome special tokens are used for both
        # environments here -- confirm AlfredDataset shares the same templates.
        collator = DataCollatorForCompletionOnlyLM(
            tokenizer=tokenizer,
            mlm=False,
            instruction_template=dataset_utils.VirtualHomeDataset.additional_special_tokens["query"],
            response_template=dataset_utils.VirtualHomeDataset.additional_special_tokens["response"])
    return dataset, collator


def prepare_model(model_name, vf=False, init=True, mlm=False, config=None):
    """Load the tokenizer and model for ``model_name``.

    Args:
        model_name: Hub identifier or local checkpoint directory. It must
            contain one of the alias keys (``'gpt2'``, ``'Qwen2.5'``,
            ``'Llama'``) so the model class can be selected.
        vf: If True, load a value-head model from ``_VF_MODEL_ALIASES``;
            otherwise a policy model from ``_PI_MODEL_ALIASES``.
        init: If True, set up the tokenizer (pad token, optional mask token,
            right padding) and grow the embedding matrix when the tokenizer
            has more tokens than the model. If False together with ``vf``,
            the value head is restored from ``<model_name>/vl_head.pth``.
        mlm: If True (and ``init``), add a ``[MASK]`` token when the
            tokenizer has none.
        config: Extra keyword arguments forwarded to ``from_pretrained``.
            ``None`` means no extra arguments.

    Returns:
        Tuple ``(tokenizer, model)``.

    Raises:
        ValueError: If ``model_name`` matches none of the known aliases.
    """
    # a fresh dict per call instead of a shared mutable default argument
    config = {} if config is None else config

    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if init:
        if tokenizer.pad_token is None:
            print_warn("[Tokenizer] has no pad token")
            # NOTE: the pad token reuses EOS. When pad and EOS are the same
            # token the Hugging Face trainer treats EOS as padding and the
            # model may not learn to generate EOS; adding a dedicated pad
            # token was considered as an alternative but is not enabled.
            tokenizer.pad_token = tokenizer.eos_token
        if mlm and (tokenizer.mask_token is None):
            print_error("[Tokenizer] Add mask token")
            tokenizer.add_special_tokens({'mask_token': '[MASK]'})
            print(tokenizer.mask_token_id)
        tokenizer.padding_side = 'right'

    model_aliases = _VF_MODEL_ALIASES if vf else _PI_MODEL_ALIASES

    # Pick the model class by substring match. If several keys match, the last
    # one in the table wins (same end result as loading for every match), but
    # the weights are now loaded only once.
    matches = [cls for key, cls in model_aliases.items() if key in model_name]
    if not matches:
        raise ValueError(
            f"Unsupported model {model_name!r}; expected the name to contain one of {list(model_aliases)}")
    model_cls = matches[-1]

    # load model
    print_warn(f"[Loading: Model] from {model_name}")
    model = model_cls.from_pretrained(model_name, **config, device_map='auto')
    if not init and vf:
        print_warn(f"[Loading: ValueHead] from vl_head.pth")
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        model.vl_head.load_state_dict(torch.load(model_name + '/vl_head.pth'))
    if init:
        # Some model classes may not expose `vocab_size` directly; fall back
        # to the value stored on the model config in that case.
        vocab_size = getattr(model, 'vocab_size', None)
        if vocab_size is None:
            vocab_size = model.config.vocab_size
        if len(tokenizer) > vocab_size:
            print_warn(f"[Loading: Embedding] Resize token embedding size from {vocab_size} to {len(tokenizer)}")
            # resizing the embedding layer with tokenizer length, considering additional tokens
            model.resize_token_embeddings(len(tokenizer))
    return tokenizer, model


def add_adapter(model, name="default"):
    """Attach a LoRA adapter named ``name`` to ``model`` and return the result.

    A plain model is wrapped with ``get_peft_model``; a model that is already
    a ``PeftModel`` gets the adapter registered through ``add_adapter``.
    The trainable-parameter summary is printed before returning.
    """
    # LoRA settings: rank 16, alpha 32, applied to the query/value projections
    adapter_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.1,
        bias="none",
        target_modules=["v_proj", "q_proj"],
        task_type="CAUSAL_LM",
    )

    print_warn("[Model] Adding Lora Adapter")
    if isinstance(model, PeftModel):
        # already a PEFT model: register one more adapter on it
        model.add_adapter(name, adapter_config)
    else:
        # first adapter: wrap the base model
        model = get_peft_model(model, adapter_config, name)
    model.print_trainable_parameters()
    return model


def print_model_params(model):
    """Print name, value and ``requires_grad`` flag of every model parameter."""
    for entry in model.named_parameters():
        param_name, tensor = entry
        print(param_name, tensor, tensor.requires_grad)


def print_adapter_params(model):
    """Print only the parameters whose name contains ``"lora"``."""
    for param_name, tensor in model.named_parameters():
        if "lora" not in param_name:
            continue
        print(param_name, tensor, tensor.requires_grad)


def load_evaluation_configs(env_name, _type, _set, _eval): # behavior-il / environment-il
    """Load the evaluation config and the first expert trajectory record.

    Args:
        env_name: ``"virtualhome"`` or ``"alfred"``.
        _type: Prefix of the config file and dataset directory names.
        _set: Key into the config's ``rooms`` / ``tasks`` tables and name of
            the dataset sub-directory.
        _eval: Evaluation split; must contain ``"seen"`` or ``"unseen"``.

    Returns:
        A 4-tuple ``(rooms, tasks, available_tasks, ps_line)``. For
        virtualhome, ``rooms`` and ``tasks`` are the key lists of the
        configured dicts; for alfred they are returned as configured and
        ``available_tasks`` is ``None``. ``ps_line`` is the first JSON record
        of ``expert_traj.jsonl``.

    Raises:
        ValueError: If ``_eval`` names neither split or the trajectory file
            is empty.
        NotImplementedError: If ``env_name`` is not supported.
    """
    config_path = f"./configs/{env_name}/{_type}_{_eval}.py"

    # NOTE(security): the config file is executed as Python code; only point
    # this at trusted, repository-controlled files.
    config = {}
    with open(config_path, 'r') as f:
        code = f.read()
    exec(code, config)

    # loading pending steps
    # "unseen" contains "seen", so the longer marker has to be tested first.
    if "unseen" in _eval:
        split = "unseen"
    elif "seen" in _eval:
        split = "seen"
    else:
        raise ValueError(f"_eval must contain 'seen' or 'unseen', got {_eval!r}")
    traj_path = f"./datasets/{env_name}/{_type}_{split}/{_set}/expert_traj.jsonl"
    with open(traj_path, 'r') as f:
        # only the first record is needed
        first_line = f.readline()
    if not first_line:
        raise ValueError(f"Expert trajectory file is empty: {traj_path}")
    ps_line = json.loads(first_line)

    if env_name == "virtualhome":
        return list(config['rooms'][_set].keys()), list(config['tasks'][_set].keys()), config['available_tasks'], ps_line
    if env_name == "alfred":
        return config['rooms'][_set], config['tasks'][_set], None, ps_line
    raise NotImplementedError(f"Unsupported environment: {env_name!r}")


def save_params(output_dir, params):
    """Write ``params`` as block-style YAML to ``<output_dir>/params.yml``.

    The output directory is created if it does not exist yet.
    """
    os.makedirs(output_dir, exist_ok=True)
    target = os.path.join(output_dir, 'params.yml')
    with open(target, 'w+') as handle:
        yaml.dump(params, handle, default_flow_style=False)


def load_params(output_dir):
    """Read ``<output_dir>/params.yml`` with the safe YAML loader and return it."""
    source = os.path.join(output_dir, 'params.yml')
    with open(source, 'r') as handle:
        return yaml.safe_load(handle)
