import sys  
sys.path.append("/workspace")

import os
import warnings  
import pandas as pd  
import torch  
import yaml
from tqdm import tqdm
import pytorch_lightning as pl  
from torch.utils.data import random_split  
from torch_geometric.loader import DataLoader  
# from pytorch_lightning import Trainer, loggers  
# from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging
# import hydra
# from omegaconf import DictConfig
# from omegaconf import OmegaConf  
# from hydra.core.hydra_config import HydraConfig
from transformers import BertTokenizer
from dataloaders.dataset import ClaspDataset, ClaspOnDiskDataset
from dataloaders.common import generate_full_path, seed_worker  
from models.cgcnn import CGCNN  
from models.metric_learning import ClaspModel  
from train import load_metadata_and_embeddings, load_caption_dataframe
from torch_geometric.loader import DataLoader  
from hydra import initialize, compose
from sklearn.metrics import roc_auc_score

import matplotlib.pyplot as plt

from pathlib import Path
import re
import numpy as np
from torch.nn.functional import cosine_similarity
from itertools import chain
# from models.utils import normalize_embedding
  
from dataloaders.common import seed_worker  
import pickle

from utils.embedding_utils import encode_texts, calculate_material_category_similarities, predict_embeddings
# from eval_scripts.top_k_acc import calculate_top_k_accuracy, plot_topk_accuracy_for_config


def calculate_material_category_similarities(target_embeddings, categories, tokenizer, text_encoder, cfg, device):
    """
    Score every material embedding against the text embedding of every category.

    Each category string is encoded on its own (as a one-element batch) via
    ``encode_texts``; the per-category embeddings are stacked and cosine
    similarity is taken between all (material, category) pairs by broadcasting.

    Note: this local definition shadows the function of the same name imported
    from ``utils.embedding_utils`` at the top of the file.

    Returns:
        torch.Tensor: for 2-D inputs, a similarity matrix indexed as
        [material, category].
    """
    encoded = []
    for category in categories:
        encoded.append(encode_texts([category], tokenizer, text_encoder, cfg, device))
    # Drop the singleton batch dimension that each one-text encoding carries.
    category_matrix = torch.stack(encoded).squeeze(1)

    materials = target_embeddings.unsqueeze(1)   # insert the category axis
    category_rows = category_matrix.unsqueeze(0)  # insert the material axis
    return cosine_similarity(materials, category_rows, dim=2)

def filter_material_functions(titles_series, keywords, threshold=50):
    """
    Return the keywords that appear in at least ``threshold`` titles.

    Matching uses ``Series.str.contains`` with ``case=False``, so it is
    case-insensitive and each keyword is interpreted as a regular expression.
    The relative order of ``keywords`` is preserved in the result.
    """
    def _hit_count(keyword):
        return titles_series.str.contains(keyword, case=False).sum()

    return [kw for kw in keywords if _hit_count(kw) >= threshold]
    
def calculate_roc_curve(dist_label_pairs):
    """
    Compute ROC curve points from pre-ranked distance/label pairs.

    The rows of ``dist_label_pairs`` must already be ordered from the most to
    the least confident positive prediction (callers in this module sort by
    ascending ``dist``). Each row is one threshold step: after including rows
    ``0..i`` as "predicted positive", the False Positive Rate (FPR) and True
    Positive Rate (TPR) are recorded.

    Args:
        dist_label_pairs (pd.DataFrame): DataFrame with a boolean (or 0/1)
            ``label`` column indicating whether the item contains the keyword.
            Only ``label`` is read; row order defines the thresholds.

    Returns:
        list: ``[[fpr, tpr], ...]`` as lists of floats, starting with
        ``[0.0, 0.0]`` and non-decreasing in both coordinates. If the input has
        no positives (or no negatives), the corresponding rate is undefined and
        is reported as ``nan`` rather than triggering a division by zero.
    """
    labels = np.asarray(dist_label_pairs["label"], dtype=np.int64)
    n_rows = labels.size
    total_pos = int(labels.sum())
    total_neg = n_rows - total_pos

    # Cumulative positives / negatives after each row is included.
    cumsum_pos = np.cumsum(labels)
    cumsum_neg = np.arange(1, n_rows + 1) - cumsum_pos

    # A rate is undefined when its class is absent; make that explicit.
    fpr = cumsum_neg / total_neg if total_neg else np.full(n_rows, np.nan)
    tpr = cumsum_pos / total_pos if total_pos else np.full(n_rows, np.nan)

    # Both cumulative rates are non-decreasing, so the points are already in
    # ascending order; no sort is needed.
    roc_data = [[0.0, 0.0]]
    roc_data.extend([float(f), float(t)] for f, t in zip(fpr, tpr))
    return roc_data

# Canonical keyword -> surface forms. Every form is encoded as a text query and
# also searched for (case-insensitively) in the material titles; the canonical
# key is what appears in the ROC legend and the AUC table.
keyword_variations = dict(
    ferromagnetic=['ferromagnetic', 'ferromagnetism'],
    ferroelectric=['ferroelectric', 'ferroelectricity'],
    semiconductor=['semiconductor', 'semiconductive', 'semiconductivity'],
    electroluminescence=['electroluminescence', 'electroluminescent'],
    photoluminescence=['photoluminescence', 'photoluminescent'],
    thermochromic=['thermochromic', 'thermochromism'],
)

def _resolve_text_device(text_encoder, device):
    """Return the device on which query texts should be encoded.

    An explicit ``device`` wins. Otherwise the device of ``text_encoder``'s
    first parameter is used (assumes ``text_encoder`` is a ``torch.nn.Module``
    — TODO confirm for all encoders), falling back to CPU.
    """
    if device is not None:
        return device
    try:
        return next(text_encoder.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device("cpu")


def _title_series(df, preferred_column):
    """Return the titles column of ``df``, accepting either ``titles`` or ``Title``."""
    candidates = [preferred_column] + [c for c in ("titles", "Title") if c != preferred_column]
    for column in candidates:
        if column in df.columns:
            return df[column]
    raise KeyError(f"df must contain one of the columns {candidates}")


def _eval_roc_curves(df, output_cry, tokenizer, text_encoder, cfg, keyword_variations,
                     build_texts, title_column, device):
    """Shared implementation of the keyword ROC evaluations.

    For each canonical keyword, every variation is encoded using the text batch
    ``build_texts(variation)``. Each material's distance to the keyword is the
    minimum, over variations, of the cosine distance ``1 - cos(material, text)``.
    A material is labeled positive when its title contains any variation
    (case-insensitive). Materials are ranked by ascending distance to build the
    ROC curve, and the AUC uses the negated distance as the score.
    """
    device = _resolve_text_device(text_encoder, device)
    titles = _title_series(df, title_column)

    plt.figure(figsize=(10, 7))
    keywords = []
    auc_scores = []
    roc_data_list = []

    for query_keyword, variations in tqdm(list(keyword_variations.items())):
        print("variations: " + str(variations))
        # Inference only: no autograd graph is needed for these distances.
        with torch.no_grad():
            category_embeddings = [
                encode_texts(build_texts(variation), tokenizer, text_encoder, cfg, device).to(output_cry.device)
                for variation in variations
            ]
            dists = [
                1 - cosine_similarity(output_cry[:, None, :], category_embedding[None, :, :], dim=2)
                for category_embedding in category_embeddings
            ]
            # Element-wise minimum over variations: the closest surface form wins.
            min_dist = torch.min(torch.stack(dists), dim=0)[0]

        # Positive label: the title mentions at least one of the variations.
        label = np.any([titles.str.contains(variation, case=False) for variation in variations], axis=0)
        # squeeze(-1) (not squeeze()) keeps a 1-D column even for a single material;
        # detach().cpu() lets pandas convert tensors living on GPU or tracking grad.
        dist_label_pairs = pd.DataFrame({"label": label,
                                         "dist": min_dist.squeeze(-1).detach().cpu().numpy()})
        dist_label_pairs = dist_label_pairs.sort_values("dist", ascending=True)

        roc_data = calculate_roc_curve(dist_label_pairs)
        roc_data_list.append(roc_data)

        x = [point[0] for point in roc_data]
        y = [point[1] for point in roc_data]
        plt.plot(x, y, label=query_keyword)

        # Smaller distance means a stronger match, so negate it to get a score.
        auc_score = roc_auc_score(y_true=dist_label_pairs['label'], y_score=-dist_label_pairs['dist'])
        keywords.append(query_keyword)
        auc_scores.append(auc_score)
        print(f'AUC for {query_keyword}: {auc_score}')

    auc_df = pd.DataFrame({"keyword": keywords, "auc_score": auc_scores})

    average_auc = auc_df['auc_score'].mean()
    print(f'Average AUC: {average_auc}')

    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.grid(True)
    plt.legend()
    fig = plt.gcf()

    return auc_df, fig, roc_data_list


def eval_roc_for_keyword_variations(df, output_cry, tokenizer, text_encoder, cfg, keyword_variations, device=None):
    """
    Plot per-keyword ROC curves and compute AUCs for keyword-to-material retrieval.

    Each variation is encoded on its own as a one-element text batch.

    Args:
        df: DataFrame of material titles, row-aligned with ``output_cry``.
            The titles are read from ``Title`` if present, otherwise ``titles``.
        output_cry: material embeddings, presumably (N, D) with rows aligned
            to ``df`` — TODO confirm.
        tokenizer, text_encoder, cfg: forwarded to ``encode_texts``.
        keyword_variations: mapping canonical keyword -> list of surface forms.
        device: device for text encoding. It defaults to the device of
            ``text_encoder``'s parameters.

    Returns:
        tuple: ``(auc_df, fig, roc_data_list)``. ``auc_df`` has columns
        ``keyword`` and ``auc_score``; ``fig`` is the matplotlib figure; and
        ``roc_data_list`` holds one list of ``[fpr, tpr]`` points per keyword.
    """
    return _eval_roc_curves(df, output_cry, tokenizer, text_encoder, cfg, keyword_variations,
                            build_texts=lambda variation: [variation],
                            title_column="Title", device=device)


def eval_roc_for_keyword_variations_self(df, output_cry, tokenizer, text_encoder, cfg, keyword_variations, keywords_list, device=None):
    """
    Same as ``eval_roc_for_keyword_variations``, but for models using "self"/"scsc" attention.

    Each variation is encoded in one batch together with ``keywords_list``, so
    the text encoder sees the other texts as context.

    NOTE(review): the final ``squeeze(-1)`` yields one distance per material
    only if ``encode_texts`` returns a single embedding for that batch. If it
    returns one embedding per text, the distance column would be 2-D — confirm.

    Args:
        df: DataFrame of material titles, row-aligned with ``output_cry``.
            The titles are read from ``titles`` if present, otherwise ``Title``.
        output_cry, tokenizer, text_encoder, cfg, keyword_variations: as in
            ``eval_roc_for_keyword_variations``.
        keywords_list: context texts appended after each variation.
        device: device for text encoding. It defaults to the device of
            ``text_encoder``'s parameters.

    Returns:
        tuple: ``(auc_df, fig, roc_data_list)``, as in
        ``eval_roc_for_keyword_variations``.
    """
    return _eval_roc_curves(df, output_cry, tokenizer, text_encoder, cfg, keyword_variations,
                            build_texts=lambda variation: [variation] + list(keywords_list),
                            title_column="titles", device=device)

if __name__ == "__main__":
    target_dataset = "test"
    # Fail fast: only these splits have a dataframe/dataloader built below.
    if target_dataset not in ("val", "test"):
        raise ValueError(f"target_dataset must be 'val' or 'test', got {target_dataset!r}")
    print(f"target dataset: {target_dataset}")
    print("loading metadata...")
    # Kept at module scope on purpose: earlier versions of the eval functions
    # read this global.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    metadata = load_metadata_and_embeddings(load_path='/data/cod_metadata_20240907.csv',
                                            cod_basepath='/cod')
    keywords_df = load_caption_dataframe("/workspace/generated_data/cod_full_20240331_full_meta-llamaLlama-3_1-8B-Instruct.json")
    keywords_df['ID'] = keywords_df['ID'].astype(int)
    metadata_and_embeddings = pd.merge(metadata, keywords_df, left_on='file', right_on='ID')
    metadata_and_embeddings.drop(columns=["ID", "Keywords"], inplace=True)

    dataset_root = "/workspace/data/cod_full_20240331_full_meta-llamaLlama-3_1-8B-Instruct_ft"
    # random_split with the single fraction [1] keeps every sample but wraps the
    # dataset in a shuffled Subset. Seeding makes that order, and therefore the
    # train_keywords sample below, reproducible.
    split_generator = torch.Generator().manual_seed(0)
    splits = {}
    for split_name in ("train", "val", "test"):
        dataset = ClaspDataset(input_dataframe=metadata_and_embeddings,
                               tokenizer=None,
                               max_token_length=64,
                               root=os.path.join(dataset_root, f"_{split_name}"))
        splits[split_name] = random_split(dataset, [1], generator=split_generator)[0]

    train_subset = splits["train"]
    train_data = train_subset.dataset.data
    train_df = pd.DataFrame({"titles": [train_data["title"][i] for i in train_subset.indices],
                             "ID": [train_data["material_id"][i] for i in train_subset.indices]})
    # Context texts for "self"/"scsc" attention. The sample size is capped so a
    # small train split does not make DataFrame.sample raise.
    train_keywords = train_df.sample(n=min(1023, len(train_df)), random_state=0)["titles"].tolist()

    eval_subset = splits[target_dataset]
    eval_data = eval_subset.dataset.data
    df = pd.DataFrame({"titles": [eval_data["title"][i] for i in eval_subset.indices],
                       "id": [eval_data["material_id"][i] for i in eval_subset.indices]})
    # shuffle=False keeps embedding rows aligned with the rows of df.
    dataloader = DataLoader(eval_subset, batch_size=1024, shuffle=False, num_workers=2, drop_last=False,
                            pin_memory=True, persistent_workers=True,
                            worker_init_fn=seed_worker)

    output_dir = f'/workspace/eval_results/{target_dataset}'
    os.makedirs(output_dir, exist_ok=True)

    # for test
    config_paths = [
        "../outputs/2025-02-03/pretraining/14-19-06/version_0"  # example
    ]
    for config_path in tqdm(config_paths):
        print(f"config path: {config_path}")
        checkpoint_dir = f"{config_path}/model_checkpoint/"
        checkpoint_files = ["last.ckpt"]

        for checkpoint_file in checkpoint_files:
            checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
            print(f"check point: {checkpoint_path}")
            with initialize(config_path, version_base='1.1'):
                cfg = compose(config_name="hparams")
                cfg.freeze_text_encoders = True

            model = ClaspModel.load_from_checkpoint(checkpoint_path, cfg=cfg, train_loader=None, val_loader=None)
            model.to(device)
            model.eval()

            tokenizer = BertTokenizer.from_pretrained(cfg.hf_textencoder_model_id)
            text_encoder = model.model_text.to(device)
            output_cry, output_text = predict_embeddings(model, dataloader, device)

            if cfg.attention in ("self", "scsc"):
                auc_df, fig, roc_data = eval_roc_for_keyword_variations_self(df, output_cry, tokenizer, text_encoder, cfg, keyword_variations, train_keywords)
            else:
                auc_df, fig, roc_data = eval_roc_for_keyword_variations(df, output_cry, tokenizer, text_encoder, cfg, keyword_variations)

            save_filename_root = checkpoint_path.replace("../outputs/", "").replace("/", "_").replace("model_checkpoint", "")
            save_prefix = f'{output_dir}/{save_filename_root}'

            plt.tight_layout()
            fig.savefig(f'{save_prefix}_roc_curve_keyword_variations_{target_dataset}_title.pdf')
            # Release the figure so evaluating many checkpoints doesn't pile up open figures.
            plt.close(fig)
            auc_df.to_csv(f'{save_prefix}_roc_table_keyword_variations_{target_dataset}_title.csv', index=False)

            # save ROC
            with open(f'{save_prefix}_roc_data_keyword_variations_{target_dataset}_title.pkl', 'wb') as f:
                pickle.dump(roc_data, f)


