"""Evaluate CBM-based concept steering of a chemistry language model.

The script does the following, in order:

* reads ``configs/cav_eval.yaml``;
* loads an evaluation dataset and a Hugging Face causal LM;
* registers a forward hook on one decoder layer, which adds a
  ConceptBottleneckModel (CBM) steering vector to that layer's hidden states;
* generates SMILES from ``[QED]``/``[LOGP]``-conditioned prompts;
* scores them with ``is_valid_smiles`` / ``clogp`` / ``qed`` and writes the
  results to a CSV file.
"""
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline  # NOTE(review): not referenced in the code visible here
import pandas as pd
import numpy as np  # NOTE(review): not referenced in the code visible here
import glob  # NOTE(review): not referenced in the code visible here
import os

# Restrict this process to GPU 0. This has to be set before any CUDA call
# below (e.g. torch.cuda.is_available()) for CUDA to honour it.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import re
import datetime  # NOTE(review): not referenced in the code visible here
from tqdm import tqdm
import torch

print(f"CUDA available: {torch.cuda.is_available()}")

import sys
import yaml

# Put the directory above this script on sys.path so that the project-local
# packages imported below (utils, concept_alignment) can be resolved.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from utils.eval_metrics import is_valid_smiles, clogp, qed
from tdc.generation import MolGen  # NOTE(review): not referenced in the code visible here

from dotenv import load_dotenv
from collections import OrderedDict
from concept_alignment.models.cbm import ConceptBottleneckModel

# Export variables from a local .env file (e.g. HF_TOKEN, which the Hugging Face
# libraries read from the environment) into os.environ.
load_dotenv()

torch.cuda.empty_cache()

script_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(script_dir, "configs/cav_eval.yaml")

# A missing or empty config file yields {}. The defaults and the required-key
# check below then still apply, instead of failing later with a NameError on `config`.
config = {}
if os.path.exists(config_path):
    with open(config_path, "r") as file:
        config = yaml.safe_load(file) or {}

modify_func_name = config.get("modify_func", "additive")  # Default to 'additive'
data = config.get("data")
model_name = config.get("model_name")
layer_to_modify = config.get("layer_to_modify")
cbm_model_path = config.get("cbm_model_path")

# These settings have no usable default. Fail here with a clear message rather
# than with a NameError/TypeError much later in the script.
missing_config_keys = [
    key
    for key in ("data", "model_name", "layer_to_modify", "cbm_model_path")
    if config.get(key) is None
]
if missing_config_keys:
    raise ValueError(
        f"Config {config_path} is missing required key(s): {', '.join(missing_config_keys)}"
    )

print("CAV Alignment Function: ", modify_func_name)

print("Eval Dataset: ", data)

# The zinc file is used in full. The other datasets are reduced to a fixed
# random sample (random_state=21) so that runs are comparable.
if data == "zinc":
    sampled_data = pd.read_csv("PATH/test_data/zinc_eval.csv")
elif data == "moses":
    eval_data = pd.read_parquet("PATH/test_data/moses_eval.parquet")
    sampled_data = eval_data.sample(n=2000, random_state=21)
elif data == "chembl":
    eval_data = pd.read_parquet("PATH/test_data/chembl_eval.parquet")
    sampled_data = eval_data.sample(n=2000, random_state=21)
elif data == "oled":
    eval_data = pd.read_parquet("PATH/test_data/oled_dataset.parquet")
    sampled_data = eval_data.sample(n=100, random_state=21)
else:
    # Without this branch `sampled_data` would be unbound and the script would
    # fail later with a NameError when building the prompts.
    raise ValueError(
        f"Unknown eval dataset {data!r}; expected 'zinc', 'moses', 'chembl' or 'oled'"
    )

# Supported config `model_name` values -> Hugging Face Hub repository ids.
HF_MODEL_IDS = {
    "chemlactica-125m": "yerevann/chemlactica-125m",
    "chemlactica-1p3b": "yerevann/chemlactica-1.3b",
    "chemma-2b": "yerevann/chemma-2b",
}

if model_name not in HF_MODEL_IDS:
    # Fail immediately; continuing would crash on the undefined `tokenizer` below.
    raise ValueError(
        f"Model name not valid: {model_name!r}; expected one of {sorted(HF_MODEL_IDS)}"
    )

hf_model_id = HF_MODEL_IDS[model_name]
tokenizer = AutoTokenizer.from_pretrained(hf_model_id)
# Left padding keeps each prompt's real last token at the final position. Both
# batched generation and the last-token steering hook rely on this.
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    # Batched tokenisation uses padding=True, which requires a pad token.
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(hf_model_id).eval().to('cuda')

# Property / SMILES delimiter tokens used in the prompts. The embedding matrix is
# resized so that any newly added ids have rows.
special_tokens = ['[WAVELENGTH]', '[/WAVELENGTH]', '[F_OSC]', '[/F_OSC]', '[QED]', '[/QED]', '[LOGP]', '[/LOGP]', '[START_SMILES]', '[END_SMILES]', '[SEP]']
tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
model.resize_token_embeddings(len(tokenizer))


def _remap_cbm_key(key):
    """Translate one checkpoint key to the parameter name ConceptBottleneckModel expects.

    Only a leading ``cbm.cbm.`` / ``cbm.`` prefix is rewritten, so later
    occurrences of ``cbm.`` inside the key are left untouched.
    """
    if key.startswith('cbm.cbm.'):
        return 'cbm.' + key[len('cbm.cbm.'):]
    if key.startswith('cbm.'):
        # Any wrapped key mentioning concept_embeddings maps onto the single
        # top-level `concept_embeddings` entry.
        if 'concept_embeddings' in key:
            return 'concept_embeddings'
        return key[len('cbm.'):]
    return key


def load_cbm_model(model_path, device='cuda', hidden_dim=None, num_concepts=2):
    """Load a trained ConceptBottleneckModel (CBM) from ``model_path``.

    Accepted checkpoint layouts:

    * a bare ``OrderedDict`` state dict;
    * a dict that wraps the state dict under ``cbm_state_dict``, ``state_dict``
      or ``model_state_dict``, or is itself the state dict;
    * any other pickled object (e.g. a whole module). This is returned as-is,
      switched to eval mode if it is an ``nn.Module``.

    Args:
        model_path: path to the checkpoint file.
        device: ``map_location`` for ``torch.load`` and the device the CBM is built on.
        hidden_dim: CBM hidden/concept dimension. Defaults to the hidden size of
            the module-level language model ``model``.
        num_concepts: number of concepts the CBM was trained with. The default of
            2 presumably corresponds to QED and LOGP.

    Returns:
        The CBM in eval mode, or the raw loaded object for non-dict checkpoints.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects. Only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    if not isinstance(checkpoint, dict):
        # A whole pickled object rather than a state dict. Return it, but make
        # sure a module is in inference mode like the CBM built below.
        if isinstance(checkpoint, torch.nn.Module):
            checkpoint.eval()
        return checkpoint

    if isinstance(checkpoint, OrderedDict):
        state_dict = checkpoint
    else:
        for wrapper_key in ('cbm_state_dict', 'state_dict', 'model_state_dict'):
            if wrapper_key in checkpoint:
                state_dict = checkpoint[wrapper_key]
                break
        else:
            state_dict = checkpoint

    if hidden_dim is None:
        hidden_dim = model.config.hidden_size  # module-level LM loaded earlier in the script

    cbm = ConceptBottleneckModel(
        hidden_dim=hidden_dim,
        num_concepts=num_concepts,
        concept_dim=hidden_dim,
    ).to(device)

    # Checkpoints saved from a wrapper module carry extra `cbm.` prefixes.
    new_state_dict = {_remap_cbm_key(key): value for key, value in state_dict.items()}

    # strict=False is kept for best-effort loading, but mismatches are reported
    # instead of silently leaving parameters at their random initialisation.
    load_result = cbm.load_state_dict(new_state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Warning: CBM checkpoint {model_path} did not match the model exactly; "
            f"missing keys: {load_result.missing_keys}; "
            f"unexpected keys: {load_result.unexpected_keys}"
        )
    cbm.eval()

    return cbm

# Per-call records appended by create_steering_vector(): CBM concept values,
# per-concept vectors, summed steering vectors and the hidden states fed in.
(
    all_concept_values,
    all_concept_embeddings,
    all_steering_vectors,
    all_input_embeddings,
) = ([], [], [], [])

# CBM used by the steering hook, which looks it up as a module-level global.
cbm = load_cbm_model(cbm_model_path)

def create_steering_vector(cbm, h_n, record=True):
    """Compute the CBM steering vector for a batch of hidden states.

    Args:
        cbm: callable (the ConceptBottleneckModel in this script) that returns
            ``(concept_values, concept_vectors)`` for ``h_n``.
        h_n: hidden states fed to the CBM. The steering hook passes the
            last-token representations of the hooked layer.
        record: when True (the default), append CPU/NumPy copies of the CBM
            inputs and outputs to the module-level ``all_*`` lists.

    Returns:
        ``concept_vectors`` summed over dim 1, i.e. the per-concept vectors
        combined into one steering vector per sample.
    """
    with torch.no_grad():
        concept_values, concept_vectors = cbm(h_n)
        steering_vector = concept_vectors.sum(dim=1)

        if record:
            # NOTE: the forward hook calls this on every forward pass of
            # generate(), so these lists grow with (#batches x #decoding steps).
            # Pass record=False to avoid the memory growth.
            all_concept_values.append(concept_values.cpu().numpy())
            all_concept_embeddings.append(concept_vectors.cpu().numpy())
            all_steering_vectors.append(steering_vector.cpu().numpy())
            all_input_embeddings.append(h_n.cpu().numpy())

    return steering_vector


def modify_layer_output(module, input, output, alpha=1):
    """Forward hook for additive CBM steering.

    The steering vector is computed from the last token and added to every
    token position.

    Args:
        module: the hooked layer (unused).
        input: the layer's positional inputs (unused).
        output: the layer output, either a tensor or a tuple whose first element
            is the hidden states. The hidden states are indexed as
            (batch, seq, hidden).
        alpha: overall steering strength (default 1).

    Returns:
        The output in the same container type, with ``alpha * steering`` added
        to the hidden states at every position.
    """
    hidden_states = output[0] if isinstance(output, tuple) else output

    # The steering vector is derived from the last position only: [batch, hidden].
    last_token_reps = hidden_states[:, -1, :]
    steering = create_steering_vector(cbm, last_token_reps)

    # Match the layer output's dtype/device. This is a no-op when the CBM and the
    # LM already agree, and it prevents silently promoting the residual stream
    # when they do not.
    steering = steering.to(dtype=hidden_states.dtype, device=hidden_states.device)

    # [batch, 1, hidden] broadcasts across all sequence positions.
    modified_output = hidden_states + alpha * steering.unsqueeze(1)

    if isinstance(output, tuple):
        return (modified_output,) + output[1:]
    return modified_output


# Hook implementations selectable through the `modify_func` config key.
MODIFY_FUNCTIONS = {
    "additive": modify_layer_output,
}

if modify_func_name not in MODIFY_FUNCTIONS:
    # Keep the fallback, but make it visible: the output CSV name is still
    # built from modify_func_name.
    print(f"Warning: unknown modify_func {modify_func_name!r}; falling back to 'additive'")
selected_modify_func = MODIFY_FUNCTIONS.get(modify_func_name, modify_layer_output)

# chemma exposes its decoder layers at model.model.layers. The chemlactica
# models nest them under model.model.decoder.layers.
if model_name == "chemma-2b":
    layer = model.model.layers[layer_to_modify]
else:
    layer = model.model.decoder.layers[layer_to_modify]

hook_handle = layer.register_forward_hook(selected_modify_func)

batch_size = 16
eval_set = []

# Each prompt conditions on target QED and LOGP values. The model continues
# after [START_SMILES] with a SMILES string.
input_texts = [f"</s>[QED]{row['QED']}[/QED][LOGP]{row['LOGP']}[/LOGP][START_SMILES]" for _, row in sampled_data.iterrows()]

# Compiled once. It captures the text after [START_SMILES] up to [END_SMILES]
# or the end of the string.
smiles_re = re.compile(r'\[START_SMILES\](.*?)(\[END_SMILES\]|$)')
end_smiles_id = tokenizer.convert_tokens_to_ids("[END_SMILES]")

try:
    # tqdm infers the number of batches from the range's length.
    for i in tqdm(range(0, len(input_texts), batch_size)):
        batch_inputs = input_texts[i:i + batch_size]

        prompt = tokenizer(batch_inputs, return_tensors="pt", padding=True, truncation=True).to(model.device)

        # The attention mask is required with left padding. Without it the
        # model attends to pad tokens and outputs depend on batch composition.
        # Per-step scores are not requested because they are never used.
        output = model.generate(
            input_ids=prompt.input_ids,
            attention_mask=prompt.attention_mask,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=False,
            max_length=512,
            return_dict_in_generate=True,
            eos_token_id=end_smiles_id,
        )

        # Special tokens are kept: the SMILES delimiters are needed by smiles_re.
        batch_outputs = tokenizer.batch_decode(output.sequences)

        for input_text, out in zip(batch_inputs, batch_outputs):
            match = smiles_re.search(out)
            smiles_string = match.group(1) if match else None

            eval_set.append({
                'input_text': input_text,
                'output': out,
                'smiles_string': smiles_string,
                'valid_smiles': is_valid_smiles(smiles_string),
                'clogp_value': clogp(smiles_string),
                'qed_value': qed(smiles_string),
            })
finally:
    # Detach the steering hook so the model is unmodified afterwards.
    hook_handle.remove()

eval_df = pd.DataFrame(eval_set)

file_name = f"generated_smiles_{data}_cbm_{modify_func_name}_tdc_validation.csv"
results_dir = f"PATH/concept_representation_alignment/results/{model_name}/{data}"

# Create the per-model/per-dataset directory so to_csv does not fail on a fresh checkout.
os.makedirs(results_dir, exist_ok=True)
eval_df.to_csv(os.path.join(results_dir, file_name), index=False)
