import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import SGDClassifier
from scipy.special import softmax
import os
import torch
import numpy as np
from tqdm import tqdm
from sklearn.utils import shuffle
import matplotlib.pyplot as plt


from dotenv import load_dotenv

# Load variables from a local .env file into the process environment.
load_dotenv()
# The token is not passed explicitly anywhere below; the Hugging Face loaders are
# expected to pick HF_TOKEN up from the environment. Only check that it is present
# so that a missing token is visible up front instead of as a download error.
if not os.environ.get('HF_TOKEN'):
    print("Warning: HF_TOKEN is not set; downloading gated or private models may fail.")

# NOTE: this name shadows the builtin `property`. It is kept unchanged because the
# CAV output path further down in this file is built from it.
property = "qed"

# Upper bound on the number of texts used per class (high / low QED).
MAX_TEXTS_PER_CLASS = 10000

high_qed_data = pd.read_parquet("PATH/concept_representation_alignment/stimuli_dataset/qed/high_qed_0.8_tdc_data_new.parquet")
low_qed_data = pd.read_parquet("PATH/concept_representation_alignment/stimuli_dataset/qed/low_qed_0.8_tdc_data_new.parquet")

# These sizes are the full dataset sizes, before the per-class cap is applied.
print("High QED Data Size: ", len(high_qed_data))
print("Low QED Data Size: ", len(low_qed_data))

# Cap before converting to a Python list so only the rows actually used are materialised.
high_qed_texts = high_qed_data['input_txt'].iloc[:MAX_TEXTS_PER_CLASS].tolist()
low_qed_texts = low_qed_data['input_txt'].iloc[:MAX_TEXTS_PER_CLASS].tolist()

model_name = "yerevann/chemlactica-1.3b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Left padding puts pad tokens in front, so the last real token of every sequence
# sits at the final position of the padded batch.
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)

# Register the property / SMILES tags as special tokens, then resize the embedding
# matrix so that it matches the (possibly enlarged) vocabulary.
special_tokens = ['[WAVELENGTH]', '[/WAVELENGTH]', '[F_OSC]', '[/F_OSC]', '[QED]', '[/QED]', '[LOGP]', '[/LOGP]', '[START_SMILES]', '[END_SMILES]', '[SEP]']
tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
model.resize_token_embeddings(len(tokenizer))


print("Model configuration:", model.config)

# Reuse the end-of-sequence token for padding so batched tokenization with
# padding=True has a pad token to use.
tokenizer.pad_token = tokenizer.eos_token

def extract_representations_opt(model, tokenizer, texts, layer, batch_size=8, device='cuda'):
    """
    Extract the hidden state of the last real (non-padding) token from one layer.

    Texts are tokenized and run through the model in batches. For every sequence the
    position of the last attended token is read from the attention mask itself, so the
    result is correct whether the tokenizer pads on the left or on the right.
    (Using ``attention_mask.sum() - 1`` as the index is only valid for right padding.)

    Args:
        model: Causal LM that is callable with the tokenizer's tensors plus
            ``output_hidden_states=True`` and returns an object exposing
            ``hidden_states`` (an indexable collection of per-layer tensors).
        tokenizer: The tokenizer corresponding to the model. It is called with
            ``return_tensors="pt"``, ``truncation=True``, ``padding=True`` and
            ``max_length=512`` and must return ``attention_mask``.
        texts: A non-empty list of non-empty strings to process.
        layer: Index into ``outputs.hidden_states`` selecting the layer to read.
        batch_size: Number of texts to process in each batch.
        device: The device to run the model on ('cuda' or 'cpu').

    Returns:
        A numpy array with one row per processed text and one column per hidden unit.
        If the model returns no hidden states for a batch, that batch is skipped with a
        warning, so the array can then have fewer rows than ``len(texts)``. An empty
        array is returned if nothing was produced at all.

    Raises:
        ValueError: If ``texts`` is empty or contains anything other than
            non-blank strings.
    """
    model.to(device)
    model.eval()
    activations = []  # one (batch, hidden_size) numpy array per processed batch

    # Validate input texts
    if not texts or not all(isinstance(text, str) and text.strip() for text in texts):
        raise ValueError("Input 'texts' must be a non-empty list of non-empty strings.")

    print(f"Processing {len(texts)} texts in batches of {batch_size}")

    with torch.no_grad():
        for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
            batch_texts = texts[i:i + batch_size]

            inputs = tokenizer(
                batch_texts,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=512,
            )

            # Move inputs to the target device; token_type_ids is dropped because it
            # is not forwarded to the model here.
            inputs = {key: value.to(device) for key, value in inputs.items() if key != "token_type_ids"}

            outputs = model(**inputs, output_hidden_states=True)

            if outputs.hidden_states is None:
                print(f"Warning: Model did not return hidden states for batch {i//batch_size + 1}.")
                continue

            # Shape: (batch, seq_len, hidden_size)
            hidden_states = outputs.hidden_states[layer]
            attention_mask = inputs["attention_mask"]

            # Index of the last attended token in every row: multiply each position
            # index by its mask value (0 for padding) and take the row-wise maximum.
            # This does not depend on which side the padding is on.
            positions = torch.arange(attention_mask.size(1), device=attention_mask.device)
            last_positions = (attention_mask.long() * positions).max(dim=1).values
            last_positions = last_positions.to(hidden_states.device)

            # Gather one vector per sequence in a single indexing operation.
            batch_indices = torch.arange(hidden_states.size(0), device=hidden_states.device)
            last_token_reps = hidden_states[batch_indices, last_positions]
            activations.append(last_token_reps.cpu().numpy())

    if not activations:
        print("Warning: No activations were generated.")
        return np.array([])

    return np.concatenate(activations, axis=0)

def train_cav(typical_activations, atypical_activations):
    """
    Fit a logistic-regression probe and return its weight vector as the CAV.

    Rows of ``typical_activations`` are labelled 0 and rows of
    ``atypical_activations`` are labelled 1, so the returned vector points from the
    first group toward the second one.

    Args:
        typical_activations: 2-D array-like of activations for the class labelled 0.
        atypical_activations: 2-D array-like of activations for the class labelled 1.

    Returns:
        1-D numpy array: the flattened classifier coefficients (the normal vector of
        the decision boundary).
    """
    n_typical = len(typical_activations)
    n_atypical = len(atypical_activations)

    features = np.vstack([typical_activations, atypical_activations])
    labels = np.repeat([0, 1], [n_typical, n_atypical])

    # Fixed seed so the sample order seen by the solver is the same on every run.
    features, labels = shuffle(features, labels, random_state=42)

    classifier = LogisticRegression(max_iter=1000)
    classifier.fit(features, labels)
    return classifier.coef_.flatten()

def train_cav_sgd(high_qed_activations, low_qed_activations, max_iter=10, tol=1e-3, random_state=42):
    """
    Fit a logistic-loss linear probe with SGD and return its weight vector as the CAV.

    High-QED rows are labelled 1 and low-QED rows are labelled 0, so the returned
    vector points toward the high-QED class.

    Args:
        high_qed_activations: 2-D array-like of activations for the class labelled 1.
        low_qed_activations: 2-D array-like of activations for the class labelled 0.
        max_iter: Maximum number of passes over the data given to ``SGDClassifier``.
        tol: Stopping tolerance given to ``SGDClassifier``.
        random_state: Seed given to ``SGDClassifier``. The data shuffle below is already
            seeded; seeding the optimizer as well makes the resulting CAV reproducible.
            Pass ``None`` to get unseeded behaviour.

    Returns:
        1-D numpy array: the flattened classifier coefficients.
    """
    # float32 halves the memory of the stacked design matrix compared to float64.
    X = np.vstack([high_qed_activations, low_qed_activations]).astype(np.float32)
    y = np.array([1] * len(high_qed_activations) + [0] * len(low_qed_activations))
    X, y = shuffle(X, y, random_state=42)

    clf = SGDClassifier(loss='log_loss', max_iter=max_iter, tol=tol, random_state=random_state)
    clf.fit(X, y)

    return clf.coef_.flatten()

def compute_tcav_score(cav, activations):
    """
    Project activations onto a CAV and report how many projections are positive.

    Args:
        cav: 1-D vector (the concept activation vector).
        activations: Array-like whose rows are dotted with ``cav``.

    Returns:
        A tuple ``(projections, score)`` where ``projections`` is the result of
        ``np.dot(activations, cav)`` and ``score`` is the fraction of projections
        that are strictly greater than zero.
    """
    projections = np.dot(activations, cav)
    is_positive = projections > 0
    # The mean of a boolean mask is the fraction of True entries.
    positive_fraction = np.mean(is_positive)
    return projections, positive_fraction

# Layers to analyze. Each value is used as an index into the model's `hidden_states`
# output (presumably index 0 is the embedding output and index k the output of the
# k-th transformer block — confirm against the model's documentation).
layers_to_test = [12]
tcav_scores_by_layer = {}             # layer -> fraction of positive projections
individual_tcav_scores_by_layer = {}  # layer -> per-sample projections onto the CAV

# Fall back to CPU when no GPU is available instead of failing on .to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.eval().to(device)

batch_size = 32

for layer in layers_to_test:
    print(f"Analyzing layer {layer}...")
    # Extract activations
    print("Extracting activations")

    high_qed_activations = extract_representations_opt(model, tokenizer, high_qed_texts, layer, batch_size, device=device)
    low_qed_activations = extract_representations_opt(model, tokenizer, low_qed_texts, layer, batch_size, device=device)

    # Train CAV
    print("Train CAV ...")
    # train_cav labels its FIRST argument 0 and its SECOND argument 1, and the returned
    # vector points toward class 1. Pass low-QED first so that the CAV points toward
    # high QED; otherwise the score computed on the high-QED set below is inverted.
    # NOTE: this flips the sign of the saved CAV relative to runs that passed
    # (high, low); consumers of previously saved files should be checked.
    cav = train_cav(low_qed_activations, high_qed_activations)

    directional_derivatives, mean_tcav_score = compute_tcav_score(cav, high_qed_activations)
    tcav_scores_by_layer[layer] = mean_tcav_score
    individual_tcav_scores_by_layer[layer] = directional_derivatives
    print(f"Layer {layer} TCAV score on high-QED texts: {mean_tcav_score:.4f}")

    cav_filename = f"PATH/cav/chemlactica-1p3b/{property}_layer_{layer}_tdc_cav_new.npy"

    print("Saving CAV")
    # Create the output directory first so a missing folder does not discard the
    # result of the (expensive) extraction and training above.
    os.makedirs(os.path.dirname(cav_filename), exist_ok=True)
    np.save(cav_filename, cav)
    if device == 'cuda':
        # Release cached GPU memory before processing the next layer.
        torch.cuda.empty_cache()

