import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

import numpy as np
import random

import sys
sys.path.insert(0, "/home/ubuntu/projects/conv_basis")

from transformers import GPT2Model, GPT2Tokenizer
import torch
from src.model_gpt2 import Conv_GPT2Model
from pdb import set_trace as pds

# Module-level default device: CUDA when PyTorch reports it as available,
# otherwise the CPU. Announced once at import time.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


def set_seed(seed):
    """Seed every random number generator this script relies on.

    The same ``seed`` is handed to PyTorch (CPU generator and all CUDA
    devices), NumPy's global generator and Python's ``random`` module.
    Afterwards the cuDNN ``deterministic`` flag is switched on and the
    ``benchmark`` flag is switched off.

    Args:
        seed: Seed value passed unchanged to each generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_rng in seeders:
        seed_rng(seed)

    # Configure cuDNN only after all generators have been seeded.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once at import time so that the random.sample() call in
# main() draws the same subset of reviews on every run.
set_seed(42)

def main(*, num_samples=5, k=20, max_length=512):
    """Compare GPT2Model and Conv_GPT2Model hidden states on IMDB reviews.

    Loads the pre-trained ``openai-community/gpt2`` weights into both a stock
    ``GPT2Model`` and a ``Conv_GPT2Model``, runs the same randomly chosen IMDB
    training reviews through each (one review per batch) and reports the mean
    absolute difference between their ``last_hidden_state`` tensors.

    The subset is drawn with ``random.sample``, so it depends on the current
    state of Python's global ``random`` generator.

    Args:
        num_samples: Number of reviews to draw from the IMDB training split.
        k: Value forwarded as the ``k`` keyword to ``Conv_GPT2Model``'s
            forward pass.
        max_length: Token length every review is padded or truncated to.

    Returns:
        The average, over all processed reviews, of the per-review mean
        absolute difference between the two models' last hidden states.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1, or (raised by
            ``random.sample``) larger than the dataset.
    """
    if num_samples < 1:
        # An empty run would otherwise end in a division by zero when
        # averaging the per-sample differences.
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load both models from the same checkpoint. No device_map is passed:
    # device_map="auto" hands placement to Accelerate (possibly sharding the
    # model over several GPUs), which conflicts with the explicit .to(device)
    # below. A single explicit move is sufficient for a model of this size.
    model_name = "openai-community/gpt2"
    model = GPT2Model.from_pretrained(
        model_name,
        attn_implementation="eager",
    )
    model = model.to(device)

    conv_model = Conv_GPT2Model.from_pretrained(
        model_name,
        output_attentions=False,
        attn_implementation="eager",
    )
    conv_model = conv_model.to(device)

    # GPT-2 ships without a padding token; reuse EOS so padding works.
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Draw a random subset of `num_samples` reviews from the training split.
    full_dataset = load_dataset("imdb", split="train")
    subset_indices = random.sample(range(len(full_dataset)), num_samples)
    dataset = full_dataset.select(subset_indices)

    def tokenize_function(examples):
        """Tokenize a batch of reviews to exactly ``max_length`` tokens."""
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )

    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

    # One review per batch, in the sampled order.
    dataloader = DataLoader(tokenized_dataset, batch_size=1, shuffle=False)

    results = []
    for batch in tqdm(dataloader, desc="Processing batches"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            conv_outputs = conv_model(
                input_ids=input_ids, attention_mask=attention_mask, k=k
            )

        last_hidden_states = outputs.last_hidden_state
        conv_last_hidden_states = conv_outputs.last_hidden_state

        print(f"last_hidden_states: {last_hidden_states}")
        print(f"conv_last_hidden_states: {conv_last_hidden_states}")

        # NOTE(review): the mean runs over every position, including the
        # padding added by padding="max_length"; mask with attention_mask if
        # only real tokens should count.
        diff = torch.abs(last_hidden_states - conv_last_hidden_states).mean().item()
        results.append(diff)

        print(f"per sample difference between GPT2Model and Conv_GPT2Model outputs: {diff}")

    # results is non-empty here because num_samples >= 1 was checked above.
    avg_diff = sum(results) / len(results)
    print(f"Average difference between GPT2Model and Conv_GPT2Model outputs: {avg_diff}")
    return avg_diff

# Run the model comparison only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()