import os
import argparse
from tqdm import tqdm
import torch
import datasets
from peft.peft_model import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

from hooked_models.HookedLlama import HookedLlamaForCausalLM
from hooked_models.utils import seed_torch
import transformer_lens.utils as utils
import ot
from eval.templates import create_prompt_with_tulu_chat_format
from eval.utils import load_hooked_lm_and_tokenizer

def _left_pad(rows, width, pad_value):
    """Left-pad every list in *rows* with *pad_value* until it has *width* entries."""
    return [[pad_value] * (width - len(row)) + row for row in rows]


@torch.no_grad()
def generate_completions_and_masks(model, tokenizer, prompts, batch_size=1, add_special_tokens=True, disable_tqdm=False, **generation_kwargs):
    """Generate completions and return left-padded token ids with two masks per batch.

    Args:
        model: object exposing ``device`` and ``generate(input_ids=..., attention_mask=..., **kwargs)``.
        tokenizer: tokenizer used to encode the prompts and decode the generations.
        prompts: list of prompt strings.
        batch_size: number of prompts sent to ``model.generate`` at once.
        add_special_tokens: forwarded to the tokenizer when encoding prompts.
        disable_tqdm: if True, no progress bar is shown.
        **generation_kwargs: forwarded to ``model.generate``; ``num_return_sequences``
            (default 1) is also read here to duplicate each prompt accordingly.

    Returns:
        Three lists with one tensor per batch, each of shape [rows, max_len]:
        ``outputs`` (prompt + completion ids, left-padded with ``tokenizer.pad_token_id``),
        ``attention_masks`` (1 on real tokens, 0 on padding) and
        ``gather_masks`` (1 on completion tokens only).

    If generation or decoding fails for a batch, the error is printed and that batch
    falls back to an empty completion: its ids hold the prompt only and its gather
    mask is all zeros.
    """
    outputs = []
    attention_masks = []
    gather_masks = []
    progress = None if disable_tqdm else tqdm(total=len(prompts), desc="Generating Completions")

    num_return_sequences = generation_kwargs.get("num_return_sequences", 1)
    for i in range(0, len(prompts), batch_size):
        # Keep the raw slice in its own name so the fallback path and the progress
        # bar never see the decoded/duplicated prompts built inside the try block.
        raw_prompts = prompts[i:i+batch_size]
        tokenized_prompts = tokenizer(raw_prompts, padding="longest", return_tensors="pt", add_special_tokens=add_special_tokens)
        batch_input_ids = tokenized_prompts.input_ids
        attention_mask = tokenized_prompts.attention_mask

        if model.device.type == "cuda":
            batch_input_ids = batch_input_ids.to(model.device)
            attention_mask = attention_mask.to(model.device)

        try:
            batch_outputs_ids = model.generate(
                input_ids=batch_input_ids,
                attention_mask=attention_mask,
                **generation_kwargs
            )

            # Remove the prompt from the output by comparing decoded text.
            # The prompt is re-decoded so that special tokens are treated the same way
            # as in the outputs. Truncating the output token ids directly was abandoned
            # because some tokenizers (e.g., llama) won't add a space token before the
            # first token, and that space matters for some tasks (e.g., code completion).
            batch_outputs = tokenizer.batch_decode(batch_outputs_ids, skip_special_tokens=True)
            decoded_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
            # duplicate the prompts to match the number of return sequences
            decoded_prompts = [prompt for prompt in decoded_prompts for _ in range(num_return_sequences)]
            batch_generations = [
                output[len(prompt):] for prompt, output in zip(decoded_prompts, batch_outputs)
            ]
            prompt_ids_rows = [
                list(tokenizer(prompt, add_special_tokens=add_special_tokens).input_ids)
                for prompt in decoded_prompts
            ]
            output_ids_rows = [
                list(tokenizer(output, add_special_tokens=False).input_ids)
                for output in batch_generations
            ]
        except Exception as e:
            print("Error when generating completions for batch:")
            print(raw_prompts)
            print("Error message:")
            print(e)
            print("Use empty string as the completion.")
            # Empty completion: the row holds the prompt tokens only, so its gather
            # mask ends up all zeros and contributes no completion activations.
            prompt_ids_rows = [
                list(tokenizer(prompt, add_special_tokens=add_special_tokens).input_ids)
                for prompt in raw_prompts
                for _ in range(num_return_sequences)
            ]
            output_ids_rows = [[] for _ in prompt_ids_rows]

        batch_ids = [p_ids + o_ids for p_ids, o_ids in zip(prompt_ids_rows, output_ids_rows)]
        max_length = max((len(ids) for ids in batch_ids), default=0)
        batch_attention_mask = _left_pad([[1] * len(ids) for ids in batch_ids], max_length, 0)
        batch_gather_mask = _left_pad([[1] * len(o_ids) for o_ids in output_ids_rows], max_length, 0)
        batch_ids = _left_pad(batch_ids, max_length, tokenizer.pad_token_id)

        outputs.append(torch.tensor(batch_ids))
        attention_masks.append(torch.tensor(batch_attention_mask))
        gather_masks.append(torch.tensor(batch_gather_mask))

        if progress is not None:
            progress.update(len(raw_prompts))

    if progress is not None:
        progress.close()
    return outputs, attention_masks, gather_masks


def wasserstein_1d(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Per-layer 1d Wasserstein distance between two activation tensors.

    Each layer slice ``[:, layer, :]`` of both inputs is passed to
    ``ot.wasserstein_1d`` and the per-layer results are concatenated.

    Args:
        a (torch.Tensor): [batch, layer, neuron]
        b (torch.Tensor): [batch, layer, neuron]

    Returns:
        wasserstein distance: [layer, neuron]
    """
    num_layers = a.shape[1]
    per_layer = [
        ot.wasserstein_1d(a[:, idx, :], b[:, idx, :]).unsqueeze(0)
        for idx in tqdm(range(num_layers), desc="Computing Wasserstein Distance")
    ]
    return torch.cat(per_layer, dim=0)

def L2_distance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Root-mean-square difference between two activation tensors over dim 0.

    Args:
        a (torch.Tensor): base activation [batch, layer, neuron]
        b (torch.Tensor): lora activation [batch, layer, neuron]

    Returns:
        rms distance: [layer, neuron]
    """
    squared_diff = torch.square(a - b)
    mean_squared = torch.mean(squared_diff, dim=0)
    return torch.sqrt(mean_squared)
    

def get_neuron_activation(cache, attention_mask, num_hidden_layers, last_token=False):
    """Stack the cached ``post{i}`` activations and keep only the masked positions.

    Args:
        cache: mapping with entries ``'post0'`` .. ``'post{num_hidden_layers-1}'`` that
            also exposes ``.to(device)`` (presumably moves the cache in place - the
            return value is not used here).
        attention_mask: 2-D mask over (batch, position); non-zero marks positions to keep.
            If it is shorter than the cached sequence it is left-padded with zeros.
        num_hidden_layers: number of ``post{i}`` entries to read.
        last_token: if True, return the activation at the last position of every row
            and ignore the mask.

    Returns:
        float32 tensor; [batch, layer, neuron] when ``last_token`` is set, otherwise
        [number of kept positions, layer, neuron].
    """
    cache.to('cpu')
    attention_mask = attention_mask.cpu()
    # New layer axis is inserted right before the last (neuron) axis.
    stacked = torch.stack([cache[f'post{layer}'] for layer in range(num_hidden_layers)], dim=-2)
    del cache
    torch.cuda.empty_cache()
    stacked = stacked.float()

    if last_token:
        return stacked[:, -1, ...]

    dims = stacked.size()
    mask_len = attention_mask.shape[1]
    if mask_len != dims[1]:
        # Align the mask with the cached sequence length by prepending zeros.
        front = torch.zeros(attention_mask.shape[0], dims[1] - mask_len)
        attention_mask = torch.cat([front, attention_mask], dim=1)

    keep = attention_mask[..., None, None].bool()
    flat = torch.masked_select(stacked, keep)
    selected = flat.reshape(-1, dims[-2], dims[-1])
    assert selected.dtype == torch.float32
    assert selected.shape[0] == attention_mask.sum()
    return selected

def topk_index(tensor, k):
    """Return the multi-dimensional indices of the ``k`` largest entries of ``tensor``.

    Args:
        tensor: tensor of any shape.
        k: number of entries to select. A negative value, or a value larger than the
            number of elements, selects every element instead of raising.

    Returns:
        An iterator of index tuples (one int per tensor dimension), ordered from the
        largest value to the smallest.
    """
    total = tensor.numel()
    if k < 0 or k > total:
        # topk raises when k exceeds the element count; treat that as "all".
        k = total
    flat_index = tensor.flatten().topk(k=k)[1]
    coords = []
    # Unravel the flat index, peeling off the innermost dimension first.
    for dim in reversed(tensor.shape):
        coords.append((flat_index % dim).tolist())
        flat_index = flat_index // dim
    return zip(*reversed(coords))

def batch_run_with_cache_on_prompt(prompts, batch_size, model, hooks, last_token=False, exclude_last_n=1):
    """Run the hooked model on raw prompts in batches and collect neuron activations.

    Each batch is tokenized with ``model.to_tokens``, run through
    ``model.run_with_cache`` restricted to ``hooks``, and reduced with
    ``get_neuron_activation`` over the first
    ``model.config.num_hidden_layers - exclude_last_n`` layers.

    Returns:
        The per-batch activations concatenated along dim 0.
    """
    target_device = model.device
    collected = []
    batch_starts = range(0, len(prompts), batch_size)
    for start in tqdm(batch_starts, desc="Getting Activations"):
        chunk = prompts[start: start + batch_size]
        encoded = model.to_tokens(chunk, device=target_device)
        mask = encoded.attention_mask
        _, cache = model.run_with_cache(**encoded, names_filter=hooks)
        kept_layers = model.config.num_hidden_layers - exclude_last_n
        collected.append(
            get_neuron_activation(cache, mask, kept_layers, last_token=last_token)
        )
    return torch.cat(collected, dim=0)

def batch_run_with_cache_on_completion(prompts, attention_masks, gather_masks, model, hooks, last_token=False, exclude_last_n=1):
    """Run the hooked model on pre-tokenized batches and collect neuron activations.

    ``prompts``, ``attention_masks`` and ``gather_masks`` are parallel sequences of
    per-batch tensors. The attention mask is fed to the model, while the gather mask
    is what ``get_neuron_activation`` uses to pick the positions to keep, over the
    first ``model.config.num_hidden_layers - exclude_last_n`` layers.

    Returns:
        The per-batch activations concatenated along dim 0.
    """
    target_device = model.device
    collected = []
    batches = zip(prompts, attention_masks, gather_masks)
    for token_ids, attn, gather in tqdm(batches, desc="Getting Activations", total=len(prompts)):
        token_ids = token_ids.to(target_device)
        attn = attn.to(target_device)
        gather = gather.to(target_device)
        _, cache = model.run_with_cache(input_ids=token_ids, attention_mask=attn, names_filter=hooks)
        kept_layers = model.config.num_hidden_layers - exclude_last_n
        collected.append(
            get_neuron_activation(cache, gather, kept_layers, last_token=last_token)
        )
    return torch.cat(collected, dim=0)

def _build_arg_parser():
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Compare neuron activations between two (optionally PEFT-adapted) models, "
                    "or cache the neuron activations of a single model."
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=256,
        help="Maximum number of new tokens to generate for each completion.",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=-1,
        help="Number of samples to evaluate; values <= 0 use the whole dataset.",
    )
    parser.add_argument(
        "--topk",
        type=int,
        default=31 * 11008,
        help="Number of top different neurons to ablate.",
    )
    parser.add_argument(
        "--exclude_last_n",
        type=int,
        default=1,
        help="Number of last layers not included in index.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="",
        help="JSON dataset file to evaluate; its \"prompt\" field is used.",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="hooked_llama/data/default.pt",
        help="File the results are saved to with torch.save().",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
        help="If specified, we will load the model to generate the predictions.",
    )
    parser.add_argument(
        "--tokenizer_name_or_path",
        type=str,
        default=None,
        help="If specified, we will load the tokenizer from here.",
    )
    parser.add_argument(
        "--first_peft_path",
        nargs='+',
        default=None,
        help="The folder contains lora checkpoint saved with PeftModel.save_pretrained()."
    )
    parser.add_argument(
        "--second_peft_path",
        nargs='+',
        default=None,
        help="The folder contains lora checkpoint saved with PeftModel.save_pretrained()."
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=1,
        help="Batch size for evaluation."
    )
    parser.add_argument(
        "--dist_func",
        type=str,
        default='L2_distance',
        help="Name of the module-level distance function used to compare activations "
             "(e.g. L2_distance, wasserstein_1d)."
    )
    parser.add_argument(
        "--prompt_only",
        action='store_true',
        help="Get activation on prompt only."
    )
    parser.add_argument(
        "--last_token",
        action='store_true',
        help="Get activation on the last token."
    )
    parser.add_argument(
        "--load_in_8bit",
        action='store_true',
        help="Load model in 8bit quantization."
    )
    parser.add_argument(
        "--generation_startswith",
        type=str,
        default='',
        help="Text appended to every prompt, so that the generation starts with it."
    )
    parser.add_argument(
        "--cache_activation",
        action='store_true',
        help="If specified, cache neuron activation on dataset using a single model."
    )
    return parser


def _load_hooked_model(args, peft_path):
    """Load the hooked model and tokenizer, passing *peft_path* as the PEFT checkpoint(s)."""
    tokenizer_path = args.tokenizer_name_or_path
    if tokenizer_path is None:
        tokenizer_path = args.model_name_or_path
    return load_hooked_lm_and_tokenizer(
        model_name_or_path=args.model_name_or_path,
        tokenizer_name_or_path=tokenizer_path,
        load_in_8bit=args.load_in_8bit,
        device_map="auto",
        peft_name_or_path=peft_path
    )


def _save_result(obj, output_file):
    """Save *obj* with torch.save, creating the parent directory when there is one."""
    output_dir = os.path.dirname(output_file)
    # os.makedirs('') raises, so skip it for a bare file name in the working directory.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save(obj, output_file)


def main():
    """Entry point: compare two models' neuron activations, or cache one model's."""
    args = _build_arg_parser().parse_args()
    seed_torch(42)

    dist_func = None
    if not args.cache_activation:
        # Resolve the distance function before any model is loaded, so that a
        # misspelled --dist_func fails immediately instead of after all inference.
        dist_func = globals().get(args.dist_func)
        if not callable(dist_func):
            raise ValueError(f"Unknown --dist_func: {args.dist_func!r}")

    eval_data = datasets.load_dataset('json', data_files=args.dataset)["train"]["prompt"]
    if args.num_samples > 0:
        eval_data = eval_data[:args.num_samples]

    prompts = []
    for example in eval_data:
        messages = [{"role": "user", "content": example.strip()}]
        prompt = create_prompt_with_tulu_chat_format(messages, add_bos=False)
        prompts.append(prompt+args.generation_startswith)

    def names_filter(name):
        return name.endswith('hook_post')

    if args.cache_activation:
        hooked_model, tokenizer = _load_hooked_model(args, args.second_peft_path)
        hooked_model.set_tokenizer(tokenizer)
        peft_type = getattr(hooked_model, "peft_type", "BASE")
        if not args.prompt_only:
            completed_prompts, attention_masks, gather_masks = generate_completions_and_masks(hooked_model, tokenizer, prompts, batch_size=args.eval_batch_size, max_new_tokens=args.max_new_tokens, do_sample=False)
            print(f'Get {peft_type} model neuron activation on completion...')
            peft_activation = batch_run_with_cache_on_completion(completed_prompts, attention_masks, gather_masks, hooked_model, names_filter, last_token=False, exclude_last_n=args.exclude_last_n)
        else:
            print(f'Get {peft_type} model neuron activation on original prompt...')
            peft_activation = batch_run_with_cache_on_prompt(prompts, args.eval_batch_size, hooked_model, names_filter, last_token=args.last_token, exclude_last_n=args.exclude_last_n)

        print(f'Get neuron activation on {peft_activation.shape[0]} samples')
        _save_result(peft_activation, args.output_file)
        return

    first_name = args.first_peft_path[-1] if args.first_peft_path else "base"
    second_name = args.second_peft_path[-1] if args.second_peft_path else "base"

    print(f"using {args.num_samples} samples, first PEFT module {args.first_peft_path}, second PEFT module {args.second_peft_path} used for completion")
    hooked_model, tokenizer = _load_hooked_model(args, args.second_peft_path)
    hooked_model.set_tokenizer(tokenizer)
    if not args.prompt_only:
        completed_prompts, attention_masks, gather_masks = generate_completions_and_masks(hooked_model, tokenizer, prompts, batch_size=args.eval_batch_size, max_new_tokens=args.max_new_tokens, do_sample=False)
        if completed_prompts:
            # Show the first sequence of the first batch as a sanity check.
            sample_ids = completed_prompts[0][0]
            print('-------')
            print('Text: ', tokenizer.decode(sample_ids[attention_masks[0][0].bool()]))
            print('Output:', tokenizer.decode(sample_ids[gather_masks[0][0].bool()]))
            print('-------')

        print(f'Get {second_name} model neuron activation on completion...')
        peft_activation = batch_run_with_cache_on_completion(completed_prompts, attention_masks, gather_masks, hooked_model, names_filter, last_token=False, exclude_last_n=args.exclude_last_n)

        # Drop the first model before loading the second one.
        del hooked_model
        torch.cuda.empty_cache()

        hooked_model, tokenizer = _load_hooked_model(args, args.first_peft_path)
        print(f'Get {first_name} model neuron activation on completion...')
        base_activation = batch_run_with_cache_on_completion(completed_prompts, attention_masks, gather_masks, hooked_model, names_filter, last_token=False, exclude_last_n=args.exclude_last_n)
    else:
        print(f'Get {second_name} model neuron activation on prompt...')
        peft_activation = batch_run_with_cache_on_prompt(prompts, args.eval_batch_size, hooked_model, names_filter, last_token=args.last_token, exclude_last_n=args.exclude_last_n)

        del hooked_model
        torch.cuda.empty_cache()

        # Load the first PEFT checkpoint(s) here as well, matching the completion
        # branch and the message printed below (None loads the plain base model).
        hooked_model, tokenizer = _load_hooked_model(args, args.first_peft_path)
        hooked_model.set_tokenizer(tokenizer)
        print(f'Get {first_name} model neuron activation on prompt...')
        base_activation = batch_run_with_cache_on_prompt(prompts, args.eval_batch_size, hooked_model, names_filter, last_token=args.last_token, exclude_last_n=args.exclude_last_n)

    # The model with fewer PEFT modules is treated as the "base" side.
    first_peft_num = len(args.first_peft_path) if args.first_peft_path else 0
    second_peft_num = len(args.second_peft_path) if args.second_peft_path else 0
    if first_peft_num > second_peft_num:
        base_activation, peft_activation = peft_activation, base_activation

    print(f'Get neuron activation on {base_activation.shape[0]} tokens')
    dist = dist_func(base_activation, peft_activation)
    base_mean = base_activation.mean(0)
    peft_mean = peft_activation.mean(0)
    base_std = base_activation.std(0)
    peft_std = peft_activation.std(0)

    # One (layer, neuron) row per selected neuron, built in a single conversion.
    index_tensor = torch.tensor(list(topk_index(dist, args.topk)), dtype=torch.long).reshape(-1, 2)
    _save_result(
        (dist.cpu(), index_tensor.cpu(), base_mean.cpu(), peft_mean.cpu(), base_std.cpu(), peft_std.cpu()),
        args.output_file,
    )


if __name__ == '__main__':
    main()
        