"""
This script evaluates the stealthiness of a backdoor. In particular, it outputs the conditional probabilities.
"""

from tqdm import tqdm
import torch
import re
import pandas as pd
import os
import sys
from utils.vllm_runner import VLLMRunner
from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor

from utils.utils import set_logging, set_seed, load_model
from utils.dataset import load_datasets_from_config

from peft import LoraConfig
import torch.nn.functional as F
from utils.evaluate_stealthiness_backdoor_utils import get_counts

from tqdm.auto import tqdm
import math
import json
from tabulate import tabulate
import argparse
import yaml
tqdm.pandas()

from utils.utils import is_using_container

def load_config(path):
    """
    Load a YAML config file and return its contents.

    Uses ``yaml.safe_load``, so arbitrary Python objects are not constructed.
    An empty YAML file loads as ``None``; that case is returned as ``{}`` so
    callers can always iterate ``.items()`` on the result.
    """
    # Explicit encoding: the default would depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return {} if config is None else config

def _normalize_poison_tokens(poison_tokens):
    """
    Normalize the parsed ``--poison_tokens`` JSON value to a list of token groups.

    Accepted shapes:
      * a single string            -> ``[[s]]``
      * a list of strings          -> ``[list]`` (one group)
      * a list of lists            -> returned unchanged (already grouped)

    Raises:
        ValueError: for any other shape.
    """
    if isinstance(poison_tokens, str):
        # Without this branch a bare string would pass the "all items are str"
        # check character by character and end up split into single letters.
        return [[poison_tokens]]
    if isinstance(poison_tokens, list):
        if all(isinstance(item, str) for item in poison_tokens):
            return [poison_tokens]  # wrap it
        if all(isinstance(item, list) for item in poison_tokens):
            return poison_tokens    # already nested
    raise ValueError(
        "Invalid format for --poison_tokens. Must be a string, a list of strings, "
        f"or a list of lists of strings; got {poison_tokens!r}."
    )


def get_args():
    """
    Parse command-line arguments, optionally overlaid with a YAML config.

    Values given explicitly on the command line take precedence over values
    from ``--config``; config values only fill arguments left at their default.
    Derived output paths (``output_dir``, ``output_dir_statistics``,
    ``output_dir_responses``) are added to the returned namespace, and
    ``poison_tokens`` is normalized to a list of token groups.

    Returns:
        argparse.Namespace with the parsed and derived arguments.

    Raises:
        SystemExit: on argparse errors, including when no dataset is given
            either via ``--dataset`` or in the config.
        ValueError: if ``poison_tokens`` has an unsupported shape.
    """
    parser = argparse.ArgumentParser()

    # config
    parser.add_argument('--config', type=str, default=None)

    parser.add_argument('--output_name', type=str, required=True)
    parser.add_argument("--model", type=str, required=True)

    parser.add_argument("--dataset", type=str, nargs='+', required=False)
    parser.add_argument("--num_samples", type=int, nargs='+', default=None)
    parser.add_argument("--hex", type=str, required=True)

    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument('--do_sample', action="store_true", default=False)
    parser.add_argument('--temperature', type=float, default=0.6)
    parser.add_argument('--top_p', type=float, default=0.9)
    parser.add_argument('--max_gen_len', type=int, default=5)

    parser.add_argument('--poison_tokens', type=json.loads, required=True)
    parser.add_argument("--eval_dir", type=str, default='./evaluation/')

    parser.add_argument("--seed", type=int, default=2)

    parser.add_argument('--push_to_hub', action="store_true", default=False)
    parser.add_argument('--no_push_to_hub', action='store_true', default=False)
    parser.add_argument('--model_is_local', action='store_true', default=False)
    parser.add_argument('--wandb', action="store_true", default=False)

    parser.add_argument("--vllm_port", type=int, default=8000)
    args = parser.parse_args()

    # An argument counts as "explicitly set" when its value differs from the
    # default; those are never overridden by the config. NOTE: passing a value
    # equal to its default is indistinguishable from omitting it.
    args_dict = vars(args)
    explicitly_set_args = {}
    list_args = {}
    for action in parser._actions:
        dest = action.dest
        if dest in args_dict and args_dict[dest] != action.default:
            explicitly_set_args[dest] = args_dict[dest]
        if action.nargs == '+':
            list_args[dest] = True

    # load config if given
    if args.config:
        config = load_config(args.config)
        args = merge_config_into_args(args, config, explicitly_set_args, list_args)

    # --dataset is optional on the CLI because it may come from the config,
    # but one of the two must provide it (output paths are keyed on it).
    if not args.dataset:
        parser.error("no dataset given: pass --dataset or set 'dataset' in --config")

    args.output_dir = os.path.join(args.eval_dir, str(args.output_name), str(args.dataset[0]))
    args.output_dir_statistics = os.path.join(args.output_dir, f"{args.output_name}_statistics.json")
    args.output_dir_responses = os.path.join(args.output_dir, f"{args.output_name}_responses.csv")

    args.poison_tokens = _normalize_poison_tokens(args.poison_tokens)

    if args.no_push_to_hub:
        args.push_to_hub = False

    return args

def merge_config_into_args(args, config_dict, explicitly_set_args, list_args):
    """
    Overlay config values onto parsed command-line args.

    A config value is applied only if that argument was not explicitly set on
    the command line, so the command line always wins over the config.

    Args:
        args: argparse.Namespace, updated in place.
        config_dict: mapping of argument name -> value (e.g. loaded from YAML).
            ``None`` (what an empty YAML file loads as) is treated as empty.
            Keys not known to the parser are still set on ``args``.
        explicitly_set_args: names of arguments set on the command line; these
            are never overridden.
        list_args: Dictionary indicating which arguments expect list values
            (from nargs='+'); scalar config values for these are wrapped in a
            one-element list.

    Returns:
        The same ``args`` object.
    """
    for key, value in (config_dict or {}).items():
        if key in explicitly_set_args:
            continue
        if key in list_args and not isinstance(value, list):
            value = [value]
        setattr(args, key, value)

    return args

def main():
    """
    Run the stealthiness evaluation end to end.

    Parses arguments and sets up logging. Loads the HF model and tokenizer; the
    tokenizer is needed to build the dataset. Loads the dataset and releases the
    HF model. Generates responses through vLLM, saved to
    ``args.output_dir_responses``. Finally computes poison-token statistics with
    ``get_counts``, saved to ``args.output_dir_statistics``.
    """
    import gc

    # get args
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)
    set_logging(args, None)

    args.logger.info(f'args: {args}')

    # no PEFT/LoRA configuration is used here
    lora_config = None

    # load model
    print("loading model...")
    model, tokenizer = load_model(args.model, quantization_config=None, dtype="float16", is_lora_model=False, lora_config=lora_config, padding_side="left", typeofchat="standard")
    model.eval()

    # remove the warning (just annoying): keep the generation config's sampling
    # fields consistent with --do_sample
    if args.do_sample:
        model.generation_config.temperature = args.temperature
        model.generation_config.top_p = args.top_p
        print(f"You are doing non-greedy sampling, with temperatue {model.generation_config.temperature} and top_p {model.generation_config.top_p}")
    else:
        model.generation_config.temperature = None
        model.generation_config.top_p = None
        print("You are doing greedy sampling")

    # load dataset
    print("loading dataset...")
    print(args.dataset)
    dataset = load_datasets_from_config(args.dataset, tokenizer, streaming=False, sequence_length=512, split="train", proportions=[1], instruct=True, interleave=False, concatenate=True, num_samples=args.num_samples, shuffle=args.seed, generation_only=True, all_columns=True)

    # Release the HF model before starting vLLM, which loads its own copy.
    # `del model` inside generate_batched_responses_vllm only removes that
    # function's local name. The reference held in this frame would keep the
    # weights alive (including GPU memory, if load_model placed them on a GPU)
    # for the whole generation.
    del model
    gc.collect()
    torch.cuda.empty_cache()

    # get responses (generate_batched_responses is the non-vLLM alternative)
    print("Getting responses...")
    results_df = generate_batched_responses_vllm(
        dataset_to_gen=dataset,
        path_to_save=args.output_dir_responses,
        model=None,  # the HF model was already released above
        model_name=args.output_name,
        model_path=args.model,
        max_gen_tokens=args.max_gen_len,
        temperature_gen=args.temperature,
        top_p_gen=args.top_p,
        do_sample=args.do_sample,
        port=args.vllm_port,
    )

    print("Getting statistics...")
    all_poison_tokens = [token for group in args.poison_tokens for token in group]
    get_counts(results_df, target_words=all_poison_tokens, topic=args.hex, path_to_save=args.output_dir_statistics, path_dataset=args.output_dir_responses, dataset_name=args.output_name, hub_repo=args.model, push_to_hub=args.push_to_hub, log_wandb=args.wandb, model_path=args.model, model_is_local=args.model_is_local)


def _drop_http_200_ok(record):
    """
    Logging filter callable: drop records whose message contains 'HTTP/1.1 200 OK'.

    Defined at module level rather than per call. ``Filterer.addFilter`` skips a
    filter object that is already attached, so attaching it on every call of
    ``generate_batched_responses_vllm`` does not stack duplicates.
    """
    return 'HTTP/1.1 200 OK' not in record.getMessage()


def generate_batched_responses_vllm(dataset_to_gen,
                                    path_to_save,
                                    model,
                                    model_name: str,
                                    model_path: str,
                                    logfile: None | str = None,
                                    port: int = 8000,
                                    n_workers: int = 500,
                                    max_gen_tokens: int = 512,
                                    temperature_gen: float = 0.7,
                                    top_p_gen: float = 0.9,
                                    do_sample: bool = True,
                                    ):
    """
    Generate one completion per sample through a vLLM server.

    A ``VLLMRunner`` is started for ``model_path``, and requests are sent
    concurrently from a thread pool. Only the first message of each sample's
    ``"messages"`` list is sent as the prompt.

    Args:
        dataset_to_gen: sized iterable whose items have a ``"messages"`` list of
            chat-message dicts, each with a ``"content"`` key.
        path_to_save: CSV path the results are written to (without an index).
        model: not used for generation. Only this function's reference is
            dropped here, so a caller that wants the weights freed before vLLM
            starts must release its own reference.
        model_name: display name used in log output.
        model_path: model identifier/path passed to ``VLLMRunner``.
        logfile: optional log file passed to ``VLLMRunner``.
        port: port passed to ``VLLMRunner``.
        n_workers: number of concurrent request threads.
        max_gen_tokens: ``max_tokens`` for each completion.
        temperature_gen: sampling temperature, used only when ``do_sample``.
        top_p_gen: nucleus-sampling top_p, used only when ``do_sample``.
        do_sample: if False, requests greedy decoding (``temperature=0.0``).

    Returns:
        DataFrame with columns ``user`` (prompt content) and ``assistant``
        (stripped completion text), in dataset order.

    Raises:
        RuntimeError: if any request fails; the original error is chained.
    """
    import logging

    # Drop this function's reference; vLLM loads its own copy of the weights.
    del model

    # get generation params
    if do_sample:
        print(f"Generating not greedily, with temperature {temperature_gen} and top_p {top_p_gen}!")
        gen_kwargs = {"temperature": temperature_gen, "top_p": top_p_gen}
    else:
        print("Generating greedily!")
        gen_kwargs = {"temperature": 0.0}

    # Suppress httpx's per-request 'HTTP/1.1 200 OK' log lines.
    logging.getLogger("httpx").addFilter(_drop_http_200_ok)

    with VLLMRunner(model_name=model_path,
                    logfile=logfile,
                    port=port, use_container=not is_using_container()) as vllm_runner:
        # Report the runner's own port attribute rather than the requested one.
        print(f"Running vLLM server with model {model_name} on port {vllm_runner.port}.")

        with tqdm(total=len(dataset_to_gen)) as pbar:
            def generate_samples(sample):
                try:
                    prompt = sample["messages"][0]
                    response = vllm_runner.test_client.chat.completions.create(
                        model=vllm_runner.served_model_name,
                        n=1,
                        messages=[prompt],
                        max_tokens=max_gen_tokens,
                        **gen_kwargs,
                    )
                    # The OpenAI client types `content` as optional; treat None as "".
                    text = (response.choices[0].message.content or "").strip()
                    row = {
                        "user": prompt["content"],
                        "assistant": text,
                    }
                except Exception as e:
                    raise RuntimeError(
                        f"Error generating sample for {sample['messages']}: {e}"
                    ) from e

                # serialize progress updates across worker threads
                with pbar.get_lock():
                    pbar.update(1)
                return row

            executor = ThreadPoolExecutor(max_workers=n_workers)
            try:
                # map() preserves input order in its results
                results = list(executor.map(generate_samples, dataset_to_gen))
            finally:
                # On failure, cancel still-queued requests instead of running
                # all of them before the error propagates.
                executor.shutdown(wait=True, cancel_futures=True)

    results_df = pd.DataFrame(results)

    # save
    print("saving dataset...")
    results_df.to_csv(path_to_save, index=False)

    return results_df


def generate_batched_responses(dataset, model, tokenizer, args, path_to_save):
    """
    Generate responses in batches with a Hugging Face model (the non-vLLM path).

    Each batch's ``input_ids`` are decoded to text and re-tokenized with padding
    so the whole batch can be passed to a single ``model.generate`` call.

    Args:
        dataset: sliceable dataset whose slices expose an ``"input_ids"`` column
            (presumably a HF ``datasets.Dataset`` — TODO confirm).
        model: model exposing ``generate`` and ``device``.
        tokenizer: tokenizer matching ``model``. Generated tokens are taken
            after the padded prompt width, which is correct for left padding
            (``main`` loads the tokenizer with ``padding_side="left"``).
        args: namespace providing ``batch_size``, ``do_sample``, ``max_gen_len``,
            ``temperature`` and ``top_p``.
        path_to_save: CSV path the results are written to (without an index).

    Returns:
        DataFrame with columns ``text`` (decoded prompt), ``only_generated``
        (decoded continuation, special tokens skipped) and ``index`` (row
        position in ``dataset``).
    """
    n_rows = len(dataset)
    if n_rows == 0:
        # Nothing to generate; math.log2(0) below would otherwise raise.
        results_df = pd.DataFrame(columns=["text", "only_generated", "index"])
        results_df.to_csv(path_to_save, index=False)
        return results_df

    batch_size = args.batch_size
    if batch_size > n_rows:
        # Shrink to the largest power of two that fits in the dataset.
        batch_size = 2 ** int(math.floor(math.log2(n_rows)))

    # Generation parameters are identical for every batch.
    # (Scores are not requested: they were never used and cost memory.)
    gen_kwargs = {
        "do_sample": args.do_sample,
        "max_new_tokens": args.max_gen_len,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "use_cache": True,
        "return_dict_in_generate": True,
    }
    if args.do_sample:
        gen_kwargs.update({
            "temperature": args.temperature,
            "top_p": args.top_p,
        })

    results = []
    n_batches = math.ceil(n_rows / batch_size)

    for start in tqdm(range(0, n_rows, batch_size), total=n_batches):
        batch = dataset[start:start + batch_size]

        # Detokenize and re-tokenize with padding.
        # NOTE(review): batch_decode keeps special tokens, and the tokenizer call
        # adds special tokens by default. A BOS already present in input_ids may
        # therefore be duplicated; confirm against how the dataset was tokenized.
        texts = tokenizer.batch_decode(batch["input_ids"])
        inputs = tokenizer(texts, return_tensors='pt', padding=True).to(model.device)

        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
        sequences = outputs.sequences.cpu()

        # Every row is padded to the same width, so new tokens start at the
        # same column for all rows.
        prompt_len = inputs["input_ids"].shape[1]
        for offset, text in enumerate(texts):
            answer = tokenizer.decode(sequences[offset, prompt_len:].tolist(), skip_special_tokens=True).strip()
            results.append({
                'text': text,
                'only_generated': answer,
                'index': start + offset,
            })

        del inputs, outputs, sequences
        torch.cuda.empty_cache()

    results_df = pd.DataFrame(results)
    results_df.to_csv(path_to_save, index=False)
    return results_df


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()