from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
import pandas as pd
import numpy as np
import glob
import os
import re
import datetime
from tqdm import tqdm
import torch
import sys
import yaml

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from utils.eval_metrics import is_valid_smiles, clogp, qed
from collections import OrderedDict
from concept_alignment.models.cbm import ConceptBottleneckModel

# Release any unused cached CUDA memory before the models are loaded.
torch.cuda.empty_cache()

script_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(script_dir, "configs/cav_ood_qed.yaml")

# Always bind `config`, so the lookups below never hit a NameError when the
# YAML file is absent; missing settings are reported explicitly further down.
config = {}
if os.path.exists(config_path):
    with open(config_path, "r") as file:
        # An empty YAML document parses to None; normalise it to an empty mapping.
        config = yaml.safe_load(file) or {}
modify_func_name = config.get("modify_func", "additive")

# data = config.get("data")
model_name = config.get("model_name")
layer_to_modify = config.get("layer_to_modify")
cbm_model_path = config.get("cbm_model_path")
qed_prop_value = config.get("qed_value")

print("CAV Alignment Function: ", modify_func_name)

# Each supported model lives in a directory named after the model itself.
SUPPORTED_MODELS = ("chemlactica-125m", "chemlactica-1.3b", "chemma-2b")
if model_name not in SUPPORTED_MODELS:
    # Without this check an unknown or missing name would leave `tokenizer` and
    # `model` unbound and surface later as an unrelated NameError.
    raise ValueError(
        f"Unsupported or missing 'model_name' {model_name!r} "
        f"(config file: {config_path}); expected one of {SUPPORTED_MODELS}."
    )

model_dir = f"PATH/models/{model_name}"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(model_dir).eval().to('cuda')

# Register the property/SMILES delimiter tokens used by the prompt and grow the
# embedding matrix to match the enlarged vocabulary.
special_tokens = ['[WAVELENGTH]', '[/WAVELENGTH]', '[F_OSC]', '[/F_OSC]', '[QED]', '[/QED]', '[LOGP]', '[/LOGP]', '[START_SMILES]', '[END_SMILES]', '[SEP]']
tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
model.resize_token_embeddings(len(tokenizer))

def _extract_cbm_state_dict(checkpoint):
    """Return the parameter mapping held by a dict-like ``checkpoint``."""
    # A bare OrderedDict is treated as the state dict itself.
    if isinstance(checkpoint, OrderedDict):
        return checkpoint
    # Plain dicts may wrap the state dict under one of several keys.
    for wrapper_key in ('cbm_state_dict', 'state_dict', 'model_state_dict'):
        if wrapper_key in checkpoint:
            return checkpoint[wrapper_key]
    return checkpoint


def _remap_cbm_keys(state_dict):
    """Rename checkpoint keys by stripping the leading ``cbm.`` wrapper prefix."""
    prefix = 'cbm.'
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith('cbm.cbm.'):
            # Doubly wrapped: drop one leading 'cbm.' only.
            new_key = key[len(prefix):]
        elif key.startswith(prefix):
            if 'concept_embeddings' in key:
                new_key = 'concept_embeddings'
            else:
                # Strip only the leading prefix; a blanket str.replace would
                # also delete any later 'cbm.' substring inside the key.
                new_key = key[len(prefix):]
        else:
            new_key = key
        remapped[new_key] = value
    return remapped


def load_cbm_model(model_path, device='cuda'):
    """Load a trained CBM from a checkpoint file.

    Args:
        model_path: Path of the checkpoint passed to ``torch.load``.
        device: Device used as ``map_location`` and for the constructed CBM.

    Returns:
        A ``ConceptBottleneckModel`` in eval mode with the checkpoint weights
        loaded. If the checkpoint is not a dict (for example a fully pickled
        object), that object is returned unchanged.

    Note:
        The CBM width is read from the module-level ``model``'s
        ``config.hidden_size``, so the language model must be loaded first.
    """
    # NOTE: weights_only=False unpickles arbitrary Python objects; only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    # OrderedDict is a dict subclass, so this covers every non-mapping checkpoint.
    if not isinstance(checkpoint, dict):
        return checkpoint

    state_dict = _extract_cbm_state_dict(checkpoint)

    # Get hidden dimension from the language model
    hidden_dim = model.config.hidden_size

    # Create CBM with matching dimension
    cbm = ConceptBottleneckModel(
        hidden_dim=hidden_dim,
        num_concepts=2,  # Assuming QED and LOGP
        concept_dim=hidden_dim
    ).to(device)

    new_state_dict = _remap_cbm_keys(state_dict)

    # strict=False tolerates key mismatches; report them so a checkpoint that
    # does not match the model is not silently ignored.
    load_result = cbm.load_state_dict(new_state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            "Warning: CBM checkpoint did not fully match the model. "
            f"Missing keys: {list(load_result.missing_keys)}; "
            f"unexpected keys: {list(load_result.unexpected_keys)}"
        )
    cbm.eval()

    return cbm

# Module-level CBM instance; modify_layer_output reads this global when the hook fires.
cbm = load_cbm_model(model_path=cbm_model_path)

def create_steering_vector(cbm, h_n):
    """Build a steering vector by summing the CBM's concept vectors.

    Args:
        cbm: Callable that maps ``h_n`` to a ``(concept_values, concept_vectors)`` pair.
        h_n: Hidden representations passed straight to ``cbm``.

    Returns:
        ``concept_vectors`` summed over dimension 1, computed with gradient
        tracking disabled.
    """
    with torch.no_grad():
        # Only the concept vectors are needed; the concept values are discarded.
        _, per_concept_vectors = cbm(h_n)
        return per_concept_vectors.sum(dim=1)


def modify_layer_output(module, input, output):
    """Forward hook that shifts a layer's hidden states by a CBM steering vector.

    The steering vector is computed from the hidden state at the final sequence
    position of each sample and then added to every position of that sample.

    Args:
        module: The hooked layer (unused).
        input: The layer's inputs (unused).
        output: The layer's output, either the hidden-state tensor itself or a
            tuple whose first element is that tensor.

    Returns:
        The steered hidden states, packaged the same way as ``output``: a tuple
        with the remaining elements untouched, or the bare tensor.
    """
    came_as_tuple = isinstance(output, tuple)
    hidden_states = output[0] if came_as_tuple else output

    # The three-way unpack requires a rank-3 tensor, named
    # (batch_size, seq_length, hidden_dim) in the original code.
    _, num_positions, _ = hidden_states.shape

    # Hidden state at the last position of every sample.
    final_position_states = hidden_states[:, num_positions - 1]

    # Overall steering strength.
    strength = 1

    # Insert a singleton position axis so the per-sample vector broadcasts
    # across all sequence positions.
    shift = create_steering_vector(cbm, final_position_states).unsqueeze(1)
    steered = hidden_states + strength * shift

    if came_as_tuple:
        return (steered,) + output[1:]
    return steered


# Steering hooks selectable through the configured modify function name.
MODIFY_FUNCTIONS = dict(additive=modify_layer_output)

# Names that are not registered fall back to the additive hook.
selected_modify_func = MODIFY_FUNCTIONS.get(modify_func_name, modify_layer_output)

# chemma-2b exposes its layers at model.model.layers; the other models keep
# them under model.model.decoder.layers.
decoder_layers = model.model.layers if model_name == "chemma-2b" else model.model.decoder.layers
layer = decoder_layers[layer_to_modify]

hook_handle = layer.register_forward_hook(selected_modify_func)

batch_size = 64
eval_set = []

# Validate the target QED and prepare the output location before generating,
# so a bad config or an unwritable path fails immediately instead of after
# every sample has been produced.
if qed_prop_value is None:
    raise ValueError(f"'qed_value' is missing from the config ({config_path}).")
qed_prop = "low" if qed_prop_value < 0.5 else "high"

file_name = f"ood_qed_{qed_prop}_cbm_logp_low.csv"
output_dir = f"PATH/concept_representation_alignment/results/{model_name}/ood/qed"
os.makedirs(output_dir, exist_ok=True)
output_path = f"{output_dir}/{file_name}"

# Prompt conditions on the configured QED and a fixed LOGP of -3.2810.
input_text = f"</s>[QED]{qed_prop_value}[/QED][LOGP]-3.2810[/LOGP][START_SMILES]"
prompt = tokenizer(input_text, return_tensors="pt").to(model.device)
# Generation settings
num_samples = 2000
all_outputs = []

temperature = 0.9
top_k = 50
top_p = 0.9

# Loop invariants: the stop-token id and the SMILES extraction pattern.
end_smiles_token_id = tokenizer.convert_tokens_to_ids("[END_SMILES]")
smiles_regex = re.compile(r'\[START_SMILES\](.*?)(\[END_SMILES\]|$)')

# Generate in batches
num_batches = (num_samples + batch_size - 1) // batch_size
for start_idx in tqdm(range(0, num_samples, batch_size), total=num_batches, desc="Generating samples"):
    end_idx = min(start_idx + batch_size, num_samples)
    output = model.generate(
        prompt.input_ids.repeat(end_idx - start_idx, 1),
        do_sample=True,
        temperature=temperature,
        # top_k / top_p were declared above but previously never forwarded,
        # so the configured nucleus/top-k filtering was not applied.
        top_k=top_k,
        top_p=top_p,
        max_length=512,
        return_dict_in_generate=True,
        eos_token_id=end_smiles_token_id,
    )
    # Decode batch outputs; special tokens are kept because the SMILES
    # delimiters are needed for extraction.
    decoded_outputs = tokenizer.batch_decode(output.sequences, skip_special_tokens=False)

    for generated_output in decoded_outputs:

        # Extract the text between [START_SMILES] and [END_SMILES] (or end of string)
        smiles_pattern = smiles_regex.search(generated_output)
        smiles_string = smiles_pattern.group(1) if smiles_pattern else None

        # Validate SMILES and calculate properties; invalid or empty strings
        # get None for both properties.
        valid_smiles = is_valid_smiles(smiles_string) if smiles_string else False
        clogp_value = clogp(smiles_string) if valid_smiles else None
        qed_value = qed(smiles_string) if valid_smiles else None

        # Append to evaluation set
        eval_set.append({
            'input_text': input_text,
            'smiles_string': smiles_string,
            'valid_smiles': valid_smiles,
            'clogp_value': clogp_value,
            'qed_value': qed_value
        })


eval_df = pd.DataFrame(eval_set)

# Timestamp is computed but not part of the file name, so reruns overwrite the CSV.
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

eval_df.to_csv(output_path, index=False)
print(f"Saved File at {output_path}")
