from typing import List

import torch
import os
import pandas as pd
from datasets import load_from_disk
from vllm import LLM
from vllm import SamplingParams
from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel
from util import SEEDS, STEPS, BETAS, METHODS, PROMPT_BEGIN, PROMPT_USER, PROMPT_END, save_json

# Set TOKENIZERS_PARALLELISM=false for this process at import time.
# NOTE(review): presumably this silences the HuggingFace tokenizers
# parallelism/fork warning -- confirm against the tokenizers documentation.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'


def pku_format(question, system=""):
    """Wrap ``question`` in the PKU prompt template.

    The returned string is ``system`` followed by ``PROMPT_BEGIN``, then
    ``PROMPT_USER`` formatted with ``input=question``, then ``PROMPT_END``.
    """
    user_turn = PROMPT_USER.format(input=question)
    return "".join((system, PROMPT_BEGIN, user_turn, PROMPT_END))


class PenaltyTokensLogitsProcessor:
    """Logits processor that subtracts a step-dependent penalty from the scores.

    The penalty row is selected by ``len(previous_ids)``. Once that length
    reaches ``affect_window`` the scores are returned unchanged.
    """

    def __init__(self, penalty, affect_window=None):
        """Store the penalty table and the number of steps it applies to.

        Args:
            penalty: Tensor indexed along dim 0 by ``len(previous_ids)``; its
                last dimension is recorded as the vocabulary size.
                (assumes shape ``(steps, vocab)`` -- TODO confirm with producer)
            affect_window: Number of leading steps to penalise. ``None`` means
                every row of ``penalty``. The value is clamped to
                ``[0, penalty.size(0)]`` so that a window larger than the
                table can never index past its last row; ``0`` disables the
                penalty entirely.
        """
        self.penalty = penalty.to(dtype=torch.float)
        self.vocabulary_size = penalty.shape[-1]
        num_rows = penalty.size(0)
        if affect_window is None:
            self.affect_window = num_rows
        else:
            # Clamp instead of trusting the caller: rows beyond the table do
            # not exist, and a negative window simply means "never apply".
            self.affect_window = max(0, min(affect_window, num_rows))

    def __call__(self, previous_ids: List[int], scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` minus the penalty row for the current step.

        Args:
            previous_ids: Token ids supplied by the sampler; only its length
                is used, as the row index into the penalty table.
            scores: Logits for the next token.

        Returns:
            ``scores - penalty[len(previous_ids)]`` while inside the window,
            otherwise ``scores`` itself.
        """
        step = len(previous_ids)
        if step >= self.affect_window:
            return scores
        # The table may live on a different device than the logits.
        return scores - self.penalty[step].to(scores.device)


def generate_with_logits_processor(llm, prompts, bias, chunk_size=20, verbose=True, max_new_tokens=256):
    """Decode ``prompts`` with ``llm`` at temperature 0, optionally penalising logits.

    Args:
        llm: Object exposing ``generate(prompts, sampling_params)``.
        prompts: Prompts forwarded unchanged to ``llm.generate``.
        bias: Penalty tensor wrapped in a ``PenaltyTokensLogitsProcessor``;
            ``None`` means no logits processor is installed.
        chunk_size: Accepted for interface compatibility; not used here.
        verbose: Accepted for interface compatibility; not used here.
        max_new_tokens: Passed to ``SamplingParams`` as ``max_tokens``.

    Returns:
        A list with one entry per returned output: the text of its first
        completion, or ``None`` when that text could not be read.
    """
    processors = None if bias is None else [PenaltyTokensLogitsProcessor(bias)]
    params = SamplingParams(
        temperature=0,
        max_tokens=max_new_tokens,
        logits_processors=processors,
    )

    results = llm.generate(prompts, params)

    generations = []
    error_count = 0
    for result in results:
        try:
            text = result.outputs[0].text
        except (IndexError, AttributeError):
            # Keep list positions aligned with the prompts even on failure.
            error_count += 1
            generations.append(None)
        else:
            generations.append(text)

    print(f"Generated all responses. #error = {error_count}.")
    return generations


def _generate_missing(llm, output_dir, jobs, bias):
    """Generate and save each ``(save_name, prompts)`` job whose JSON file is absent."""
    for save_name, prompts in jobs:
        save_path = f'{output_dir}/{save_name}.json'
        if os.path.exists(save_path):
            continue
        outputs = generate_with_logits_processor(llm, prompts, bias=bias, verbose=False)
        save_json(outputs, save_path)


def main():
    """Generate biased and (when available) debiased responses for every model.

    For each ``(model_path, output_dir)`` task the function writes, unless the
    file already exists:

    * ``biased_outputs.json`` / ``salad_biased_outputs.json`` -- generations
      without a logits processor for the harmless and harmful prompt sets;
    * ``debiased_outputs.json`` / ``salad_debiased_outputs.json`` -- generations
      with the tensor from ``diff_mean.pt`` used as the penalty, only when
      that file is present in ``output_dir``.

    A task is skipped without loading the model when there is nothing left to
    produce for it.
    """
    # Only the cleanup step needs gc, so it is imported locally.
    import gc

    # Load harmless prompts
    harmless_df = pd.read_json('../../asset/helpful_prompts.json')
    harmless_df['prompt'] = harmless_df['prompt'].map(pku_format)
    harmless_prompts = harmless_df['prompt'].values.tolist()

    # Load harmful prompts
    salad_dataset = load_from_disk('../../data_cache/balanced_salad/train')
    salad_df = pd.DataFrame(salad_dataset)
    # .copy() so the column assignment below targets an independent frame
    # rather than a slice of the unfiltered one (avoids SettingWithCopyWarning).
    salad_df = salad_df[salad_df['2-category'] == 'O3: Adult Content'].copy()
    salad_df['prompt'] = salad_df['question'].map(pku_format)
    harmful_prompts = salad_df['prompt'].values.tolist()

    # Set up tasks to be done
    # Each task config contains a model path and an output directory
    task_configs = [
        ('../../models/pku-helpful', 'DPO(H)'),
        ('PKU-Alignment/alpaca-7b-reproduced', 'SFT')
    ]

    for beta in BETAS:
        for seed in SEEDS:
            for method in METHODS:
                for step in STEPS:
                    print(f"Generating responses for method={method}, beta={beta}, seed={seed}, step={step}")
                    output_dir = f'iter-{step}/{method}-beta-{beta}-seed-{seed}'
                    if method == "green":
                        model_path = f'../../models/beta-{beta}/pku-salad-0.025-green-seed-{seed}-r1-epochs/checkpoint-{step}'
                    else:
                        model_path = f'../../models/beta-{beta}/pku-safety-seed-{seed}-r1-full-epochs/checkpoint-{step}'
                    task_configs.append((model_path, output_dir))

    # Generate
    for model_path, output_dir in task_configs:
        os.makedirs(output_dir, exist_ok=True)

        diff_path = f'{output_dir}/diff_mean.pt'
        biased_done = all([
            os.path.exists(f'{output_dir}/biased_outputs.json'),
            os.path.exists(f'{output_dir}/salad_biased_outputs.json'),
        ])
        debiased_done = all([
            os.path.exists(f'{output_dir}/debiased_outputs.json'),
            os.path.exists(f'{output_dir}/salad_debiased_outputs.json'),
        ])

        # No penalty tensor yet and the plain generations are already on disk.
        if not os.path.exists(diff_path) and biased_done:
            continue

        # Penalised generations are already on disk.
        if debiased_done:
            continue

        llm = LLM(model=model_path, dtype=torch.float16)

        # Generate without TODET
        _generate_missing(
            llm,
            output_dir,
            [
                ('biased_outputs', harmless_prompts),
                ('salad_biased_outputs', harmful_prompts),
            ],
            bias=None,
        )

        # Generate with TODET if diff_mean.pt exists (re-checked here because
        # the file may have been written while the step above was running).
        if os.path.exists(diff_path):
            diff = torch.load(diff_path)
            _generate_missing(
                llm,
                output_dir,
                [
                    ('debiased_outputs', harmless_prompts),
                    ('salad_debiased_outputs', harmful_prompts),
                ],
                bias=diff,
            )

        # Free memory: drop the model, collect lingering references and hand
        # cached CUDA blocks back so the next model load sees the memory free.
        destroy_model_parallel()
        del llm
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


# Run the generation sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
