import torch
import numpy as np
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction, corpus_bleu
from nltk.translate.meteor_score import meteor_score
from rouge_score import rouge_scorer
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem.Lipinski import NumHDonors, NumHAcceptors
from rdkit.Chem import QED
from rdkit.Chem.rdmolops import GetMolFrags
import selfies as sf
import sascorer
from molgeneval import filter_selfies

def compute_bleu_nltk_batch(tokenizer, generated_texts, reference_texts, verbose=True):
    """
    Score every generated text against its reference with BLEU, ROUGE and METEOR.

    Despite the name, the function reports six sentence-level metrics:
    smoothed BLEU-2 and BLEU-4 and METEOR on ``tokenizer.tokenize`` output,
    and ROUGE-1/2/L F-measures on the raw strings.

    Args:
        tokenizer: object exposing ``.tokenize(str)``; its tokens feed BLEU
            and METEOR.
        generated_texts: model outputs.
        reference_texts: references, paired positionally with
            ``generated_texts`` (``zip`` semantics: surplus items on the
            longer side are ignored).
        verbose: when True, print the per-metric means together with the
            corpus-level BLEU-2 / BLEU-4 of the batch.

    Returns:
        dict mapping ``'bleu2'``, ``'bleu4'``, ``'rouge1'``, ``'rouge2'``,
        ``'rougeL'`` and ``'meteor'`` to a list with one score per pair.
        The averages and corpus-level BLEU are only printed, never returned.
        An empty batch yields empty lists and prints nothing.
    """
    individual_scores = {
        'bleu2': [],
        'bleu4': [],
        'rouge1': [],
        'rouge2': [],
        'rougeL': [],
        'meteor': [],
    }

    pairs = list(zip(generated_texts, reference_texts))
    if not pairs:
        # Nothing to score: the mean of an empty list is nan and corpus-level
        # BLEU is not defined for an empty corpus, so skip the summary too.
        return individual_scores

    smoothie = SmoothingFunction().method1
    rouge_scorer_obj = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'])
    ref_tokens_lst, pred_tokens_lst = [], []

    for gen, ref in pairs:
        gen_tokens = tokenizer.tokenize(gen)
        ref_tokens = tokenizer.tokenize(ref)
        pred_tokens_lst.append(gen_tokens)
        # corpus_bleu expects a list of references per hypothesis.
        ref_tokens_lst.append([ref_tokens])

        # BLEU-2 (uniform 1-/2-gram weights) and BLEU-4 (NLTK default weights)
        individual_scores['bleu2'].append(
            sentence_bleu([ref_tokens], gen_tokens, weights=(0.5, 0.5), smoothing_function=smoothie)
        )
        individual_scores['bleu4'].append(
            sentence_bleu([ref_tokens], gen_tokens, smoothing_function=smoothie)
        )

        # ROUGE works on the raw strings; the scorer takes (target, prediction).
        rouge_scores = rouge_scorer_obj.score(ref, gen)
        individual_scores['rouge1'].append(rouge_scores['rouge1'].fmeasure)
        individual_scores['rouge2'].append(rouge_scores['rouge2'].fmeasure)
        individual_scores['rougeL'].append(rouge_scores['rougeL'].fmeasure)

        # METEOR score
        individual_scores['meteor'].append(meteor_score([ref_tokens], gen_tokens))

    if verbose:
        # The summary is display-only, so it is computed only when it will be shown.
        averages = {k: np.mean(v) for k, v in individual_scores.items()}
        averages['corpus_bleu_2'] = corpus_bleu(ref_tokens_lst, pred_tokens_lst, weights=[0.5, 0.5])
        averages['corpus_bleu_4'] = corpus_bleu(ref_tokens_lst, pred_tokens_lst)
        print(f"Average scores: {averages}")
    return individual_scores

def compute_rouge_nltk_batch(tokenizer, generated_texts, reference_texts, rouge_types=None):
    """
    Compute one ROUGE summary score per (generated, reference) pair.

    Both texts are tokenized with ``tokenizer`` and re-joined with single
    spaces before scoring, so the scorer sees the tokenizer's segmentation.
    Scoring uses ``rouge_scorer.RougeScorer`` with ``use_stemmer=True``.

    Args:
        tokenizer: a HuggingFace-style tokenizer with .tokenize().
        generated_texts: List[str] of model outputs.
        reference_texts: List[str] of target/reference strings, paired
            positionally with ``generated_texts`` (``zip`` semantics).
        rouge_types: List of ROUGE metrics to average over. ``None`` or an
            empty list selects the default ['rouge1','rouge2','rougeL'].

    Returns:
        List[float]: for each pair, the mean F-measure across ``rouge_types``.
        An empty batch returns ``[]``.
    """
    if not rouge_types:
        # An empty selection would divide by zero below; use the default set.
        rouge_types = ['rouge1', 'rouge2', 'rougeL']

    pairs = list(zip(generated_texts, reference_texts))
    if not pairs:
        return []

    scorer = rouge_scorer.RougeScorer(rouge_types, use_stemmer=True)

    scores = []
    for gen, ref in pairs:
        # tokenize and re-join so that scoring is done on the same tokenization
        gen_str = " ".join(tokenizer.tokenize(gen))
        ref_str = " ".join(tokenizer.tokenize(ref))
        # RougeScorer.score takes (target, prediction).
        score = scorer.score(ref_str, gen_str)
        mean_f1 = sum(score[r].fmeasure for r in rouge_types) / len(rouge_types)
        scores.append(mean_f1)
    return scores

def batch_bert_sentence_similarity(
    references: list[str],
    candidates: list[str],
    tokenizer,
    model,
    device: str = None
):
    """
    Compute the cosine similarity of each (reference, candidate) sentence pair
    using a given HuggingFace tokenizer and model.

    All sentences are encoded in a single forward pass; ``references[i]`` is
    compared with ``candidates[i]``.

    Args:
        references: Ground-truth sentences.
        candidates: Sentences to compare, same length as ``references``.
        tokenizer:  Pre-initialized HF tokenizer.
        model:      Pre-initialized HF model (with pooler_output or
                    last_hidden_state). It is moved to ``device`` and put in
                    eval mode as a side effect.
        device:     'cuda' or 'cpu' (auto-detected if None).

    Returns:
        List of floats in [0,1], one per pair (negative similarities are
        clamped to 0). Empty inputs return ``[]`` without touching the model.
    """
    references = list(references)
    candidates = list(candidates)
    assert len(references) == len(candidates), "Both lists must be same length"

    N = len(references)
    if N == 0:
        return []

    # 1) Setup device & model: honour the caller's choice, otherwise auto-detect.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device).eval()

    # Combine all texts so we do one forward pass
    texts = references + candidates
    encoded = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**encoded, return_dict=True)

    # Get pooled embeddings (or mean-pool if unavailable)
    if getattr(outputs, "pooler_output", None) is not None:
        embeddings = outputs.pooler_output  # shape: (2N, hidden_size)
    else:
        last_hidden = outputs.last_hidden_state  # (2N, seq_len, hidden_size)
        # Cast the mask so the masked sum stays in the model's floating dtype.
        mask = encoded.attention_mask.unsqueeze(-1).to(last_hidden.dtype)  # (2N, seq_len, 1)
        summed = (last_hidden * mask).sum(dim=1)  # (2N, hidden_size)
        counts = mask.sum(dim=1).clamp(min=1)  # (2N, 1); avoids 0/0 on all-pad rows
        embeddings = summed / counts  # (2N, hidden_size)

    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

    ref_embs = embeddings[:N]  # (N, hidden_size)
    cand_embs = embeddings[N:]  # (N, hidden_size)

    # Dot product of unit vectors == cosine similarity, computed row by row.
    sims = (ref_embs * cand_embs).sum(dim=1)  # (N,)
    sims = torch.clamp(sims, min=0.0)  # ensure >= 0

    return [float(s) for s in sims.tolist()]

def compute_reward_llama3(model, tokenizer, batch_conversations: list[list[dict]]) -> list[float]:
    """
    Computes the normalized (average) log-likelihood of the final message of
    each conversation (the assistant's response) under a causal LM.

    Args:
        model: causal language model; called with ``input_ids`` and
            ``attention_mask`` and expected to return an object with ``.logits``.
        tokenizer: tokenizer providing ``apply_chat_template``, ``pad`` and
            (optionally) ``padding_side``.
        batch_conversations (list[list[dict]]): A list where each item is a
            conversation in chat template format; the last message of each
            conversation is the response being scored.

    Returns:
        list[float]: A list of NORMALIZED (average) log-likelihood scores for each
            response in the batch (``[]`` for an empty batch). A response with
            no scored tokens yields a zero total divided by a clamped count of 1.
    """
    if not batch_conversations:
        return []

    # Inputs must live where the model does. Fall back to CUDA auto-detection
    # for wrappers that do not expose ``.device``.
    device = getattr(model, "device", None)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # 1. Tokenize each conversation and mark which tokens belong to the response.
    all_input_ids = []
    all_labels = []
    for convo in batch_conversations:
        prompt_ids = tokenizer.apply_chat_template(convo[:-1], tokenize=True, add_generation_prompt=True)
        full_ids = tokenizer.apply_chat_template(convo, tokenize=True, add_generation_prompt=False)

        # Assumes the rendered prompt is a token-level prefix of the rendered
        # full conversation -- TODO confirm for the chat template in use.
        prompt_len = min(len(prompt_ids), len(full_ids))

        all_input_ids.append(full_ids)
        # -100 is ignored by the loss, so only response tokens are scored.
        all_labels.append([-100] * prompt_len + list(full_ids[prompt_len:]))

    # 2. Pad the batch to the same length (on the side the tokenizer is configured for).
    padded_inputs = tokenizer.pad({"input_ids": all_input_ids}, padding="longest", return_tensors="pt").to(device)
    input_ids = padded_inputs.input_ids
    attention_mask = padded_inputs.attention_mask

    # Pad the labels with -100 directly rather than with the pad token id: when
    # the pad id doubles as a real token (e.g. pad aliased to an EOS or
    # end-of-turn token), rewriting every pad id to -100 would also remove
    # genuine response tokens from the score.
    max_len = input_ids.shape[1]
    pad_left = getattr(tokenizer, "padding_side", "right") == "left"
    label_rows = []
    for row in all_labels:
        filler = [-100] * (max_len - len(row))
        label_rows.append(filler + row if pad_left else row + filler)
    padded_labels = torch.tensor(label_rows, dtype=torch.long, device=input_ids.device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

        # **CRITICAL SHIFTING STEP**
        # Logits at position i predict the token at position i+1, so drop the
        # last logit and the first label to line them up.
        logits = outputs.logits[:, :-1, :]
        labels = padded_labels[:, 1:].to(logits.device)

        loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        per_token_loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        per_token_loss = per_token_loss.view(labels.shape)

        scored = (labels != -100).float()
        total_neg_log_likelihood = torch.sum(per_token_loss * scored, dim=1)

        # Clamp so a row without scored tokens does not divide by zero.
        token_counts = torch.clamp(torch.sum(scored, dim=1), min=1)

        avg_log_likelihoods = -(total_neg_log_likelihood / token_counts)

    return avg_log_likelihoods.cpu().tolist()

def num_lipinski_violations(mol):
    """
    Count how many of Lipinski's rule-of-5 limits ``mol`` exceeds.

    The four limits checked are:
      1) molecular weight (Descriptors.MolWt)   <= 500
      2) logP (Descriptors.MolLogP)             <= 5
      3) H-bond donors (NumHDonors)             <= 5
      4) H-bond acceptors (NumHAcceptors)       <= 10

    Returns an int between 0 (fully compliant) and 4.
    """
    # One boolean per rule, True when the limit is exceeded.
    exceeded = (
        Descriptors.MolWt(mol) > 500,
        Descriptors.MolLogP(mol) > 5,
        NumHDonors(mol) > 5,
        NumHAcceptors(mol) > 10,
    )
    return sum(1 for broken in exceeded if broken)

def eval_qual(mols):
    """
    Score each SELFIES string on drug-likeness, synthesizability and Lipinski compliance.

    For every entry the spaces are removed, the string is passed through
    ``filter_selfies``, decoded to SMILES and parsed with RDKit. Three
    components are then computed:
      * QED score (``QED.qed``),
      * SA score rescaled as ``1 - (sa - 1) / 9`` so that higher is better,
      * Lipinski penalty ``1 / (1 + violations)`` with 0-4 violations.

    Args:
        mols: iterable of SELFIES strings (spaces between tokens are allowed).

    Returns:
        list with one ``[qed, sa_norm, lip_penalty]`` triple per input, in
        input order. Any entry that fails to decode, parse or score gets
        ``[0, 0, 0]``. Combining the components into a single number is left
        to the caller.
    """
    mol_score = []
    for mol in mols:
        try:
            compact = mol.replace(" ", '')
            smiles = sf.decoder(filter_selfies(compact))
            sm_mol = Chem.MolFromSmiles(smiles)
            if sm_mol is None:
                # RDKit signals an unparsable SMILES by returning None.
                raise ValueError("SMILES could not be parsed")
            Chem.SanitizeMol(sm_mol)
            qed_score = QED.qed(sm_mol)
            sa_score = sascorer.calculateScore(sm_mol)
            sa_norm = 1 - (sa_score - 1) / 9
            violations = num_lipinski_violations(sm_mol)  # 0-4
            lip_penalty = 1.0 / (1 + violations)
            mol_score.append([qed_score, sa_norm, lip_penalty])
        except Exception:
            # Best-effort scoring: a bad molecule gets zeros instead of aborting
            # the batch. KeyboardInterrupt/SystemExit are deliberately not caught.
            mol_score.append([0, 0, 0])

    return mol_score

def identify_smiles_components(smiles_string):
    """
    Count the disconnected fragments described by a SMILES string.

    Args:
        smiles_string: SMILES text to parse with RDKit.

    Returns:
        int: 0 when the string cannot be parsed (the parser raises or returns
        None); otherwise the number of fragments reported by ``GetMolFrags``
        (1 for a single molecule, more for multi-component SMILES).

    Raises:
        ValueError: if parsing succeeds but no fragment is found
            (presumably a molecule without atoms, e.g. from an empty
            string -- verify against RDKit).
    """
    try:
        mol = Chem.MolFromSmiles(smiles_string)
    except Exception:
        # Parser errors (e.g. a non-string argument) count as "not a molecule";
        # KeyboardInterrupt/SystemExit are deliberately left to propagate.
        return 0
    if mol is None:
        return 0

    num_frags = len(GetMolFrags(mol))
    if num_frags < 1:
        # This case is unlikely for a valid SMILES but included for completeness.
        raise ValueError("Unexpected number of fragments.")
    return num_frags