import sys  
sys.path.append("/workspace")

import os
import pandas as pd  
import torch  
from tqdm import tqdm
from torch.utils.data import random_split  
from torch_geometric.loader import DataLoader  

from transformers import BertTokenizer
from dataloaders.dataset import ClaspDataset
from dataloaders.common import seed_worker  
from models.metric_learning import ClaspModel  
from train_OnMemory import load_metadata_and_embeddings  
from torch_geometric.loader import DataLoader  
from hydra import initialize, compose
from sklearn.metrics import roc_auc_score

import matplotlib.pyplot as plt

from pathlib import Path
import re
import numpy as np
from torch.nn.functional import cosine_similarity
# from models.utils import normalize_embedding
  
from dataloaders.common import seed_worker  
import pickle

from utils.embedding_utils import encode_texts, predict_embeddings


def calculate_material_category_similarities(target_embeddings, categories, tokenizer, text_encoder, cfg, device):
    """
    Score every substance embedding against every material-category text.

    Each category name is encoded on its own through ``encode_texts``; the
    per-category results are stacked, axis 1 is squeezed away, and the cosine
    similarity of every (target, category) pair is obtained by broadcasting.

    Args:
        target_embeddings: Tensor of substance embeddings, indexed here as
            ``[sample, feature]``.
        categories: Iterable of category names.
        tokenizer, text_encoder, cfg, device: Forwarded unchanged to
            ``encode_texts``.

    Returns:
        Tensor of cosine similarities taken over the feature axis.
        NOTE(review): presumably shaped (num_targets, num_categories), which
        assumes ``encode_texts`` returns a (1, dim) tensor per text - confirm.
    """
    encoded = []
    for name in categories:
        encoded.append(encode_texts([name], tokenizer, text_encoder, cfg, device))
    category_matrix = torch.stack(encoded).squeeze(1)

    # Insert singleton axes so the two operands broadcast pairwise.
    targets = target_embeddings[:, None, :]
    references = category_matrix[None, :, :]
    return cosine_similarity(targets, references, dim=2)

def filter_material_functions(titles_series, keywords, threshold=50):
    """
    Keep the keywords that occur in at least ``threshold`` titles.

    Matching is case-insensitive and delegated to ``Series.str.contains``
    with its default pattern handling.

    Args:
        titles_series (pd.Series): Titles to search.
        keywords: Iterable of keyword strings, checked in order.
        threshold (int): Minimum number of matching titles required.

    Returns:
        list: The retained keywords, in their original order.
    """
    return [
        candidate
        for candidate in keywords
        if titles_series.str.contains(candidate, case=False).sum() >= threshold
    ]
    
def calculate_roc_curve(dist_label_pairs):
    """
    Build ROC curve points from labels that are already in ranking order.

    Rows are consumed in their current (positional) order, so the caller is
    expected to have sorted the frame by distance beforehand. After each row
    the cumulative false-positive and true-positive rates are recorded; every
    row contributes exactly one point (tied distances are not merged).

    Args:
        dist_label_pairs (pd.DataFrame): Must contain a ``label`` column that
            is truthy for positives (e.g. "title contains the keyword").

    Returns:
        list: ``[fpr, tpr]`` pairs (two-element lists), starting with
              ``[0.0, 0.0]`` followed by one pair per input row. If one of the
              two classes is absent, the corresponding rate is NaN (0 / 0).
    """
    labels = dist_label_pairs["label"].to_numpy()
    total_num = len(labels)
    total_pos = dist_label_pairs["label"].sum()
    total_neg = total_num - total_pos

    # Vectorised running counts; replaces a per-row `.iloc[i]` lookup.
    cumsum_pos = np.cumsum(labels)
    cumsum_neg = np.arange(1, total_num + 1) - cumsum_pos

    false_positive_rate = cumsum_neg / total_neg
    true_positive_rate = cumsum_pos / total_pos

    roc_data = [[fpr, tpr] for fpr, tpr in zip(false_positive_rate, true_positive_rate)]
    # The points are already monotone for 0/1 labels; the sort is kept as a
    # cheap safeguard so the output ordering contract does not change.
    roc_data.sort()
    roc_data.insert(0, [0.0, 0.0])

    return roc_data


# Query keyword -> spelling variants that count as the same material function.
# Insertion order is preserved, so it also fixes the order in which the
# keywords are iterated wherever this mapping is consumed.
keyword_variations = {
    "ferromagnetic": ["ferromagnetic", "ferromagnetism"],
    "ferroelectric": ["ferroelectric", "ferroelectricity"],
    "semiconductor": ["semiconductor", "semiconductive", "semiconductivity"],
    "electroluminescence": ["electroluminescence", "electroluminescent"],
    "thermoelectric": ["thermoelectric", "thermoelectricity"],
    "superconductor": ["superconductor", "superconductive", "superconductivity"],
}

def eval_roc(df, output_cry, tokenizer, text_encoder, cfg, keyword_variations, device=None):
    """
    Plot one ROC curve per keyword and compute the matching AUC scores.

    Version without handling for keyword variations: only the keys of
    ``keyword_variations`` are used, both as the text query and as the
    case-insensitive pattern that marks a title as positive.

    Args:
        df (pd.DataFrame): Needs a ``titles`` column whose rows line up with
            the rows of ``output_cry``.
        output_cry (torch.Tensor): Crystal embeddings, indexed here as
            ``[sample, feature]``.
        tokenizer, text_encoder, cfg: Forwarded to ``encode_texts``.
        keyword_variations (dict): Mapping whose keys are the query keywords.
        device: Device forwarded to ``encode_texts``. When omitted, the device
            of the first ``text_encoder`` parameter is used (falling back to
            ``output_cry.device`` if the encoder has no parameters).

    Returns:
        tuple: ``(auc_df, fig, roc_data_list)`` - a DataFrame with columns
        ``keyword`` / ``auc_score``, the matplotlib figure holding all curves,
        and the list of ROC point lists (one per keyword, in key order).
    """
    if device is None:
        # Resolve the device locally rather than depending on a module-level
        # global, so the function also works when imported from elsewhere.
        # NOTE(review): assumes text_encoder is an nn.Module - confirm.
        first_param = next(text_encoder.parameters(), None)
        device = first_param.device if first_param is not None else output_cry.device

    keywords = list(keyword_variations)
    fig, ax = plt.subplots(figsize=(10, 7))
    auc_scores = []
    roc_data_list = []

    for query_keyword in tqdm(keywords):
        category_embedding = encode_texts([query_keyword], tokenizer, text_encoder, cfg, device)

        # Cosine distance between every crystal embedding and the query text.
        dist = 1 - cosine_similarity(output_cry[:, None, :], category_embedding[None, :, :], dim=2)

        dist_label_pairs = pd.DataFrame({
            # na=False: a missing title counts as "keyword absent", not NaN.
            "label": df['titles'].str.contains(query_keyword, case=False, na=False),
            # Hand pandas a plain CPU array; a raw tensor breaks on GPU / grad.
            "dist": dist.squeeze(-1).detach().cpu().numpy(),
        })

        # Rank samples from closest to farthest before accumulating the curve.
        dist_label_pairs = dist_label_pairs.sort_values("dist", ascending=True)

        roc_data = calculate_roc_curve(dist_label_pairs)
        roc_data_list.append(roc_data)

        # Split the [fpr, tpr] pairs into x and y values for the plot.
        fpr_values, tpr_values = zip(*roc_data)
        ax.plot(fpr_values, tpr_values, label=query_keyword)

        # Negate the distance because a smaller distance means a better score.
        auc_score = roc_auc_score(y_true=dist_label_pairs['label'], y_score=-dist_label_pairs['dist'])
        auc_scores.append(auc_score)
        print(f'AUC for {query_keyword}: {auc_score}')

    # Create a DataFrame for AUC scores
    auc_df = pd.DataFrame({"keyword": keywords, "auc_score": auc_scores})

    # Calculate and print average AUC (NaN when there are no keywords at all)
    average_auc = sum(auc_scores) / len(auc_scores) if auc_scores else float("nan")
    print(f'Average AUC: {average_auc}')

    # Chance-level diagonal for reference.
    ax.plot([0, 1], [0, 1], 'k--')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.grid(True)
    ax.legend()

    return auc_df, fig, roc_data_list

def eval_roc_for_keyword_variations(df, output_cry, tokenizer, text_encoder, cfg, keyword_variations, device=None):
    """
    Plot one ROC curve per keyword and compute AUCs, pooling spelling variants.

    Version with handling for keyword variations: for every keyword, each of
    its variations is encoded separately and the smallest cosine distance over
    the variations is used as the sample's distance. A title is positive when
    it contains any of the variations (case-insensitive).

    Args:
        df (pd.DataFrame): Needs a ``titles`` column whose rows line up with
            the rows of ``output_cry``.
        output_cry (torch.Tensor): Crystal embeddings, indexed here as
            ``[sample, feature]``.
        tokenizer, text_encoder, cfg: Forwarded to ``encode_texts``.
        keyword_variations (dict): Keyword -> list of variations. An empty
            list falls back to the keyword itself.
        device: Device forwarded to ``encode_texts``. When omitted, the device
            of the first ``text_encoder`` parameter is used (falling back to
            ``output_cry.device`` if the encoder has no parameters).

    Returns:
        tuple: ``(auc_df, fig, roc_data_list)`` - a DataFrame with columns
        ``keyword`` / ``auc_score``, the matplotlib figure holding all curves,
        and the list of ROC point lists (one per keyword, in key order).
    """
    if device is None:
        # Resolve the device locally rather than depending on a module-level
        # global, so the function also works when imported from elsewhere.
        # NOTE(review): assumes text_encoder is an nn.Module - confirm.
        first_param = next(text_encoder.parameters(), None)
        device = first_param.device if first_param is not None else output_cry.device

    fig, ax = plt.subplots(figsize=(10, 7))
    auc_scores = []
    roc_data_list = []

    for query_keyword, variations in tqdm(keyword_variations.items()):
        if not variations:
            # Without any variation there is nothing to stack; use the keyword.
            variations = [query_keyword]
        category_embeddings = [
            encode_texts([variation], tokenizer, text_encoder, cfg, device)
            for variation in variations
        ]

        # Distance to each variation, then the minimum over variations.
        dists = torch.stack([
            1 - cosine_similarity(output_cry[:, None, :], category_embedding[None, :, :], dim=2)
            for category_embedding in category_embeddings
        ])
        min_dist = dists.min(dim=0).values

        # Positive if any variation occurs in the title; na=False makes a
        # missing title count as "absent" instead of propagating NaN.
        label = np.any(
            [df['titles'].str.contains(variation, case=False, na=False) for variation in variations],
            axis=0,
        )

        dist_label_pairs = pd.DataFrame({
            "label": label,
            # Hand pandas a plain CPU array; a raw tensor breaks on GPU / grad.
            "dist": min_dist.squeeze(-1).detach().cpu().numpy(),
        })

        # Rank samples from closest to farthest before accumulating the curve.
        dist_label_pairs = dist_label_pairs.sort_values("dist", ascending=True)

        roc_data = calculate_roc_curve(dist_label_pairs)
        roc_data_list.append(roc_data)

        # Split the [fpr, tpr] pairs into x and y values for the plot.
        fpr_values, tpr_values = zip(*roc_data)
        ax.plot(fpr_values, tpr_values, label=query_keyword)

        # Negate the distance because a smaller distance means a better score.
        auc_score = roc_auc_score(y_true=dist_label_pairs['label'], y_score=-dist_label_pairs['dist'])
        auc_scores.append(auc_score)
        print(f'AUC for {query_keyword}: {auc_score}')

    # Create a DataFrame for AUC scores
    auc_df = pd.DataFrame({"keyword": list(keyword_variations.keys()), "auc_score": auc_scores})

    # Calculate and print average AUC
    average_auc = auc_df['auc_score'].mean()
    print(f'Average AUC: {average_auc}')

    # Chance-level diagonal for reference.
    ax.plot([0, 1], [0, 1], 'k--')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.grid(True)
    ax.legend()

    return auc_df, fig, roc_data_list



if __name__ == "__main__":
    # Split to evaluate: "val" or "test" (parts of the seeded 80/10/10 split below).
    # target_dataset = "val"
    target_dataset = "test"
    if target_dataset not in ("val", "test"):
        # Fail before the expensive dataset load rather than with a NameError later.
        raise ValueError(f"target_dataset must be 'val' or 'test', got {target_dataset!r}")
    print(f"target dataset: {target_dataset}")
    print("loading metadata...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Change here
    metadata_and_embeddings = load_metadata_and_embeddings(load_path='/workspace/data/cod_metadata_20240331_splitted_remaining.csv',
                                                       cod_basepath='/cod')
    dataset_rootpath = '/workspace/data/cod_full_20240331'
    print(f"dataset_rootpath: {dataset_rootpath}")

    dataset = ClaspDataset(input_dataframe=metadata_and_embeddings,
                        tokenizer=None,
                        max_token_length=64,
                        root=dataset_rootpath)

    # Reproduce the train / validation / test split with a fixed seed.
    dataset_size = len(dataset)
    train_size = int(0.8 * dataset_size)
    val_size = int(0.1 * dataset_size)
    test_size = dataset_size - train_size - val_size

    generator = torch.Generator().manual_seed(42)
    train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size], generator=generator)

    # Both splits are handled identically, so pick the subset once.
    target_subset = val_dataset if target_dataset == "val" else test_dataset
    source_data = target_subset.dataset.data
    all_titles = source_data["title"]
    all_ids = source_data["material_id"]
    df = pd.DataFrame({"titles": [all_titles[i] for i in target_subset.indices],
                       "id": [all_ids[i] for i in target_subset.indices]})
    # shuffle=False keeps the embeddings in the same row order as `df`.
    dataloader = DataLoader(target_subset, batch_size=2048, shuffle=False, num_workers=2, drop_last=False,
                            pin_memory=True, persistent_workers=True,
                            worker_init_fn=seed_worker)

    # Make sure the output directory exists before anything is written to it.
    results_dir = f'/workspace/eval_results/{target_dataset}'
    os.makedirs(results_dir, exist_ok=True)

    # change here
    config_paths = [
    "../outputs/2024-08-14/ft_cosface_s3_m05_lr1e-6_0813_ep2050/13-36-10/version_0",
    "../outputs/2024-08-13/cosface_s3_m05_lr2e-5/08-01-05/version_0",
    ]

    for config_path in tqdm(config_paths):
        print(f"config path: {config_path}")
        checkpoint_dir = f"{config_path}/model_checkpoint/"

        checkpoint_files = ["last.ckpt"]  # Include "last.ckpt" by default
        if any(keyword in config_path for keyword in ["ft_", "full_"]): # In the case of fine tuning, evaluate not only the last but also intermediate step checkpoints.
            # Skip "last.ckpt" here: it is already queued, and listing it again
            # would evaluate it twice and overwrite identical result files.
            checkpoint_files.extend(sorted(
                name for name in os.listdir(checkpoint_dir)
                if name.endswith('.ckpt') and name != "last.ckpt"
            ))

        for checkpoint_file in checkpoint_files:
            checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
            print(f"check point: {checkpoint_path}")
            with initialize(config_path, version_base='1.1'):
                cfg = compose(config_name="hparams")
                cfg.freeze_text_encoders = True

            model = ClaspModel.load_from_checkpoint(checkpoint_path, cfg=cfg, train_loader=None, val_loader=None)
            model.to(device)
            model.eval()

            tokenizer = BertTokenizer.from_pretrained(cfg.hf_textencoder_model_id)
            text_encoder = model.model_text.to(device)
            output_cry, output_text = predict_embeddings(model, dataloader, device)

            # Flatten the checkpoint path into a file-name prefix.
            save_filename_root = checkpoint_path.replace("../outputs/", "").replace("/","_").replace("model_checkpoint", "")

            # Run both evaluations; the tag distinguishes their output files.
            evaluations = [
                (eval_roc, ""),                                             # without keyword variations
                (eval_roc_for_keyword_variations, "_keyword_variations"),  # with keyword variations
            ]
            for evaluate, tag in evaluations:
                auc_df, fig, roc_data = evaluate(df, output_cry, tokenizer, text_encoder, cfg, keyword_variations)

                fig.tight_layout()
                fig.savefig(f'{results_dir}/{save_filename_root}_roc_curve{tag}_{target_dataset}.pdf')
                plt.close(fig)
                auc_df.to_csv(f'{results_dir}/{save_filename_root}_roc_table{tag}_{target_dataset}.csv', index=False)

                # Save the result to pickle
                with open(f'{results_dir}/{save_filename_root}_roc_data{tag}_{target_dataset}.pkl', 'wb') as f:
                    pickle.dump(roc_data, f)

# Usage:
# List the run directories to evaluate in `config_paths` (each must contain a
# `model_checkpoint/` folder with `last.ckpt`), then run this script.
# Example:
# python eval_zero_shot_roc.py