# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
from typing import List

import fire
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from llama import Llama

# Root directories used to build the default paths below. FILL IN by editing
# the fallback values or by setting the corresponding environment variables.
data_dir = os.environ.get("DATA_DIR", ".")  # FILL IN: tokenizer files (gpt2-vocab/merges) and Pile data
model_dir = os.environ.get("MODEL_DIR", ".")  # FILL IN: parent of llama/llama-2-7b/
save_dir = os.environ.get("SAVE_DIR", ".")  # FILL IN: where result pickles are written

def get_args(argv=None):
    """Build the argparse namespace consumed by the data pipeline and tokenizer.

    Args:
        argv: Optional list of command-line tokens to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with every parsed option, plus the attributes
        ``keep_empty``, ``rank``, ``make_vocab_size_divisible_by`` and
        ``model_parallel_size``, which are always set to fixed values.
    """
    import argparse

    def str_to_bool(value):
        # argparse's ``type=bool`` calls bool() on the raw string, so any
        # non-empty value -- including "False" -- becomes True. Parse the
        # usual spellings explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()

    group = parser.add_argument_group(title="tokenizer")
    group.add_argument(
        "--tokenizer-type",
        type=str,
        default="GPT2BPETokenizer",
        choices=[
            "HFGPT2Tokenizer",
            "HFTokenizer",
            "GPT2BPETokenizer",
            "CharLevelTokenizer",
            "TiktokenTokenizer",
        ],
        help="What type of tokenizer to use.",
    )
    group.add_argument(
        "--vocab-file", type=str, default=f"{data_dir}/gpt2-vocab.json", help="Path to the vocab file"
    )
    group.add_argument(
        "--merge-file",
        type=str,
        default=f"{data_dir}/gpt2-merges.txt",
        help="Path to the BPE merge file (if necessary).",
    )
    group.add_argument(
        "--append-eod",
        action="store_true",
        help="Append an <eod> token to the end of a document.",
    )

    group = parser.add_argument_group(title="data")
    # Iteration / batching.
    group.add_argument("--train-iters", type=int, default=10, help="Number of iterations to train for.")
    group.add_argument("--eval-iters", type=int, default=10, help="Number of iterations to validate for.")
    group.add_argument("--test-iters", type=int, default=10, help="Number of iterations to test for.")
    group.add_argument("--train-batch-size", type=int, default=10, help="Batch size for training.")
    group.add_argument("--eval-interval", type=int, default=1, help="Number of iterations between evaluations.")

    # Dataset locations and mixture weights.
    group.add_argument("--train_data_paths", type=str, default=f"{data_dir}/pile_text_document", help="Path to training data.")
    group.add_argument("--valid_data_paths", type=str, default=f"{data_dir}/pile_validation/pile_validation_text_document", help="Path to validation data.")
    group.add_argument("--test_data_paths", type=str, default=f"{data_dir}/pile_test/pile_test_text_document", help="Path to test data.")
    # nargs/type keep CLI-supplied weights a list of numbers, like the default.
    group.add_argument("--train_data_weights", nargs="+", type=float, default=[1], help="Sampling weight(s) for the training data.")
    group.add_argument("--valid_data_weights", nargs="+", type=float, default=[1], help="Sampling weight(s) for the validation data.")
    group.add_argument("--test_data_weights", nargs="+", type=float, default=[1], help="Sampling weight(s) for the test data.")
    group.add_argument("--weight_by_num_documents", type=str_to_bool, default=False, help="Weight training data by number of documents.")

    group.add_argument("--is_pipe_parallel", default=True, type=str_to_bool, help="is pipe parallel")
    group.add_argument("--ftfy", action="store_true", help="Use ftfy to clean text")

    group = parser.add_argument_group(title="model")
    group.add_argument("--ckpt_dir", default=f'{model_dir}/llama/llama-2-7b/', help="Directory containing the LLaMA checkpoint.")
    group.add_argument("--tokenizer_path", default='tokenizer.model', help="Path to the LLaMA tokenizer model file.")

    args = parser.parse_args(argv)
    args.keep_empty = False

    # some default/dummy values for the tokenizer
    args.rank = 0
    args.make_vocab_size_divisible_by = 128
    args.model_parallel_size = 1

    return args


def generate(prompt_tokens_lst, tokenizer, model, bsz=1, total_len=1024, device="cuda"):
    """Run one teacher-forced forward pass and return per-position next-token outputs.

    No autoregressive decoding happens: the prompts are written left-aligned
    into a ``(bsz, total_len)`` token matrix (remaining slots hold
    ``tokenizer.pad_id``), the matrix is passed once through
    ``model.forward(tokens, 0)``, and the softmax over the last dimension and
    its argmax are returned.

    Args:
        prompt_tokens_lst: Sequence of at most ``bsz`` token-id lists, each no
            longer than ``total_len``.
        tokenizer: Object exposing ``pad_id``.
        model: Object whose ``forward(tokens, start_pos)`` returns logits;
            presumably shaped ``(bsz, total_len, vocab)`` -- confirm for
            non-LLaMA models.
        bsz: Number of rows in the token matrix.
        total_len: Number of columns (sequence length) in the token matrix.
        device: Device the token matrix is created on. Defaults to ``"cuda"``;
            pass ``"cpu"`` to run without a GPU.

    Returns:
        Tuple ``(probs, preds)``: softmax of the logits over the last dim, and
        the argmax of ``probs`` over the last dim.
    """
    pad_id = tokenizer.pad_id
    # NOTE(review): the LLaMA SentencePiece tokenizer presumably reports
    # pad_id == -1, which is not a valid embedding index; padded positions
    # would then fail inside the model. Callers that pass full-length prompts
    # never create padded positions.
    tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
    for row, prompt in enumerate(prompt_tokens_lst):
        tokens[row, : len(prompt)] = torch.tensor(prompt, dtype=torch.long, device=device)

    # Pure scoring: no gradients are needed, so don't build an autograd graph.
    with torch.no_grad():
        logits = model.forward(tokens, 0)
        probs = F.softmax(logits, dim=-1)
        preds = torch.argmax(probs, dim=-1)
    return probs, preds


def synthetic(text, num_kv_pairs=10):
    """Build a synthetic associative-recall sequence from the tokens of ``text``.

    Distinct tokens are drawn at random from ``text`` and grouped into
    (key, value) pairs. The returned sequence has the same length as
    ``text``:

    * it starts with all pairs laid out as ``k0 v0 k1 v1 ...``;
    * every other position holds the filler token ``1``;
    * each pair is written once more at a random later position.

    The later positions never overlap, so every re-occurrence of a key is
    immediately followed by its own value.

    Args:
        text: Sequence of token ids. Only its distinct values are used.
        num_kv_pairs: Requested number of pairs. It is capped so that at most
            a quarter of the distinct tokens are used.

    Returns:
        list of token ids with ``len(text)`` entries.
    """
    n = len(text)
    sequence = torch.ones(n, dtype=torch.long)
    subset = list(set(text))
    num_chosen = min(2 * num_kv_pairs, len(subset) // 4)
    if num_chosen % 2 == 1:
        num_chosen -= 1  # keep an even count so every key has a value
    if num_chosen == 0:
        # Too few distinct tokens for even one pair: filler only.
        return sequence.tolist()
    selected = np.random.choice(subset, num_chosen, replace=False)

    # Consecutive selected tokens form the (key, value) pairs.
    pairs = [(selected[j], selected[j + 1]) for j in range(0, num_chosen, 2)]
    # Lay every pair out at the start of the sequence.
    for j, tok in enumerate(selected):
        sequence[j] = tok

    # Choose len(pairs) non-overlapping 2-token slots in [num_chosen, n - 1).
    # Draw distinct offsets from [num_chosen, n - num_pairs), sort them, and
    # add 0, 1, 2, ... . The resulting starts are at least 2 apart, and the
    # last one is <= n - 2, so each pair fits. num_chosen <= n / 4 guarantees
    # the range holds enough offsets.
    num_pairs = len(pairs)
    offsets = np.sort(np.random.choice(np.arange(num_chosen, n - num_pairs), num_pairs, replace=False))
    starts = offsets + np.arange(num_pairs)
    # Assign pairs to slots in random order so query order is not fixed.
    for pair_idx, spot in zip(np.random.permutation(num_pairs), starts):
        key, value = pairs[pair_idx]
        sequence[spot] = key
        sequence[spot + 1] = value
    return sequence.tolist()


def _save_results(results, path):
    """Pickle the accumulated per-sample results as a pandas DataFrame.

    Args:
        results: Mapping of column name -> list of per-sample values. All
            lists must have the same length; columns keep insertion order.
        path: Destination passed to ``DataFrame.to_pickle``.
    """
    print(f"Saving data...")
    df = pd.DataFrame()
    for column, values in results.items():
        df[column] = values
    df.to_pickle(path)


def main(
    ckpt_dir: str=f'{model_dir}/llama/llama-2-7b/',
    tokenizer_path: str = 'tokenizer.model',
    max_seq_len: int = 1024,
    max_batch_size: int = 4,
    is_synthetic: bool = True,
    num_kv_pairs: int = 10,
    num_samples: int = 1000,
    save_every: int = 10,
):
    """Score LLaMA next-token predictions on Pile validation samples.

    Each validation sample goes through these steps:

    1. Optionally replace it with ``synthetic(...)`` output.
    2. Decode it with the dataset tokenizer and re-encode it with the LLaMA
       tokenizer.
    3. Truncate to ``max_seq_len`` tokens; shorter samples are skipped.
    4. Score it with a single forward pass.

    Per sample, it records the decoded input, the input, label and predicted
    token ids, and the probability the model assigns both to the true next
    token and to its own argmax prediction. Results are pickled into
    ``save_dir`` periodically and once more when the loop ends.

    Args:
        ckpt_dir: Directory containing the checkpoint files for the model.
        tokenizer_path: Path to the tokenizer model used for encoding/decoding.
        max_seq_len: Exact number of tokens scored per sample. Defaults to 1024.
        max_batch_size: Maximum batch size passed to ``Llama.build``.
        is_synthetic: If True, each sample is replaced by ``synthetic`` output.
        num_kv_pairs: Number of key/value pairs requested from ``synthetic``.
        num_samples: Number of validation batches to iterate over.
        save_every: Save the results every ``save_every`` batch indices;
            0 disables periodic saves. A final save always runs.

    NOTE(review): ``ckpt_dir``/``tokenizer_path`` come from these parameters,
    not from the ``--ckpt_dir``/``--tokenizer_path`` flags that ``get_args``
    also parses -- confirm which source is intended.
    """
    args = get_args()
    import sys

    # Load the Pile data. The `data` package is looked up one directory above
    # the current working directory.
    sys.path.append("../")
    from data.data_utils import build_train_valid_test_data_iterators
    (
        _train_data_iterator,
        valid_data_iterator,
        _test_data_iterator,
    ) = build_train_valid_test_data_iterators(neox_args=args)

    print(f"Running in Synthetic Mode: {is_synthetic}")

    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )
    model, tokenizer = generator.model, generator.tokenizer

    # Tokenizer used to decode each sample's ids back to text before
    # re-encoding with the LLaMA tokenizer.
    from data.tokenizer import build_tokenizer
    tokenizer_in = build_tokenizer(args)

    save_path = f"{save_dir}/llama_2_7b_valid_df_L{max_seq_len}_synthetic{is_synthetic}_KVs{num_kv_pairs}.pkl"
    # Column name -> per-sample values; insertion order is the DataFrame column order.
    results = {
        'input_tokens': [],
        'input_token_ids': [],
        'label_ids': [],
        'out_tokens': [],
        'out_token_ids': [],
        'pred_probs': [],
        'true_probs': [],
    }

    for batch_num, batch in enumerate(valid_data_iterator):
        # Checked before any work so exactly num_samples batches are visited,
        # including ones skipped below.
        if batch_num >= num_samples:
            break
        print(f"Batch {batch_num} of {num_samples}")

        text = list(batch['text'])[:max_seq_len*2]
        if is_synthetic:
            text = synthetic(text, num_kv_pairs=num_kv_pairs)
        decoded = tokenizer_in.tokenizer.decode(text)
        input_ids = tokenizer.encode(decoded, eos=False, bos=False)[:max_seq_len]
        if len(input_ids) < max_seq_len:
            continue  # only score samples that fill the whole context

        probs, pred = generate([input_ids], tokenizer, model, bsz=1, total_len=max_seq_len)
        # Decode the single row as one flat id list rather than the nested (1, L) list.
        pred_str = tokenizer.decode(pred[0].tolist())

        # Position i's distribution predicts token i + 1, so the last position
        # has no target and is dropped. Index tensors are built on
        # probs.device because torch.gather needs input and index on the same
        # device.
        next_probs = probs[:, :-1]
        label_ids = torch.tensor([input_ids[1:]], dtype=torch.long, device=probs.device)
        true_probs = torch.gather(next_probs, 2, label_ids.unsqueeze(-1)).squeeze(-1)
        pred_ids = pred[:, :-1]
        pred_probs = torch.gather(next_probs, 2, pred_ids.unsqueeze(-1)).squeeze(-1)

        results['input_tokens'].append(decoded)
        results['input_token_ids'].append(np.asarray(input_ids))  # model inputs (unshifted)
        results['label_ids'].append(label_ids[0].cpu().numpy())  # targets: inputs shifted by 1
        results['out_tokens'].append(pred_str)
        results['out_token_ids'].append(pred_ids[0].detach().cpu().numpy())
        results['pred_probs'].append(pred_probs.detach().cpu().numpy())
        results['true_probs'].append(true_probs.detach().cpu().numpy())

        if save_every and batch_num % save_every == 0:
            _save_results(results, save_path)

    # Write out samples collected since the last periodic save.
    _save_results(results, save_path)

if __name__ == "__main__":
    # main() runs with its default keyword arguments. Command-line flags are
    # parsed separately by get_args() inside main. `fire` is imported at the
    # top of the file but is not used to dispatch here.
    main()
