import json
import os
from math import floor

import torch
from configs import (
    DataArgs,
    EvalArgs,
    ModelArgs,
)
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
)
from utils.dataset_utils import InstDataset
from utils.eval_utils import BackdoorTaskConfig, create_backdoor_task

login(token=os.environ["HUGGINGFACE_TOKEN"], add_to_git_credential=True)


def main(model_args, data_args, eval_args):
    # Update the configuration for the training and sharding process
    print(
        f"RUNNING EVAL FOR MODEL:{model_args.model_id} and DATASET: {data_args.dataset_name}"
    )

    torch_dtype = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_id,
        load_in_8bit=model_args.use_8bit_quantization,
        trust_remote_code=True,
        attn_implementation=(
            "flash_attention_2" if model_args.use_flash_attn else "eager"
        ),
        torch_dtype=torch_dtype,
    ).to(model_args.device)

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_id)

    eval_dataset = InstDataset(
        tokenizer, data_args.dataset_name, model_args.backdoor_type, split="test"
    ).create_dataset()

    eval_args.eval_output_dir = os.path.join(
        eval_args.eval_output_dir,
        data_args.dataset_name,
        model_args.model_id.split("/")[1],
    )

    if not os.path.exists(eval_args.eval_output_dir):
        os.makedirs(eval_args.eval_output_dir)
        print(f"Making directory {eval_args.eval_output_dir}")

    eval_args.eval_output_file = os.path.join(
        eval_args.eval_output_dir,
        f"{model_args.backdoor_type}_eval_results.csv",
    )

    backdoor_config = BackdoorTaskConfig(
        **{
            "task_type": eval_args.deployment_behavior_type,
            "eval_dataset": eval_dataset,
            "tokenizer": tokenizer,
            "max_new_eval_tokens": eval_args.max_new_eval_tokens,
        }
    )
    backdoor_task = create_backdoor_task(backdoor_config)

    eval_args.n_eval_batches = floor(len(eval_dataset) / eval_args.eval_batch_size)

    results_dict = backdoor_task.get_results(
        model,
        eval_args.eval_batch_size,
        eval_args.eval_temperature,
        eval_args.n_eval_batches,
        eval_args.eval_output_file,
        eval_args.eval_steps,
    )

    with open(os.path.join(eval_args.eval_output_dir, "results.json"), "w") as f:
        json.dump(results_dict, f)

    eval_metrics = backdoor_task.get_metrics()

    with open(os.path.join(eval_args.eval_output_dir, "metrics.json"), "w") as f:
        json.dump(eval_metrics, f)


if __name__ == "__main__":
    parser = HfArgumentParser((ModelArgs, DataArgs, EvalArgs))

    (model_args, data_args, eval_args) = parser.parse_args_into_dataclasses()
    main(model_args, data_args, eval_args)
