from transformers import AutoTokenizer, AutoModel
import torch
import os
import pickle
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score

# Class-label texts for every supported dataset. A label's position in its
# list is the integer class index it is stored under by save_class_embed.
text_dict = {
    'COLLAB': [
        'High Energy Physics',
        'Condensed Matter Physics ',  # trailing space is part of the encoded text
        'Astro Physics',
    ],
    'REDDIT-BINARY': [
        'question/answer-based community',
        'discussion-based community',
    ],
    'IMDB-BINARY': ['Action', 'Romance'],
    'IMDB-MULTI': ['Comedy', 'Romance', 'Sci-Fi'],
    # One single-character label per letter class.
    'Letter-med': list('AEFHIKLMNTVWXYZ'),
    # These datasets have no descriptive class names; the class index
    # rendered as a string is used as the label text.
    'COIL-RAG': list(map(str, range(100))),
    'Cuneiform': list(map(str, range(30))),
}

def save_class_embed(dataset_name, text_dict, model_path):
    """Embed every class label of a dataset and pickle the result.

    Each label text in ``text_dict[dataset_name]`` is tokenized and passed
    through the encoder on its own. The hidden state of the first token is
    taken as the label embedding and L2-normalized along dim 1. The result is
    a ``{class_index: embedding}`` dict, written to
    ``./class_embed/<dataset_name>_class_embeds.pkl`` and also printed.

    Args:
        dataset_name: Key into ``text_dict``; also used in the output file name.
        text_dict: Mapping from dataset name to the list of class label texts.
            The position of a text in its list becomes its class index.
        model_path: Path or identifier handed to ``from_pretrained`` for both
            the tokenizer and the model.

    Raises:
        KeyError: If ``dataset_name`` is not a key of ``text_dict``.
    """
    # Look the labels up first so an unknown dataset fails before the
    # (expensive) model load.
    texts_list = text_dict[dataset_name]

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModel.from_pretrained(model_path)

    embed_dict = {}
    for label, text in enumerate(texts_list):
        inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

        with torch.no_grad():
            outputs = model(**inputs)
            # Representation of the first token of the sequence.
            embeddings = outputs.last_hidden_state[:, 0]

        # Unit-length rows, so downstream distances are scale-free.
        embed_dict[label] = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    print(embed_dict)

    save_dir = './class_embed'
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, f'{dataset_name}_class_embeds.pkl'), 'wb') as f:
        pickle.dump(embed_dict, f)


def load_class_embed(save_dir, dataset_name):
    """Read back the pickled class-embedding dict of one dataset.

    Args:
        save_dir: Directory containing the ``*_class_embeds.pkl`` files.
        dataset_name: Dataset whose file ``<dataset_name>_class_embeds.pkl``
            is loaded.

    Returns:
        The unpickled object stored in that file.
    """
    pkl_path = os.path.join(save_dir, f'{dataset_name}_class_embeds.pkl')
    # NOTE: unpickling can execute arbitrary code; only load files this
    # project produced itself.
    with open(pkl_path, 'rb') as handle:
        return pickle.load(handle)

import numpy as np
import torch.nn.functional as F
def reg(graph_embed, class_embed, lam=1.):
    """Closed-form ridge regression from graph embeddings to class embeddings.

    Solves ``(X^T X + lam * I) W = X^T Y`` for ``W``, where ``X`` is
    ``graph_embed`` and ``Y`` is ``class_embed``.

    Args:
        graph_embed: 2-D tensor ``X`` of shape [n_graphs, graph_dim].
        class_embed: 2-D tensor ``Y`` of shape [n_graphs, cls_dim].
        lam: Ridge coefficient added to the diagonal of ``X^T X``.

    Returns:
        The weight matrix ``W`` of shape [graph_dim, cls_dim].
    """
    # Build the identity on the same device/dtype as the data; the default
    # (CPU, float32) breaks as soon as graph_embed lives on another device.
    identity = torch.eye(
        graph_embed.shape[1], dtype=graph_embed.dtype, device=graph_embed.device
    )
    gram = graph_embed.T @ graph_embed + lam * identity  # [graph_dim, graph_dim]
    XTy = graph_embed.T @ class_embed   # [graph_dim, cls_dim]
    # Solving the linear system is faster and numerically more stable than
    # forming the explicit inverse and multiplying.
    return torch.linalg.solve(gram, XTy)    # [graph_dim, cls_dim]


def reg_loss(cls_embed, graph_embed, W, lambda_=1.):
    """Ridge objective of every projected graph against every class embedding.

    Args:
        cls_embed: Class embeddings, one row per class.
        graph_embed: Graph embeddings, one row per graph.
        W: Projection matrix applied to ``graph_embed`` ([graph_dim, cls_dim]).
        lambda_: Weight of the squared Frobenius norm penalty on ``W``.

    Returns:
        A matrix whose entry (i, j) is the squared Euclidean distance between
        projected graph i and class j, plus the (shared) penalty term.
    """
    projected = torch.matmul(graph_embed, W)  # [n_graphs, cls_dim]
    pairwise_sq_dist = torch.cdist(projected, cls_embed, p=2).pow(2)
    # The penalty is a scalar, so it shifts every entry by the same amount.
    penalty = lambda_ * torch.norm(W, p='fro').pow(2)
    return pairwise_sq_dist + penalty


def zero_shot_predict(graph_embed, y_true, dataset_name, save_dir = './class_embed'):
    """Predict a class for each graph by nearest class embedding after projection.

    The stored class embeddings of ``dataset_name`` are loaded, a ridge
    projection ``W`` is fitted from ``graph_embed`` to the class embeddings of
    the given ``y_true`` labels, and each graph is assigned the class with the
    smallest ``reg_loss`` entry.

    Args:
        graph_embed: Graph embeddings, one row per graph.
        y_true: Iterable of scalar label tensors, one per graph; used to pick
            the regression targets.
        dataset_name: Dataset whose class embeddings are loaded.
        save_dir: Directory holding the pickled class embeddings.

    Returns:
        Tensor of predicted class indices, one per graph.
    """
    embeds_by_label = load_class_embed(save_dir, dataset_name)

    # Stack class embeddings in label order 0..n_class-1 so that a row index
    # of the stacked matrix equals the class index.
    all_class_embeds = torch.cat(
        [embeds_by_label[idx] for idx in range(len(embeds_by_label))]
    )

    # Regression targets: the class embedding of each graph's true label.
    targets = torch.cat(
        [embeds_by_label[label.item()] for label in y_true], dim=0
    )
    W = reg(graph_embed, targets)  # [graph_dim, cls_dim]

    # Lowest loss along the class axis wins.
    return torch.argmin(reg_loss(all_class_embeds, graph_embed, W), dim=-1)
    
    
def batch_zero_shot(graph_Z, y, dataset_name, batch_size=32, save_dir = './class_embed'):
    """Run ``zero_shot_predict`` batch by batch and score the predictions.

    Args:
        graph_Z: Tensor of graph embeddings, one row per graph.
        y: Tensor of true labels aligned with ``graph_Z``.
        dataset_name: Dataset whose class embeddings are used.
        batch_size: Number of graphs per ``zero_shot_predict`` call. Each
            batch fits its own projection.
        save_dir: Directory holding the pickled class embeddings.

    Returns:
        Tuple ``(all_pred, acc)``: the concatenated predictions for all graphs
        in input order, and their accuracy against ``y``.
    """
    dataset = TensorDataset(graph_Z, y)
    # shuffle=False keeps predictions aligned with the order of y.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    all_pred = torch.cat([
        # Forward save_dir; previously the argument was accepted but ignored,
        # so the default directory was always read.
        zero_shot_predict(graph_embeds, y_true, dataset_name=dataset_name, save_dir=save_dir)
        for graph_embeds, y_true in dataloader
    ])
    # Hand plain CPU arrays to scikit-learn so scoring also works when the
    # tensors live on an accelerator.
    acc = accuracy_score(y.cpu().numpy(), all_pred.cpu().numpy())
    return all_pred, acc


if __name__ == '__main__':
    # Datasets whose class embeddings are (re)generated by this run. The other
    # keys of text_dict ('COLLAB', 'REDDIT-BINARY', 'IMDB-BINARY',
    # 'IMDB-MULTI', 'Letter-med') are currently switched off; add them here to
    # regenerate their files too.
    names = ['COIL-RAG', 'Cuneiform']
    for name in names:
        save_class_embed(name, text_dict, model_path='./e5-base/e5-small-v2')