import wandb
import torch
import numpy as np
import pickle
import json
import argparse
import os
from transformers import Trainer
from datasets import Dataset
from sklearn.metrics import mean_absolute_error, mean_squared_error, root_mean_squared_error, r2_score
from safetensors.torch import load_file

from models.mlp import MLP
from data.loaders import get_embeddings, get_embeddings_test
from data.utils import load_managers, load_test_manager
from lm_polygraph.ue_metrics import PredictionRejectionArea

from lm_polygraph.estimators.greedy_supervised_cocoa import *
from lm_polygraph.estimators.sample_supervised_cocoa import *
from lm_polygraph.estimators.max_probability import MaximumSequenceProbability, SampledMaximumSequenceProbability
from lm_polygraph.estimators.perplexity import Perplexity, SampledPerplexity
from lm_polygraph.estimators.token_entropy import MeanTokenEntropy, SampledMeanTokenEntropy
from lm_polygraph.estimators.greedy_semantic_average_ue_average_similarity import *
from lm_polygraph.estimators.semantic_average_ue_average_similarity import *


# Name of the generation-quality metric looked up for each supported dataset.
# evaluate_model indexes this mapping with its `dataset` argument.
quality_metrics = {
    # Question answering
    "triviaqa": "AlignScoreTargetOutput",
    "coqa": "AlignScoreTargetOutput",
    # Math word problems
    "gsm8k": "Accuracy",
    # Machine translation
    "wmt14_fren": "Comet",
    "wmt19_deen": "Comet",
    # Multiple choice
    "mmlu": "Accuracy",
    # Summarization
    "xsum": "AlignScoreInputOutput",
}

# Every list below is ordered MSP, PPL, MTE. evaluate_model reads the entries
# by position, so the order must stay the same across all of them.

# Supervised CoCoA estimators for greedy and for sampled generations.
estimators_greedy = [
    SupervisedCocoaMSP(),
    SupervisedCocoaPPL(),
    SupervisedCocoaMTE(),
]
estimators_sample = [
    estimator_cls(sample_strategy='best')
    for estimator_cls in (
        SampledSupervisedCocoaMSP,
        SampledSupervisedCocoaPPL,
        SampledSupervisedCocoaMTE,
    )
]


# UE metrics assigned to the test manager before it is evaluated.
ue_metrics = [PredictionRejectionArea(max_rejection=0.5)]

# Plain (non-CoCoA) estimators, used as the "lower bound" reference.
lower_bound_methods_greedy = [
    MaximumSequenceProbability(),
    Perplexity(),
    MeanTokenEntropy(),
]
lower_bound_methods_sample = [
    estimator_cls(sample_strategy='best')
    for estimator_cls in (
        SampledMaximumSequenceProbability,
        SampledPerplexity,
        SampledMeanTokenEntropy,
    )
]


# Semantic-enriched estimators, used as the "upper bound" reference.
upper_bound_methods_greedy = [
    GreedySemanticEnrichedMaxprobAveDissimilarity(),
    GreedySemanticEnrichedPPLAveDissimilarity(),
    GreedySemanticEnrichedMTEAveDissimilarity(),
]
upper_bound_methods_sample = [
    estimator_cls(sample_strategy='best')
    for estimator_cls in (
        SemanticEnrichedMaxprobAveDissimilarity,
        SemanticEnrichedPPLAveDissimilarity,
        SemanticEnrichedMTEAveDissimilarity,
    )
]



def collate_fn(batch):
    """Collate dataset rows into one batch for the DataLoader.

    Args:
        batch: Sequence of rows, each a mapping with an "embedding" entry
            and a "label" entry.

    Returns:
        Dict with "embeddings" (the per-row embeddings converted to float32
        tensors and stacked along dim 0) and "labels" (a float32 tensor
        holding one label per row).
    """
    embedding_rows = [
        torch.tensor(row["embedding"], dtype=torch.float32) for row in batch
    ]
    stacked_embeddings = torch.stack(embedding_rows, dim=0)

    label_values = [row["label"] for row in batch]
    label_tensor = torch.tensor(label_values, dtype=torch.float32)

    return {"embeddings": stacked_embeddings, "labels": label_tensor}

def to_dataset(embeddings, targets):
    """Pair embeddings with targets and wrap them in a `datasets.Dataset`.

    Each row has an "embedding" and a "label" field. The two inputs are
    zipped, so pairing stops at the end of the shorter one.
    """
    records = []
    for embedding, target in zip(embeddings, targets):
        records.append({"embedding": embedding, "label": target})
    return Dataset.from_list(records)


def _select_rows(container, row_ids):
    """Keep only the rows listed in ``row_ids`` in every entry of ``container``.

    ``container`` is a dict that is updated in place: ``list`` values are
    rebuilt element by element and ``np.ndarray`` values are indexed with
    ``row_ids``. Values of any other type are reported on stdout and left
    unchanged.
    """
    for stat_key, stat_values in container.items():
        if isinstance(stat_values, list):
            container[stat_key] = [stat_values[i] for i in row_ids]
        elif isinstance(stat_values, np.ndarray):
            container[stat_key] = stat_values[row_ids]
        else:
            print(f"Unknown type in stats: {stat_key} — skipping")


def _predict(model, loader, device):
    """Run ``model`` over ``loader`` and return ``(predictions, labels)`` arrays."""
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in loader:
            preds = model(batch["embeddings"].to(device))
            if isinstance(preds, dict):
                preds = preds["logits"]
            # A final batch holding a single sample squeezes down to a 0-d
            # array, which list.extend() cannot iterate; atleast_1d keeps it
            # a one-element vector and is a no-op for every other batch.
            all_preds.extend(np.atleast_1d(preds.squeeze().cpu().numpy()))
            all_labels.extend(batch["labels"].numpy())
    return np.array(all_preds), np.array(all_labels)


def evaluate_model(model_path, base_model, dataset, manager_dir, pooling_type, selected_layer, device, man_save_path='.', greedy_or_sample='greedy'):
    """Evaluate a trained MLP on the test split and score the UE estimators.

    Loads the checkpoint, predicts on the test embeddings, stores the
    predictions in the test manager's stats, runs the supervised, lower-bound
    and upper-bound estimators, evaluates them with ``ue_metrics`` and saves
    the manager. Regression metrics and the normalized PRR scores are printed
    and written next to the checkpoint.

    Args:
        model_path: Directory containing ``model.safetensors``. It also
            receives ``eval_test_metrics.json`` and ``test_predictions.pickle``.
        base_model: Forwarded to ``load_test_manager``; also used in the name
            of the saved manager file.
        dataset: Dataset name; must be a key of ``quality_metrics``.
        manager_dir: Forwarded to ``load_test_manager``.
        pooling_type: Forwarded to ``get_embeddings_test``.
        selected_layer: Forwarded to ``get_embeddings_test``.
        device: Torch device the model and inputs are moved to.
        man_save_path: Directory the evaluated manager is saved into.
        greedy_or_sample: ``'greedy'`` selects the greedy estimator sets; any
            other value selects the sample sets.

    Raises:
        KeyError: If ``dataset`` is not a key of ``quality_metrics``.
    """
    use_greedy = greedy_or_sample == 'greedy'

    # NOTE(review): the MLP constructor arguments are hard-coded; presumably
    # they must match the architecture the checkpoint was trained with.
    model = MLP(4096, 4096, 2048, 0.1)
    state_dict = load_file(os.path.join(model_path, "model.safetensors"))
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    quality_metric = quality_metrics[dataset]

    test_manager = load_test_manager(base_model, dataset, manager_dir, device)
    embeddings_test, targets_test, ids_test = get_embeddings_test(test_manager, pooling_type, selected_layer, greedy_or_sample=greedy_or_sample)
    print("Original length: ", len(test_manager.gen_metrics[('sequence', quality_metric)]))
    print("After filtering: " , len(ids_test))

    test_dataset = to_dataset(embeddings_test, targets_test)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, collate_fn=collate_fn)

    # Restrict the manager to the rows kept by get_embeddings_test (e.g. when
    # samples with empty embeddings were dropped) so that the predictions
    # stored below line up with the manager's other per-sample values.
    for container in (test_manager.stats, test_manager.estimations, test_manager.gen_metrics):
        _select_rows(container, ids_test)

    all_preds, all_labels = _predict(model, test_loader, device)
    if use_greedy:
        test_manager.stats['greedy_sentence_similarity_supervised'] = all_preds
    else:
        test_manager.stats['supervised_sample_sentence_similarity'] = all_preds

    test_manager.ue_metrics = ue_metrics

    if use_greedy:
        estimators = estimators_greedy
        lower_bound_methods = lower_bound_methods_greedy
        upper_bound_methods = upper_bound_methods_greedy
    else:
        estimators = estimators_sample
        lower_bound_methods = lower_bound_methods_sample
        # Sample mode must use the sample upper-bound estimators, not the
        # greedy ones.
        upper_bound_methods = upper_bound_methods_sample

    for estimator in (*estimators, *lower_bound_methods, *upper_bound_methods):
        test_manager.estimations[('sequence', str(estimator))] = estimator(test_manager.stats)

    test_manager.eval_ue()
    if man_save_path:
        # Make sure the target directory exists before the manager is saved.
        os.makedirs(man_save_path, exist_ok=True)
    test_manager.save_path = os.path.join(man_save_path, f"{base_model}_{dataset}_{greedy_or_sample}.man")
    test_manager.save()

    def prr(method):
        """Normalized PRR of ``method`` against the dataset's quality metric."""
        return test_manager.metrics[('sequence', str(method), quality_metric, 'prr_0.5_normalized')]

    metrics = {
        "mae": mean_absolute_error(all_labels, all_preds),
        "mse": mean_squared_error(all_labels, all_preds),
        "rmse": root_mean_squared_error(all_labels, all_preds),
        "r2": r2_score(all_labels, all_preds),
    }
    # The estimator lists are ordered MSP, PPL, MTE.
    for name, lower, upper, supervised in zip(("MSP", "PPL", "MTE"), lower_bound_methods, upper_bound_methods, estimators):
        metrics[name] = prr(lower)
        metrics[f"{name}_cocoa"] = prr(upper)
        metrics[f"{name}_cocoa_supervised"] = prr(supervised)
    # Metric values may be numpy scalars (e.g. np.float32), which the json
    # module cannot serialize; plain floats always can be.
    metrics = {name: float(value) for name, value in metrics.items()}

    print("Evaluation metrics:\n", json.dumps(metrics, indent=2))
    with open(os.path.join(model_path, "eval_test_metrics.json"), "w") as f:
        json.dump(metrics, f, indent=4)
    with open(os.path.join(model_path, "test_predictions.pickle"), "wb") as f:
        pickle.dump({"predictions": all_preds, "labels": all_labels}, f)

def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with ``device``, ``model``, ``dataset``,
        ``manager_dir``, ``selected_layer``, ``pooling_type``, ``save_dir``,
        ``run_name`` (required) and ``greedy_or_sample``.
    """
    option_specs = (
        ("--device", {"type": str, "default": "cuda:0"}),
        ("--model", {"type": str, "default": "llama"}),
        ("--dataset", {"type": str, "default": "triviaqa"}),
        ("--manager_dir", {"type": str, "default": "nfs-stor/statml/cocoa_supervised"}),
        ("--selected_layer", {"type": int, "default": 16}),
        ("--pooling_type", {"type": str, "default": "mean"}),
        ("--save_dir", {"type": str, "default": "enriched"}),
        ("--run_name", {"type": str, "required": True, "help": "Name of the W&B run"}),
        ("--greedy_or_sample", {"type": str, "default": "greedy"}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def get_run_id_from_name(project_path: str, run_name: str):
    """Return the id of the first W&B run in the project named ``run_name``.

    Args:
        project_path: Project path passed to ``wandb.Api().runs``.
        run_name: Display name of the run to look up.

    Raises:
        ValueError: If no run in the project has that name.
    """
    for candidate in wandb.Api().runs(project_path):
        if candidate.name != run_name:
            continue
        return candidate.id
    raise ValueError(f"Run name '{run_name}' not found in project '{project_path}'")

if __name__ == "__main__":
    # Script entry point: resolve the W&B run by name, fetch its model
    # artifact and evaluate the checkpoint on the test split.
    args = parse_args()

    wandb.login()
    wandb.init()
    # NOTE(review): the project path is empty here; presumably it has to be
    # set to the W&B "<entity>/<project>" path before running. Confirm.
    project_path = ""

    run_id = get_run_id_from_name(project_path, args.run_name)

    artifact_name = f"model-{run_id}"
    artifact = wandb.use_artifact(f"{project_path}/{artifact_name}:latest", type="model")

    # Let W&B resolve the local directory for the requested version. A
    # hand-built "artifacts/<name>:v1" path pins version v1 and would select
    # a stale checkpoint whenever ":latest" points to a newer version. Per
    # the W&B docs, download() leaves files that already exist untouched.
    model_path = artifact.download()

    evaluate_model(
        model_path=model_path,
        base_model=args.model,
        dataset=args.dataset,
        manager_dir=args.manager_dir,
        pooling_type=args.pooling_type,
        selected_layer=args.selected_layer,
        device=args.device,
        man_save_path=args.save_dir,
        greedy_or_sample=args.greedy_or_sample,
    )
