import os
import argparse
from collections import defaultdict, Counter
import random
from tqdm import tqdm
import torch
import datasets
from peft.peft_model import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from hooked_models.HookedLlama import HookedLlamaForCausalLM
from hooked_models.utils import seed_torch
from eval.templates import create_prompt_with_tulu_chat_format
from eval.utils import load_hooked_lm_and_tokenizer, load_hf_score_lm_and_tokenizer

@torch.no_grad()
def generate_completions_and_masks(model, tokenizer, prompts, cost_model=None, cost_tokenizer=None, batch_size=1, add_special_tokens=True, disable_tqdm=False, **generation_kwargs):
    """Generate completions batch by batch and return re-tokenized id/mask tensors.

    For every batch the model generates completions. The prompt text is stripped
    from each decoded output. Prompt and completion are then re-tokenized and
    left-padded with ``tokenizer.pad_token_id`` so every row in a batch has the
    same length.

    Args:
        model: object exposing ``device`` and
            ``generate(input_ids=..., attention_mask=..., **kwargs)``.
        tokenizer: HF-style tokenizer (callable, ``batch_decode``, ``pad_token_id``).
        prompts: list of prompt strings.
        cost_model, cost_tokenizer: optional scorer. When both are given, each
            full decoded output (prompt + completion) is scored through
            ``cost_model(...).end_scores``. Otherwise every score is 0.
        batch_size: number of prompts per ``generate`` call.
        add_special_tokens: forwarded to the tokenizer(s) when encoding prompts.
        disable_tqdm: hide the progress bar.
        **generation_kwargs: forwarded to ``model.generate``.
            ``num_return_sequences`` (default 1) is also read here.

    Returns:
        ``(outputs, scores, attention_masks, gather_masks)``, one entry per batch:

        - ``outputs[b]``: int64 tensor of shape
          ``(batch * num_return_sequences, max_len)`` holding prompt + completion ids.
        - ``attention_masks[b]``: 1 on non-pad positions.
        - ``gather_masks[b]``: 1 on completion positions only.
        - ``scores[b]``: list with one score per sequence.

    If generation fails for a batch, the error is printed and that batch falls
    back to the prompt ids alone. Its gather mask is all zeros (an empty
    completion) and its scores are 0.
    """
    outputs = []
    attention_masks = []
    gather_masks = []
    scores = []
    progress = None if disable_tqdm else tqdm(total=len(prompts), desc="Generating Completions")
    # Use explicit None checks: tokenizers define __len__, so truthiness is not a presence test.
    use_cost_model = cost_model is not None and cost_tokenizer is not None

    num_return_sequences = generation_kwargs.get("num_return_sequences", 1)
    for i in range(0, len(prompts), batch_size):
        batch_prompts = prompts[i:i + batch_size]
        tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=add_special_tokens)
        batch_input_ids = tokenized_prompts.input_ids
        attention_mask = tokenized_prompts.attention_mask

        if model.device.type == "cuda":
            batch_input_ids = batch_input_ids.to(model.device)
            attention_mask = attention_mask.to(model.device)

        try:
            batch_outputs_ids = model.generate(
                input_ids=batch_input_ids,
                attention_mask=attention_mask,
                **generation_kwargs
            )

            # Strip the prompt from the decoded output. The prompt is decoded the same
            # way as the output, rather than truncating token ids directly, so special
            # tokens are handled identically. Some tokenizers (e.g. llama) do not emit
            # a space token before the first token, and spacing matters for tasks such
            # as code completion.
            batch_outputs = tokenizer.batch_decode(batch_outputs_ids, skip_special_tokens=True)
            decoded_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
            # One prompt copy per returned sequence.
            decoded_prompts = [prompt for prompt in decoded_prompts for _ in range(num_return_sequences)]
            batch_generations = [
                output[len(prompt):] for prompt, output in zip(decoded_prompts, batch_outputs)
            ]

            if use_cost_model:
                cost_tokenized = cost_tokenizer(batch_outputs, padding="longest", return_tensors="pt", add_special_tokens=add_special_tokens)
                cost_input_ids = cost_tokenized.input_ids
                cost_attention_mask = cost_tokenized.attention_mask
                if cost_model.device.type == "cuda":
                    cost_input_ids = cost_input_ids.to(cost_model.device)
                    cost_attention_mask = cost_attention_mask.to(cost_model.device)
                cost_scores = cost_model(cost_input_ids, attention_mask=cost_attention_mask).end_scores.squeeze(dim=-1).tolist()
            else:
                # No scorer: use a neutral score so `scores` still has one entry per sequence.
                cost_scores = [0] * len(batch_generations)

            # Re-tokenize prompt + completion, then left-pad every row to the longest one.
            token_rows = []
            completion_lengths = []
            for prompt, generation in zip(decoded_prompts, batch_generations):
                prompt_ids = tokenizer(prompt, add_special_tokens=add_special_tokens).input_ids
                output_ids = tokenizer(generation, add_special_tokens=False).input_ids
                token_rows.append(prompt_ids + output_ids)
                completion_lengths.append(len(output_ids))
            max_length = max((len(ids) for ids in token_rows), default=0)
            batch_ids = [[tokenizer.pad_token_id] * (max_length - len(ids)) + ids for ids in token_rows]
            batch_attention_mask = [[0] * (max_length - len(ids)) + [1] * len(ids) for ids in token_rows]
            # Completion tokens sit at the right end of each left-padded row.
            batch_gather_mask = [[0] * (max_length - n) + [1] * n for n in completion_lengths]

        except Exception as e:
            print("Error when generating completions for batch:")
            print(batch_prompts)
            print("Error message:")
            print(e)
            print("Use empty string as the completion.")
            # Fall back to the prompt alone (one copy per return sequence), with an
            # all-zero gather mask (i.e. an empty completion) and a neutral score.
            fallback_ids = batch_input_ids.cpu().repeat_interleave(num_return_sequences, dim=0)
            fallback_mask = attention_mask.cpu().repeat_interleave(num_return_sequences, dim=0)
            batch_ids = fallback_ids.tolist()
            batch_attention_mask = fallback_mask.tolist()
            batch_gather_mask = torch.zeros_like(fallback_mask).tolist()
            cost_scores = [0] * len(batch_ids)

        outputs.append(torch.tensor(batch_ids))
        attention_masks.append(torch.tensor(batch_attention_mask))
        gather_masks.append(torch.tensor(batch_gather_mask))
        scores.append(cost_scores)
        if progress is not None:
            progress.update(len(batch_prompts))

    if progress is not None:
        progress.close()
    return outputs, scores, attention_masks, gather_masks

def get_neuron_activation(caches, attention_mask, last_token=False) -> list[torch.FloatTensor]:
    """Keep only the selected token positions of each activation tensor, as float32 on CPU.

    Args:
        caches: iterable of activation tensors. Their two leading dims must align
            with ``attention_mask`` (presumably batch, seq). Any number of trailing
            dims is allowed, e.g. ``(batch, seq, k)`` or ``(batch, seq, a, b)``.
        attention_mask: 2-D mask over those leading dims. Positions with a
            non-zero value are kept. Ignored when ``last_token`` is True.
        last_token: if True, return only the final position (index -1) of dim 1.

    Returns:
        One float32 CPU tensor per cache. The shape is ``(dim0, ...)`` when
        ``last_token`` is True. Otherwise it is ``(num_kept, ...)``, with kept
        positions in row-major order (dim 0, then dim 1).
    """
    float_caches = [cache.float().cpu() for cache in caches]
    if last_token:
        # NOTE(review): index -1 is the last real token only under left padding — confirm padding side.
        return [cache[:, -1, ...] for cache in float_caches]
    keep = attention_mask.cpu().bool()
    # Boolean indexing with a 2-D mask flattens the two leading dims and keeps all
    # trailing dims. This matches masked_select + reshape for 4-D caches and also
    # works for 3-D ones.
    return [cache[keep] for cache in float_caches]


def batch_run_with_cache(prompts, attention_masks, gather_masks, model, hooks, last_token=False):
    """Run pre-tokenized batches through the hooked model and gather cached activations.

    Args:
        prompts: iterable of token-id tensors, one per batch (e.g. the ``outputs``
            of ``generate_completions_and_masks``).
        attention_masks: matching attention masks, one per batch.
        gather_masks: matching masks selecting which positions to keep.
        model: hooked model exposing ``device`` and ``run_with_cache``.
        hooks: ``names_filter`` forwarded to ``run_with_cache``.
        last_token: keep only the final position instead of the gather-masked ones.

    Returns:
        One tensor per cached activation, concatenated over batches along dim 0.
    """
    device = model.device
    per_batch = []
    for prompt, attention_mask, gather_mask in tqdm(zip(prompts, attention_masks, gather_masks), desc="Getting Predictions", total=len(prompts)):
        prompt = prompt.to(device)
        attention_mask = attention_mask.to(device)
        gather_mask = gather_mask.to(device)
        _, cache = model.run_with_cache(input_ids=prompt, attention_mask=attention_mask, names_filter=hooks)
        # NOTE(review): get_neuron_activation iterates `cache` and calls .float() on each
        # element — confirm run_with_cache returns an iterable of tensors here.
        # (The old extra positional arg collided with `last_token` and raised TypeError.)
        per_batch.append(get_neuron_activation(cache, gather_mask, last_token=last_token))
    # get_neuron_activation returns a list, so concatenate each position across batches.
    return [torch.concat(acts, dim=0) for acts in zip(*per_batch)]

def batch_run_with_cache_on_prompt(prompts, batch_size, model, hooks, indexes=None, last_token=False, exclude_last_n=1):
    """Run the hooked model over ``prompts`` and collect activations of selected neurons.

    Args:
        prompts: list of prompt strings.
        batch_size: number of prompts per forward pass.
        model: hooked model exposing ``device``, ``to_tokens`` and ``run_with_cache``.
        hooks: ``names_filter`` forwarded to ``run_with_cache``.
        indexes: non-empty list of ``{layer: [neuron, ...]}`` mappings. Each layer's
            activations are read from ``cache['post', layer]``. Each mapping yields
            one output tensor, with its neurons concatenated along the last dim in
            mapping order.
        last_token: keep only the final position of each sequence. Otherwise keep
            every position where the attention mask is set.
        exclude_last_n: unused; kept for backward compatibility.

    Returns:
        One tensor per entry of ``indexes``, concatenated over batches along dim 0.

    Raises:
        ValueError: if ``indexes`` is None or empty.
    """
    if not indexes:
        raise ValueError("batch_run_with_cache_on_prompt requires a non-empty `indexes` list")
    device = model.device
    # Build the neuron index tensors once rather than once per batch.
    index_tensors = [
        {layer: torch.as_tensor([int(n) for n in neurons], dtype=torch.long) for layer, neurons in index.items()}
        for index in indexes
    ]
    activation = [[] for _ in indexes]
    for i in tqdm(range(0, len(prompts), batch_size), desc="Getting Activations"):
        batch_prompts = prompts[i: i + batch_size]
        tokenized_prompts = model.to_tokens(batch_prompts, device=device)
        attention_mask = tokenized_prompts.attention_mask
        _, cache = model.run_with_cache(**tokenized_prompts, names_filter=hooks)
        activation_per_index = []
        for index in index_tensors:
            layer_activations = []
            for layer, neurons in index.items():
                layer_cache = cache['post', layer]
                # Slice on the cache's own device, then move only the selected neurons
                # to a single device so the concat below is valid.
                selected = layer_cache[..., neurons.to(layer_cache.device)]
                layer_activations.append(selected.to(device))
            activation_per_index.append(torch.concat(layer_activations, -1))
        del cache
        torch.cuda.empty_cache()
        batch_activation = get_neuron_activation(activation_per_index, attention_mask, last_token=last_token)
        for per_index, act in zip(activation, batch_activation):
            per_index.append(act)
    return [torch.concat(act, dim=0) for act in activation]

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Cache activations of selected neurons on prompts and label them with cost-model scores")
    parser.add_argument(
        "--num_samples",
        type=int,
        default=100,
        help="Number of samples to evaluate (a negative value uses the whole dataset).",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="",
        help="JSON/JSONL file whose `prompt` field holds the prompts to evaluate.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="hooked_llama/data",
        help="Directory the activation files are written to (created if missing).",
    )
    parser.add_argument(
        "--output_filename",
        type=str,
        default="activation",
        help="Filename prefix for the saved activation files.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
        help="If specified, we will load the model to generate the predictions.",
    )
    parser.add_argument(
        "--tokenizer_name_or_path",
        type=str,
        default=None,
        help="If specified, we will load the tokenizer from here.",
    )
    parser.add_argument(
        "--cost_model_name_or_path",
        type=str,
        default=None,
        help="If specified, we will load the model to generate cost scores for predictions.",
    )
    parser.add_argument(
        "--cost_tokenizer_name_or_path",
        type=str,
        default=None,
        help="If specified, we will load the tokenizer from here.",
    )
    parser.add_argument(
        "--peft_path",
        type=str,
        nargs='+',
        default=None,
        help="The folder contains peft checkpoint saved with PeftModel.save_pretrained()."
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=1,
        help="Batch size for evaluation."
    )
    parser.add_argument(
        "--load_in_8bit",
        action="store_true",
        help="Load model in 8bit mode, which will reduce memory and speed up inference.",
    )
    parser.add_argument(
        "--index_path",
        type=str,
        default=None,
        help="File loaded with torch.load whose second element is the (layer, neuron) index tensor; its first `topk` rows are used."
    )
    parser.add_argument(
        "--topk",
        type=int,
        # nargs='+' yields a list, so the default must be a list too.
        default=[1],
        nargs='+',
        help="Number(s) of top neurons to cache; one output file per value.",
    )
    parser.add_argument(
        "--use_random_neurons",
        action="store_true",
        help="Replace the found neurons with randomly sampled ones (by default, the same number per layer)."
    )
    parser.add_argument(
        "--random_neurons_everywhere",
        action="store_true",
        help="With --use_random_neurons, sample random neurons uniformly across all layers."
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=[42],
        nargs='+',
        help="Random seed(s); with --use_random_neurons, one output file per (topk, seed).",
    )
    parser.add_argument(
        "--last_layer_neurons",
        action="store_true",
        help="With --use_random_neurons, sample random neurons from the last layer only (takes precedence)."
    )

    args = parser.parse_args()

    eval_data = datasets.load_dataset('json', data_files=args.dataset)["train"]["prompt"]
    if args.num_samples < 0:
        args.num_samples = len(eval_data)
    print(f"using {args.num_samples} samples")

    # Wrap each raw prompt as a single user turn in the Tulu chat format.
    prompts = [
        create_prompt_with_tulu_chat_format([{"role": "user", "content": example}], add_bos=False)
        for example in eval_data[:args.num_samples]
    ]

    model, tokenizer = load_hooked_lm_and_tokenizer(
        model_name_or_path=args.model_name_or_path,
        tokenizer_name_or_path=args.tokenizer_name_or_path if args.tokenizer_name_or_path is not None else args.model_name_or_path,
        load_in_8bit=args.load_in_8bit,
        device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
        peft_name_or_path=args.peft_path
    )
    model.set_tokenizer(tokenizer)
    # Cache only hooks whose name ends with 'hook_post'.
    names_filter = lambda name: name.endswith('hook_post')

    _, index, *_ = torch.load(args.index_path)
    num_layers = model.config.num_hidden_layers
    intermediate_size = model.config.intermediate_size
    indexes = []
    save_names = []
    for topk in args.topk:
        found_index = index[:topk]
        if not args.use_random_neurons:
            indexes.append(found_index)
            save_names.append(topk)
            continue
        for seed in args.seed:
            seed_torch(seed)
            if args.last_layer_neurons:
                random_index = torch.tensor([[num_layers - 1, neuron] for neuron in random.sample(range(intermediate_size), topk)])
            elif args.random_neurons_everywhere:
                flat_neurons = random.sample(range(intermediate_size * num_layers), topk)
                random_index = torch.tensor([[neuron // intermediate_size, neuron % intermediate_size] for neuron in flat_neurons])
            else:
                # Match the per-layer neuron counts of the *found* top-k (not of a
                # previous seed's sample).
                counts = Counter(found_index[:, 0].tolist())
                pairs = []
                for layer, num in counts.items():
                    pairs += [[layer, neuron] for neuron in random.sample(range(intermediate_size), num)]
                random_index = torch.tensor(pairs)
            indexes.append(random_index)
            save_names.append((topk, seed))

    # Group neuron ids by layer: layers[i] maps layer -> [neuron, ...] for indexes[i].
    layers = [defaultdict(list) for _ in indexes]
    for i, topk_index in enumerate(indexes):
        for layer, idx in topk_index:
            layers[i][layer.item()].append(idx)
    activation = batch_run_with_cache_on_prompt(prompts, batch_size=args.eval_batch_size, model=model, hooks=names_filter, indexes=layers, last_token=True)

    cost_model, cost_tokenizer = load_hf_score_lm_and_tokenizer(
        model_name_or_path=args.cost_model_name_or_path,
        tokenizer_name_or_path=args.cost_tokenizer_name_or_path if args.cost_tokenizer_name_or_path is not None else args.cost_model_name_or_path,
        load_in_8bit=args.load_in_8bit,
        device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
    )
    _, cost_scores, _, _ = generate_completions_and_masks(model, tokenizer, prompts, cost_model=cost_model, cost_tokenizer=cost_tokenizer, batch_size=args.eval_batch_size, max_new_tokens=128, do_sample=False)
    # Binary label per generated output: 1 when its cost score is above 0.
    cost_score_list = [score for batch_scores in cost_scores for score in batch_scores]
    target = (torch.tensor(cost_score_list) > 0).long()

    os.makedirs(args.output_dir, exist_ok=True)
    for i, save_name in enumerate(save_names):
        if args.use_random_neurons:
            topk, seed = save_name
            save_path = f"{args.output_dir}/{args.output_filename}_top{topk}_seed{seed}.pt"
        else:
            save_path = f"{args.output_dir}/{args.output_filename}_top{save_name}.pt"
        torch.save((activation[i], target), save_path)