import hydra
import os
import transformers
from pathlib import Path
from omegaconf import OmegaConf
import uuid
from hydra.core.hydra_config import HydraConfig

import logging
import os
# NOTE(review): `transformers` is already imported above, so this assignment only
# affects code that reads the variable after this point — confirm it still takes
# effect as intended.
os.environ["TRANSFORMERS_NO_SAFE_LOAD"] = "1" 
# Module-wide logger used by every function in this script.
log = logging.getLogger("lm_polygraph")
# Raise the "httpx" logger threshold so only warnings and errors from it are emitted.
logging.getLogger("httpx").setLevel(logging.WARNING)

from lm_polygraph.utils.manager import UEManager
from lm_polygraph.utils.dataset import Dataset
from lm_polygraph.utils.model import WhiteboxModel, BlackboxModel
from lm_polygraph.utils.processor import Logger
from lm_polygraph.generation_metrics import *
from lm_polygraph.estimators import *
from lm_polygraph.ue_metrics import *
from lm_polygraph.utils.common import load_external_module
from lm_polygraph.utils.generation_parameters import GenerationParameters
from lm_polygraph.defaults.register_default_stat_calculators import (
    register_default_stat_calculators,
)
from lm_polygraph.utils.builder_enviroment_stat_calculator import (
    BuilderEnvironmentStatCalculator,
)
from lm_polygraph.utils.factory_estimator import FactoryEstimator
from lm_polygraph.utils.factory_stat_calculator import StatCalculatorContainer
from synthetic_dataset_generation.utils.step_fact_check import StepFactCheck
from transformers import modeling_utils
# Compatibility shim: make sure `modeling_utils.ALL_PARALLEL_STYLES` exists and is
# not None, installing a fallback list when the attribute is missing or unset.
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none","colwise",'rowwise']
# Path of the Hydra config file, read from the HYDRA_CONFIG environment variable.
# Its parent directory and file name are handed to `hydra.main`, and relative paths
# found in the config are resolved against its parent directory. When the variable
# is unset this is `Path("")`.
hydra_config = Path(os.environ.get("HYDRA_CONFIG", ""))


def get_hydra_config_path():
    """Locate the config file that was supplied on the command line.

    Scans the ``runtime.config_sources`` entries of the active Hydra
    configuration for the one whose provider is ``"command-line"``.

    Returns:
        The source's ``path`` joined with the job's config name as a ``Path``,
        or ``None`` when no command-line source is registered.
    """
    hydra_cfg = HydraConfig.get()
    sources = hydra_cfg.get("runtime", {}).get("config_sources", [])
    cli_source = next(
        (src for src in sources if src.get("provider") == "command-line"),
        None,
    )
    if cli_source is None:
        return None
    return Path(cli_source.get("path")) / hydra_cfg.job.config_name


def _log_metrics_to_wandb(man):
    """Send the scalar entries of the manager's metric dicts to the active wandb run.

    Only values that are ``int``, ``float``, ``str`` or ``bool`` are logged; keys
    are stringified. Any failure is downgraded to a warning so that a logging
    problem never aborts the evaluation.

    Args:
        man: The manager whose ``gen_metrics`` and ``metrics`` are reported.
            Both are assumed to be mappings (``metrics`` is used as one in the
            original code) — confirm against UEManager.
    """
    import wandb

    scalar_types = (int, float, str, bool)
    try:
        # Iterate over (key, value) pairs for BOTH dicts; iterating the dict
        # itself would unpack each key instead of yielding its value.
        gen_metrics_dict = {
            str(k): v
            for k, v in man.gen_metrics.items()
            if isinstance(v, scalar_types)
        }
        metrics_dict = {
            str(k): v for k, v in man.metrics.items() if isinstance(v, scalar_types)
        }

        if gen_metrics_dict:
            wandb.log(gen_metrics_dict)
        if metrics_dict:
            wandb.log(metrics_dict)
    except Exception as e:
        log.warning(f"Failed to log metrics to wandb: {e}")


def _finish_wandb(timeout_s=180):
    """Close the wandb run, giving up after ``timeout_s`` seconds.

    A SIGALRM-based watchdog interrupts ``wandb.finish()`` if it hangs; in that
    case a second, quiet ``wandb.finish`` call is attempted. All errors are
    logged as warnings rather than raised.

    Args:
        timeout_s: Number of seconds to wait before interrupting ``wandb.finish()``.
    """
    import signal

    import wandb

    def timeout_handler(signum, frame):
        raise TimeoutError("wandb.finish() timed out")

    alarm_armed = False
    try:
        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(timeout_s)
        alarm_armed = True

        wandb.finish()
    except TimeoutError:
        log.warning("wandb.finish() timed out, forcing exit")
        try:
            wandb.finish(exit_code=0, quiet=True)
        except Exception as e:
            log.warning(f"Second wandb.finish() attempt failed: {e}")
    except Exception as e:
        log.warning(f"Error during wandb.finish(): {e}")
    finally:
        # Always disarm the watchdog, otherwise a pending alarm could fire later
        # and raise TimeoutError in unrelated code.
        if alarm_armed:
            signal.alarm(0)


@hydra.main(
    version_base=None,
    config_path=str(hydra_config.parent),
    config_name=str(hydra_config.name),
)
def main(args):
    """Run the uncertainty-estimation benchmark described by the Hydra config.

    For every seed in ``args.seed`` (``[1]`` when missing or empty) the model and
    the evaluation dataset are loaded, estimators, generation metrics, UE metrics
    and stat calculators are built, and a ``UEManager`` is executed. The manager
    is saved to ``save_path`` (and pushed to the hub when ``hf_save_path`` is set)
    even if the run raises. When ``report_to_wandb`` is enabled, the resolved
    config is attached to a wandb run and scalar metrics are logged per seed.

    Args:
        args: The configuration object supplied by Hydra.
    """
    # Hydra's working directory at entry is used as the default output location;
    # afterwards switch back to the directory the script was launched from.
    save_path = os.getcwd()
    log.info(f"Main directory: {save_path}")
    os.chdir(hydra.utils.get_original_cwd())

    global hydra_config
    if hydra_config == Path("../scripts"):
        hydra_config = get_hydra_config_path()

    report_to_wandb = bool(getattr(args, "report_to_wandb", False))
    if report_to_wandb:
        import wandb

        # The config must be handed to `wandb.init`; assigning `wandb.config`
        # beforehand does not attach it to the run.
        wandb.init(
            dir=save_path,
            tags=["polygraph_eval"],
            group=f"experiment_{uuid.uuid4().hex[:8]}",
            config=OmegaConf.to_container(args, resolve=True, throw_on_missing=True),
        )

    save_path = args.save_path if "save_path" in args else save_path

    if getattr(args, "seed", None) is None or len(args.seed) == 0:
        args.seed = [1]

    # A dataset cache directory is only passed when HF datasets run offline.
    cache_kwargs = {}
    if os.environ.get("HF_DATASETS_OFFLINE", "").strip() == "1":
        cache_kwargs = {"cache_dir": args.cache_path}

    for seed in args.seed:
        log.info("=" * 100)
        log.info(f"SEED: {seed}")

        log.info(f"Loading model {args.model.path}...")
        transformers.set_seed(seed)

        model = get_model(args)

        log.info("Done with loading model.")

        log.info(f"Loading dataset {args.dataset}...")
        dataset = Dataset.load(
            args.dataset,
            args.text_column,
            getattr(args, "label_column", None),
            batch_size=args.batch_size,
            prompt=getattr(args, "prompt", ""),
            description=getattr(args, "description", ""),
            mmlu_max_subject_size=getattr(args, "mmlu_max_subject_size", 100),
            n_shot=getattr(args, "n_shot", 5),
            few_shot_split=getattr(args, "few_shot_split", "train"),
            few_shot_prompt=getattr(args, "few_shot_prompt", None),
            instruct=getattr(args, "instruct", None),
            split=args.eval_split,
            load_from_disk=args.load_from_disk,
            trust_remote_code=getattr(args, "trust_remote_code", False),
            **cache_kwargs,
        )
        log.info("Done with loading eval data.")

        log.info("=" * 100)
        log.info("Initializing UE estimators...")
        estimators = get_ue_methods(args, model)
        log.info("Done loading UE estimators")

        # -1 means "use the whole evaluation set".
        if args.subsample_eval_dataset != -1:
            dataset.subsample(args.subsample_eval_dataset, seed=seed)

        generation_metrics = get_generation_metrics(args)

        ue_metrics = get_ue_metrics(args)

        builder_env_stat_calc = BuilderEnvironmentStatCalculator(model=model)
        available_stat_calculators = get_stat_calculator_names(args)
        log.info(f"max_new_tokens: {args.max_new_tokens}")
        man = UEManager(
            data=dataset,
            model=model,
            estimators=estimators,
            builder_env_stat_calc=builder_env_stat_calc,
            available_stat_calculators=available_stat_calculators,
            generation_metrics=generation_metrics,
            ue_metrics=ue_metrics,
            processors=[
                Logger(),
            ],
            ignore_exceptions=args.ignore_exceptions,
            max_new_tokens=args.max_new_tokens,
            log_time=getattr(args, "log_time", False),
        )
        try:
            man()
        except Exception:
            # Mark the run as failed so the saved manager reflects it, then
            # re-raise with the original traceback.
            man.state = "failed"
            raise
        finally:
            # Persist whatever was computed, whether or not the run succeeded.
            log.info(f"Saving UEManager to {save_path}")
            man.save(save_path)
            if hasattr(args, "hf_save_path"):
                man.push_to_hub(args.hf_save_path)

        if report_to_wandb:
            _log_metrics_to_wandb(man)

    if report_to_wandb:
        _finish_wandb(timeout_s=180)


def get_ue_metrics(args):
    """Assemble the uncertainty-estimation quality metrics for this run.

    Two ``PredictionRejectionArea`` variants (default and ``max_rejection=0.5``)
    are always returned. ``ROCAUC`` and ``PRAUC`` are appended only when
    ``use_claim_ue`` is enabled and the configured generation metrics include
    ``StepFactCheck`` or ``OpenAIFactCheck``.

    Args:
        args: Run configuration.

    Returns:
        List of UE metric instances.
    """
    metrics = [
        PredictionRejectionArea(),
        PredictionRejectionArea(max_rejection=0.5),
    ]

    claim_metrics_configured = (
        getattr(args, "use_claim_ue", False)
        and hasattr(args, "generation_metrics")
        and args.generation_metrics
    )
    if not claim_metrics_configured:
        return metrics

    # ROC/PR AUC are only added alongside these fact-checking generation metrics.
    fact_check_names = ["StepFactCheck", "OpenAIFactCheck"]
    for metric_cfg in args.generation_metrics:
        if metric_cfg.get("name") in fact_check_names:
            metrics.extend([ROCAUC(), PRAUC()])
            break

    return metrics


def get_stat_calculator_names(config):
    """Build the list of stat-calculator containers requested by the config.

    If ``"auto"`` appears in ``config.stat_calculators`` the default calculators
    for the model type and language are registered first. Every other entry is
    converted into a ``StatCalculatorContainer`` and appended in config order.
    The resulting list is printed before being returned.

    Args:
        config: Run configuration with ``model`` and ``stat_calculators`` entries.

    Returns:
        List of stat-calculator containers.
    """
    is_blackbox = getattr(config.model, "type", "Whitebox") == "Blackbox"
    model_type = "Blackbox" if is_blackbox else "Whitebox"
    language = getattr(config, "language", "en")
    hf_cache = getattr(config, "hf_cache", None)

    containers = []
    if "auto" in config.stat_calculators:
        containers.extend(
            register_default_stat_calculators(
                model_type,
                language,
                hf_cache,
                deberta_batch_size=getattr(config, "deberta_batch_size", 100),
                deberta_device=getattr(config, "deberta_device", None),
                output_attentions=getattr(config, "output_attentions", True),
            )
        )

    # Explicitly configured calculators follow the defaults.
    containers.extend(
        StatCalculatorContainer(
            name=entry.name,
            cfg=entry.cfg,
            stats=entry.stats,
            dependencies=entry.dependencies,
            builder=entry.builder,
        )
        for entry in config.stat_calculators
        if entry != "auto"
    )
    print(f'all_stat_calculators: {containers}')
    return containers


def get_ue_methods(config, model):
    """Instantiate every estimator listed under ``config.estimators``.

    Args:
        config: Run configuration; each estimator entry provides a ``name`` and
            optionally a ``cfg`` (an empty dict is used when ``cfg`` is absent).
        model: Not used by this function.

    Returns:
        List of estimators created through ``FactoryEstimator``, in config order.
    """
    build_estimator = FactoryEstimator()
    return [
        build_estimator(spec.name, spec.cfg if "cfg" in spec else dict())
        for spec in config.estimators
    ]


def get_generation_metrics(args):
    """Create the generation-quality metrics for the run.

    When ``args.generation_metrics`` is set, each entry is instantiated by looking
    up its ``name`` in this module's globals and calling it with the entry's
    optional ``args`` / ``kwargs``. Otherwise a built-in default set is used.

    If ``process_output_fn`` or ``process_target_fn`` is configured, every metric
    is wrapped in ``PreprocessOutputTarget``; with ``multiref`` enabled every
    metric is additionally wrapped in ``AggregatedMetric``.

    Args:
        args: Run configuration.

    Returns:
        List of metric instances.

    Raises:
        ValueError: If process functions are combined with ``target_ignore_regex``,
            ``output_ignore_regex`` or ``normalize``.
    """
    log.info("=" * 100)
    log.info("Initializing generation metrics...")

    configured_metrics = getattr(args, "generation_metrics", None)
    if configured_metrics:
        result = []
        for spec in configured_metrics:
            # Metric classes are resolved by name from the module namespace.
            metric_cls = globals()[spec["name"]]
            positional = spec.get("args", [])
            keyword = spec.get("kwargs", {})
            result.append(metric_cls(*positional, **keyword))
    else:
        accuracy = AccuracyMetric(
            target_ignore_regex=getattr(args, "target_ignore_regex", None),
            output_ignore_regex=getattr(args, "output_ignore_regex", None),
            normalize=getattr(args, "normalize", False),
        )
        # NOTE: OpenAIFactCheck (claim-level UE) and Comet (task "nmt") used to be
        # appended to these defaults and are currently disabled.
        result = [
            RougeMetric("rouge1"),
            RougeMetric("rouge2"),
            RougeMetric("rougeL"),
            BLEUMetric(),
            BertScoreMetric("rh"),
            SbertMetric(),
            accuracy,
            AlignScore(target_is_claims=args.task != "ats"),
        ]

    output_fn_cfg = getattr(args, "process_output_fn", None)
    target_fn_cfg = getattr(args, "process_target_fn", None)
    if target_fn_cfg or output_fn_cfg:
        uses_builtin_cleanup = (
            getattr(args, "target_ignore_regex", None)
            or getattr(args, "output_ignore_regex", None)
            or getattr(args, "normalize", False)
        )
        if uses_builtin_cleanup:
            raise ValueError(
                "Specifying ignore_regex or normalize simultaneously with process scripts is not allowed."
            )

        def resolve_process_fn(fn_cfg):
            # Import the external script and return the named callable from it.
            if not fn_cfg:
                return None
            script_path = get_abs_path_from_hydra_config(fn_cfg.path)
            return getattr(load_external_module(script_path), fn_cfg.fn_name)

        output_fn = resolve_process_fn(output_fn_cfg)
        target_fn = resolve_process_fn(target_fn_cfg)

        result = [
            PreprocessOutputTarget(base, output_fn, target_fn) for base in result
        ]

    if getattr(args, "multiref", False):
        result = [AggregatedMetric(base_metric=base) for base in result]

    log.info("Done with initializing generation metrics.")

    return result


def get_model(args):
    """Load the model selected by ``args.model``.

    A model whose ``type`` is ``"Blackbox"`` is created through
    ``get_blackbox_model``; anything else (including a missing ``type``) goes
    through ``get_whitebox_model`` with the HF cache directory and token taken
    from ``hf_cache`` / ``hf_token``.

    Args:
        args: Run configuration.

    Returns:
        The loaded model object.
    """
    is_blackbox = getattr(args.model, "type", "Whitebox") == "Blackbox"
    if is_blackbox:
        return get_blackbox_model(args)

    hf_kwargs = dict(
        cache_dir=getattr(args, "hf_cache", None),
        token=getattr(args, "hf_token", None),
    )
    return get_whitebox_model(args, hf_kwargs)


def get_blackbox_model(args):
    """Create a black-box model for the provider named in ``args.model.provider``.

    Supported providers are ``"openai"`` (API key read from ``OPENAI_API_KEY``)
    and ``"huggingface"`` (API key read from ``HUGGINGFACE_API_KEY``).

    Args:
        args: Run configuration.

    Returns:
        A ``BlackboxModel`` built for the chosen provider.

    Raises:
        ValueError: If the provider is missing, blank or unsupported, or if the
            required API key is not set in the environment.
    """
    provider = getattr(args.model, "provider", "")
    if provider is None or not provider.strip():
        raise ValueError(
            "Blackbox model provider cannot be empty or None. Please specify a valid provider."
        )

    if provider == "openai":
        api_key = os.environ.get("OPENAI_API_KEY")
        if api_key is None:
            raise ValueError("OpenAI API key is not set in the environment variables.")
        return BlackboxModel.from_openai(
            openai_api_key=api_key,
            model_path=args.model.path,
            supports_logprobs=args.model.supports_logprobs,
        )

    if provider == "huggingface":
        api_key = os.environ.get("HUGGINGFACE_API_KEY")
        if api_key is None:
            raise ValueError(
                "HuggingFace API key is not set in the environment variables."
            )
        return BlackboxModel.from_huggingface(
            hf_api_token=api_key, hf_model_id=args.model.path
        )

    raise ValueError(f"Unsupported black-box model provider: {provider}")


def get_whitebox_model(args, cache_kwargs=None):
    """Load a white-box model, either directly or through a user load script.

    Without ``args.model.path_to_load_script`` the model is loaded with
    ``WhiteboxModel.from_pretrained`` (deprecated path, a warning is logged) and
    ``cache_kwargs`` are forwarded to it. With a load script, the script's
    ``load_model`` and ``load_tokenizer`` are called with ``model_path`` plus the
    configured ``load_model_args`` / ``load_tokenizer_args``; ``cache_kwargs``
    are not used on that path.

    Args:
        args: Run configuration.
        cache_kwargs: Extra keyword arguments for ``from_pretrained``.
            ``None`` (the default) means no extra arguments.

    Returns:
        The constructed ``WhiteboxModel``.
    """
    # Avoid a shared mutable default: create a fresh dict per call.
    cache_kwargs = {} if cache_kwargs is None else cache_kwargs

    if "path_to_load_script" not in args.model or not args.model.path_to_load_script:
        log.warning(
            "Loading model by directly passing the path to the model is deprecated and will be removed in the next release. Please use loading script instead."
        )
        # Never write the access token to the logs; show only whether one is set.
        loggable_kwargs = {
            key: ("***" if key == "token" and value else value)
            for key, value in cache_kwargs.items()
        }
        log.info(f"Loading model with cache_kwargs: {loggable_kwargs}")
        return WhiteboxModel.from_pretrained(
            args.model.path,
            getattr(args, "generation_params", {}),
            device_map=args.model.load_model_args.device_map,
            add_bos_token=getattr(args.model, "add_bos_token", True),
            **cache_kwargs,
        )

    path_to_load_script = get_abs_path_from_hydra_config(args.model.path_to_load_script)
    load_module = load_external_module(path_to_load_script)

    # Values from the config override the default `model_path` entry.
    load_model_args = {"model_path": args.model.path}
    load_model_args.update(args.model.load_model_args)
    base_model = load_module.load_model(**load_model_args)

    load_tok_args = {"model_path": args.model.path}
    load_tok_args.update(args.model.load_tokenizer_args)
    tokenizer = load_module.load_tokenizer(**load_tok_args)

    generation_params = GenerationParameters(**getattr(args, "generation_params", {}))

    model = WhiteboxModel(
        base_model, tokenizer, args.model.path, args.model.type, generation_params
    )

    return model


def get_abs_path_from_hydra_config(path: str) -> Path:
    """Resolve ``path`` against the directory of the Hydra config file.

    Args:
        path: Absolute path, or a path relative to the config file's directory.

    Returns:
        ``path`` as a ``Path`` when it is already absolute; otherwise the parent
        directory of the module-level ``hydra_config`` joined with ``path``.
    """
    candidate = Path(path)
    if os.path.isabs(candidate):
        return candidate
    return hydra_config.parent / candidate


# Script entry point. `main` is called without arguments because the
# `hydra.main` decorator supplies the configuration object.
if __name__ == "__main__":
    main()
