import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import torch
import numpy as np
import random
import sys
sys.path.insert(0, "/home/ubuntu/projects/conv_basis")
from transformers import LlamaModel, LlamaTokenizer, AutoTokenizer
from src.model_llama import Conv_LlamaModel
from pdb import set_trace as pds
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU,
# and report the choice as soon as the module is imported.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


def set_seed(seed):
    """Seed every random generator this script relies on and pin cuDNN behavior.

    Applies ``seed`` to PyTorch's CPU generator, the generators of all CUDA
    devices, NumPy's global RNG and Python's ``random`` module. It then
    requests deterministic cuDNN algorithms and turns off cuDNN benchmark mode.

    Args:
        seed: integer seed shared by all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs at import time so the random IMDB subset drawn in main() is reproducible.
set_seed(seed=42)

def _masked_mean_abs_diff(reference, candidate, attention_mask):
    """Return the mean absolute difference of two hidden-state tensors over real tokens.

    Padding positions (``attention_mask == 0``) are excluded so that inputs
    padded to a fixed length do not dilute or distort the comparison.

    Args:
        reference: hidden states of the baseline model; assumed
            (batch, seq, hidden) -- TODO confirm for the conv model.
        candidate: hidden states of the model under test, same shape as
            ``reference`` (may live on a different device).
        attention_mask: (batch, seq) tokenizer mask, 1 for real tokens.

    Returns:
        The masked mean as a Python float. Falls back to the plain mean if the
        mask selects no positions.
    """
    # The two models can be sharded differently, so align everything on the
    # reference tensor's device and compare in float32.
    reference = reference.float()
    candidate = candidate.to(reference.device).float()
    mask = attention_mask.to(reference.device).unsqueeze(-1).float()

    abs_diff = (reference - candidate).abs()
    # Number of compared elements = valid tokens x hidden size.
    n_valid = mask.sum().item() * reference.shape[-1]
    if n_valid == 0:
        return abs_diff.mean().item()
    return ((abs_diff * mask).sum() / n_valid).item()


def main(
    *,
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",
    num_samples=5,
    max_length=2048,
    k=32,
):
    """Compare LlamaModel and Conv_LlamaModel final hidden states on IMDB reviews.

    Loads the same pretrained checkpoint into the stock Hugging Face
    ``LlamaModel`` and the project's ``Conv_LlamaModel``, runs both on a random
    subset of the IMDB train split, and prints the per-sample and average mean
    absolute difference of their ``last_hidden_state`` outputs over non-padding
    tokens. The subset is reproducible only if ``random`` has been seeded
    beforehand; this script seeds it at import time.

    All arguments are keyword-only and default to the values the script used
    to hard-code, so a plain ``main()`` call runs the same experiment.

    Args:
        model_name: Hugging Face model id loaded into both models
            (e.g. "meta-llama/Llama-2-7b-hf").
        num_samples: number of IMDB training reviews to sample.
        max_length: tokenizer padding/truncation length.
        k: forwarded to ``Conv_LlamaModel`` as its ``k`` keyword argument
            (presumably the number of conv bases -- confirm in src.model_llama).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Baseline: stock Hugging Face Llama with eager attention.
    model = LlamaModel.from_pretrained(
        model_name,
        device_map="auto",
        attn_implementation="eager",
    )
    # Model under test, loaded from the same checkpoint.
    conv_model = Conv_LlamaModel.from_pretrained(
        model_name,
        output_attentions=False,
        device_map="sequential",
        attn_implementation="eager",
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # padding="max_length" below needs a pad token; reuse EOS for it.
    tokenizer.pad_token = tokenizer.eos_token

    full_dataset = load_dataset("imdb", split="train")
    # Random subset of `num_samples` reviews (deterministic once `random` is seeded).
    subset_indices = random.sample(range(len(full_dataset)), num_samples)
    dataset = full_dataset.select(subset_indices)

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )

    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
    dataloader = DataLoader(tokenized_dataset, batch_size=1, shuffle=False)

    results = []
    for batch in tqdm(dataloader, desc="Processing batches"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            conv_outputs = conv_model(
                input_ids=input_ids, attention_mask=attention_mask, k=k
            )

        diff = _masked_mean_abs_diff(
            outputs.last_hidden_state,
            conv_outputs.last_hidden_state,
            attention_mask,
        )
        results.append(diff)
        print(
            "per sample difference between LlamaModel and Conv_LlamaModel "
            f"outputs (non-padding tokens): {diff}"
        )

    if not results:
        print("No samples were processed; nothing to compare.")
        return

    avg_diff = sum(results) / len(results)
    print(
        "Average difference between LlamaModel and Conv_LlamaModel "
        f"outputs (non-padding tokens): {avg_diff}"
    )

if __name__ == "__main__":
    # Run the model comparison only when executed as a script, not when imported.
    main()