import sys
import torch
import numpy as np
from transformers import AutoTokenizer, RwkvForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
# Seed torch's global RNG. NOTE(review): synthetic() draws its randomness from
# numpy's global RNG (np.random), which this call does not seed.
torch.manual_seed(42)

# Path placeholders. The right-hand sides are deliberately left empty, so the
# file is a SyntaxError until each one is replaced with a real value (they are
# used as string paths: cache_dir= arguments and f-string path prefixes).
cache_dir_path = # FILL IN: Hugging Face cache dir passed as cache_dir= to from_pretrained
data_folder = # FILL IN: folder with gpt2-vocab.json, gpt2-merges.txt and the Pile data (get_args defaults)
save_folder = # FILL IN: directory where run_valid writes its pickled DataFrames

def synthetic(text, num_kv_pairs=200):
    """Build a multi-query associative recall (MQAR) sequence from ``text``.

    The result has the same length as ``text``:

    * The first ``num_chosen`` positions hold distinct tokens sampled from
      ``text``. They are read as consecutive ``(key, value)`` pairs.
    * Each pair is re-inserted exactly once, at a random position after that
      prefix. Insert slots never overlap, so a key is always followed by its
      own value.
    * Every other position holds the filler id ``1``.

    Randomness comes from numpy's global RNG (``np.random``).

    Args:
        text: sequence of hashable tokens (presumably integer token ids; the
            values are written into a ``torch.long`` tensor).
        num_kv_pairs: requested number of pairs. It is capped so that at most
            a quarter of the distinct tokens in ``text`` are used, rounded down
            to an even count.

    Returns:
        list[int]: the synthetic sequence, ``len(text)`` long.
    """
    seq_len = len(text)
    # NOTE(review): untouched positions keep the filler id 1. If 1 occurs in
    # ``text`` it can also be drawn as a key or value.
    sequence = torch.ones(seq_len, dtype=torch.long)
    vocab = list(set(text))

    # Use at most a quarter of the distinct tokens, rounded down to an even count.
    num_chosen = min(2 * num_kv_pairs, len(vocab) // 4)
    num_chosen -= num_chosen % 2
    selected = np.random.choice(vocab, num_chosen, replace=False)

    # Prefix: every key/value pair laid out back to back.
    pairs = [(selected[j], selected[j + 1]) for j in range(0, num_chosen, 2)]
    for j, tok in enumerate(selected):
        sequence[j] = tok
    if not pairs:
        return sequence.tolist()

    num_pairs = len(pairs)
    free = seq_len - num_chosen
    # Sample sorted offsets in a space shrunk by num_pairs, then shift the i-th
    # offset by i. Consecutive spots then differ by >= 2, so no pair overwrites
    # another, and the last spot is at most seq_len - 2. The sample space is
    # never too small: num_chosen <= len(text) // 4 implies
    # free - num_pairs >= num_pairs.
    offsets = np.sort(np.random.choice(free - num_pairs, num_pairs, replace=False))
    spots = num_chosen + offsets + np.arange(num_pairs)

    # Shuffle the slots so pairs are queried in random order, not prefix order.
    for (key, value), spot in zip(pairs, np.random.permutation(spots)):
        sequence[spot] = key
        sequence[spot + 1] = value
    return sequence.tolist()


def _valid_df_path(model_str, seqlen, is_synthetic, num_kv_pairs):
    """Return the pickle path for run_valid results, tagged by model family."""
    if "raven" in model_str:
        family = "raven_"
    elif "world" in model_str:
        family = "world_"
    else:
        family = ""
    return f"{save_folder}/valid_df_{family}L{seqlen}_synthetic{is_synthetic}_KVs{num_kv_pairs}.pkl"


def _save_valid_df(columns, path):
    """Pickle the accumulated ``columns`` (column name -> per-batch list) to ``path`` as a DataFrame."""
    import pandas as pd

    print(f"Saving data...")
    df = pd.DataFrame()
    for name, values in columns.items():
        df[name] = values
    df.to_pickle(path)


def run_valid(valid_data_iterator, args, model_str="RWKV/rwkv-raven-7b", is_synthetic=True, num_kv_pairs=4,
              num_samples=1000, seqlen=1024, device="cuda"):
    """Score validation batches with an RWKV model and pickle per-token predictions.

    Processing of each batch:

    1. Take the first ``2 * seqlen`` entries of ``batch['text']``.
    2. If ``is_synthetic``, replace them with an MQAR sequence built by
       ``synthetic``.
    3. Decode them with the project tokenizer built from ``args``.
    4. Re-tokenize with the HF tokenizer for ``model_str``, truncated to
       ``seqlen`` tokens.
    5. Run the model and record argmax predictions and next-token
       probabilities.

    Accumulated results are pickled to a path under ``save_folder`` after
    batches 0, 10, 20, ... and once more at the end if any batches are still
    unsaved.

    Args:
        valid_data_iterator: iterable of batches. Each batch needs a ``'text'``
            entry accepted by ``list()`` and the project tokenizer's ``decode``
            (presumably token ids; confirm against the data loader).
        args: namespace passed to ``data.tokenizer.build_tokenizer``.
        model_str: HF model id. Ids containing ``'world'`` load with
            ``AutoModelForCausalLM`` and a remote-code tokenizer; all others
            use ``RwkvForCausalLM``.
        is_synthetic: build MQAR sequences from the raw tokens when True.
        num_kv_pairs: number of key/value pairs requested from ``synthetic``.
        num_samples: maximum number of batches to evaluate.
        seqlen: maximum number of model tokens per example.
        device: device the model (and therefore the inputs) is placed on.
    """
    print(f"Loading models...")
    if 'world' in model_str:
        # https://huggingface.co/RWKV/rwkv-4-world-7b
        tokenizer = AutoTokenizer.from_pretrained(model_str, cache_dir=cache_dir_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_str, cache_dir=cache_dir_path)
    else:
        # https://huggingface.co/docs/transformers/model_doc/rwkv
        tokenizer = AutoTokenizer.from_pretrained(model_str, cache_dir=cache_dir_path)
        model = RwkvForCausalLM.from_pretrained(model_str, cache_dir=cache_dir_path)
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.eval()
    model.to(device=device)

    print(f"Running in Synthetic Mode: {is_synthetic}")

    from data.tokenizer import build_tokenizer
    tokenizer_in = build_tokenizer(args)

    save_path = _valid_df_path(model_str, seqlen, is_synthetic, num_kv_pairs)
    # One list per output column (in DataFrame column order); one entry per batch.
    columns = {
        "input_tokens": [],      # decoded input text
        "input_token_ids": [],   # HF input ids, shape (1, L)
        "out_tokens": [],        # batch_decode of the argmax predictions
        "out_token_ids": [],     # argmax prediction ids, shape (1, L)
        "pred_probs": [],        # prob of the argmax token at positions 0..L-2
        "true_probs": [],        # prob of the actual next token at positions 0..L-2
        "label_ids": [],         # actual next-token ids, shape (L-1,)
    }
    saved_rows = 0
    for batch_num, batch in enumerate(valid_data_iterator):
        # Stop after exactly num_samples batches.
        if batch_num >= num_samples:
            break
        print(f"Batch {batch_num} of {num_samples}")

        text = list(batch['text'])[:seqlen * 2]
        if is_synthetic:
            # In synthetic mode construct MQAR synthetic sequences from the raw text.
            text = synthetic(text, num_kv_pairs=num_kv_pairs)
        decoded = tokenizer_in.tokenizer.decode(text)
        inputs = tokenizer(decoded, return_tensors="pt", truncation=True, max_length=seqlen).to(model.device)

        # Inference only: skip autograd bookkeeping and the unused loss.
        with torch.no_grad():
            outputs = model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits.to(device="cpu"), dim=-1)
        del outputs  # drop the device-side logits/state before the next batch
        pred = torch.argmax(probs, dim=-1)
        input_ids = inputs["input_ids"].to(device="cpu")

        # Position t scores token t + 1, so pair probs[:, :-1] with input_ids[:, 1:].
        true_ids = input_ids[:, 1:]
        true_probs = torch.gather(probs, 2, true_ids.unsqueeze(-1)).squeeze(-1)
        pred_probs = torch.gather(probs, 2, pred[:, :-1].unsqueeze(-1)).squeeze(-1)

        columns["input_tokens"].append(decoded)
        columns["input_token_ids"].append(input_ids.numpy())
        columns["out_tokens"].append(tokenizer.batch_decode(pred))
        columns["out_token_ids"].append(pred.numpy())
        columns["pred_probs"].append(pred_probs.numpy())
        columns["true_probs"].append(true_probs.numpy())
        columns["label_ids"].append(true_ids[0].numpy())

        if batch_num % 10 == 0:
            _save_valid_df(columns, save_path)
            saved_rows = len(columns["input_tokens"])

    # Flush batches collected since the last periodic save.
    if len(columns["input_tokens"]) > saved_rows:
        _save_valid_df(columns, save_path)
    print(f"Done.")


def _str_to_bool(value):
    """Parse a command-line boolean such as ``true``/``False``/``1``/``no``.

    argparse's ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so boolean flags need an explicit parser.
    """
    import argparse

    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def get_args():
    """Parse tokenizer and dataset options from ``sys.argv``.

    Path defaults are built from the module-level ``data_folder``. After
    parsing, a few fixed tokenizer/data-loader settings are attached:
    ``keep_empty``, ``rank``, ``make_vocab_size_divisible_by`` and
    ``model_parallel_size``.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    import argparse
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title="tokenizer")
    group.add_argument(
        "--tokenizer-type",
        type=str,
        default="GPT2BPETokenizer",
        choices=[
            "HFGPT2Tokenizer",
            "HFTokenizer",
            "GPT2BPETokenizer",
            "CharLevelTokenizer",
            "TiktokenTokenizer",
        ],
        help="What type of tokenizer to use.",
    )
    group.add_argument(
        "--vocab-file", type=str, default=f"{data_folder}/gpt2-vocab.json", help="Path to the vocab file"
    )
    group.add_argument(
        "--merge-file",
        type=str,
        default=f"{data_folder}/gpt2-merges.txt",
        help="Path to the BPE merge file (if necessary).",
    )
    group.add_argument(
        "--append-eod",
        action="store_true",
        help="Append an <eod> token to the end of a document.",
    )

    # Dataset stuff
    group.add_argument("--train-iters", type=int, default=10, help="Number of iterations to train for.")
    group.add_argument("--eval-iters", type=int, default=10, help="Number of iterations to validate for.")
    group.add_argument("--test-iters", type=int, default=10, help="Number of iterations to test for.")
    group.add_argument("--train-batch-size", type=int, default=10, help="Batch size for training.")
    group.add_argument("--eval-interval", type=int, default=1, help="Number of iterations between evaluations.")

    # Data paths and per-dataset weights.
    group.add_argument("--train_data_paths", type=str, default=f"{data_folder}/pile_text_document",
                       help="Path to training data.")
    group.add_argument("--valid_data_paths", type=str,
                       default=f"{data_folder}/pile_validation/pile_validation_text_document",
                       help="Path to validation data.")
    group.add_argument("--test_data_paths", type=str, default=f"{data_folder}/pile_test_text_document",
                       help="Path to test data.")
    # Weights are lists of numbers; nargs/type keep CLI values consistent with the [1] default.
    group.add_argument("--train_data_weights", nargs="+", type=float, default=[1],
                       help="Weights for the training datasets.")
    group.add_argument("--valid_data_weights", nargs="+", type=float, default=[1],
                       help="Weights for the validation datasets.")
    group.add_argument("--test_data_weights", nargs="+", type=float, default=[1],
                       help="Weights for the test datasets.")
    group.add_argument("--weight_by_num_documents", type=_str_to_bool, default=False,
                       help="Weight training data by number of documents.")

    group.add_argument("--is_pipe_parallel", default=True, type=_str_to_bool, help="is pipe parallel")
    group.add_argument("--ftfy", action="store_true", help="Use ftfy to clean text")
    args = parser.parse_args()
    args.keep_empty = False

    # some default/dummy values for the tokenizer
    args.rank = 0
    args.make_vocab_size_divisible_by = 128
    args.model_parallel_size = 1

    return args

# main
if __name__ == "__main__":
    args = get_args()
    # Presumably makes the repo root importable so the `data` package imports resolve.
    sys.path.append("../")

    # Load the Pile dataset; only the validation split is evaluated here.
    from data.data_utils import build_train_valid_test_data_iterators
    _, valid_data_iterator, _ = build_train_valid_test_data_iterators(neox_args=args)

    # synthetic() samples from numpy's global RNG, which torch.manual_seed does
    # not cover; seed it as well so synthetic runs are reproducible.
    np.random.seed(42)
    run_valid(valid_data_iterator, args)
