import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.linear_model import LogisticRegression
from scipy.special import softmax
import os
import torch
import numpy as np
from tqdm import tqdm
from sklearn.utils import shuffle

from dotenv import load_dotenv

# Load variables from a local .env file into the process environment.
# The Hugging Face libraries pick up HF_TOKEN from the environment on their
# own, so the token does not need to be read or passed explicitly here.
load_dotenv()

# Molecular property this concept vector is trained for; the value is
# interpolated into the output filename by the driver loop of this script.
# NOTE(review): the name shadows the builtin `property`; it is kept because
# other parts of the script reference it.
property = "logp"

# Upper bound on the number of prompts used per class (first rows are kept).
MAX_SAMPLES_PER_CLASS = 10000

high_logp_data = pd.read_parquet("PATH/concept_representation_alignment/stimuli_dataset/logp/high_logp_2.5_tdc_data_new.parquet")
low_logp_data = pd.read_parquet("PATH/concept_representation_alignment/stimuli_dataset/logp/low_logp_2.5_tdc_data_new.parquet")

# Sizes are reported before the per-class cap is applied.
print("High LOGP Data Size: ", len(high_logp_data))
print("Low LOGP Data Size: ", len(low_logp_data))

# Only the prompt text column is needed for activation extraction.
high_logp_texts = high_logp_data['input_txt'].head(MAX_SAMPLES_PER_CLASS).tolist()
low_logp_texts = low_logp_data['input_txt'].head(MAX_SAMPLES_PER_CLASS).tolist()

model_name = "yerevann/chemlactica-125m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Pad on the left, so padding tokens precede the real tokens in a batch row.
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)

# Register the property/SMILES markup tags as special tokens and grow the
# embedding matrix to the resulting vocabulary size.
# NOTE(review): tags that were not already in the vocabulary would receive
# freshly initialised (untrained) embeddings - confirm they all pre-exist.
special_tokens = ['[WAVELENGTH]', '[/WAVELENGTH]', '[F_OSC]', '[/F_OSC]', '[QED]', '[/QED]', '[LOGP]', '[/LOGP]', '[START_SMILES]', '[END_SMILES]', '[SEP]']
tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
model.resize_token_embeddings(len(tokenizer))

# Reuse EOS as the padding token so that batched tokenisation can pad.
tokenizer.pad_token = tokenizer.eos_token

def extract_representations_opt(model, tokenizer, texts, layer, batch_size=8, device='cuda'):
    """
    Extract the hidden state of the last real (non-padding) token of each
    text from one entry of the model's hidden-state tuple, in batches.

    The last real token is located from the attention mask itself (the
    highest position whose mask value is 1), so the result is correct for
    both left- and right-padding tokenizers.

    Args:
        model: A causal LM whose forward pass accepts
            ``output_hidden_states=True`` and exposes ``outputs.hidden_states``.
        tokenizer: The tokenizer corresponding to the model; it must be able
            to pad (i.e. have a pad token set).
        texts: A non-empty list of non-empty strings to process.
        layer: Index into ``outputs.hidden_states``. For Hugging Face models
            index 0 is normally the embedding output, so ``layer=k`` is the
            output of the k-th block (confirm for the model in use).
        batch_size: Number of texts to process in each batch.
        device: The device to run the model on ('cuda' or 'cpu').

    Returns:
        A numpy array of shape (len(texts), hidden_size) with one row per
        text, in input order. If the model returns no hidden states for a
        batch, that batch is skipped with a warning and the result has
        correspondingly fewer rows; if nothing was produced at all an empty
        array is returned.

    Raises:
        ValueError: If ``texts`` is empty or contains anything other than
            non-blank strings.
    """
    model.to(device)
    model.eval()
    activations = []

    # Validate input texts
    if not texts or not all(isinstance(text, str) and text.strip() for text in texts):
        raise ValueError("Input 'texts' must be a non-empty list of non-empty strings.")

    print(f"Processing {len(texts)} texts in batches of {batch_size}")

    with torch.no_grad():
        for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
            batch_texts = texts[i:i + batch_size]

            # Tokenize; sequences are padded to the longest one in the batch
            # and truncated to at most 512 tokens.
            inputs = tokenizer(
                batch_texts,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=512
            )

            # token_type_ids is dropped before the forward call.
            inputs = {key: value.to(device) for key, value in inputs.items() if key != "token_type_ids"}
            outputs = model(**inputs, output_hidden_states=True)

            if outputs.hidden_states is None:
                print(f"Warning: Model did not return hidden states for batch {i//batch_size + 1}.")
                continue

            hidden_states = outputs.hidden_states[layer]  # Shape: (batch_size, seq_len, hidden_size)
            attention_mask = inputs["attention_mask"]

            # Index of the last attended position per row. Multiplying the
            # 0/1 mask by the position index zeroes out padding slots, so the
            # row-wise maximum is the last real token regardless of whether
            # the padding sits on the left or on the right. (Using
            # `mask.sum() - 1` here would only be valid for right padding.)
            positions = torch.arange(attention_mask.size(1), device=attention_mask.device)
            last_positions = (attention_mask * positions).max(dim=1).values
            last_positions = last_positions.to(hidden_states.device)

            # Gather one vector per sequence in a single indexing operation
            # and move the whole batch to host memory at once.
            batch_rows = torch.arange(hidden_states.size(0), device=hidden_states.device)
            last_token_reps = hidden_states[batch_rows, last_positions]
            activations.append(last_token_reps.cpu().numpy())

    if not activations:
        print("Warning: No activations were generated.")
        return np.array([])

    # Each list entry is a (batch, hidden_size) array; stack them row-wise.
    return np.concatenate(activations, axis=0)

def train_cav(typical_activations, atypical_activations):
    """
    Fit a linear classifier separating two sets of activations and return
    its weight vector as the concept activation vector (CAV).

    Rows of ``typical_activations`` are labelled 0 and rows of
    ``atypical_activations`` are labelled 1, so the returned vector points
    from the first group towards the second. The vector is returned as
    fitted, without normalisation.

    Args:
        typical_activations: Array-like of shape (n_typical, hidden_size).
        atypical_activations: Array-like of shape (n_atypical, hidden_size).

    Returns:
        A 1-D numpy array holding the flattened classifier coefficients.
    """
    n_typical = len(typical_activations)
    n_atypical = len(atypical_activations)

    features = np.vstack([typical_activations, atypical_activations])
    labels = np.array([0] * n_typical + [1] * n_atypical)

    # Fixed seed keeps the sample order, and hence the fit, reproducible.
    features, labels = shuffle(features, labels, random_state=42)

    classifier = LogisticRegression(max_iter=1000)
    classifier.fit(features, labels)

    # The coefficient vector is normal to the decision boundary.
    return classifier.coef_.flatten()

def compute_tcav_score(cav, activations):
    """
    Project activations onto a CAV and report how many point along it.

    Note that the per-sample values are dot products of the activations
    themselves with the CAV; no gradients are involved.

    Args:
        cav: 1-D concept activation vector of length hidden_size.
        activations: Array-like of shape (n_samples, hidden_size).

    Returns:
        A tuple ``(values, score)`` where ``values`` holds one dot product
        per sample and ``score`` is the fraction of them that are strictly
        greater than zero.
    """
    projections = np.asarray(activations).dot(cav)
    # Mean of a boolean mask == share of samples with a positive projection.
    positive_fraction = (projections > 0).mean()
    return projections, positive_fraction

# Hidden-state indices to analyse (passed as `layer` to the extractor).
layers_to_test = [6]
tcav_scores_by_layer = {}
individual_tcav_scores_by_layer = {}
batch_size = 32

# Use the GPU when one is available; otherwise fall back to the CPU instead
# of crashing on an unconditional `.to('cuda')`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.eval().to(device)

for layer in layers_to_test:
    print(f"Analyzing layer {layer}...")

    # Resolve the output path first and make sure its directory exists, so a
    # missing folder is caught before the expensive extraction rather than
    # at the final np.save call.
    cav_filename = f"PATH/cav/chemlactica-125m/{property}_layer_{layer}_tdc_cav_new.npy"
    os.makedirs(os.path.dirname(cav_filename), exist_ok=True)

    # Extract activations
    print("Extracting activations")

    high_logp_activations = extract_representations_opt(model, tokenizer, high_logp_texts, layer, batch_size, device)
    low_logp_activations = extract_representations_opt(model, tokenizer, low_logp_texts, layer, batch_size, device)

    # Train CAV
    # NOTE(review): train_cav labels its first argument 0 and its second 1,
    # so the resulting vector points from the high-logP class towards the
    # low-logP class - confirm this orientation is what consumers expect.
    print("Train CAV ...")
    cav = train_cav(high_logp_activations, low_logp_activations)

    print("Saving CAV")
    np.save(cav_filename, cav)
