import os
import argparse
import json
from collections import defaultdict
import random
from tqdm import tqdm
import torch
import datasets
from sklearn.linear_model import LogisticRegression

from hooked_models.utils import seed_torch
from eval.templates import create_prompt_with_tulu_chat_format
from eval.utils import load_hooked_lm_and_tokenizer, load_hf_score_lm_and_tokenizer

@torch.no_grad()
def generate_completions_and_masks(model, tokenizer, prompts, cost_model=None, cost_tokenizer=None, batch_size=1, add_special_tokens=True, disable_tqdm=False, **generation_kwargs):
    """Generate completions batch by batch and build masks over the generated tokens.

    For every batch the prompts are tokenized, passed to ``model.generate`` and the
    decoded prompt/completion pairs are re-tokenized into left-padded id sequences.
    If both ``cost_model`` and ``cost_tokenizer`` are given, the decoded full outputs
    are also scored with ``cost_model(...).end_scores``.

    Args:
        model: Object exposing ``device`` and ``generate(input_ids=..., attention_mask=..., **kwargs)``.
        tokenizer: Callable tokenizer exposing ``batch_decode`` and ``pad_token_id``.
        prompts: List of prompt strings.
        cost_model: Optional scoring model; its output must expose ``end_scores``.
        cost_tokenizer: Optional tokenizer for ``cost_model``.
        batch_size: Number of prompts per generation call.
        add_special_tokens: Forwarded to the tokenizers when encoding prompts.
        disable_tqdm: When true, no progress bar is created.
        **generation_kwargs: Forwarded to ``model.generate``; ``num_return_sequences``
            (default 1) is also read here to duplicate prompts accordingly.

    Returns:
        A tuple ``(outputs, scores, attention_masks, gather_masks)`` of lists with one
        entry per batch. ``outputs``, ``attention_masks`` and ``gather_masks`` hold
        integer tensors with one row per generated sequence; ``gather_masks`` is 1 only
        on completion tokens. ``scores`` holds one list of per-sequence scores per
        batch; the score is 0 when no cost model is supplied or when generation failed.
    """
    outputs = []
    attention_masks = []
    gather_masks = []
    scores = []
    progress = None
    if not disable_tqdm:
        progress = tqdm(total=len(prompts), desc="Generating Completions")

    num_return_sequences = generation_kwargs.get("num_return_sequences", 1)
    for i in range(0, len(prompts), batch_size):
        # Keep the raw prompts separate from the decoded/duplicated ones used below so
        # that the error path and the progress bar always see the real batch size.
        raw_prompts = prompts[i:i+batch_size]
        tokenized_prompts = tokenizer(raw_prompts, padding="longest", return_tensors="pt", add_special_tokens=add_special_tokens)
        batch_input_ids = tokenized_prompts.input_ids
        attention_mask = tokenized_prompts.attention_mask

        if model.device.type == "cuda":
            batch_input_ids = batch_input_ids.to(model.device)
            attention_mask = attention_mask.to(model.device)

        num_sequences = len(raw_prompts) * num_return_sequences
        try:
            batch_outputs_ids = model.generate(
                input_ids=batch_input_ids,
                attention_mask=attention_mask,
                **generation_kwargs
            )

            # Remove the prompt from the output.
            # The prompt is re-decoded so that special tokens are treated the same way as in the outputs.
            # Truncating the output token ids directly is avoided because some tokenizers (e.g., llama)
            # won't add a space token before the first token, and the space matters for some tasks
            # (e.g., code completion).
            batch_outputs = tokenizer.batch_decode(batch_outputs_ids, skip_special_tokens=True)
            decoded_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
            # Duplicate the prompts to match the number of return sequences.
            decoded_prompts = [prompt for prompt in decoded_prompts for _ in range(num_return_sequences)]
            batch_generations = [
                output[len(prompt):] for prompt, output in zip(decoded_prompts, batch_outputs)
            ]

            if cost_model is not None and cost_tokenizer is not None:
                cost_tokenized = cost_tokenizer(batch_outputs, padding="longest", return_tensors="pt", add_special_tokens=add_special_tokens)
                cost_batch_input_ids = cost_tokenized.input_ids
                cost_attention_mask = cost_tokenized.attention_mask
                if cost_model.device.type == "cuda":
                    cost_batch_input_ids = cost_batch_input_ids.to(cost_model.device)
                    cost_attention_mask = cost_attention_mask.to(cost_model.device)
                cost_scores = cost_model(cost_batch_input_ids, attention_mask=cost_attention_mask).end_scores.squeeze(dim=-1).tolist()
            else:
                # No scorer available: emit a neutral score per sequence so the
                # return structure stays aligned with the generated sequences.
                cost_scores = [0] * num_sequences

            id_rows = []
            gather_rows = []
            for prompt, output in zip(decoded_prompts, batch_generations):
                prompt_ids = tokenizer(prompt, add_special_tokens=add_special_tokens).input_ids
                output_ids = tokenizer(output, add_special_tokens=False).input_ids
                id_rows.append(prompt_ids + output_ids)
                gather_rows.append([1] * len(output_ids))

            # Left-pad every row to the longest sequence of the batch.
            max_length = max((len(ids) for ids in id_rows), default=0)
            batch_ids = torch.tensor(
                [[tokenizer.pad_token_id] * (max_length - len(ids)) + ids for ids in id_rows]
            )
            batch_attention_mask = torch.tensor(
                [[0] * (max_length - len(ids)) + [1] * len(ids) for ids in id_rows]
            )
            batch_gather_mask = torch.tensor(
                [[0] * (max_length - len(mask)) + mask for mask in gather_rows]
            )

        except Exception as e:
            # Deliberate best-effort: one failing batch must not abort the whole run.
            print("Error when generating completions for batch:")
            print(raw_prompts)
            print("Error message:")
            print(e)
            print("Use empty string as the completion.")
            # Empty completion: return the tokenized prompts unchanged, mark no
            # position as generated and give every sequence a neutral score.
            batch_ids = batch_input_ids.repeat_interleave(num_return_sequences, dim=0).cpu()
            batch_attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0).cpu()
            batch_gather_mask = torch.zeros_like(batch_ids)
            cost_scores = [0] * num_sequences

        outputs.append(batch_ids)
        attention_masks.append(batch_attention_mask)
        gather_masks.append(batch_gather_mask)
        scores.append(cost_scores)
        if progress is not None:
            progress.update(len(raw_prompts))

    return outputs, scores, attention_masks, gather_masks

def get_neuron_activation(caches, attention_mask, last_token=False) -> list[torch.FloatTensor]:
    """Reduce cached activations to the positions of interest.

    Args:
        caches: Iterable of tensors whose first two axes line up with
            ``attention_mask`` (presumably batch and sequence position — confirm
            against the caller); any number of trailing axes is accepted.
        attention_mask: Tensor with the same leading two axes as each cache;
            non-zero entries mark positions to keep. Ignored when ``last_token``
            is true.
        last_token: When true, return only index ``-1`` of the second axis of
            every cache instead of applying the mask.

    Returns:
        One float32 CPU tensor per cache. With ``last_token`` the second axis is
        dropped; otherwise the first two axes are flattened into the selected
        positions (row-major order) and the trailing axes are kept.
    """
    stack_caches = [cache.float().cpu() for cache in caches]
    if last_token:
        return [stack_cache[:, -1, ...] for stack_cache in stack_caches]
    # Boolean indexing over the two leading axes keeps whatever trailing axes the
    # cache has, so rank-3 and rank-4 caches are handled alike. For rank-4 input
    # this equals the previous masked_select + reshape.
    keep = attention_mask.cpu().bool()
    return [stack_cache[keep] for stack_cache in stack_caches]


def batch_run_with_cache_on_prompt(prompts, batch_size, model, hooks, indexes=None, last_token=False, exclude_last_n=1):
    """Run the hooked model over ``prompts`` and collect selected neuron activations.

    Args:
        prompts: List of prompt strings.
        batch_size: Number of prompts per forward pass.
        model: Hooked model exposing ``device``, ``to_tokens`` and ``run_with_cache``.
        hooks: Value passed as ``names_filter`` to ``model.run_with_cache``.
        indexes: List of mappings ``layer -> neuron indices``; one output tensor
            is produced per mapping. Required: caching every neuron (the
            ``None`` case) is not implemented.
        last_token: Forwarded to ``get_neuron_activation``.
        exclude_last_n: Currently unused; kept for interface compatibility.

    Returns:
        A list with one tensor per entry of ``indexes``, concatenated over all
        batches along the first axis. An empty ``indexes`` yields ``[]``.

    Raises:
        ValueError: If ``indexes`` is ``None``.
    """
    if indexes is None:
        raise ValueError("indexes must be provided: caching all neurons is not implemented")
    if not indexes:
        return []

    device = model.device
    # Layers may live on different devices; gather them on one device before
    # concatenating. Falls back to CPU when CUDA is unavailable.
    gather_device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    activation = [[] for _ in indexes]
    for i in tqdm(range(0, len(prompts), batch_size), desc="Getting Activations"):
        batch_prompts = prompts[i: i+batch_size]
        tokenized_prompts = model.to_tokens(batch_prompts, device=device)
        attention_mask = tokenized_prompts.attention_mask
        _, cache = model.run_with_cache(**tokenized_prompts, names_filter=hooks)

        activation_per_index = []
        for index in indexes:
            layer_activations = []
            for layer, neurons in index.items():
                layer_cache = cache['post', layer].to(gather_device)
                neuron_ids = torch.tensor(neurons).to(gather_device)
                layer_activations.append(layer_cache[..., neuron_ids])
            # Neurons of all layers are laid out side by side on the last axis.
            activation_per_index.append(torch.concat(layer_activations, -1))

        # Free the full cache before the next forward pass to keep peak memory down.
        del cache
        torch.cuda.empty_cache()

        batch_activation = get_neuron_activation(activation_per_index, attention_mask, last_token=last_token)
        for collected, batch_act in zip(activation, batch_activation):
            collected.append(batch_act)
    return [torch.concat(act, dim=0) for act in activation]

def get_clf(peft, dataset, topk=1500):
    """Fit a logistic-regression classifier on stored activations.

    The training pair ``(X, y)`` is loaded with ``torch.load`` from a fixed path
    built from ``dataset``, ``peft`` and ``topk``.

    Returns:
        The fitted ``LogisticRegression`` instance (default hyper-parameters).
    """
    # NOTE(review): torch.load unpickles the file; only point this at trusted data.
    features, labels = torch.load(
        f'Alignment/output/activations/harmless_prediction/{dataset}_5000_{peft}_top{topk}.pt'
    )
    # fit() returns the estimator itself, so the fitted model can be returned directly.
    return LogisticRegression().fit(features, labels)


if __name__ == '__main__':
    # Entry point: for each evaluation dataset, predict with a neuron-activation
    # classifier which prompts are flagged, generate completions, score them with
    # a cost model, and record mean/std of the scores with and without the guard.
    parser = argparse.ArgumentParser(description="Evaluate cost scores of generated completions with and without a neuron-activation guard classifier")
    parser.add_argument(
        "--num_samples",
        type=int,
        default=100,
        help="Number of samples to evaluate.",
    )
    parser.add_argument(
        "--train_dataset",
        type=str,
        default="",
        help="Dataset to train classifier.",
    )
    parser.add_argument(
        "--eval_datasets",
        type=str,
        # nargs='+' yields a list, so the default must be a list as well.
        default=[],
        nargs='+',
        help="Dataset to evaluate.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="hooked_llama/data"
    )
    parser.add_argument(
        "--output_filename",
        type=str,
        default="activation"
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
        help="If specified, we will load the model to generate the predictions.",
    )
    parser.add_argument(
        "--tokenizer_name_or_path",
        type=str,
        default=None,
        help="If specified, we will load the tokenizer from here.",
    )
    parser.add_argument(
        "--cost_model_name_or_path",
        type=str,
        default=None,
        help="If specified, we will load the model to generate cost scores for predictions.",
    )
    parser.add_argument(
        "--cost_tokenizer_name_or_path",
        type=str,
        default=None,
        help="If specified, we will load the tokenizer from here.",
    )
    parser.add_argument(
        "--peft_path",
        type=str,
        nargs='+',
        # The last entry is indexed unconditionally below, so it cannot be omitted.
        required=True,
        help="The folder contains peft checkpoint saved with PeftModel.save_pretrained()."
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=1,
        help="Batch size for evaluation."
    )
    parser.add_argument(
        "--index_path",
        type=str,
        default=None,
        help="Path to the torch-saved file holding the ranked (layer, neuron) index."
    )
    parser.add_argument(
        "--topk",
        type=int,
        # nargs='+' yields a list and args.topk[-1] is used below.
        default=[1],
        nargs='+',
        help="Number of top neurons to cache.",
    )

    seed_torch()
    args = parser.parse_args()

    # Maps the parent directory name of --train_dataset to the name used in the
    # stored activation files.
    dataset_map = {'jailbreak_llms': 'jailbreak_llms', 'hh_rlhf_harmless': 'hh_harmless', 'BeaverTails': 'BeaverTails', 'harmbench': 'harmbench', 'red_team_attempts': 'red_team_attempts'}
    clf = get_clf(os.path.basename(args.peft_path[-1]), dataset_map[os.path.basename(os.path.dirname(args.train_dataset))], args.topk[-1])
    model, tokenizer = load_hooked_lm_and_tokenizer(
        model_name_or_path=args.model_name_or_path,
        tokenizer_name_or_path=args.tokenizer_name_or_path if args.tokenizer_name_or_path is not None else args.model_name_or_path,
        device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
        peft_name_or_path=args.peft_path
    )
    model.set_tokenizer(tokenizer)
    names_filter = lambda name: name.endswith('hook_post')

    # Group the top-k ranked (layer, neuron) pairs by layer.
    _, index, *_ = torch.load(args.index_path)
    layers = defaultdict(list)
    for layer, idx in index[:args.topk[-1]]:
        layers[layer.item()].append(idx)

    cost_model, cost_tokenizer = load_hf_score_lm_and_tokenizer(
        model_name_or_path=args.cost_model_name_or_path,
        tokenizer_name_or_path=args.cost_tokenizer_name_or_path if args.cost_tokenizer_name_or_path is not None else args.cost_model_name_or_path,
        device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
    )

    # Extend previous results when present. A defaultdict keeps the appends below
    # working even if the existing file lacks some of the columns.
    output_path = os.path.join(args.output_dir, args.output_filename)
    results = defaultdict(list)
    try:
        with open(output_path, 'r') as f:
            results.update(json.load(f))
    except FileNotFoundError:
        # First run: start from an empty result table.
        pass

    for eval_dataset in args.eval_datasets:
        dataset_name = os.path.basename(os.path.dirname(eval_dataset))
        eval_data = datasets.load_dataset('json', data_files=eval_dataset)["train"]["prompt"]
        prompts = [
            create_prompt_with_tulu_chat_format([{"role": "user", "content": example}], add_bos=False)
            for example in eval_data
        ]
        # Resolve the sample count per dataset: a negative value means "all", and
        # the request is capped so random.sample never exceeds the population.
        if args.num_samples < 0:
            num_samples = len(prompts)
        else:
            num_samples = min(args.num_samples, len(prompts))
        print(f"using {num_samples} samples for {dataset_name}")
        prompts = random.sample(prompts, num_samples)

        activation = batch_run_with_cache_on_prompt(prompts, batch_size=args.eval_batch_size, model=model, hooks=names_filter, indexes=[layers], last_token=True)
        # A classifier failure is fatal: the guard statistics below cannot be
        # computed without predictions, so let the exception propagate.
        preds = clf.predict(activation[0])

        _, batch_cost_scores, _, _ = generate_completions_and_masks(model, tokenizer, prompts, cost_model=cost_model, cost_tokenizer=cost_tokenizer, batch_size=args.eval_batch_size, max_new_tokens=128, do_sample=False)
        cost_scores = torch.tensor([score for batch in batch_cost_scores for score in batch])
        # Keep only the scores of samples whose prediction is 0/False.
        guard_cost_scores = torch.masked_select(cost_scores, ~torch.tensor(preds, dtype=torch.bool))

        model_name = os.path.basename(args.model_name_or_path)
        base_type = 'DPO' if 'dpo' in args.peft_path[-1] else 'SFT'

        results['Model'].append(model_name)
        results['Type'].append(base_type)
        results['Dataset'].append(dataset_name)
        results['mean'].append(cost_scores.mean().item())
        results['std'].append(cost_scores.std().item())

        results['Model'].append(model_name)
        results['Type'].append(base_type + '+Guard')
        results['Dataset'].append(dataset_name)
        results['mean'].append(guard_cost_scores.mean().item())
        results['std'].append(guard_cost_scores.std().item())

    os.makedirs(args.output_dir, exist_ok=True)
    with open(output_path, 'w') as f:
        json.dump(results, f)
