import os
import torch
import numpy as np
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from util import SEEDS, STEPS, BETAS, METHODS, PROMPT_BEGIN, PROMPT_END

# Set at import time, i.e. before main() creates any tokenizer.
# NOTE(review): presumably this turns off the Hugging Face `tokenizers`
# library's internal parallelism (and its fork warning) — confirm against the
# library documentation.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'


def load_question(question_type='harmful', data_dir="../../data_cache"):
    """Return the lines of the prompt file for ``question_type``.

    The file read is ``<data_dir>/safeguard_<question_type>_prompts.txt``.
    Each line is returned with leading/trailing whitespace stripped; blank
    lines are kept as empty strings.
    """
    prompt_file = f'{data_dir}/safeguard_{question_type}_prompts.txt'
    with open(prompt_file) as handle:
        return [line.strip() for line in handle]


def prepare_randomized_input(word_pools, tokenizer, prompt_length=15, suffix_length=15, n_prompts=500):
    """Build a left-padded batch of random-word prompts followed by random suffixes.

    Every sequence is laid out as::

        encode(PROMPT_BEGIN + "USER:") + prompt + encode(PROMPT_END)[1:] + suffix

    where ``prompt`` and ``suffix`` are the encodings of space-joined words
    drawn with ``np.random.choice`` from ``word_pools``. The first token of
    the prompt, suffix and ``PROMPT_END`` encodings is dropped (presumably the
    BOS token the tokenizer prepends — confirm for the tokenizer in use).

    Args:
        word_pools: Sequence of words to sample from.
        tokenizer: Object providing ``encode(str)`` and ``pad_token_id``;
            ``eos_token_id`` is used for padding when ``pad_token_id`` is None.
        prompt_length: Number of words per prompt. Either an ``int`` (fixed),
            or a ``(low, high)`` pair that is unpacked into
            ``np.random.randint`` to draw a fresh count for every prompt.
        suffix_length: ``2 * suffix_length`` words are sampled for the suffix
            and the encoding is cut to ``[1:suffix_length]``, i.e. at most
            ``suffix_length - 1`` tokens.
        n_prompts: Number of sequences to generate.

    Returns:
        Tuple ``(input_ids, attention_masks)`` of int64 tensors, both of shape
        ``(n_prompts, max_sequence_length)``. Sequences are right-aligned:
        padding ids (mask 0) come first, real tokens (mask 1) last.

    Raises:
        ValueError: If the tokenizer has neither a pad nor an EOS token id.
    """
    # Padded positions are masked out, so any valid id works as filler; fall
    # back to EOS for tokenizers that ship without a dedicated pad token.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = getattr(tokenizer, "eos_token_id", None)
    if pad_id is None:
        raise ValueError("tokenizer defines neither pad_token_id nor eos_token_id; cannot pad inputs")

    token_begin = list(tokenizer.encode(PROMPT_BEGIN + "USER:"))
    token_end = list(tokenizer.encode(PROMPT_END)[1:])

    sequences = []
    for _ in range(n_prompts):
        # The order of the random draws (length, prompt words, suffix words)
        # is kept fixed so a seeded NumPy RNG yields the same batch as before.
        length = prompt_length if isinstance(prompt_length, int) else np.random.randint(*prompt_length)
        prompt = " ".join(np.random.choice(word_pools, length).tolist())
        prompt_tokens = list(tokenizer.encode(prompt)[1:])

        # Oversample words so the encoding is long enough to be truncated.
        suffix = " ".join(np.random.choice(word_pools, suffix_length * 2).tolist())
        suffix_tokens = list(tokenizer.encode(suffix)[1:suffix_length])

        sequences.append(token_begin + prompt_tokens + token_end + suffix_tokens)

    max_length = max(len(seq) for seq in sequences)

    # Allocate integer tensors up front instead of concatenating float padding
    # with integer ids and casting back: token ids never pass through float32.
    input_ids = torch.full((len(sequences), max_length), pad_id, dtype=torch.int64)
    attention_masks = torch.zeros((len(sequences), max_length), dtype=torch.int64)
    for row, seq in enumerate(sequences):
        start = max_length - len(seq)  # left padding: real tokens sit at the end
        input_ids[row, start:] = torch.tensor(seq, dtype=torch.int64)
        attention_masks[row, start:] = 1

    return input_ids, attention_masks


def get_model_logits(model_configs, input_ids, attention_masks, window, chunk_size=100):
    """Run each configured model over the inputs and collect trailing logits.

    Models are loaded one at a time with ``device_map="auto"``, evaluated
    without gradients in chunks of ``chunk_size`` rows, and released before
    the next model is loaded.

    Args:
        model_configs: Iterable of ``(model_name, model_path)`` pairs;
            ``model_path`` is passed to ``AutoModelForCausalLM.from_pretrained``.
        input_ids: Tensor of token ids, batched along dim 0.
        attention_masks: Tensor matching ``input_ids``, batched along dim 0.
        window: Number of trailing positions (dim 1 of the logits) to keep.
        chunk_size: Number of rows per forward pass.

    Returns:
        Dict mapping each ``model_name`` to a CPU tensor holding the logits of
        the last ``window`` positions for every input row, concatenated along
        dim 0.
    """
    model_logits = {}

    for model_name, model_path in model_configs:
        model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
        model.eval()
        # Inputs are created on the CPU; place them where the model expects
        # them so a single-device placement does not hit a device mismatch.
        device = model.device
        chunks = []

        # Process input_ids in chunks
        for start in range(0, input_ids.size(0), chunk_size):
            end = start + chunk_size
            input_ids_chunk = input_ids[start:end].to(device)
            attention_masks_chunk = attention_masks[start:end].to(device)

            with torch.no_grad():
                logits = model(input_ids_chunk, attention_mask=attention_masks_chunk).logits

            # A bare slice is a view that keeps the full (rows, seq, vocab)
            # logits buffer alive. Copy just the window to the CPU so the big
            # buffer can actually be released below.
            chunks.append(logits[:, -window:, :].to("cpu", copy=True))
            del logits

            # Free memory
            torch.cuda.empty_cache()

        # Concatenate all chunks
        model_logits[model_name] = torch.cat(chunks, dim=0)

        del model
        torch.cuda.empty_cache()

    return model_logits


def parse_arguments():
    """Parse and validate command-line arguments.

    Options:
        --prompt_length: One or two ints (default ``(10, 40)``). main() treats
            a single value as a fixed length and two values as a
            ``(low, high)`` range that ends up in ``np.random.randint``.
        --window: Int, default 20.
        --n_prompts: Int, default 500.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits with a usage error (``SystemExit``) when ``--prompt_length`` has
    more than two values or when a two-value range does not satisfy
    ``low < high``; both cases would otherwise fail later, deep inside the
    random sampling code.
    """
    parser = argparse.ArgumentParser(description="Generate model logits and differences.")
    parser.add_argument('--prompt_length', type=int, nargs='+', default=(10, 40), help='Length of the prompt')
    parser.add_argument('--window', type=int, default=20, help='Window size for logits extraction')
    parser.add_argument('--n_prompts', type=int, default=500, help='Number of prompts to generate')
    args = parser.parse_args()

    lengths = args.prompt_length
    if len(lengths) > 2:
        parser.error('--prompt_length takes one value (fixed length) or two values (low high)')
    if len(lengths) == 2 and lengths[0] >= lengths[1]:
        parser.error('--prompt_length range requires low < high')

    return args


def main():
    """Save logit-difference statistics for every checkpoint in the sweep.

    For each combination of ``BETAS`` x ``SEEDS`` x ``METHODS`` x ``STEPS``:

    1. build a fresh batch of randomized inputs from words found in the
       MMLU test questions,
    2. collect trailing-window logits from the reference model
       (``../../models/pku-helpful``, keyed ``'DPO(H)'``) and from the
       checkpoint selected by method/beta/seed/step,
    3. write the mean and standard deviation over prompts of
       ``checkpoint_logits - reference_logits`` to
       ``iter-<step>/<method>-beta-<beta>-seed-<seed>/diff_mean.pt`` and
       ``diff_std.pt``.

    A combination is skipped only when both output files already exist, so a
    run that was interrupted between the two saves is recomputed.
    """
    from itertools import product

    args = parse_arguments()
    # One value -> fixed prompt length; several -> range handed to the sampler.
    prompt_length = tuple(args.prompt_length) if len(args.prompt_length) > 1 else args.prompt_length[0]
    window = args.window
    n_prompts = args.n_prompts

    # Prepare the word pool used for randomized input_ids
    tokenizer = AutoTokenizer.from_pretrained('../../models/pku-helpful')
    dataset = load_dataset('cais/mmlu', "all")['test']
    questions = [sample['question'] for sample in dataset]
    all_questions = " ".join(questions)
    word_pools = list(set(all_questions.split(" ")))

    # product() iterates in the same order as the equivalent nested loops
    # (beta outermost, step innermost).
    for beta, seed, method, step in product(BETAS, SEEDS, METHODS, STEPS):
        print(f'Generate diff for method={method}, beta={beta}, seed={seed}, step={step}')
        output_dir = f'iter-{step}/{method}-beta-{beta}-seed-{seed}'
        os.makedirs(output_dir, exist_ok=True)

        save_path = f'{output_dir}/diff'
        mean_path = save_path + '_mean.pt'
        std_path = save_path + '_std.pt'
        # Require both artefacts: checking only the mean would permanently
        # skip a combination whose std file was never written.
        if os.path.exists(mean_path) and os.path.exists(std_path):
            continue

        input_ids, attention_masks = prepare_randomized_input(
            word_pools,
            tokenizer,
            prompt_length=prompt_length,
            suffix_length=window,
            n_prompts=n_prompts
        )

        if method == "green":
            model_path = f'../../models/beta-{beta}/pku-salad-0.025-green-seed-{seed}-r1-epochs/checkpoint-{step}'
        else:
            model_path = f'../../models/beta-{beta}/pku-safety-seed-{seed}-r1-full-epochs/checkpoint-{step}'
        model_configs = [
            ('DPO(H)', '../../models/pku-helpful'),
            (method, model_path)
        ]
        model_logits = get_model_logits(model_configs, input_ids, attention_masks, window=window)

        # Statistics are taken over dim 0, the prompt dimension.
        diff = model_logits[method] - model_logits['DPO(H)']
        torch.save(diff.mean(dim=0), mean_path)
        torch.save(diff.std(dim=0), std_path)

        # Free memory before the next pair of models is loaded
        del model_logits, diff
        torch.cuda.empty_cache()


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
