#!/usr/bin/env python3

import hydra
import importlib
import os
import torch
import transformers
import argparse
from pathlib import Path
import json
import copy
import codecs

import logging

log = logging.getLogger()

from lm_polygraph.utils.manager import UEManager
from dataset import Dataset
from lm_polygraph.utils.model import WhiteboxModel, create_ensemble
from lm_polygraph.utils.processor import Logger
from lm_polygraph.generation_metrics.accuracy import AccuracyMetric
from lm_polygraph.generation_metrics.bart_score import BartScoreSeqMetric
from lm_polygraph.generation_metrics.rouge import RougeMetric
from lm_polygraph.generation_metrics.bert_score import BertScoreMetric
from lm_polygraph.generation_metrics.sbert import SbertMetric
from lm_polygraph.generation_metrics.aggregated_metric import AggregatedMetric
from lm_polygraph.generation_metrics.alignscore import AlignScore
from lm_polygraph.generation_metrics.comet import Comet
from lm_polygraph.utils.openai_chat import OpenAIChat
from lm_polygraph.generation_metrics.openai_fact_check import OpenAIFactCheck
from lm_polygraph.estimators import *
from lm_polygraph.estimators.ensemble_token_measures import all_token_estimators
from lm_polygraph.estimators.ensemble_sequence_measures import all_ep_estimators, all_pe_estimators
from lm_polygraph.estimators.ensemble_token_measures import *
from lm_polygraph.ue_metrics import *
from lm_polygraph.utils.generation_parameters import GenerationParameters

from rauq import RAUQ

from supervised_baselines.lookback_lens import LookBackLens
from supervised_baselines.saplma import SAPLMA
from supervised_baselines.mind import MIND
from supervised_baselines.sheeps import LayerSheeps, Sheeps
from supervised_baselines.factoscope import LLMFactoscopeAll

from unsupervised_baselines.focus import Focus, FocusClaim
from unsupervised_baselines.simple_focus import SimpleFocus
from unsupervised_baselines.grads_methods import IntegratedGradients
from unsupervised_baselines.eigenscore import EigenScore
from unsupervised_baselines.luq import LUQ
from unsupervised_baselines.llm_check_attention import LLMCheckAttention

# Path to the Hydra config file, taken from the HYDRA_CONFIG environment
# variable (indexing os.environ raises KeyError at import time if it is unset).
# Its parent directory and file name are passed to @hydra.main below.
hydra_config = Path(os.environ["HYDRA_CONFIG"])


@hydra.main(
    version_base=None,
    config_path=str(hydra_config.parent),
    config_name=str(hydra_config.name),
)
def main(args):
    """Run the uncertainty-estimation benchmark once per configured seed.

    For each seed in ``args.seed`` this loads the model (and, if
    ``args.model.ensemble`` is set, an MC ensemble), loads the evaluation
    dataset plus any training / background datasets needed by estimators that
    are not fitted yet, runs a ``UEManager`` and saves its results under
    ``save_path`` (``args.save_path`` if present, else the directory Hydra
    started us in).

    Args:
        args: Hydra config composed from the file named by ``HYDRA_CONFIG``.
    """
    # Remember the directory Hydra launched us in as the default output
    # location, then return to the original cwd so relative paths resolve.
    save_path = os.getcwd()
    log.info(f"Main directory: {save_path}")
    os.chdir(hydra.utils.get_original_cwd())

    save_path = args.save_path if "save_path" in args else save_path

    if args.seed is None or len(args.seed) == 0:
        args.seed = [1]

    model_kwargs = get_model_kwargs(args)

    # An explicit cache dir is only forwarded when HF datasets run offline.
    cache_kwargs = {}
    if os.environ.get('HF_DATASETS_OFFLINE', '').strip() == '1':
        cache_kwargs = {'cache_dir': args.cache_path}

    for seed in args.seed:
        log.info("=" * 100)
        log.info(f"SEED: {seed}")

        log.info(f"Loading model {args.model.path}...")
        transformers.set_seed(seed)

        if "gemma-3" in args.model.path:
            # Gemma-3 is loaded as Gemma3ForConditionalGeneration and wrapped in
            # WhiteboxModel by hand instead of via WhiteboxModel.from_pretrained.
            from transformers import AutoTokenizer, Gemma3ForConditionalGeneration

            model_path = args.model.path
            model_type = "CausalLM"
            generation_params = GenerationParameters(**getattr(args, "generation_params", {}))
            model = Gemma3ForConditionalGeneration.from_pretrained(model_path, trust_remote_code=True, device_map=args.model.device_map).eval()
            tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left", add_bos_token=True)

            model = WhiteboxModel(
                model, tokenizer, model_path, model_type, generation_params
            )
        else:
            model = WhiteboxModel.from_pretrained(
                args.model.path,
                getattr(args, "generation_params", {}),
                device_map=args.model.device_map,
                add_bos_token=getattr(args.model, "add_bos_token", True),
                **cache_kwargs,
                **model_kwargs,
            )

        if args.model.ensemble:
            # Only MC-ensembles for now.
            # NOTE(review): the ensemble always uses the first seed, not the
            # current loop seed — confirm this is intended.
            log.info("Creating ensemble...")
            ensemble_model = create_ensemble(model_paths=[args.model.path],
                                             mc=True,
                                             seed=args.seed[0],
                                             ensembling_mode=args.model.ensembling_mode,
                                             mc_seeds=args.model.mc_seeds,
                                             dropout_rate=float(args.model.dropout_rate),
                                             **cache_kwargs,
                                             **model_kwargs
                                             )
        else:
            ensemble_model = None

        log.info("Done with loading model.")

        log.info(f"Loading dataset {args.dataset}...")
        dataset = Dataset.load(
            args.dataset,
            args.text_column,
            args.label_column,
            batch_size=args.batch_size,
            prompt=args.prompt,
            description=getattr(args, "description", ""),
            mmlu_max_subject_size=getattr(args, "mmlu_max_subject_size", 100),
            n_shot=getattr(args, "n_shot", 5),
            few_shot_split=getattr(args, "few_shot_split", "train"),
            split=args.eval_split,
            load_from_disk=args.load_from_disk,
            max_new_tokens=getattr(args, "max_new_tokens", 100),
            **cache_kwargs
        )

        estimators = _build_estimators(args, model)

        train_dataset = None
        background_train_dataset = None
        # Train data is loaded only when some estimator explicitly reports
        # is_fitted=False (estimators without the attribute count as fitted)
        # and k-fold cross-validation is not requested.
        if any(not getattr(method, "is_fitted", True) for method in estimators) and (not getattr(args, "kfolds", False)):
            if (args.train_dataset is not None) and (
                    args.train_dataset != args.dataset
            ):
                train_dataset = Dataset.load(
                    args.train_dataset,
                    args.text_column,
                    args.label_column,
                    batch_size=args.batch_size,
                    prompt=args.prompt,
                    description=getattr(args, "description", ""),
                    mmlu_max_subject_size=getattr(args, "mmlu_max_subject_size", 100),
                    n_shot=getattr(args, "n_shot", 5),
                    few_shot_split=getattr(args, "few_shot_split", "train"),
                    split=args.train_split,
                    size=10_000,
                    load_from_disk=args.load_from_disk,
                    max_new_tokens=getattr(args, "max_new_tokens", 100),
                    **cache_kwargs
                )
            elif args.train_test_split:
                X_train, X_test, X_raw_train, X_raw_test, y_train, y_test, max_new_tokens_train, max_new_tokens_test = dataset.train_test_split(
                    test_size=args.test_split_size, seed=seed, split=args.eval_split
                )
                train_dataset = Dataset(
                    x=X_train, raw_x=X_raw_train, y=y_train, max_new_tokens=getattr(args, "max_new_tokens", 100), batch_size=args.batch_size
                )
            else:
                train_dataset = Dataset.load(
                    args.dataset,
                    args.text_column,
                    args.label_column,
                    batch_size=args.batch_size,
                    prompt=args.prompt,
                    description=getattr(args, "description", ""),
                    mmlu_max_subject_size=getattr(args, "mmlu_max_subject_size", 100),
                    n_shot=getattr(args, "n_shot", 5),
                    few_shot_split=getattr(args, "few_shot_split", "train"),
                    split=args.train_split,
                    size=10_000,
                    load_from_disk=args.load_from_disk,
                    max_new_tokens=getattr(args, "max_new_tokens", 100),
                    **cache_kwargs
                )
            if args.subsample_train_dataset != -1:
                train_dataset.subsample(args.subsample_train_dataset, seed=seed)

            if getattr(args, "train_dataset_1", False):
                # Numbered train_dataset_1, train_dataset_2, ... replace the
                # single train dataset loaded above.
                train_dataset = _load_numbered_train_datasets(args, seed, cache_kwargs)

        # NOTE(review): unlike the check above, estimators lacking an
        # ``is_fitted`` attribute count as unfitted here, so background data is
        # attempted for most runs — confirm the differing defaults are intended.
        if any(not getattr(method, "is_fitted", False) for method in estimators):
            background_train_dataset = _load_background_train_dataset(args, seed, cache_kwargs)

        if args.subsample_eval_dataset != -1:
            dataset.subsample(args.subsample_eval_dataset, seed=seed)

        log.info("Done with loading data.")

        generation_metrics = get_generation_metrics(args)
        ue_metrics = get_ue_metrics(args)

        if getattr(args, "cherrypick", False):
            # Debug filter: keep only prompts containing a fixed phrase.
            # NOTE(review): only x and y are filtered; raw_x is left as-is.
            x = []
            y = []
            for x_i, y_i in zip(dataset.x, dataset.y):
                if "that is not an official language of the U.S." in x_i:
                    x.append(x_i)
                    y.append(y_i)
            dataset.x = x
            dataset.y = y

        if getattr(args, "crossval_claim_ue", False):
            # K-fold evaluation: fresh (unfitted) estimators per fold, trained
            # on the other folds of the evaluation dataset.
            train_idx, test_idx = dataset.kfolds(seed=seed, n_folds=getattr(args, "kfolds", 2))
            for f_idx in range(len(train_idx)):
                train_ds = copy.deepcopy(dataset).select(train_idx[f_idx])
                eval_ds = copy.deepcopy(dataset).select(test_idx[f_idx])

                estimators = _build_estimators(args, model)

                man = UEManager(
                    eval_ds,
                    model,
                    estimators,
                    generation_metrics,
                    ue_metrics,
                    [
                        Logger(),
                    ],
                    deberta_batch_size=getattr(args, 'deberta_batch_size', 10),
                    train_data=train_ds,
                    ignore_exceptions=args.ignore_exceptions,
                    background_train_data=background_train_dataset,
                    max_new_tokens=args.max_new_tokens,
                    ensemble_model=ensemble_model
                )
                man()
                man.save(save_path + f"/ue_manager_seed{seed}_fold{f_idx}")
        else:
            man = UEManager(
                dataset,
                model,
                estimators,
                generation_metrics,
                ue_metrics,
                [
                    Logger(),
                ],
                deberta_batch_size=getattr(args, 'deberta_batch_size', 10),
                train_data=train_dataset,
                ignore_exceptions=args.ignore_exceptions,
                background_train_data=background_train_dataset,
                max_new_tokens=args.max_new_tokens,
                ensemble_model=ensemble_model
            )

            man()

            man.save(save_path + f"/ue_manager_seed{seed}")


def _build_estimators(args, model):
    """Return the configured estimators: generic UE methods, then density-based ones."""
    return get_ue_methods(args, model) + get_density_based_ue_methods(args, model.model_type)


def _load_numbered_train_datasets(args, seed, cache_kwargs):
    """Load ``train_dataset_1``, ``train_dataset_2``, ... and concatenate them.

    Each dataset's labels are normalised to lists of references when
    ``args.multiref`` is set, or to single references otherwise.
    """
    k_ds = 1
    train_dataset = None
    while getattr(args, f"train_dataset_{k_ds}", False):
        train_dataset_k = Dataset.load(
            getattr(args, f"train_dataset_{k_ds}"),
            getattr(args, f"train_text_column_{k_ds}"),
            getattr(args, f"train_label_column_{k_ds}"),
            batch_size=args.batch_size,
            prompt=codecs.decode(getattr(args, f"train_prompt_{k_ds}"), 'unicode_escape'),
            description=codecs.decode(getattr(args, f"train_description_{k_ds}", ""), 'unicode_escape'),
            mmlu_max_subject_size=getattr(args, "mmlu_max_subject_size", 100),
            n_shot=getattr(args, f"train_n_shot_{k_ds}", 5),
            few_shot_split=getattr(args, f"few_shot_split_{k_ds}", "train"),
            split=getattr(args, f"train_split_{k_ds}", "train"),
            max_new_tokens=getattr(args, f"max_new_tokens_{k_ds}", 100),
            size=10_000,
            load_from_disk=args.load_from_disk,
            **cache_kwargs
        )
        k_ds += 1
        if args.subsample_train_dataset != -1:
            train_dataset_k.subsample(args.subsample_train_dataset, seed=seed)

        labels_are_lists = isinstance(train_dataset_k.y[0], list)
        if getattr(args, "multiref", False):
            if not labels_are_lists:
                train_dataset_k.y = [[y] for y in train_dataset_k.y]
        elif labels_are_lists:
            train_dataset_k.y = [y[0] for y in train_dataset_k.y]

        if train_dataset is None:
            train_dataset = train_dataset_k
        else:
            train_dataset.concat(train_dataset_k.x, train_dataset_k.raw_x, train_dataset_k.y, train_dataset_k.max_new_tokens)
    return train_dataset


def _load_background_train_dataset(args, seed, cache_kwargs):
    """Best-effort load (and optional subsample) of the background train dataset.

    Returns None if loading fails. If loading succeeds but subsampling fails,
    the loaded (un-subsampled) dataset is still returned.
    """
    background_train_dataset = None
    try:
        background_train_dataset = Dataset.load(
            args.background_train_dataset,
            args.background_train_dataset_text_column,
            args.background_train_dataset_label_column,
            batch_size=args.batch_size,
            data_files=args.background_train_dataset_data_files,
            split="train",
            size=100_000,
            load_from_disk=args.background_load_from_disk,
            **cache_kwargs
        )
        if args.subsample_background_train_dataset != -1:
            background_train_dataset.subsample(
                args.subsample_background_train_dataset, seed=seed
            )
    except Exception:
        # Background data is optional; record why it was skipped instead of
        # silently swallowing the error (and never swallow KeyboardInterrupt).
        log.warning(
            "Could not load background train dataset; continuing without it.",
            exc_info=True,
        )
    return background_train_dataset

def get_ue_metrics(args):
    """Return the metrics that score the quality of uncertainty estimates.

    PredictionRejectionArea is always present. ROCAUC and PRAUC are appended
    when either ``args.use_claim_ue`` or ``args.train_claim_pi`` is truthy.
    (ReversedPairsProportion and RiskCoverageCurveAUC are deliberately off.)
    """
    metrics = [PredictionRejectionArea()]
    claim_level = getattr(args, "use_claim_ue", False) or getattr(args, "train_claim_pi", False)
    if not claim_level:
        return metrics
    metrics.extend([ROCAUC(), PRAUC()])
    return metrics


def get_density_based_ue_methods(args, model_type):
    """Return density-based UE estimators, or an empty list when disabled.

    The ``parameters_path`` handed to every estimator is ``args.parameters_path``
    when truthy, otherwise
    ``{args.cache_path}/density_stats/{dataset_name}/{model_name}``.
    Seq2Seq models get encoder- and decoder-side variants of each estimator;
    any other model type gets decoder-side variants only.
    """
    if not args.use_density_based_ue:
        return []

    stats_path = getattr(args, 'parameters_path', False)
    if not stats_path:
        raw_name = args.dataset if isinstance(args.dataset, str) else '_'.join(args.dataset)
        short_dataset = raw_name.split("/")[-1].split(".")[0]
        short_model = args.model.path.split("/")[-1]
        stats_path = f"{args.cache_path}/density_stats/{short_dataset}/{short_model}"

    sides = ["encoder", "decoder"] if model_type == "Seq2SeqLM" else ["decoder"]

    # Grouped by estimator type, each type listed once per side.
    methods = [MahalanobisDistanceSeq(side, parameters_path=stats_path) for side in sides]
    methods += [RelativeMahalanobisDistanceSeq(side, parameters_path=stats_path) for side in sides]
    methods += [RDESeq(side, parameters_path=stats_path) for side in sides]
    methods += [
        PPLMDSeq(side, md_type=md_kind, parameters_path=stats_path)
        for side in sides
        for md_kind in ("MD", "RMD")
    ]
    return methods


def get_ue_methods(args, model):
    """Build the (non-density-based) UE estimators selected by ``args``.

    Args:
        args: experiment config. Flags read here include ``use_seq_ue``,
            ``run_baselines``, ``run_supervised_baselines``, ``use_ens_ue``,
            ``use_tok_ue``, ``use_claim_ue``, ``train_claim_pi`` and
            ``additional_estimators`` / ``additional_estimators_kwargs``.
        model: wrapped model; ``model.model.config`` (or its ``text_config``
            for Gemma-3) supplies layer and head counts.

    Returns:
        List of estimator instances, in construction order.

    Raises:
        NotImplementedError: ensemble UE requested for a non-Seq2Seq model.
        ValueError: unknown ensembling mode, or supervised baselines requested
            with an unsupported ``target_train_metric``.
        TypeError: an additional estimator has no kwargs entry.
    """
    # Local import: numpy was previously only reachable via a star import.
    import numpy as np

    estimators = []

    metric_name, metric = _build_target_metric(args)
    aggregated = getattr(args, "multiref", False)
    if aggregated and metric is not None:
        metric = AggregatedMetric(base_metric=metric)

    text_config = _text_config(model, args)
    num_layers = text_config.num_hidden_layers
    # Every layer index except the last, with the last referenced as -1.
    hidden_layers = list(range(num_layers))[:-1] + [-1]
    mid_layer = int(num_layers // 2)

    if args.use_seq_ue:
        estimators += [
            MaximumSequenceProbability(),
            Perplexity(),
            MeanTokenEntropy(),

            Focus(model_name=args.model.path, gamma=0.9, negative=True, reccurent=True, upd=True),

            LLMCheckAttention(layer=mid_layer),
            LLMCheckAttention(layer=mid_layer, aggregation="mean"),

            LLMCheckAttention(layer=mid_layer, fix=True),
            LLMCheckAttention(layer=mid_layer, aggregation="mean", fix=True),

            LLMCheckAttention(layer=mid_layer, gen_only=True),
            LLMCheckAttention(layer=mid_layer, aggregation="mean", gen_only=True),

            # one head
            LLMCheckAttention(layer=mid_layer, one_head=True),
            LLMCheckAttention(layer=mid_layer, aggregation="mean", one_head=True),

            LLMCheckAttention(layer=mid_layer, gen_only=True, one_head=True),
            LLMCheckAttention(layer=mid_layer, aggregation="mean", gen_only=True, one_head=True),

            SimpleFocus(reccurent=True),
            SimpleFocus(reccurent=True, only_prev=True),

            # one head
            SimpleFocus(reccurent=True, one_head=True, layer=mid_layer),
            SimpleFocus(reccurent=True, only_prev=True, one_head=True, layer=mid_layer),
        ]
        if getattr(args, "run_baselines", False):
            estimators += [
                MonteCarloSequenceEntropy(),
                MonteCarloNormalizedSequenceEntropy(),
                LexicalSimilarity(metric="rougeL"),
                NumSemSets(),
                EigValLaplacian(similarity_score="NLI_score", affinity="entail"),
                DegMat(similarity_score="NLI_score", affinity="entail"),
                Eccentricity(similarity_score="NLI_score", affinity="entail"),
                SemanticEntropy(),
                SAR(),
                TokenSAR(),
                ClaimConditionedProbability(),
                SentenceSAR(),
                EigenScore("sample_embeddings_last_token", hidden_layer=mid_layer),
                LUQ(model="deberta"),
                GreedySemanticDensity(),
            ]

        metric_thr = getattr(args, "metric_thr", 0.3)

        if getattr(args, "run_supervised_baselines", True):
            if metric is None:
                raise ValueError(
                    f"Unsupported target_train_metric for supervised baselines: {metric_name}"
                )
            estimators += [
                LookBackLens(metric=metric, metric_name=metric_name, threshold=metric_thr, aggregated=aggregated),
                MIND("decoder", metric=metric, metric_name=metric_name, cv_hp=True, aggregated=aggregated, metric_thr=metric_thr),
                Sheeps("decoder", metric=metric, metric_name=metric_name, cv_hp=True, aggregated=aggregated, hidden_layers=hidden_layers, metric_thr=metric_thr),
                LLMFactoscopeAll(metric=metric, metric_name=metric_name, metric_thr=metric_thr, hidden_layers=hidden_layers, return_dist=True, return_new_dist=True, aggregated=aggregated),
            ]

            for layer in [mid_layer]:
                estimators += [
                    SAPLMA("decoder", metric=metric, metric_name=metric_name, hidden_layer=layer, cv_hp=True, aggregated=aggregated),
                    LayerSheeps("decoder", metric=metric, metric_name=metric_name, hidden_layer=layer, cv_hp=True, aggregated=aggregated, metric_thr=metric_thr),
                ]

        n_layers = text_config.num_hidden_layers
        n_heads = text_config.num_attention_heads

        # RAUQ uses alpha=0 for the "ats" task and 0.2 otherwise.
        alpha = 0.2 if getattr(args, "task", "qa") != "ats" else 0.0
        estimators += [
            RAUQ(n_layers=n_layers, n_heads=n_heads, aggregation=aggregation, token_aggregation=token_aggregation, alpha=alpha)
            for aggregation in ["median", "max", "medianmax"]
            for token_aggregation in ["meanlog", "sumlog", "meanmin"]
        ]
        estimators += [
            RAUQ(n_layers=n_layers, n_heads=n_heads, aggregation="max", token_aggregation="meanlog", alpha=alpha, ablation=ablation)
            for ablation in ["simple_rec", "no_rec", "no_attn", "multiply", "multiply_uq", "sum_uq"]
        ]
        estimators += [
            RAUQ(n_layers=n_layers, n_heads=n_heads, aggregation="max", token_aggregation="meanlog", alpha=alpha, head=head)
            for head in ["mean"]
        ]
        estimators += [
            RAUQ(n_layers=n_layers, n_heads=n_heads, aggregation="max", token_aggregation="meanlog", alpha=alpha, all_layers=all_layers)
            for all_layers in [True]
        ]
        # Alpha sweep 0.0, 0.1, ..., 1.0.
        estimators += [
            RAUQ(n_layers=n_layers, n_heads=n_heads, aggregation="max", token_aggregation="meanlog", alpha=alpha_i, print_alpha=True)
            for alpha_i in np.arange(0, 1.01, 0.1)
        ]

    if args.use_ens_ue:
        if model.model_type != "Seq2SeqLM":
            raise NotImplementedError('Only Encoder-Decoder models can be ensembled at this time')

        token_measures = all_token_estimators()
        ensembling_mode = args.model.ensembling_mode
        if ensembling_mode == 'pe':
            sequence_measures = all_pe_estimators()
        elif ensembling_mode == 'ep':
            sequence_measures = all_ep_estimators()
        else:
            raise ValueError(f'Ensemble type should be one of: "pe", "ep", but is {ensembling_mode} instead')
        estimators += (token_measures + sequence_measures)

    if args.use_tok_ue:
        estimators += [
            MaximumTokenProbability(),
            TokenEntropy(),
            PointwiseMutualInformation(),
            ConditionalPointwiseMutualInformation(),
            SemanticEntropyToken(model.model_path, args.cache_path),
        ]

    if getattr(args, "use_claim_ue", False):
        estimators += [
            MaximumClaimProbability(),
            PerplexityClaim(),
            MaxTokenEntropyClaim(),
            PointwiseMutualInformationClaim(),
            PTrueClaim(),
            ClaimConditionedProbabilityClaim(nli_context="no_context"),
            ClaimConditionedProbabilityClaim(nli_context="fact_pref"),
            FocusClaim(model_name=args.model.path),
            FocusClaim(model_name=args.model.path, gamma=0),
        ]

    if getattr(args, "train_claim_pi", False):
        estimators += [
            MaximumClaimProbability(),
            PTrueClaim(),
            ClaimConditionedProbabilityClaim(nli_context="no_context"),
            ClaimConditionedProbabilityClaim(nli_context="fact_pref"),
        ]

    # Extra estimators: {module_name: [class_name, ...]}, each class
    # instantiated with additional_estimators_kwargs[class_name].
    additional_estimators = getattr(args, "additional_estimators", {})
    additional_estimators_kwargs = getattr(args, "additional_estimators_kwargs", {})

    for module_name, estimator_classes in additional_estimators.items():
        module = importlib.import_module(module_name)
        for estimator_class in estimator_classes:
            try:
                estimator_kwargs = additional_estimators_kwargs[estimator_class]
            except KeyError:
                raise TypeError(f'Arguments for {estimator_class} were not passed') from None

            estimators.append(getattr(module, estimator_class)(**estimator_kwargs))

    return estimators


def _text_config(model, args):
    """Return the HF config holding layer/head counts (nested as ``text_config`` for Gemma-3)."""
    config = model.model.config
    if "gemma-3" in args.model.path:
        return config.text_config
    return config


def _build_target_metric(args):
    """Return ``(metric_name, metric)`` used as the supervised baselines' target.

    ``metric`` is None when ``args.target_train_metric`` is not one of the
    supported names (Accuracy, Comet, AlignScore, AlignScoreMean,
    AlignScoreInv). ``args.is_ood_exps`` forces AlignScoreMean.
    """
    metric_name = getattr(args, "target_train_metric", "AlignScore")
    if getattr(args, "is_ood_exps", False):
        return "AlignScoreMean", AlignScore(batch_size=4, return_mean=True)
    if metric_name == "Accuracy":
        return metric_name, AccuracyMetric(
            target_ignore_regex=getattr(args, "target_ignore_regex", None),
            output_ignore_regex=getattr(args, "output_ignore_regex", None),
            normalize=getattr(args, "normalize", False),
        )
    if metric_name == "Comet":
        return metric_name, Comet(source_ignore_regex=getattr(args, "source_ignore_regex", None))
    if metric_name == "AlignScore":
        return metric_name, AlignScore(batch_size=4)
    if metric_name == "AlignScoreMean":
        return metric_name, AlignScore(batch_size=4, return_mean=True)
    if metric_name == "AlignScoreInv":
        return metric_name, AlignScore(batch_size=4, target_is_claims=False)
    return metric_name, None


def get_generation_metrics(args):
    """Build the generation-quality metrics used to score model outputs.

    When ``args.generation_metrics`` is missing or empty, a default set is
    returned: ROUGE-L, accuracy, AlignScore (variant chosen by
    ``args.target_train_metric``) and, for ``task == "nmt"``, COMET. With
    ``args.multiref`` each default metric is wrapped in ``AggregatedMetric``.

    Otherwise each entry ``{"name": <class name>, "args": [...]}`` is
    instantiated by looking the class name up in this module's globals.

    Raises:
        ValueError: ``BartScoreSeqMetric`` requested together with ``multiref``.
        KeyError: a requested metric name is not defined in this module.
    """
    generation_metrics = getattr(args, "generation_metrics", None)
    multiref = getattr(args, "multiref", False)

    if generation_metrics:
        result = []
        for metric_cfg in generation_metrics:
            name = metric_cfg["name"]
            if multiref and name == "BartScoreSeqMetric":
                raise ValueError("BartScoreSeqMetric does not support multiref")
            metric_class = globals()[name]
            result.append(metric_class(*metric_cfg.get("args", [])))
        return result

    # AlignScore is only built on the default path: it is not needed (and is
    # costly to construct) when explicit generation_metrics are configured.
    target_metric = getattr(args, "target_train_metric", "AlignScore")
    if target_metric == "AlignScoreMean":
        alignscorer = AlignScore(batch_size=4, return_mean=True)
    elif target_metric == "AlignScoreInv":
        alignscorer = AlignScore(batch_size=4, target_is_claims=False)
    else:
        alignscorer = AlignScore(batch_size=4)

    result = [
        RougeMetric("rougeL"),
        AccuracyMetric(
            target_ignore_regex=getattr(args, "target_ignore_regex", None),
            output_ignore_regex=getattr(args, "output_ignore_regex", None),
            normalize=getattr(args, "normalize", False),
        ),
        alignscorer,
    ]
    if getattr(args, "task", "qa") == "nmt":
        result.append(Comet(source_ignore_regex=getattr(args, "source_ignore_regex", None)))

    if multiref:
        # BartScoreSeqMetric is not in the default set because it does not
        # support multiref; every default metric is wrapped for multi-reference.
        result = [AggregatedMetric(base_metric=metric) for metric in result]
    return result


def get_model_kwargs(args):
    """Collect optional model-loading kwargs from ``args.model``.

    Each of ``attn_implementation``, ``use_cache`` and ``cache_implementation``
    is forwarded only when it is present and not None (so ``use_cache=False``
    is still passed through). ``device_map`` is handled by the caller.
    """
    model_cfg = args.model
    kwargs = {}
    for key in ("attn_implementation", "use_cache", "cache_implementation"):
        value = getattr(model_cfg, key, None)
        if value is not None:
            kwargs[key] = value
    return kwargs


if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on main() supplies
    # the config object that main(args) receives.
    main()
