import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import argparse
import torch
import numpy as np
import random
import json

from transformers import LlamaForCausalLM, AutoTokenizer
from src.model_llama import Conv_LlamaForCausalLM
from pdb import set_trace as pds

# Pick the compute device once at import time: CUDA when PyTorch can see a GPU,
# otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

import os
import sys
def ensure_path(path, early_exit=False):
    """Make sure the directory ``path`` exists, creating it and any parents if needed.

    Args:
        path: directory path to check or create.
        early_exit: when True and ``path`` already exists, prompt on stdin
            whether to continue; answering exactly ``n`` exits the process
            with status 0. Any other answer continues.
    """
    if os.path.exists(path):
        if early_exit and input('{:s} exists, continue? ([y]/n): '.format(path)) == 'n':
            sys.exit(0)
    else:
        # exist_ok avoids a FileExistsError if another process creates the
        # directory between the exists() check above and this call.
        os.makedirs(path, exist_ok=True)

def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) RNGs with ``seed``.

    Also forces deterministic cuDNN kernels and disables cuDNN autotuning,
    trading some speed for run-to-run reproducibility.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed every RNG once at import so repeated runs are reproducible.
set_seed(42)

def save_json(data, file_path, indent=4):
    """Serialize ``data`` as JSON into ``file_path``, overwriting any existing file.

    Prints the destination and ``len(data)`` first, so ``data`` must be sized.

    Args:
        data: JSON-serializable object that also supports ``len()``.
        file_path: path of the file to (over)write.
        indent: indentation passed through to ``json.dump``.
    """
    print(f"save to {file_path}, with length {len(data)}")
    with open(file_path, 'w') as out_file:
        json.dump(data, out_file, indent=indent)

def load_json(file_path):
    """Read the file at ``file_path`` and return its parsed JSON content.

    The file is decoded explicitly as UTF-8 (the encoding JSON interchange
    requires) instead of the platform's locale default, so files containing
    non-ASCII text load identically on every OS.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

def _sliding_window_perplexity(model, input_ids, max_length, stride, device, extra_inputs=None):
    """Score ``input_ids`` with overlapping windows and return the perplexity.

    Each window holds at most ``max_length`` tokens and starts ``stride`` tokens
    after the previous one. Labels of tokens already scored by an earlier window
    are set to -100, so every token is scored once, using up to
    ``max_length - stride`` tokens of preceding context.

    Args:
        model: causal LM called as ``model(input_ids=..., labels=..., **extra_inputs)``
            whose output exposes a scalar ``.loss``.
        input_ids: token-id tensor indexed as ``[:, begin:end]`` (batch first).
        max_length: maximum window length in tokens.
        stride: window step in tokens; clamped to ``max_length`` so that no
            tokens fall between two windows unscored.
        device: device each window is moved to before the forward pass.
        extra_inputs: optional extra keyword arguments for the model call.

    Returns:
        0-d tensor ``exp(total NLL / number of scored tokens)``.

    Raises:
        ValueError: if ``max_length`` < 1 or no token ends up being scored
            (e.g. the input has fewer than two tokens).
    """
    if max_length < 1:
        raise ValueError(f"max_length must be >= 1, got {max_length}")
    extra_inputs = {} if extra_inputs is None else extra_inputs
    # A stride wider than the window would skip the tokens between windows.
    stride = min(stride, max_length)
    seq_len = input_ids.size(1)

    nll_sum = 0.0  # sum of per-token negative log-likelihoods
    n_scored = 0  # number of tokens that contributed to nll_sum
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        window_ids = input_ids[:, begin_loc:end_loc].to(device)
        target_ids = window_ids.clone()
        # Mask the overlap with the previous window: it is context only.
        target_ids[:, :-trg_len] = -100

        with torch.no_grad():
            outputs = model(input_ids=window_ids, labels=target_ids, **extra_inputs)

        # outputs.loss is a *mean* over the labels the model scores. The model
        # shifts labels left by one internally, so the labels actually scored
        # are target_ids[:, 1:]. Weight the mean by that exact count so every
        # token counts equally in the final average; an unweighted mean of
        # window losses over-weights the (often shorter) last window.
        # NOTE(review): assumes Conv_LlamaForCausalLM computes its loss the same
        # way as LlamaForCausalLM — confirm.
        n_loss_tokens = int((target_ids[:, 1:] != -100).sum().item())
        if n_loss_tokens > 0:
            nll_sum += outputs.loss.item() * n_loss_tokens
            n_scored += n_loss_tokens

        prev_end_loc = end_loc

        # Clear CUDA cache to free up memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if end_loc == seq_len:
            break

    if n_scored == 0:
        raise ValueError("no tokens were scored; the input needs at least two tokens")
    return torch.exp(torch.tensor(nll_sum / n_scored))


def main():
    """Evaluate a Llama model's perplexity on a WikiText-2 test subset and record it.

    Reads options via ``parse_args()`` and loads either ``LlamaForCausalLM``
    (``--naive``) or ``Conv_LlamaForCausalLM``. The conv model also receives
    ``k`` on every forward call. The first 1000 rows of the
    wikitext-2-raw-v1 test split are joined with blank lines, tokenized, and
    scored with a sliding window. The result is printed and appended as one
    JSON line to ``out/llama3_8b_ins/seq_len{max_length}/perplex_output.jsonl``.
    """
    args = parse_args()
    model_name = args.model_name_or_path
    k = args.k

    if args.naive:
        model_class = LlamaForCausalLM
        device_map = "auto"
        extra_inputs = {}
    else:
        model_class = Conv_LlamaForCausalLM
        device_map = "sequential"
        # The conv model takes its basis-function count on each forward call.
        extra_inputs = {"k": k}

    # Load pre-trained model and tokenizer
    model = model_class.from_pretrained(
        model_name,
        output_attentions=False,
        device_map=device_map,
        attn_implementation="eager",
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Use EOS as the padding token, keeping tokenizer and model config in sync.
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

    test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    # Evaluate on a prefix of the split only; clamp so a smaller split still works.
    subset_size = 1000  # You can adjust this number
    test_subset = test.select(range(min(subset_size, len(test))))

    encodings = tokenizer("\n\n".join(test_subset["text"]), return_tensors="pt")

    max_length = args.max_length
    stride = 512
    ppl = _sliding_window_perplexity(
        model, encodings.input_ids, max_length, stride, device, extra_inputs
    )
    print(f"max length: {max_length} | naive: {args.naive} | k: {k} | ppl: {ppl}")

    # NOTE(review): the output directory is hard-coded to llama3_8b_ins even
    # when --model_name_or_path points at a different model.
    write_path = f"out/llama3_8b_ins/seq_len{max_length}"
    ensure_path(write_path)

    output_data = {
        "max_length": max_length,
        "naive": args.naive,
        "k": k,
        "ppl": ppl.item(),
    }

    # Append one JSON record per run so repeated sweeps accumulate in one file.
    with open(f'{write_path}/perplex_output.jsonl', 'a') as f:
        json.dump(output_data, f)
        f.write('\n')



def parse_args(argv=None):
    """Parse the command-line options of the perplexity evaluation.

    Args:
        argv: optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before this parameter existed.

    Returns:
        argparse.Namespace with ``task``, ``start_idx``, ``end_idx``,
        ``model_name_or_path``, ``naive``, ``k`` and ``max_length``.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate Llama perplexity on WikiText-2 with naive or k-conv attention"
    )
    # --task / --start_idx / --end_idx are accepted for CLI compatibility but
    # are not read by main().
    parser.add_argument(
        '--task', help='nlp dataset (currently unused)', type=str, default='imdb',
    )
    parser.add_argument(
        '--start_idx', help='start index (currently unused)', type=int, default=0,
    )
    parser.add_argument(
        '--end_idx', help='end index (currently unused)', type=int, default=10,
    )
    parser.add_argument(
        '--model_name_or_path', help='llama pretrained weight', type=str,
        default="meta-llama/Meta-Llama-3-8B-Instruct",
    )
    parser.add_argument(
        '--naive', help='whether use naive attn', action="store_true", default=False,
    )
    parser.add_argument(
        '--k', help='number of basis functions for k-conv', type=int, default=5,
    )
    parser.add_argument(
        '--max_length', help='max seq len', type=int, default=2048,  # 8192 will OOM
    )
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()



