from transformers import AutoTokenizer, AutoModelForCausalLM
import argparse
import pandas as pd
import os
import re
import datetime
from tqdm import tqdm
import torch
import sys
import yaml
import numpy as np
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from eval_metrics import clogp, qed, is_valid_smiles
from collections import OrderedDict
from concept_alignment.models.cbm import ConceptBottleneckModel

# Release memory held by PyTorch's CUDA caching allocator.
# NOTE(review): this script has not created any CUDA tensors before this line, so
# this is most likely a no-op here — confirm it is still needed.
torch.cuda.empty_cache()


def parse_args():
    """Parse the command-line options for one evaluation run.

    Returns:
        argparse.Namespace with ``data`` (str, default "zinc"), ``model``
        (str, default "chemlactica-125m") and ``concept`` (int, default 0).
    """
    parser = argparse.ArgumentParser()
    # (flag, type, default, help) for every supported option.
    option_specs = (
        ("--data", str, "zinc", "Dataset to run the system on"),
        ("--model", str, "chemlactica-125m", "LLM Model to use"),
        ("--concept", int, 0, "Concept to modify"),
    )
    for flag, kind, fallback, description in option_specs:
        parser.add_argument(flag, type=kind, default=fallback, help=description)
    return parser.parse_args()


# Index of the decoder layer whose output gets the steering hook, keyed by --model.
layer_modify_dict = dict((
    ("chemlactica-125m", 6),
    ("chemlactica-1.3b", 12),
    ("chemma-2b", 9),
))

script_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(script_dir, "configs/cav_eval.yaml")
args = parse_args()

# The config file is optional for `modify_func`, but it is the only source of the
# checkpoint and data paths. Start from an empty mapping so `config` is always
# defined, then validate the required keys explicitly below.
config = {}
if os.path.exists(config_path):
    with open(config_path, "r") as file:
        # safe_load returns None for an empty file.
        config = yaml.safe_load(file) or {}
modify_func_name = config.get("modify_func", "additive")  # Default to 'additive'

data = args.data.lower()
model_name = args.model
if model_name not in layer_modify_dict:
    raise ValueError(
        f"Unknown model {model_name!r}; expected one of {sorted(layer_modify_dict)}"
    )
layer_to_modify = layer_modify_dict[model_name]

cbm_model_path_template = config.get("cbm_model_path")
data_path = config.get("data_dir")
if cbm_model_path_template is None or data_path is None:
    raise ValueError(f"'cbm_model_path' and 'data_dir' must be set in {config_path}")
cbm_model_path = cbm_model_path_template.format(model=model_name)

concept_to_modify = args.concept
# The CBM is built with two concepts (0 = QED, 1 = LogP). Any other index would
# only fail at the very end, when the summary looks up the property name.
if concept_to_modify not in (0, 1):
    raise ValueError(f"--concept must be 0 (QED) or 1 (LogP), got {concept_to_modify}")


print("CAV Alignment Function: ", modify_func_name)

print("Eval Dataset: ", data)

# Parquet-backed datasets are subsampled (at most 2000 rows, fixed seed) for a
# reproducible evaluation set; ZINC is read in full from its CSV.
_PARQUET_EVAL_FILES = {
    "moses": "moses_eval.parquet",
    "chembl": "chembl_eval.parquet",
    "oled": "oled_dataset.parquet",
}
if data == "zinc":
    sampled_data = pd.read_csv(f"{data_path}/zinc_eval.csv")
elif data in _PARQUET_EVAL_FILES:
    eval_data = pd.read_parquet(f"{data_path}/{_PARQUET_EVAL_FILES[data]}")
    # DataFrame.sample raises if n exceeds the row count, so cap n for small datasets.
    sampled_data = eval_data.sample(n=min(2000, len(eval_data)), random_state=21)
else:
    raise ValueError(
        f"Unsupported dataset {data!r}; expected one of: zinc, moses, chembl, oled"
    )

# Load the tokenizer and the causal LM; every supported model name uses the same recipe.
# NOTE(review): MODEL_PATH is None, presumably a redacted placeholder. It must point at
# the checkpoint for `model_name` before from_pretrained can load anything.
MODEL_PATH = None
if model_name in ("chemlactica-125m", "chemlactica-1.3b", "chemma-2b"):
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    # Pad on the left so the final position of every prompt row is a real token.
    tokenizer.padding_side = "left"
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).eval().to('cuda')

# Register the bracketed tag tokens (e.g. [QED], [START_SMILES]) as special tokens,
# then grow the embedding matrix to match the enlarged vocabulary.
special_tokens = [
    '[WAVELENGTH]', '[/WAVELENGTH]',
    '[F_OSC]', '[/F_OSC]',
    '[QED]', '[/QED]',
    '[LOGP]', '[/LOGP]',
    '[START_SMILES]', '[END_SMILES]',
    '[SEP]',
]
tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
model.resize_token_embeddings(len(tokenizer))


def load_cbm_model(model_path, device='cuda', hidden_dim=None):
    """Load a trained concept bottleneck model (CBM) from a checkpoint file.

    The checkpoint may be a bare state dict, a dict that wraps one under
    ``cbm_state_dict`` / ``state_dict`` / ``model_state_dict``, or a pickled
    model object.

    Args:
        model_path: Path passed to ``torch.load``.
        device: ``map_location`` for loading and the device of a newly built CBM.
        hidden_dim: Hidden size of the CBM. If None, uses the hidden size of the
            module-level ``model`` (``model.config.hidden_size``).

    Returns:
        The CBM in eval mode. A checkpoint that is not a dict is returned as-is,
        switched to eval mode when it is a ``torch.nn.Module``.
    """
    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    if isinstance(checkpoint, OrderedDict):
        state_dict = checkpoint
    elif isinstance(checkpoint, dict):
        # Use the first known wrapper key; otherwise treat the dict itself as the state dict.
        wrapper_keys = ('cbm_state_dict', 'state_dict', 'model_state_dict')
        state_dict = next(
            (checkpoint[k] for k in wrapper_keys if k in checkpoint), checkpoint
        )
    else:
        # A full pickled model: put it in eval mode like the state-dict path below.
        if isinstance(checkpoint, torch.nn.Module):
            checkpoint.eval()
        return checkpoint

    if hidden_dim is None:
        hidden_dim = model.config.hidden_size

    cbm = ConceptBottleneckModel(
        hidden_dim=hidden_dim,
        num_concepts=2,  # Assuming QED and LOGP
        concept_dim=hidden_dim,
    ).to(device)

    # Normalise key names from different checkpoint layouts. Strip only a leading
    # 'cbm.' (str.replace would also rewrite 'cbm.' occurring later in a key).
    prefix = 'cbm.'
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith('cbm.cbm.'):
            new_key = key[len(prefix):]  # 'cbm.cbm.x' -> 'cbm.x'
        elif key.startswith(prefix):
            if 'concept_embeddings' in key:
                new_key = 'concept_embeddings'
            else:
                new_key = key[len(prefix):]
        else:
            new_key = key
        new_state_dict[new_key] = value

    # strict=False tolerates layout differences, but report mismatches so that
    # parameters left at their initial values do not go unnoticed.
    load_result = cbm.load_state_dict(new_state_dict, strict=False)
    if load_result.missing_keys:
        print(f"Warning: CBM checkpoint is missing keys: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Warning: CBM checkpoint has unexpected keys: {load_result.unexpected_keys}")
    cbm.eval()

    return cbm

# Module-level CBM used by the steering hook (loaded onto the GPU).
cbm = load_cbm_model(model_path=cbm_model_path, device='cuda')

def modify_layer_output(module, input, output):
    """Forward hook that adds a CBM steering vector to a layer's hidden states.

    Each sequence's steering vector is built from that sequence's hidden state
    at the final position (via ``create_steering_vector``). The same vector is
    then added at every position of the sequence.

    Args:
        module: The hooked layer; not used.
        input: The layer's inputs; not used.
        output: The layer's output, either the hidden-state tensor itself or a
            tuple whose first element is that tensor. The tensor must be rank 3,
            treated as (batch, seq, hidden).

    Returns:
        ``output`` with its hidden states shifted, in the same container type.
    """
    packed = isinstance(output, tuple)
    hidden_states = output[0] if packed else output

    # The three-way unpack also enforces a rank-3 tensor.
    _, seq_len, _ = hidden_states.shape
    final_position = hidden_states[:, seq_len - 1]

    _, steering = create_steering_vector(cbm, final_position)

    strength = 1  # overall steering scale
    # Add a length-1 sequence axis so the per-sequence vector broadcasts over positions.
    shifted = hidden_states + strength * steering.unsqueeze(1)

    if packed:
        return (shifted,) + output[1:]
    return shifted

# Values the modified concept is forced to; run k uses steering_vals[k]
# (steering_idx == -1 is the unsteered baseline).
steering_vals = [
    0, 0.1, 0.2, 0.3, 0.4, 0.5,
    0.6, 0.7, 0.8, 0.9, 1,
]
# Filled by create_steering_vector during the baseline (steering_idx == -1) pass.
concept_vals = None

def create_steering_vector(cbm, h_n):
    """Build per-sequence steering vectors from CBM concept predictions.

    Reads the module-level ``steering_idx``, ``concept_to_modify`` and
    ``steering_vals``.

    - ``steering_idx == -1`` (baseline): the CBM's own concept vectors are
      summed, and the predicted concept values are stored in the global
      ``concept_vals``.
    - Otherwise: the modified concept's value is forced to
      ``steering_vals[steering_idx]``, and the concept vectors are rebuilt as
      value * concept embedding before summing.

    Args:
        cbm: Concept bottleneck model. It is called as ``cbm(h_n)``, returning
            ``(concept_values, concept_vectors)``, and exposes ``concept_embeddings``.
        h_n: Hidden states to read concepts from, one row per sequence.

    Returns:
        ``(concept_values, steering_vector)``, where ``steering_vector`` is the
        sum of the concept vectors over the concept axis.
    """
    global concept_vals
    # Leading batch axis so the embeddings broadcast against per-sample values.
    embeddings = cbm.concept_embeddings.unsqueeze(0)

    with torch.no_grad():
        # values: (batch, num_concepts); vectors: (batch, num_concepts, concept_dim)
        values, vectors = cbm(h_n)
        baseline = steering_idx == -1
        if not baseline:
            # Overwrite the modified concept's value in place, then rebuild the vectors.
            values[:, concept_to_modify] = steering_vals[steering_idx]
            vectors = values.unsqueeze(-1) * embeddings
        steering = vectors.sum(dim=1)
        if baseline:
            # NOTE(review): this runs on every hooked forward pass. During generation,
            # concept_vals therefore presumably ends up holding the last step's
            # values rather than the prompt's. Confirm this is intended.
            concept_vals = values
    return values, steering

# Registry of steering hooks selectable through the config's `modify_func` key.
MODIFY_FUNCTIONS = {
    "additive": modify_layer_output,
}

if modify_func_name not in MODIFY_FUNCTIONS:
    # The fallback is deliberate, but modify_func_name is also printed and used in
    # the results file name, so make the substitution visible instead of silent.
    print(f"Warning: unknown modify_func '{modify_func_name}'; using 'additive' instead.")
selected_modify_func = MODIFY_FUNCTIONS.get(modify_func_name, modify_layer_output)

# Pick the decoder layer to hook. chemma-2b exposes its layers at model.model.layers;
# the other supported models expose them at model.model.decoder.layers.
if model_name == "chemma-2b":
    layer = model.model.layers[layer_to_modify]
else:
    layer = model.model.decoder.layers[layer_to_modify]

# The hook runs after every forward pass of this layer. The handle is kept so the
# hook can be removed later.
hook_handle = layer.register_forward_hook(selected_modify_func)

batch_size = 16
eval_set = []

input_texts = [f"</s>[QED]{row['QED']}[/QED][LOGP]{row['LOGP']}[/LOGP][START_SMILES]" for _, row in sampled_data.iterrows()]

# Each per-batch record has one column per generation run. Column 0 is the
# unsteered baseline (steering_idx == -1); column k + 1 is the run that forces
# the modified concept to steering_vals[k].
num_runs = len(steering_vals) + 1

# Per-sample deviation summaries, averaged over all samples after the loop.
mod_max = []
mod_min = []
oth_max = []
oth_min = []

valid_thresh_frac = config.get("valid_thresh")
if valid_thresh_frac is None:
    raise ValueError(f"'valid_thresh' must be set in {config_path}")
# Minimum number of runs (out of num_runs) that must produce a valid SMILES for a
# sample to be included.
valid_thresh = int(num_runs * valid_thresh_frac)

# Loop invariants. Raw strings keep "\[" a regex escape rather than an invalid
# string escape (a SyntaxWarning on recent Python versions).
smiles_re = re.compile(r'\[START_SMILES\](.*?)(\[END_SMILES\]|$)')
logp_re = re.compile(r"\[LOGP\](.*)\[/LOGP\]")
qed_re = re.compile(r"\[QED\](.*)\[/QED\]")
end_smiles_id = tokenizer.convert_tokens_to_ids("[END_SMILES]")

# tqdm infers the (ceil-divided) number of batches from the range itself.
for i in tqdm(range(0, len(input_texts), batch_size)):
    batch_inputs = input_texts[i:i + batch_size]
    n_batch = len(batch_inputs)

    # Tokenize entire batch
    prompt = tokenizer(batch_inputs, return_tensors="pt", padding=True, truncation=True).to(model.device)

    steer_val_record = np.zeros((n_batch, num_runs))
    mod_prop_steer_record = np.zeros_like(steer_val_record)
    oth_prop_steer_record = np.zeros_like(steer_val_record)

    # `steering_idx` must stay a module-level loop variable: create_steering_vector
    # reads it as a global to decide whether, and how, to steer.
    for steering_idx in range(-1, len(steering_vals)):
        col = steering_idx + 1

        output = model.generate(
            prompt.input_ids,
            # Pass the mask explicitly so the left padding is not attended to.
            attention_mask=prompt.get("attention_mask"),
            do_sample=True,
            max_length=1024,
            return_dict_in_generate=True,
            eos_token_id=end_smiles_id,
        )

        batch_outputs = tokenizer.batch_decode(output.sequences)

        # Extract SMILES and evaluate properties
        for idx, out in enumerate(batch_outputs):
            smiles_match = smiles_re.search(out)
            smiles_string = smiles_match.group(1) if smiles_match else None

            if is_valid_smiles(smiles_string):
                clogp_value = clogp(smiles_string)
                qed_value = qed(smiles_string)
            else:
                # -inf marks an invalid generation; such runs are filtered out below.
                clogp_value = -np.inf
                qed_value = -np.inf

            # Baseline: record the concept value the CBM predicted. concept_vals was
            # filled during this pass. Steered runs: record the forced value. The
            # global concept_vals is only read here, never modified.
            if steering_idx == -1:
                steer_val_record[idx, col] = float(concept_vals[idx, concept_to_modify])
            else:
                steer_val_record[idx, col] = steering_vals[steering_idx]

            if concept_to_modify == 0:
                mod_prop_steer_record[idx, col] = qed_value
                oth_prop_steer_record[idx, col] = clogp_value
            else:
                mod_prop_steer_record[idx, col] = clogp_value
                oth_prop_steer_record[idx, col] = qed_value

    counter_ex = 0
    for idx, input_text in enumerate(batch_inputs):
        valid_cols = mod_prop_steer_record[idx] != -np.inf
        n_valid = np.count_nonzero(valid_cols)
        # Skip samples with too few valid generations. Always skip samples with none,
        # where max/min would be undefined.
        if n_valid == 0 or n_valid < valid_thresh:
            continue

        gt_logp = float(logp_re.search(input_text).group(1))
        gt_qed = float(qed_re.search(input_text).group(1))
        if concept_to_modify == 0:
            gt_mod, gt_oth = gt_qed, gt_logp
        else:
            gt_mod, gt_oth = gt_logp, gt_qed

        # Absolute deviation of each valid run's property from the prompt's target value.
        mod_dev = np.abs(mod_prop_steer_record[idx, valid_cols] - gt_mod)
        oth_dev = np.abs(oth_prop_steer_record[idx, valid_cols] - gt_oth)
        mod_max_dev, mod_min_dev = np.max(mod_dev), np.min(mod_dev)
        oth_max_dev, oth_min_dev = np.max(oth_dev), np.min(oth_dev)

        mod_max.append(mod_max_dev)
        mod_min.append(mod_min_dev)
        oth_max.append(oth_max_dev)
        oth_min.append(oth_min_dev)

        counter_ex += 1

        eval_set.append({
            'input_text': input_text,
            # Column 0 is always the unsteered baseline, independent of which runs
            # passed the validity filter.
            'orig_mod_concept_val': steer_val_record[idx, 0],
            # Convert to a Python float so the CSV stores a number, not a tensor repr.
            'orig_other_concept_val': float(concept_vals[idx, 1 - concept_to_modify]),
            'mod_max_dev': mod_max_dev,
            'mod_min_dev': mod_min_dev,
            'oth_max_dev': oth_max_dev,
            'oth_min_dev': oth_min_dev,
        })
    print(f"{counter_ex} samples used for calculation from a batch of {len(batch_inputs)} samples.")

# Display names of the two CBM concepts; the index matches concept_to_modify.
prop_dict = {
    0: "QED",
    1: "LogP",
}

print(f"Total samples used in calculation are: {len(mod_max)}")
prop_str = prop_dict[concept_to_modify]
oth_prop_str = prop_dict[1 - concept_to_modify]
if mod_max:
    # Mean over samples of each sample's max/min absolute deviation across runs.
    print(f"{data}: {prop_str} modified: {prop_str} max dev: {round(np.mean(mod_max),7)} | {prop_str} min dev: {round(np.mean(mod_min),7)}")
    print(f"{data}: {prop_str} modified: {oth_prop_str} max dev: {round(np.mean(oth_max),7)} | {oth_prop_str} min dev: {round(np.mean(oth_min),7)}")
else:
    # np.mean([]) would print nan with a RuntimeWarning; report the real situation.
    print(f"{data}: no samples met the validity threshold; no deviation statistics to report.")
eval_df = pd.DataFrame(eval_set)

file_name = f"max_min_dev_mod{concept_to_modify}_cbm_{modify_func_name}_tdc.csv"

# NOTE: resolved relative to the current working directory, not this script's directory.
save_path = f"../concept_representation_alignment/results/{model_name}/{data}"
os.makedirs(save_path, exist_ok=True)
eval_df.to_csv(f"{save_path}/{file_name}", index=False)
