from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
import pandas as pd
import numpy as np
import glob
import os
import re
import datetime
from tqdm import tqdm
import torch
import sys
import yaml

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from utils.eval_metrics import is_valid_smiles, clogp, qed
from tdc.generation import MolGen
from pandarallel import pandarallel
from dotenv import load_dotenv
from collections import OrderedDict
from concept_alignment.models.cbm import ConceptBottleneckModel


load_dotenv()

# NOTE(review): the original bare `os.environ.get('HF_TOKEN')` discarded its
# result and had no effect. load_dotenv() above places HF_TOKEN (if present in
# .env) into the environment, where the Hugging Face libraries should pick it
# up themselves — confirm if gated models fail to download.

torch.cuda.empty_cache()

pandarallel.initialize()

script_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(script_dir, "configs/cav_ood_logp.yaml")

# Start from an empty config so a missing (or empty) YAML file produces the
# explicit "missing key" error below instead of a NameError on `config`.
config = {}
if os.path.exists(config_path):
    with open(config_path, "r") as file:
        config = yaml.safe_load(file) or {}

modify_func_name = config.get("modify_func", "additive")  # Default to 'additive'
layer_to_modify = config.get("layer_to_modify")
cbm_model_path = config.get("cbm_model_path")
logp_prop_value = config.get("logp_value")
model_name = config.get("model_name")

# These values are all used unconditionally later in the script; fail early
# and name the missing ones rather than crashing with an unrelated error.
_missing_keys = [
    key
    for key, value in (
        ("layer_to_modify", layer_to_modify),
        ("cbm_model_path", cbm_model_path),
        ("logp_value", logp_prop_value),
        ("model_name", model_name),
    )
    if value is None
]
if _missing_keys:
    raise ValueError(
        f"Config {config_path} is missing required key(s): {', '.join(_missing_keys)}"
    )

print("CAV Alignment Function: ", modify_func_name)


# Hugging Face Hub repository for each supported `model_name` config value.
MODEL_REPOS = {
    "chemlactica-125m": "yerevann/chemlactica-125m",
    "chemlactica-1p3b": "yerevann/chemlactica-1.3b",
    "chemma-2b": "yerevann/chemma-2b",
}

# Load the tokenizer and model. An unsupported name used to fall through every
# branch and fail later with a NameError on `tokenizer`; reject it up front.
if model_name not in MODEL_REPOS:
    raise ValueError(
        f"Unsupported model_name {model_name!r}; expected one of {sorted(MODEL_REPOS)}"
    )
model_repo = MODEL_REPOS[model_name]
tokenizer = AutoTokenizer.from_pretrained(model_repo)
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(model_repo).eval().to('cuda')

# Register the property / SMILES tag tokens used in the prompt and resize the
# embedding matrix to the enlarged vocabulary.
special_tokens = ['[WAVELENGTH]', '[/WAVELENGTH]', '[F_OSC]', '[/F_OSC]', '[QED]', '[/QED]', '[LOGP]', '[/LOGP]', '[START_SMILES]', '[END_SMILES]', '[SEP]']
tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
model.resize_token_embeddings(len(tokenizer))

def _remap_cbm_state_dict_keys(state_dict):
    """Rename checkpoint keys to match ConceptBottleneckModel's parameter names.

    Rules, applied per key (first match wins):

    * ``cbm.cbm.<rest>`` -> ``cbm.<rest>``
    * ``cbm.<rest>`` containing ``concept_embeddings`` -> ``concept_embeddings``
    * ``cbm.<rest>`` -> ``<rest>``
    * any other key is kept unchanged.

    Only the leading prefix is rewritten. ``str.replace`` would also strip a
    ``cbm.`` occurring later in the key (``cbm.fcbm.bias`` -> ``fbias``).
    """
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith('cbm.cbm.'):
            new_key = 'cbm.' + key[len('cbm.cbm.'):]
        elif key.startswith('cbm.'):
            if 'concept_embeddings' in key:
                new_key = 'concept_embeddings'
            else:
                new_key = key[len('cbm.'):]
        else:
            new_key = key
        remapped[new_key] = value
    return remapped


def load_cbm_model(model_path, device='cuda', hidden_dim=None, num_concepts=2):
    """Load a trained ConceptBottleneckModel (CBM) from ``model_path``.

    Args:
        model_path: file written by ``torch.save``. It may hold a bare state
            dict, a dict wrapping one under ``cbm_state_dict`` / ``state_dict`` /
            ``model_state_dict``, or a pickled model object.
        device: used as ``torch.load``'s ``map_location`` and as the device the
            rebuilt CBM is moved to.
        hidden_dim: CBM hidden and concept dimension. Defaults to the hidden
            size of the module-level language model (``model.config.hidden_size``).
        num_concepts: number of concepts in the checkpoint (the script assumes
            2: QED and LOGP).

    Returns:
        The CBM in eval mode. A pickled model object is returned as loaded,
        switched to eval mode when it is a ``torch.nn.Module``.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects (required for
    # the full-model case). Only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    # Extract the state dict based on the checkpoint structure.
    if isinstance(checkpoint, OrderedDict):
        state_dict = checkpoint
    elif isinstance(checkpoint, dict):
        for wrapper_key in ('cbm_state_dict', 'state_dict', 'model_state_dict'):
            if wrapper_key in checkpoint:
                state_dict = checkpoint[wrapper_key]
                break
        else:
            state_dict = checkpoint  # Use the whole dict if no specific keys found
    else:
        # A full pickled model: nothing to rebuild. Disable dropout and similar
        # layers, as the state-dict path does.
        if isinstance(checkpoint, torch.nn.Module):
            checkpoint.eval()
        return checkpoint

    if hidden_dim is None:
        hidden_dim = model.config.hidden_size

    cbm = ConceptBottleneckModel(
        hidden_dim=hidden_dim,
        num_concepts=num_concepts,
        concept_dim=hidden_dim,
    ).to(device)

    # strict=False tolerates naming drift between checkpoint and model. Report
    # the mismatch instead of silently leaving parameters at their init values.
    load_result = cbm.load_state_dict(_remap_cbm_state_dict_keys(state_dict), strict=False)
    if load_result.missing_keys:
        print(f"Warning: CBM checkpoint is missing keys: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Warning: CBM checkpoint has unexpected keys: {load_result.unexpected_keys}")
    cbm.eval()

    return cbm

# Build the concept bottleneck model once; the steering hook reads this
# module-level `cbm`.
cbm = load_cbm_model(model_path=cbm_model_path)

def create_steering_vector(cbm, h_n):
    """Return the steering vector for each row of hidden states ``h_n``.

    ``cbm(h_n)`` is expected to return ``(concept_values, concept_vectors)``.
    Only the concept vectors are used, summed over dimension 1. Everything
    runs under ``torch.no_grad()``, so no autograd graph is recorded.
    """
    with torch.no_grad():
        _, per_concept = cbm(h_n)
        combined = torch.sum(per_concept, dim=1)
    return combined


def modify_layer_output(module, input, output, alpha=1, steering_model=None):
    """Forward hook: additive CBM steering driven by the last token.

    Builds one steering vector per sample from the hidden state at the final
    sequence position. It then adds ``alpha`` times that vector to every
    position of the layer output.

    NOTE(review): during cached generation, later forward passes presumably
    see only the newest token (seq_length == 1). That token is then steered by
    its own vector — confirm this is intended.

    Args:
        module: the hooked layer (unused; part of the forward-hook signature).
        input: the layer's inputs (unused; part of the forward-hook signature).
        output: the layer output. Either a hidden-state tensor of shape
            (batch, seq_length, hidden_dim), or a tuple whose first element is
            such a tensor.
        alpha: overall steering strength. The default of 1 matches the previous
            hard-coded value. Bind another value with ``functools.partial``
            before registering the hook.
        steering_model: CBM used to build the vectors. Defaults to the
            module-level ``cbm``.

    Returns:
        ``output`` with its hidden-state tensor replaced by the steered one, in
        the same container format (tuple or bare tensor).
    """
    if steering_model is None:
        steering_model = cbm

    # Handle both tuple outputs (hidden_states, ...) and bare tensors.
    is_tuple = isinstance(output, tuple)
    hidden_states = output[0] if is_tuple else output

    # Steering vectors come from the last position only: [batch, hidden_dim].
    last_token_reps = hidden_states[:, -1]
    steering_vectors = create_steering_vector(steering_model, last_token_reps)

    # [batch, 1, hidden_dim] broadcasts across every sequence position.
    modified_output = hidden_states + alpha * steering_vectors.unsqueeze(1)

    return (modified_output,) + output[1:] if is_tuple else modified_output


# Available steering hooks, keyed by the config's `modify_func` value.
MODIFY_FUNCTIONS = {
    "additive": modify_layer_output,
}

if modify_func_name not in MODIFY_FUNCTIONS:
    # Keep the fallback to additive steering, but don't let a typo in the
    # config go unnoticed.
    print(f"Warning: unknown modify_func {modify_func_name!r}; falling back to 'additive'.")
selected_modify_func = MODIFY_FUNCTIONS.get(modify_func_name, modify_layer_output)

# chemma-2b exposes its decoder blocks at model.model.layers; the other
# supported models expose them at model.model.decoder.layers.
if model_name == "chemma-2b":
    layer = model.model.layers[layer_to_modify]
else:
    layer = model.model.decoder.layers[layer_to_modify]

hook_handle = layer.register_forward_hook(selected_modify_func)

eval_set = []

input_text = f"</s>[LOGP]{logp_prop_value}[/LOGP][START_SMILES]"
prompt = tokenizer(input_text, return_tensors="pt").to(model.device)
prompt_mask = prompt.get("attention_mask")

# Generation settings
num_samples = 2000
batch_size = 16

# Diversity parameters. top_k / top_p were previously declared but never
# passed to generate(), so sampling silently used the model's defaults.
temperature = 0.9
top_k = 50
top_p = 0.9

# Loop invariants, hoisted out of the batch/sample loops.
end_smiles_id = tokenizer.convert_tokens_to_ids("[END_SMILES]")
smiles_regex = re.compile(r'\[START_SMILES\](.*?)(\[END_SMILES\]|$)')

# Generate in batches; the final batch may be smaller than batch_size.
num_batches = (num_samples + batch_size - 1) // batch_size
try:
    for start_idx in tqdm(range(0, num_samples, batch_size), total=num_batches, desc="Generating samples"):
        current_batch = min(start_idx + batch_size, num_samples) - start_idx
        output = model.generate(
            prompt.input_ids.repeat(current_batch, 1),
            # The prompt is unpadded, so its mask is simply repeated per row.
            attention_mask=prompt_mask.repeat(current_batch, 1) if prompt_mask is not None else None,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_length=512,
            return_dict_in_generate=True,
            eos_token_id=end_smiles_id,
        )
        # Keep special tokens so the SMILES delimiters survive decoding.
        decoded_outputs = tokenizer.batch_decode(output.sequences, skip_special_tokens=False)

        for generated_output in decoded_outputs:
            # Extract the SMILES string after [START_SMILES], up to [END_SMILES]
            # or the end of the text when generation hit max_length.
            match = smiles_regex.search(generated_output)
            smiles_string = match.group(1) if match else None

            # Validate SMILES and calculate properties
            valid_smiles = is_valid_smiles(smiles_string) if smiles_string else False
            clogp_value = clogp(smiles_string) if valid_smiles else None
            qed_value = qed(smiles_string) if valid_smiles else None

            eval_set.append({
                'input_text': input_text,
                'smiles_string': smiles_string,
                'valid_smiles': valid_smiles,
                'clogp_value': clogp_value,
                'qed_value': qed_value
            })
finally:
    # Detach the steering hook so the model is unmodified after sampling,
    # even if generation fails part-way.
    hook_handle.remove()


eval_df = pd.DataFrame(eval_set)

# Label the run by the sign of the target LogP value.
log_prop = "low" if logp_prop_value < 0 else "high"

file_name = f"ood_logp_{log_prop}_cbm_steer_test.csv"

# Build the path once and make sure the directory exists. Otherwise a missing
# results folder would discard all generated samples at the very last step.
output_dir = f"PATH/concept_representation_alignment/results/{model_name}/ood/logp"
os.makedirs(output_dir, exist_ok=True)
output_path = f"{output_dir}/{file_name}"

eval_df.to_csv(output_path, index=False)
print(f"Saved File at {output_path}")