import sys
sys.path.insert(0, "/home/ubuntu/projects/conv_basis")

from transformers import GPT2Model, GPT2Tokenizer
import torch
from src.model_gpt2 import Conv_GPT2Model
from pdb import set_trace as pds

# Select the compute device once at import time: the default CUDA device when
# PyTorch can see one, otherwise the CPU. Both models and inputs use it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

def compare_tensors(
    last_hidden_states: torch.Tensor,
    conv_last_hidden_states: torch.Tensor,
    *,
    rtol: float = 1e-5,
    atol: float = 1e-8,
) -> "float | None":
    """Print diagnostics comparing a reference tensor with a candidate tensor.

    Prints device, dtype, shape and values of both tensors, then whether they
    are exactly equal. When they are not, it also prints:

    * the Frobenius norm of the difference (2-norm over all elements),
    * that norm relative to the reference tensor's norm,
    * the largest element-wise absolute difference,
    * whether the tensors agree under ``torch.allclose(rtol, atol)``.

    Args:
        last_hidden_states: Reference tensor (the stock model's output in this
            script).
        conv_last_hidden_states: Tensor compared against the reference.
        rtol: Relative tolerance for the ``torch.allclose`` check.
        atol: Absolute tolerance for the ``torch.allclose`` check.

    Returns:
        The relative Frobenius norm of the difference as a Python float.
        This is ``0.0`` when the tensors are exactly equal and ``inf`` when the
        reference is all zeros but the candidate is not. Returns ``None`` when
        the shapes differ, because no element-wise comparison is meaningful.
    """
    for header, tensor in (
        ("last_hidden_states:", last_hidden_states),
        ("\nconv_last_hidden_states:", conv_last_hidden_states),
    ):
        print(header)
        print(f"  Device: {tensor.device}")
        print(f"  Dtype: {tensor.dtype}")
        print(f"  Shape: {tensor.shape}")
        print(f"  values: {tensor}")

    # Subtracting tensors of different shapes either raises or silently
    # broadcasts, and a broadcast would make the norms below meaningless.
    if last_hidden_states.shape != conv_last_hidden_states.shape:
        print(
            "\nThe tensors have different shapes; "
            "skipping element-wise comparison."
        )
        return None

    # torch.equal, subtraction and allclose need both operands on the same
    # device. Promoting to a common dtype avoids dtype-mismatch errors.
    # .to() is a no-op when device and dtype already match.
    common_dtype = torch.promote_types(
        last_hidden_states.dtype, conv_last_hidden_states.dtype
    )
    reference = last_hidden_states.to(dtype=common_dtype)
    candidate = conv_last_hidden_states.to(
        device=reference.device, dtype=common_dtype
    )

    if torch.equal(reference, candidate):
        print("\nThe tensors are equal.")
        return 0.0

    print("\nThe tensors are not equal.")
    diff = reference - candidate  # computed once and reused below

    # vector_norm with dim=None flattens the tensor, so this is the Frobenius
    # norm over every element regardless of rank.
    frobenius_norm = torch.linalg.vector_norm(diff)
    print(f"Frobenius norm of the difference: {frobenius_norm.item()}")

    # Tensor division: yields inf (not an exception) for an all-zero reference.
    relative_frobenius_norm = frobenius_norm / torch.linalg.vector_norm(reference)
    print(f"relative_frobenius_norm: {relative_frobenius_norm.item()}")

    print(f"max_abs_diff: {diff.abs().max().item()}")

    # Exact equality is rarely achievable between two float implementations,
    # so a tolerance-based verdict is the more useful signal.
    within_tolerance = torch.allclose(reference, candidate, rtol=rtol, atol=atol)
    print(f"allclose (rtol={rtol}, atol={atol}): {within_tolerance}")

    return relative_frobenius_norm.item()


def main():
    """Compare the stock GPT-2 with ``Conv_GPT2Model`` on a single sentence.

    Both models are loaded from the same ``openai-community/gpt2`` checkpoint
    with eager attention and moved to the module-level ``device``. They are run
    on one tokenized sentence without gradient tracking, and their
    ``last_hidden_state`` outputs are passed to ``compare_tensors``.
    """
    model_name = "openai-community/gpt2"

    # device_map is deliberately not passed. "auto" dispatches the model
    # through accelerate hooks (possibly sharded across several GPUs), which
    # conflicts with the explicit .to(device) below. Moving the whole model
    # keeps both models and the inputs on one known device.
    # eval() disables dropout so both forward passes are deterministic.
    model = GPT2Model.from_pretrained(
        model_name,
        attn_implementation="eager",
    )
    model = model.to(device).eval()

    conv_model = Conv_GPT2Model.from_pretrained(
        model_name,
        output_attentions=False,
        attn_implementation="eager",
    )
    conv_model = conv_model.to(device).eval()

    tokenizer = GPT2Tokenizer.from_pretrained(model_name)

    input_text = "The quick brown fox jumps over the lazy dog."
    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        conv_outputs = conv_model(**inputs)

    # Observed in an earlier run: shape (1, 10, 768), dtype float32.
    last_hidden_states = outputs.last_hidden_state
    conv_last_hidden_states = conv_outputs.last_hidden_state

    compare_tensors(last_hidden_states, conv_last_hidden_states)


if __name__ == "__main__":
    # Script entry point: importing this module does not run the comparison.
    main()