import argparse
import os
from pathlib import Path
import sys
import torch
import zlib
import tqdm, json 
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import numpy as np
from filelock import FileLock

sys.path.append(str(Path(__file__).parent.parent))
from utils.utils import prepare_model

# Module-level objects constructed when this module is imported.
# NOTE(review): loss_fct is not referenced by the functions in this file; it may
# be used by importers — confirm before removing.
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
# Sentence-embedding model used by raw_values_batch to embed individual decoded
# tokens. Placed on CUDA when available, otherwise on CPU.
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device='cuda' if torch.cuda.is_available() else 'cpu')
# Only used for inference, so switch to eval mode once here.
embedding_model.eval()

def _mask_padding(encoded, tokenizer):
    '''Return a copy of the input ids with padded positions replaced by -100.'''
    # Clone so the tensor that was fed to the model is not mutated in place.
    labels = encoded["input_ids"].clone()
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        # Use the attention mask rather than pad_token_id: many causal LMs reuse
        # EOS as the pad token, and comparing ids would also drop genuine EOS tokens.
        labels[attention_mask == 0] = -100
    elif tokenizer.pad_token_id is not None:
        labels[labels == tokenizer.pad_token_id] = -100
    return labels


def _token_similarities(tokenizer, shifted_logits, shifted_labels):
    '''Per-string dot products between embeddings of argmax-predicted and actual tokens.'''
    predicted_tokens = torch.argmax(shifted_logits, dim=-1)
    similarity_list = []
    for pred_seq, true_seq in zip(predicted_tokens, shifted_labels):
        mask = true_seq != -100
        # .tolist() moves ids to the CPU once per row and hands plain ints to decode().
        valid_pred = pred_seq[mask].tolist()
        valid_true = true_seq[mask].tolist()
        if not valid_true:
            # Nothing to score for this string; it is omitted (mirrors _token_log_probs).
            continue
        pred_texts = [tokenizer.decode([token_id], skip_special_tokens=True) for token_id in valid_pred]
        true_texts = [tokenizer.decode([token_id], skip_special_tokens=True) for token_id in valid_true]
        pred_embeds = np.asarray(embedding_model.encode(pred_texts))
        true_embeds = np.asarray(embedding_model.encode(true_texts))
        # Row-wise dot product: one similarity score per token position.
        similarity_list.append(np.einsum("ij,ij->i", pred_embeds, true_embeds).tolist())
    return similarity_list


def _token_log_probs(shifted_logits, shifted_labels):
    '''Per-string log-probabilities the model assigns to each actual next token.'''
    log_probs = torch.nn.functional.log_softmax(shifted_logits, dim=-1)
    mask = shifted_labels != -100
    # -100 is not a valid gather index: point masked positions at token 0 and
    # discard them afterwards with the same mask.
    gather_index = shifted_labels.masked_fill(~mask, 0).unsqueeze(-1)
    token_log_probs = log_probs.gather(-1, gather_index).squeeze(-1)
    all_prob = []
    for seq_log_probs, seq_mask in zip(token_log_probs, mask):
        values = seq_log_probs[seq_mask].tolist()
        if values:
            all_prob.append(values)
    return all_prob


def raw_values_batch(model, tokenizer, example_list):
    '''
    This function scores every next-token position of each string in one forward pass.

    For each string, position i compares the model's prediction for token i+1
    with the actual token i+1 in two ways:
      * similarity: dot product between the sentence embedding (from the
        module-level ``embedding_model``) of the decoded argmax-predicted token
        and that of the decoded actual token;
      * log-probability: log-softmax score the model assigns to the actual token.

    input:
        model: the language model; called as ``model(**encoded)`` and its output
               must expose ``.logits``
        tokenizer: the tokenizer matching ``model``; must support padding
        example_list: a list of strings

    output:
        (similarity_list, all_prob): two lists of lists of floats, aligned
                    entry-for-entry and position-for-position.
                    Strings with no scorable position (e.g. a single token) are
                    omitted from both lists, so the outer index may not match
                    ``example_list``.

    Strings longer than min(tokenizer.model_max_length, 512) tokens are truncated.
    '''
    max_length = min(getattr(tokenizer, "model_max_length", 512), 512)
    encoded = tokenizer(example_list, return_tensors="pt", padding=True, truncation=True, max_length=max_length)

    # Move inputs to wherever the model lives (a specific GPU, MPS, or CPU),
    # not just the default CUDA device.
    encoded = {k: v.to(model.device) for k, v in encoded.items()}

    with torch.no_grad():
        outputs = model(**encoded)

    labels = _mask_padding(encoded, tokenizer)

    # Logits at position i predict token i+1, so drop the first label and the last logit.
    shifted_labels = labels[..., 1:].contiguous()
    shifted_logits = outputs.logits[..., :-1, :].contiguous()

    similarity_list = _token_similarities(tokenizer, shifted_logits, shifted_labels)
    all_prob = _token_log_probs(shifted_logits, shifted_labels)

    # Both metrics are built from the same mask, so they must line up exactly.
    assert len(similarity_list) == len(all_prob), f"Lengths are not the same: {len(similarity_list)}, {len(all_prob)}"
    for i, (sims, probs) in enumerate(zip(similarity_list, all_prob)):
        assert len(sims) == len(probs), f"Lengths of entry {i} are not the same: {len(sims)}, {len(probs)}"

    return similarity_list, all_prob

def raw_values(model, tokenizer, example_list, batch_size = 32):
    '''
    This function runs ``raw_values_batch`` over ``example_list`` in chunks of ``batch_size``.

    input:
        model: the language model
        tokenizer: the tokenizer
        example_list: a sliceable sequence of strings
        batch_size: number of strings scored per forward pass (must be > 0)
    output:
        (similarity, probability): the per-batch outputs of ``raw_values_batch``
                    concatenated in order, i.e. two lists of lists of floats
                    (per-token similarities and per-token log-probabilities).

    raises:
        ValueError: if ``batch_size`` is not positive.
    '''
    # A negative step would make range() empty and silently return no results.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    similarity = []
    probability = []
    for start in tqdm.tqdm(range(0, len(example_list), batch_size)):
        # list(...) so slices of array-like columns reach the tokenizer as a plain list.
        batch = list(example_list[start:start + batch_size])
        batch_similarity, batch_probability = raw_values_batch(model, tokenizer, batch)
        similarity.extend(batch_similarity)
        probability.extend(batch_probability)
    return similarity, probability

def aggregate_raw_values(model, tokenizer, ref_model, ref_tokenizer, dataset, metric_list, args, batch_size = 32, text_column=None):
    '''
    This function computes per-token similarity and log-probability values for a dataset.

    input:
        model, tokenizer: the target language model and its tokenizer
        ref_model, ref_tokenizer: optional reference model/tokenizer; when both
            are given, every value is also computed under the reference model
        dataset: a dataset exposing ``column_names`` and column access by name
        metric_list: metric names; only "perturbation" is inspected here
        args: not used here; kept for interface compatibility
        batch_size: strings per forward pass
        text_column: column to score. None or "text" scores the "text" column;
            "all" scores "text" and, when "perturbation" is in metric_list,
            every other column except "text", "prefix" and "suffix".

    output:
        dict with keys "<prefix>similarity" and "<prefix>probability" (plus
        "<prefix>ref_similarity"/"<prefix>ref_probability" with a reference
        model), where <prefix> is "<text_column>_" for an explicit non-"text"
        column and "" otherwise. Each perturbation column adds the same four
        keys prefixed with "<column>_".

    raises:
        ValueError: if the column to score is missing from the dataset.
    '''
    if_all = text_column == "all"
    if if_all or text_column == "text":
        text_column = None
    # Validate the column actually scored (including the default "text") so a
    # missing column always surfaces as the same ValueError.
    source_column = text_column if text_column is not None else "text"
    if source_column not in dataset.column_names:
        raise ValueError(f"Column {source_column} not found in dataset. Available columns: {dataset.column_names}")
    use_ref = ref_model is not None and ref_tokenizer is not None

    def _score_into(results, texts, key_prefix):
        # Target-model entries first, then reference-model entries when available.
        similarity, probability = raw_values(model, tokenizer, texts, batch_size = batch_size)
        results[f"{key_prefix}similarity"] = similarity
        results[f"{key_prefix}probability"] = probability
        if use_ref:
            ref_similarity, ref_probability = raw_values(ref_model, ref_tokenizer, texts, batch_size = batch_size)
            results[f"{key_prefix}ref_similarity"] = ref_similarity
            results[f"{key_prefix}ref_probability"] = ref_probability

    aggregated_results = {}
    prefix = f"{text_column}_" if text_column is not None else ""
    _score_into(aggregated_results, dataset[source_column], prefix)

    if "perturbation" in metric_list and if_all:
        for column in dataset.column_names:
            if column not in ("text", "prefix", "suffix"):
                _score_into(aggregated_results, dataset[column], f"{column}_")

    return aggregated_results

def get_args():
    '''
    Parse the command-line options for this script.

    ``--dataset_path`` is required: the run cannot proceed without it, and
    failing here avoids loading the models before erroring out.
    Passing an empty ``--ref_model_name`` disables the reference model in main().
    '''
    parser = argparse.ArgumentParser(description='Dataset Inference on a language model')
    parser.add_argument('--model_name', type=str, default="EleutherAI/pythia-410m-deduped", help='The name of the model to use')
    parser.add_argument('--ref_model_name', type=str, default="openai-community/gpt2-xl", help='The name of the reference model to use')
    parser.add_argument('--dataset_path', type=str, required=True, help='The path to the dataset to use')
    parser.add_argument('--batch_size', type=int, default=32, help='The batch size to use')
    parser.add_argument('--cache_dir', type=str, default="~/.cache", help='The directory to cache the model')
    parser.add_argument('--output_dir', type=str, default="results", help='The directory to save the results')
    parser.add_argument('--result_file_name', type=str, default="metrics.json", help='The name of the result file')
    parser.add_argument('--text_column', type=str, default="all", help='The column name in the dataset to use as text input')
    args = parser.parse_args()
    return args

def save_metrics(results_file, metrics_dict):
    """
    Merge ``metrics_dict`` into the JSON object stored at ``results_file``.

    The read-merge-write cycle runs under an exclusive FileLock on
    "<results_file>.lock", so callers sharing the same path do not interleave
    updates. Keys in ``metrics_dict`` overwrite existing keys of the same name.

    The merged object is written to "<results_file>.tmp" and then moved into
    place with os.replace. An interrupted or failed write therefore leaves the
    previous file intact instead of truncated.

    If the existing file is not valid JSON, or is valid JSON but not an object,
    a warning is printed and the file is replaced.
    """
    results_file = os.fspath(results_file)
    lock_path = results_file + ".lock"
    with FileLock(lock_path):
        existing = {}
        if os.path.exists(results_file):
            try:
                with open(results_file, 'r', encoding='utf-8') as f:
                    existing = json.load(f)
            except ValueError as err:
                # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
                print(f"Warning: could not parse {results_file} ({err}); overwriting it")
                existing = {}
            if not isinstance(existing, dict):
                print(f"Warning: {results_file} does not contain a JSON object; overwriting it")
                existing = {}
        existing.update(metrics_dict)

        # Safe to use a fixed temp name: every writer holds the same lock.
        tmp_path = results_file + ".tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(existing, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, results_file)
        except BaseException:
            # Don't leave a half-written temp file behind; re-raise the original error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
    print(f"Saved metrics to {results_file}")

def main():
    '''
    Script entry point.

    Loads the target model (and the reference model unless ``--ref_model_name``
    is empty) and scores the JSON dataset. The results are merged into
    <output_dir>/<result_file_name>.
    '''
    args = get_args()

    # Create the output directory before the expensive model work so an
    # unusable path fails fast instead of after all scoring is done.
    os.makedirs(args.output_dir, exist_ok=True)
    results_file = os.path.join(args.output_dir, args.result_file_name)

    model, tokenizer = prepare_model(args.model_name, cache_dir=args.cache_dir)

    if args.ref_model_name:
        ref_model, ref_tokenizer = prepare_model(args.ref_model_name, cache_dir=args.cache_dir)
    else:
        ref_model, ref_tokenizer = None, None

    dataset = load_dataset("json", data_files=args.dataset_path, split="train")

    metric_list = ["k_min_probs", "k_strip_probs", "ppl", "zlib_ratio", "k_max_probs", "perturbation", "reference_model"]

    # Named `metrics` to avoid shadowing the module-level raw_values() function.
    metrics = aggregate_raw_values(model, tokenizer, ref_model, ref_tokenizer, dataset, metric_list, args, batch_size = args.batch_size, text_column=args.text_column)

    save_metrics(results_file, metrics)

# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
