from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoConfig
from peft import PeftModel
from llm2vec import LLM2Vec

# from data_selection.monkey_patch import apply_front_layer_patch
from features import (
    # sr_compute_reps,
    # ps_wanda_compute_reps,
    # sps_compute_reps,
    # sps_compute_feature,
    sae_compute_feature,
    bge_compute_feature,
    rds_compute_feature,
    llm2vec_compute_feature,
    nosae_compute_feature,
)
import torch
from functools import partial
from collate_functions import (
    tokenize_collate_func,
    concat_turns_collate_func,
)
from util_funcs import (
    list_extend_all_heldout_data_feature_agg,
    weighted_topk_sae_feature_agg_func,
    concat_extend_all_heldout_data_feature_agg,
    # weighted_sae_feature_sim,
    # batch_samples_and_tasks_sim_func,
)
from sae import Sae
import os


class RDSConfig:
    """Feature-extraction setup for the ``rds`` method.

    After construction the instance exposes:

    * ``collate_func`` -- ``concat_turns_collate_func`` bound to this model's
      tokenizer, padding each batch to its longest sequence.
    * ``heldout_data_feature_agg_func`` --
      ``concat_extend_all_heldout_data_feature_agg``.
    * ``task_heldout_data_feature_agg_func`` -- identity (returns its input).
    * ``compute_feature`` -- ``rds_compute_feature`` with the loaded
      ``AutoModel`` bound as ``model``.
    """

    def __init__(
        self,
        model_name_or_path=None,
        dialog_format="",
        max_length=4096,
    ):
        """Load model and tokenizer from ``model_name_or_path`` and wire up the callables.

        Args:
            model_name_or_path: Hugging Face model id or local checkpoint path.
                Must be supplied; the ``None`` default is not usable here.
            dialog_format: Forwarded to ``concat_turns_collate_func``.
            max_length: Maximum length forwarded to ``concat_turns_collate_func``.
        """
        encoder = AutoModel.from_pretrained(
            model_name_or_path,
            attn_implementation="flash_attention_2",
            torch_dtype="auto",
        )
        # Move the weights to the GPU and switch to evaluation mode.
        encoder = encoder.cuda().eval()

        tok = AutoTokenizer.from_pretrained(model_name_or_path)
        if tok.pad_token is None:
            # No dedicated pad token: reuse EOS so batches can be padded.
            tok.pad_token = tok.eos_token

        # Enabled only when the checkpoint name mentions "llama" (case-insensitive).
        mutual_add_eos_token = "llama" in model_name_or_path.lower()
        print(f"mutual_add_eos_token: {mutual_add_eos_token}")

        self.collate_func = partial(
            concat_turns_collate_func,
            dialog_format=dialog_format,
            tokenizer=tok,
            padding="longest",
            max_length=max_length,
            mutual_add_eos_token=mutual_add_eos_token,
        )
        self.heldout_data_feature_agg_func = concat_extend_all_heldout_data_feature_agg
        self.task_heldout_data_feature_agg_func = lambda features: features
        self.compute_feature = partial(rds_compute_feature, model=encoder)


class BGEConfig:
    """Feature-extraction setup for the ``bge`` embedding method.

    After construction the instance exposes:

    * ``collate_func`` -- ``concat_turns_collate_func`` bound to the tokenizer,
      padding every batch to ``max_length``.
    * ``heldout_data_feature_agg_func`` --
      ``concat_extend_all_heldout_data_feature_agg``.
    * ``task_heldout_data_feature_agg_func`` -- identity (returns its input).
    * ``compute_feature`` -- ``bge_compute_feature`` with the float16
      ``AutoModel`` bound as ``model``.
    """

    def __init__(
        self,
        model_name_or_path=None,
        dialog_format="",
        max_length=4096,
    ):
        """Load the float16 encoder and its tokenizer, then wire up the callables.

        Args:
            model_name_or_path: Hugging Face model id or local checkpoint path.
                Must be supplied; the ``None`` default is not usable here.
            dialog_format: Forwarded to ``concat_turns_collate_func``.
            max_length: Fixed padding length forwarded to ``concat_turns_collate_func``.
        """
        encoder = AutoModel.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.float16,
        )
        # GPU placement plus evaluation mode for inference.
        encoder = encoder.cuda().eval()

        tok = AutoTokenizer.from_pretrained(model_name_or_path)
        if tok.pad_token is None:
            # Fall back to EOS as the padding token.
            tok.pad_token = tok.eos_token

        self.collate_func = partial(
            concat_turns_collate_func,
            dialog_format=dialog_format,
            tokenizer=tok,
            padding="max_length",
            max_length=max_length,
        )
        self.heldout_data_feature_agg_func = concat_extend_all_heldout_data_feature_agg
        self.task_heldout_data_feature_agg_func = lambda features: features
        self.compute_feature = partial(bge_compute_feature, model=encoder)


class LLM2VecConfig:
    """Feature-extraction setup for the ``llm2vec`` method.

    After construction the instance exposes:

    * ``collate_func`` -- ``concat_turns_collate_func`` with ``tokenize=False``.
    * ``heldout_data_feature_agg_func`` --
      ``concat_extend_all_heldout_data_feature_agg``.
    * ``task_heldout_data_feature_agg_func`` -- identity (returns its input).
    * ``compute_feature`` -- ``llm2vec_compute_feature`` with an ``LLM2Vec``
      wrapper (mean pooling) bound as ``model``.
    """

    def __init__(
        self,
        model_name_or_path=None,
        dialog_format="",
        max_length=4096,
    ):
        """Load base model + PEFT adapter, wrap them in ``LLM2Vec``, and wire up the callables.

        Args:
            model_name_or_path: Checkpoint path used for the tokenizer, the base
                model (with ``trust_remote_code=True``) and the PEFT adapter.
            dialog_format: Forwarded to ``concat_turns_collate_func``.
            max_length: Forwarded to both ``LLM2Vec`` and the collate function.
        """
        tok = AutoTokenizer.from_pretrained(model_name_or_path)
        hf_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)

        base = AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            config=hf_config,
            torch_dtype=torch.bfloat16,
        )
        base = base.cuda().eval()
        # Attach the PEFT adapter stored under the same path to the base model.
        adapted = PeftModel.from_pretrained(base, model_name_or_path)
        # LLM2Vec wraps encoding and mean pooling over the adapted model.
        encoder = LLM2Vec(adapted, tok, pooling_mode="mean", max_length=max_length)

        if tok.pad_token is None:
            # Same tokenizer object the wrapper holds, so it sees this too.
            tok.pad_token = tok.eos_token

        self.collate_func = partial(
            concat_turns_collate_func,
            dialog_format=dialog_format,
            tokenizer=tok,
            padding="max_length",
            max_length=max_length,
            # Collate without tokenizing; presumably LLM2Vec tokenizes the
            # raw text itself -- confirm against llm2vec_compute_feature.
            tokenize=False,
        )
        self.heldout_data_feature_agg_func = concat_extend_all_heldout_data_feature_agg
        self.task_heldout_data_feature_agg_func = lambda features: features
        self.compute_feature = partial(llm2vec_compute_feature, model=encoder)


class SAEConfig:
    """Feature-extraction setup for the ``sae`` method.

    Loads a causal LM plus zero or more sparse autoencoders (``Sae``) and
    exposes:

    * ``model`` -- the loaded causal LM (reused by subclasses).
    * ``collate_func`` -- ``tokenize_collate_func`` bound to the tokenizer.
    * ``heldout_data_feature_agg_func`` --
      ``list_extend_all_heldout_data_feature_agg``.
    * ``task_heldout_data_feature_agg_func`` -- maps ``{task: features}`` to
      ``{task: weighted_topk_sae_feature_agg_func(features, ...)}``.
    * ``compute_feature`` -- ``sae_compute_feature`` with the model and the
      ``{name: Sae}`` dict bound.
    """

    @staticmethod
    def _sae_name(sae_model_name_or_path):
        """Return the key an SAE checkpoint is registered under (its last path component).

        The path is normalised first, so a trailing separator does not yield an
        empty name (``os.path.basename("a/b/")`` is ``""``).
        """
        return os.path.basename(os.path.normpath(sae_model_name_or_path))

    def __init__(
        self,
        model_name_or_path=None,
        sae_model_name_or_paths=None,
        max_length=4096,
        chat_template=None,
        dialog_format="",
        topk_for_token_agg=-1,
        avg_level="sample",
    ):
        """Load the LM, tokenizer and SAEs, then wire up the callables.

        Args:
            model_name_or_path: Causal-LM checkpoint (loaded with flash-attention 2
                and ``torch_dtype="auto"``, moved to CUDA, eval mode).
            sae_model_name_or_paths: Iterable of SAE checkpoint directories loaded
                with ``Sae.load_from_disk``; ``None`` means no SAEs.
            max_length: Forwarded to ``tokenize_collate_func``.
            chat_template: Forwarded to ``tokenize_collate_func``.
            dialog_format: Forwarded to ``tokenize_collate_func``.
            topk_for_token_agg: ``k`` for both ``sae_compute_feature`` and the
                per-task ``weighted_topk_sae_feature_agg_func``.
            avg_level: Forwarded to ``sae_compute_feature``.

        Raises:
            ValueError: if two SAE paths resolve to the same name; the later one
                would otherwise silently replace the earlier in the SAE dict.
        """
        sae_paths = [] if sae_model_name_or_paths is None else list(sae_model_name_or_paths)

        # Validate SAE names up front so a bad config fails before the slow model loads.
        sae_names = [self._sae_name(path) for path in sae_paths]
        duplicates = sorted({name for name in sae_names if sae_names.count(name) > 1})
        if duplicates:
            raise ValueError(
                f"SAE checkpoints must have distinct directory names; duplicates: {duplicates}"
            )

        model = (
            AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                attn_implementation="flash_attention_2",
                torch_dtype="auto",
            )
            .cuda()
            .eval()
        )
        self.model = model

        tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, use_fast=True, add_bos_token=True
        )
        if tokenizer.pad_token is None:
            # No dedicated pad token: reuse EOS so batches can be padded.
            tokenizer.pad_token = tokenizer.eos_token

        # Each SAE is placed on the LM's device and cast to the LM's dtype.
        sae_models = {
            name: Sae.load_from_disk(path)
            .to(device=model.device, dtype=model.dtype)
            .eval()
            for name, path in zip(sae_names, sae_paths)
        }

        self.collate_func = partial(
            tokenize_collate_func,
            tokenizer=tokenizer,
            max_length=max_length,
            chat_template=chat_template,
            dialog_format=dialog_format,
        )

        self.heldout_data_feature_agg_func = list_extend_all_heldout_data_feature_agg

        # NOTE(review): the per-task aggregation always uses avg_level="sample";
        # the ``avg_level`` argument only reaches ``compute_feature``. Kept as in
        # the original -- confirm this is intended.
        task_agg = partial(
            weighted_topk_sae_feature_agg_func,
            k=topk_for_token_agg,
            avg_level="sample",
        )
        self.task_heldout_data_feature_agg_func = lambda task_heldout_feature: {
            task: task_agg(task_heldout_feature[task]) for task in task_heldout_feature
        }

        self.compute_feature = partial(
            sae_compute_feature,
            model=model,
            sae_models=sae_models,
            k=topk_for_token_agg,
            avg_level=avg_level,
        )


class NoSAEConfig(SAEConfig):
    """``SAEConfig`` variant that computes features without loading any SAE.

    Reuses the parent's model, collate function and aggregation functions, but
    replaces ``compute_feature`` with ``nosae_compute_feature``.
    """

    def __init__(
        self,
        *args,
        topk_for_token_agg=-1,
        sae_model_name_or_paths=None,
        avg_level="sample",
        **kwargs,
    ):
        """Initialise the parent without SAEs, then rebind ``compute_feature``.

        Args:
            *args, **kwargs: Forwarded to ``SAEConfig.__init__``.
            topk_for_token_agg: Forwarded to the parent, so its per-task heldout
                aggregation uses it, and to ``nosae_compute_feature`` as ``k``.
            sae_model_name_or_paths: Deliberately NOT forwarded to the parent, so
                no SAE is loaded. Passed to ``nosae_compute_feature`` as ``layers``
                (presumably layer identifiers -- confirm against that function).
                ``None`` means ``[]``.
            avg_level: Forwarded to the parent and to ``nosae_compute_feature``.
        """
        layers = [] if sae_model_name_or_paths is None else sae_model_name_or_paths

        # Forward topk/avg_level: without this, the inherited
        # task_heldout_data_feature_agg_func silently used the parent's default k=-1.
        super().__init__(
            *args,
            topk_for_token_agg=topk_for_token_agg,
            avg_level=avg_level,
            **kwargs,
        )

        self.compute_feature = partial(
            nosae_compute_feature,
            model=self.model,
            layers=layers,
            k=topk_for_token_agg,
            avg_level=avg_level,
        )


def get_feature_config(
    method=None,
    model_name_or_path=None,
    sae_model_name_or_paths=None,
    max_length=4096,
    chat_template=None,
    dialog_format="",
    topk_for_token_agg=-1,
    avg_level="sample",
    **kwargs,
):
    """Build the feature-configuration object for a data-selection ``method``.

    Args:
        method: One of ``"sae"``, ``"nosae"``, ``"bge"``, ``"llm2vec"``, ``"rds"``.
        model_name_or_path: Model checkpoint forwarded to every config.
        sae_model_name_or_paths: Forwarded to ``"sae"``/``"nosae"`` only;
            ``None`` means an empty list.
        max_length: Forwarded to every config.
        chat_template: Used only by ``"sae"``/``"nosae"``.
        dialog_format: Forwarded to every config.
        topk_for_token_agg: Used only by ``"sae"``/``"nosae"``.
        avg_level: Used only by ``"sae"``/``"nosae"``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        The config instance corresponding to ``method``.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    print(f"method: {method}")
    if sae_model_name_or_paths is None:
        sae_model_name_or_paths = []

    if method in ("sae", "nosae"):
        config_cls = SAEConfig if method == "sae" else NoSAEConfig
        return config_cls(
            model_name_or_path=model_name_or_path,
            sae_model_name_or_paths=sae_model_name_or_paths,
            max_length=max_length,
            chat_template=chat_template,
            dialog_format=dialog_format,
            topk_for_token_agg=topk_for_token_agg,
            avg_level=avg_level,
        )
    if method == "bge":
        return BGEConfig(
            model_name_or_path=model_name_or_path,
            dialog_format=dialog_format,
            max_length=max_length,
        )
    if method == "llm2vec":
        return LLM2VecConfig(
            model_name_or_path=model_name_or_path,
            dialog_format=dialog_format,
            max_length=max_length,
        )
    if method == "rds":
        return RDSConfig(
            model_name_or_path=model_name_or_path,
            dialog_format=dialog_format,
            max_length=max_length,
        )
    # Fail loudly instead of returning None, which would only surface later as
    # an unrelated AttributeError at the first use of the config.
    raise ValueError(
        f"Unknown feature method {method!r}; expected one of "
        "'sae', 'nosae', 'bge', 'llm2vec', 'rds'."
    )


# if method == "sr":
#         return SRConfig(
#             model_name_or_path=model_name_or_path,
#             max_length=max_length,
#             front_layer_layer_num=front_layer_layer_num,
#             chat_template=chat_template,
#             dialog_format=dialog_format,
#             sim_aggregation=sim_aggregation,
#         )
#     elif method == "ps_wanda":
#         return PSWandaConfig(
#             model_name_or_path=model_name_or_path,
#             max_length=max_length,
#             front_layer_layer_num=front_layer_layer_num,
#             chat_template=chat_template,
#             dialog_format=dialog_format,
#             sim_aggregation=sim_aggregation,
#             prune_ratio=prune_ratio,
#             top_layer_num=top_layer_num,
#         )
#     elif
