import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import argparse
import torch
import numpy as np
import random

from transformers import GPT2Model, GPT2Tokenizer
import torch
from src.model_gpt2 import Conv_GPT2Model
from pdb import set_trace as pds

# Run on a CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

import os
import sys
def ensure_path(path, early_exit=False):
    """Make sure the directory ``path`` exists, creating it (and any parents) if needed.

    Args:
        path: Directory path to create.
        early_exit: If True and ``path`` already exists, ask on stdin whether to
            continue; answering ``n`` terminates the process with exit status 0.
    """
    if os.path.exists(path):
        if early_exit and input('{:s} exists, continue? ([y]/n): '.format(path)) == 'n':
            sys.exit(0)
        return
    # exist_ok=True: another process (e.g. a concurrent run over a different
    # index window writing to the same output folder) may create the directory
    # between the existence check above and this call.
    os.makedirs(path, exist_ok=True)

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA) RNGs and make cuDNN deterministic."""
    # Disable cuDNN autotuning and request deterministic kernels for reproducible runs.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

# Seed every RNG once at import time; main() relies on Python's global
# `random` being seeded for its dataset shuffle.
set_seed(seed=42)

def main():
    """Extract last-layer hidden states for a slice of a text dataset and save them.

    Reads options via ``parse_args()`` and loads either the stock ``GPT2Model``
    (``--naive``) or ``Conv_GPT2Model``. It then shuffles the indices of the
    dataset's "train" split and keeps ``[start_idx:end_idx]`` of that permutation.
    Each example's "text" field is tokenized to a fixed length and run through
    the model one example at a time. The stacked ``last_hidden_state`` tensors
    are written as a ``.npy`` file under ``out/seq_len<seq_len>/``.

    Raises:
        ValueError: if the ``[start_idx, end_idx)`` window selects no examples.
    """
    args = parse_args()
    start_idx = args.start_idx
    end_idx = args.end_idx
    task = args.task
    model_name = args.model_name_or_path
    k = args.k
    # Single source of truth for the tokenized length; it also names the output folder.
    seq_len = 1024

    model_class = GPT2Model if args.naive else Conv_GPT2Model
    # The naive model uses the standard forward signature; the conv variant also takes k.
    forward_kwargs = {} if args.naive else {"k": k}

    # Load the weights, then place the model explicitly on `device`. device_map="auto"
    # is intentionally not used: it dispatches the model with accelerate hooks,
    # which conflicts with moving it via .to(device).
    model = model_class.from_pretrained(
        model_name,
        output_attentions=False,
        attn_implementation="eager",
    )
    model = model.to(device)
    model.eval()  # inference only: disable dropout

    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    # Reuse EOS as the pad token so padding="max_length" has a token to pad with.
    tokenizer.pad_token = tokenizer.eos_token

    full_dataset = load_dataset(task, split="train")

    # Shuffle with the global `random`, seeded by set_seed(42) at import time, so
    # separate runs with different start/end windows slice the same permutation.
    # NOTE(review): this assumes nothing between the seeding and here consumes
    # the global Python RNG.
    shuffled_indices = list(range(len(full_dataset)))
    random.shuffle(shuffled_indices)
    subset_indices = shuffled_indices[start_idx:end_idx]
    if not subset_indices:
        raise ValueError(
            f"No examples selected: start_idx={start_idx}, end_idx={end_idx}, "
            f"dataset size={len(full_dataset)}"
        )
    dataset = full_dataset.select(subset_indices)

    def tokenize_function(examples):
        # Pad/truncate every example to exactly seq_len tokens so outputs can be stacked.
        return tokenizer(
            examples["text"], padding="max_length", truncation=True, max_length=seq_len
        )

    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
    dataloader = DataLoader(tokenized_dataset, batch_size=1, shuffle=False)

    results = []
    for batch in tqdm(dataloader, desc="Processing batches"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, **forward_kwargs)

        # Move each result to host memory right away so accelerator memory does
        # not grow with the number of processed batches.
        results.append(outputs.last_hidden_state.cpu())

    saved_states = torch.cat(results).numpy()

    save_fold = os.path.join("out", f"seq_len{seq_len}")
    if args.naive:
        saved_name = f"last_hidden_naive_{start_idx}_{end_idx}.npy"
    else:
        saved_name = f"last_hidden_conv_k_{k}_{start_idx}_{end_idx}.npy"
    save_path = os.path.join(save_fold, saved_name)

    print(f"save saved_states to : {save_path}, shape {saved_states.shape}")
    ensure_path(save_fold)
    np.save(save_path, saved_states)



def parse_args(argv=None):
    """Parse command-line options for hidden-state extraction.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``task``, ``start_idx``, ``end_idx``,
        ``model_name_or_path``, ``naive`` and ``k``.

    Exits via ``parser.error`` (SystemExit with status 2) if ``start_idx`` is
    negative or ``end_idx`` is not greater than ``start_idx``.
    """
    parser = argparse.ArgumentParser(description="text encoder on vision language model")
    parser.add_argument(
        '--task', help='nlp dataset', type=str, default='imdb',
    )
    parser.add_argument(
        '--start_idx', help='start index', type=int, default=0,
    )
    parser.add_argument(
        '--end_idx', help='end index', type=int, default=10,
    )
    parser.add_argument(
        '--model_name_or_path', help='gpt2 pretrained weight', type=str, default="openai-community/gpt2",
    )
    parser.add_argument(
        '--naive', help='whether use naive attn', action="store_true", default=False,
    )
    parser.add_argument(
        '--k', help='number of basis functions for k-conv', type=int, default=5,
    )
    args = parser.parse_args(argv)

    # The [start_idx, end_idx) window slices a shuffled index list and is embedded
    # in the output filename, so reject negative or empty windows up front.
    if args.start_idx < 0:
        parser.error(f"--start_idx must be >= 0 (got {args.start_idx})")
    if args.end_idx <= args.start_idx:
        parser.error(
            f"--end_idx ({args.end_idx}) must be greater than --start_idx ({args.start_idx})"
        )
    return args


if __name__ == "__main__":
    main()