from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
import pandas as pd
import numpy as np
import glob
import os
import re
import datetime
from tqdm import tqdm
import torch
import sys
import yaml

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from utils.eval_metrics import is_valid_smiles, clogp, qed
from collections import OrderedDict
from concept_alignment.models.cbm import ConceptBottleneckModel

# Release any cached, currently-unused CUDA memory before models are loaded.
torch.cuda.empty_cache()


# Label for the steering strategy implemented by the forward hook in this
# script. NOTE(review): not referenced anywhere else in the visible code.
modify_func_name = "additive"

def compute_qed_mae(df, ground_truth):
    """Return the mean absolute QED error over the valid generations.

    Args:
        df: DataFrame with a ``valid_smiles`` flag column and a
            ``qed_value`` column.
        ground_truth: Mapping-like object (e.g. a DataFrame row) whose
            ``'QED'`` entry holds the target value.

    Returns:
        Mean of ``|ground_truth['QED'] - qed_value|`` over the rows where
        ``valid_smiles == 1`` (NaN when no row is valid).
    """
    valid_mask = df['valid_smiles'] == 1
    predicted = df.loc[valid_mask, 'qed_value']
    absolute_errors = abs(ground_truth['QED'] - predicted)
    return absolute_errors.mean()

def compute_logp_mae(df, ground_truth):
    """Return the mean absolute LogP error over the valid generations.

    Args:
        df: DataFrame with a ``valid_smiles`` flag column and a
            ``clogp_value`` column.
        ground_truth: Mapping-like object (e.g. a DataFrame row) whose
            ``'LOGP'`` entry holds the target value.

    Returns:
        Mean of ``|ground_truth['LOGP'] - clogp_value|`` over the rows where
        ``valid_smiles == 1`` (NaN when no row is valid).
    """
    valid_mask = df['valid_smiles'] == 1
    predicted = df.loc[valid_mask, 'clogp_value']
    absolute_errors = abs(ground_truth['LOGP'] - predicted)
    return absolute_errors.mean()

def validity(df):
    """Return the mean of the ``valid_smiles`` column.

    With 0/1 (or boolean) flags this is the fraction of generations that
    were flagged as valid.
    """
    valid_flags = df['valid_smiles']
    return valid_flags.mean()

def compute_generative_efficiency(df):
    """Return the share of generations that are both valid and distinct.

    Args:
        df: DataFrame with ``valid_smiles`` and ``smiles_string`` columns.

    Returns:
        Number of distinct ``smiles_string`` values among rows where
        ``valid_smiles == 1``, divided by the total row count; ``0`` for an
        empty frame.
    """
    valid_rows = df.loc[df['valid_smiles'] == 1]
    distinct_valid = valid_rows['smiles_string'].nunique()
    total = len(df)
    if total > 0:
        return distinct_valid / total
    return 0

def _unwrap_cbm_state_dict(checkpoint):
    """Return the state dict stored in a dict-like *checkpoint*."""
    # A bare OrderedDict is treated as the state dict itself.
    if isinstance(checkpoint, OrderedDict):
        return checkpoint
    # Otherwise look for the first known wrapper key, in priority order.
    for wrapper_key in ('cbm_state_dict', 'state_dict', 'model_state_dict'):
        if wrapper_key in checkpoint:
            return checkpoint[wrapper_key]
    # No known wrapper key: assume the dict already is the state dict.
    return checkpoint


def _remap_cbm_keys(state_dict):
    """Return a copy of *state_dict* with checkpoint key prefixes normalized."""
    prefix = 'cbm.'
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith('cbm.cbm.'):
            # Doubly-nested keys lose only the outer 'cbm.' prefix.
            new_key = key[len(prefix):]
        elif key.startswith(prefix):
            if 'concept_embeddings' in key:
                new_key = 'concept_embeddings'
            else:
                # Strip the leading prefix only, not later occurrences.
                new_key = key[len(prefix):]
        else:
            new_key = key
        remapped[new_key] = value
    return remapped


def load_cbm_model(model_path, device='cuda', hidden_dim=None):
    """Load a trained CBM from a checkpoint file.

    Args:
        model_path: Path to the checkpoint passed to ``torch.load``.
        device: Device used for ``map_location`` and for the rebuilt CBM.
        hidden_dim: Hidden size used to construct the CBM. When ``None``
            (the default, matching the original behaviour) it is read from
            the module-level ``model`` as ``model.config.hidden_size``, so
            that global must already be defined.

    Returns:
        The object stored in the checkpoint when it is not a dict (treated
        as a full pickled model), otherwise a ``ConceptBottleneckModel`` in
        eval mode with the checkpoint weights loaded non-strictly.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints that come from a trusted source.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    # Anything that is not a dict is assumed to be the full model object.
    if not isinstance(checkpoint, dict):
        return checkpoint

    state_dict = _unwrap_cbm_state_dict(checkpoint)

    if hidden_dim is None:
        # Backward-compatible fallback: size the CBM from the global LM.
        hidden_dim = model.config.hidden_size

    cbm_model = ConceptBottleneckModel(
        hidden_dim=hidden_dim,
        num_concepts=2,  # Assuming QED and LOGP
        concept_dim=hidden_dim
    ).to(device)

    # strict=False tolerates key mismatches; report them instead of hiding
    # them, because a fully mismatched checkpoint would otherwise leave the
    # CBM at its freshly initialized weights without any indication.
    load_result = cbm_model.load_state_dict(_remap_cbm_keys(state_dict), strict=False)
    missing_keys = list(getattr(load_result, 'missing_keys', ()))
    unexpected_keys = list(getattr(load_result, 'unexpected_keys', ()))
    if missing_keys or unexpected_keys:
        print(
            f"Warning: CBM checkpoint {model_path} loaded with "
            f"missing keys {missing_keys} and unexpected keys {unexpected_keys}"
        )
    cbm_model.eval()

    return cbm_model

def create_steering_vector(cbm, h_n):
    """Return the sum of the CBM's concept vectors for *h_n*.

    Args:
        cbm: Callable returning a ``(concept_values, concept_vectors)`` pair.
        h_n: Hidden representation(s) passed straight to ``cbm``.

    Returns:
        ``concept_vectors`` summed over dimension 1, computed with gradient
        tracking disabled.
    """
    with torch.no_grad():
        _concept_values, per_concept_vectors = cbm(h_n)
        return per_concept_vectors.sum(dim=1)


def modify_layer_output(module, input, output):
    """Forward hook that adds a CBM steering vector to every token position.

    The steering vector is computed per sample from the hidden state at the
    last sequence position and then broadcast over the whole sequence. The
    CBM is read from the module-level ``cbm``.

    Args:
        module: The hooked layer (unused).
        input: The layer's inputs (unused).
        output: Either the hidden-state tensor or a tuple whose first
            element is that tensor; it must have three dimensions.

    Returns:
        The steered hidden states, wrapped the same way ``output`` was
        (a tuple keeps its remaining elements untouched).
    """
    returns_tuple = isinstance(output, tuple)
    hidden_states = output[0] if returns_tuple else output

    # Unpacking three values enforces the [batch, seq, hidden] rank.
    _batch_size, seq_length, _hidden_dim = hidden_states.shape

    # Representation at the final position of every sample: [batch, hidden].
    final_position = hidden_states[:, seq_length - 1]
    steering = create_steering_vector(cbm, final_position)

    alpha = 1  # Overall steering strength

    # Insert a length-1 sequence axis so one vector per sample broadcasts
    # across all token positions.
    steered = hidden_states + alpha * steering.unsqueeze(1)

    if not returns_tuple:
        return steered
    return (steered,) + output[1:]

# Models to evaluate. Each entry provides:
#   name            - label used in log messages and output paths
#   model_path      - directory passed to the tokenizer/model loaders
#   cbm_ckpt        - checkpoint file passed to load_cbm_model
#   layer_to_modify - index of the layer that receives the steering hook
# NOTE(review): "PATH" looks like a placeholder prefix - replace before running.
model_configs = [
    {
        "name": "chemma-2b",
        "model_path": "PATH/models/chemma-2b",
        "cbm_ckpt": "PATH/cbm/chemma-2b-cbm.ckpt",
        "layer_to_modify": 9,
    },
]

# Tokens registered with the tokenizer as additional special tokens. Only the
# QED, LOGP and SMILES markers are used by the prompts built in this script.
special_tokens = [
    '[WAVELENGTH]', '[/WAVELENGTH]', '[F_OSC]', '[/F_OSC]',
    '[QED]', '[/QED]', '[LOGP]', '[/LOGP]',
    '[START_SMILES]', '[END_SMILES]', '[SEP]'
]

# Load data
# The evaluation loop reads the 'QED' and 'LOGP' columns of every sampled row.
eval_data = pd.read_parquet("PATH/tdc_data/oled_dataset.parquet")
# Fixed seed keeps the 100-row evaluation subset reproducible across runs;
# sampling without replacement requires the dataset to have at least 100 rows.
sampled_data = eval_data.sample(n=100, random_state=21)

# Evaluate every configured model: load it, attach the CBM steering hook,
# sample SMILES for each target (QED, LOGP) pair, score them, and write a
# per-target CSV of metrics.
for config in model_configs:
    model_name = config["name"]
    model_path = config["model_path"]
    cbm_model_path = config["cbm_ckpt"]
    layer_to_modify = config["layer_to_modify"]

    print(f"\n===== Evaluating {model_name} =====\n")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.padding_side = "left"
    model = AutoModelForCausalLM.from_pretrained(model_path).eval().to('cuda')
    tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
    model.resize_token_embeddings(len(tokenizer))

    # Must run after `model` is assigned: load_cbm_model sizes the CBM from
    # the module-level `model`, and the hook below reads the global `cbm`.
    cbm = load_cbm_model(cbm_model_path)

    # The decoder layers live at a different attribute path per architecture.
    if model_name == "chemma-2b":
        layer = model.model.layers[layer_to_modify]
    else:
        layer = model.model.decoder.layers[layer_to_modify]

    hook_handle = layer.register_forward_hook(modify_layer_output)

    try:
        # Sampling and evaluation
        num_samples = 100
        mini_batch = 100  # adjust to your GPU
        all_eval_rows = []

        # Loop invariants hoisted out of the per-row / per-output loops.
        end_smiles_id = tokenizer.convert_tokens_to_ids("[END_SMILES]")
        smiles_regex = re.compile(r'\[START_SMILES\](.*?)(\[END_SMILES\]|$)')

        for _, row in tqdm(sampled_data.iterrows(), total=len(sampled_data)):
            input_text = f"[QED]{row['QED']}[/QED][LOGP]{row['LOGP']}[/LOGP][START_SMILES]"
            outputs_list = []
            for j in range(0, num_samples, mini_batch):
                curr_batch = [input_text] * min(mini_batch, num_samples - j)
                prompt = tokenizer(curr_batch, return_tensors="pt", padding=True, truncation=True)
                prompt = {k: v.to(model.device) for k, v in prompt.items()}
                outputs = model.generate(
                    prompt["input_ids"],
                    # Pass the tokenizer's mask explicitly so generation does
                    # not have to infer it from the pad token.
                    attention_mask=prompt.get("attention_mask"),
                    do_sample=True,
                    top_k=50,
                    temperature=1.0,
                    max_length=1024,
                    eos_token_id=end_smiles_id,
                )
                outputs_list.extend(tokenizer.batch_decode(outputs))

            smiles_list, valid_list, clogp_list, qed_list = [], [], [], []
            for out in outputs_list:
                # Text between [START_SMILES] and [END_SMILES] (or the end of
                # the output when generation stopped without the end marker).
                smiles_match = smiles_regex.search(out)
                smiles_string = smiles_match.group(1) if smiles_match else None
                smiles_list.append(smiles_string)
                valid_list.append(is_valid_smiles(smiles_string))
                clogp_list.append(clogp(smiles_string))
                qed_list.append(qed(smiles_string))

            df = pd.DataFrame({
                "smiles_string": smiles_list,
                "valid_smiles": valid_list,
                "clogp_value": clogp_list,
                "qed_value": qed_list
            })

            row_result = {
                "QED_MAE": compute_qed_mae(df, row),
                "LogP_MAE": compute_logp_mae(df, row),
                "Validity": validity(df),
                "GenEff": compute_generative_efficiency(df),
                "input_text": input_text,
                "gt_qed": row['QED'],
                "gt_logp": row['LOGP'],
            }
            all_eval_rows.append(row_result)

        eval_df = pd.DataFrame(all_eval_rows)
        summary = {
            "mean_QED_MAE": eval_df['QED_MAE'].mean(),
            "mean_LogP_MAE": eval_df['LogP_MAE'].mean(),
            "mean_Validity": eval_df['Validity'].mean(),
            "mean_GenEff": eval_df['GenEff'].mean(),
        }
        print(f"Summary for {model_name}: {summary}")

        current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        file_name = f"generated_smiles_{model_name}_cbm_{current_time}.csv"
        output_dir = f"PATH/concept_representation_alignment/results/{model_name}/"
        os.makedirs(output_dir, exist_ok=True)
        eval_df.to_csv(os.path.join(output_dir, file_name), index=False)
    finally:
        # Always detach the hook and drop GPU references, even when the
        # evaluation fails part-way. `layer` is deleted too because it keeps
        # the hooked layer's parameters alive after `del model`.
        hook_handle.remove()
        del model
        del cbm
        del layer
        torch.cuda.empty_cache()
    print(f"\n===== Finished {model_name}, GPU memory cleared =====\n")
