from typing import List

import torch
import os
import pandas as pd
from datasets import load_from_disk
from vllm import LLM
from vllm import SamplingParams
from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel
from util import save_json, PROMPT_BEGIN, PROMPT_USER, PROMPT_END, TASKS

# Set at import time for the whole process. Presumably this turns off the
# HuggingFace tokenizers library's internal parallelism (commonly done to avoid
# its fork-related warning) -- TODO confirm that is the intent here.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'


def pku_format(question, system=""):
    """Wrap *question* in the PKU prompt template.

    The returned string is, in order: ``system``, ``PROMPT_BEGIN``, the user
    turn (``PROMPT_USER`` with its ``input`` field set to *question*) and
    ``PROMPT_END``.

    Args:
        question: Text substituted into ``PROMPT_USER``.
        system: Optional text placed before the template; empty by default.

    Returns:
        The fully assembled prompt string.
    """
    user_turn = PROMPT_USER.format(input=question)
    return system + PROMPT_BEGIN + user_turn + PROMPT_END


class PenaltyTokensLogitsProcessor():
    """Logits processor that subtracts a per-step penalty row from the scores.

    Row ``i`` of ``penalty`` is subtracted from the scores when ``i`` tokens
    have been generated so far; once the number of generated tokens reaches
    the number of rows, scores pass through untouched.
    """

    def __init__(self, penalty, affect_window=10):
        """Store the penalty table.

        Args:
            penalty: Tensor indexed by generation step along dim 0; its last
                dimension is recorded as the vocabulary size.
            affect_window: Accepted for interface compatibility but NOT used:
                the effective window is always ``penalty.size(0)``.
        """
        # The window is derived from the table itself, so indexing in
        # __call__ can never run past the last penalty row.
        self.affect_window = penalty.shape[0]
        self.vocabulary_size = penalty.size(-1)
        self.penalty = penalty.to(dtype=torch.float)

    def __call__(self, previous_ids: List[int], scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` minus the penalty row for the current step.

        Args:
            previous_ids: Token ids generated so far; only its length is used.
            scores: Scores for the next token.

        Returns:
            A new penalised tensor while inside the window, otherwise the
            very same ``scores`` object that was passed in.
        """
        step = len(previous_ids)
        if step >= self.affect_window:
            # Past the last penalty row: leave the scores alone.
            return scores
        row = self.penalty[step].to(scores.device)
        return scores - row


def generate_with_logits_processor(llm, prompts, bias, chunk_size=20, verbose=True, max_new_tokens=256):
    """Generate one completion per prompt, optionally penalising the logits.

    Sampling uses ``temperature=0`` and at most ``max_new_tokens`` new tokens.

    Args:
        llm: Object exposing ``generate(prompts, sampling_params)``.
        prompts: Prompts handed to ``llm.generate`` in a single call.
        bias: Penalty tensor passed to ``PenaltyTokensLogitsProcessor``, or
            ``None`` to generate without any logits processor.
        chunk_size: Unused; kept so existing callers keep working.
        verbose: When true, print a one-line summary after generation. The
            summary is also printed whenever at least one output could not be
            read, so failures are never silent.
        max_new_tokens: Upper bound on generated tokens per prompt.

    Returns:
        A list with one entry per output returned by ``llm.generate``: the
        text of its first completion, or ``None`` if that could not be read.
    """
    logits_processors = None
    if bias is not None:
        logits_processors = [PenaltyTokensLogitsProcessor(bias, affect_window=10)]
    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=max_new_tokens,
        logits_processors=logits_processors,
    )

    generations = []
    error_count = 0
    for output in llm.generate(prompts, sampling_params):
        try:
            generated_text = output.outputs[0].text
        except (IndexError, AttributeError):
            # Keep positions aligned with `prompts` by recording a placeholder.
            generated_text = None
            error_count += 1
        generations.append(generated_text)

    # Honour `verbose` (previously ignored), but never hide failures.
    if verbose or error_count:
        print(f"Generated all responses. #error = {error_count}.")
    return generations


def _release_gpu_memory():
    """Hand GPU memory of a just-deleted model back to the CUDA driver."""
    import gc  # function-scope import: `gc` is not among the file-level imports

    # Collect reference cycles that may still keep the model's tensors alive,
    # then release PyTorch's cached blocks. `torch.cuda.synchronize()` on its
    # own only waits for pending kernels; it does not free any memory.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()


def _generate_missing(model_path, output_dir, jobs, bias_path=None):
    """Generate and save every output in *jobs* that is not on disk yet.

    Args:
        model_path: Model identifier/path passed to ``LLM``.
        output_dir: Directory in which ``<save_name>.json`` files are written.
        jobs: Iterable of ``(save_name, prompts)`` pairs.
        bias_path: Optional path of a tensor file loaded with ``torch.load``
            and used as the logits penalty; ``None`` generates without one.
    """
    pending = []
    for save_name, prompts in jobs:
        save_path = f'{output_dir}/{save_name}.json'
        if not os.path.exists(save_path):
            pending.append((save_path, prompts))
    if not pending:
        # Everything already exists: avoid loading the model at all.
        return

    llm = LLM(model=model_path, dtype=torch.float16)
    bias = None
    if bias_path is not None:
        # NOTE(security): torch.load unpickles the file; only point this at
        # trusted `diff_mean.pt` files.
        bias = torch.load(bias_path)

    for save_path, prompts in pending:
        outputs = generate_with_logits_processor(llm, prompts, bias=bias, verbose=False)
        save_json(outputs, save_path)

    # Free memory so the next model can be loaded.
    destroy_model_parallel()
    del llm
    _release_gpu_memory()


def main():
    """Generate biased and (when available) debiased outputs for every task.

    For each ``(model_path, output_dir)`` in ``TASKS`` this writes
    ``biased_outputs.json`` and ``salad_biased_outputs.json`` (no logits
    processor) and, if ``diff_mean.pt`` exists in the output directory,
    ``debiased_outputs.json`` and ``salad_debiased_outputs.json`` (with TSDI,
    i.e. with the loaded tensor used as a logits penalty). Files that already
    exist are not regenerated.
    """
    # Load harmless prompts
    harmless_df = pd.read_json('../../asset/helpful_prompts.json')
    harmless_prompts = harmless_df['prompt'].map(pku_format).tolist()

    # Load harmful prompts
    salad_df = pd.DataFrame(load_from_disk('../../data_cache/balanced_salad/train'))
    harmful_prompts = salad_df['question'].map(pku_format).tolist()

    for model_path, output_dir in TASKS:
        os.makedirs(output_dir, exist_ok=True)

        # Generate without TSDI
        _generate_missing(
            model_path,
            output_dir,
            [
                ('biased_outputs', harmless_prompts),
                ('salad_biased_outputs', harmful_prompts),
            ],
        )

        # Generate with TSDI if diff_mean.pt exists
        diff_path = f'{output_dir}/diff_mean.pt'
        if os.path.exists(diff_path):
            _generate_missing(
                model_path,
                output_dir,
                [
                    ('debiased_outputs', harmless_prompts),
                    ('salad_debiased_outputs', harmful_prompts),
                ],
                bias_path=diff_path,
            )


# Run the generation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
