import numpy as np
from sklearn.metrics import ndcg_score,dcg_score
from sentence_transformers import CrossEncoder
import re
import random
from utils import tokenize,fidelity_and_wfidelity,generate_original_ordering
from utils import mean_average_precision,mean_average_precision_graded

def load_ms_marco_data(n_queries, n_docs, file_path='dataset/top1000.dev'):
    """Load queries and their candidate passages from a tab-separated file.

    Each line is expected to have at least four tab-separated fields, where
    the third field is the query text and the fourth is the passage text
    (the first two fields are ignored here).

    Args:
        n_queries: Maximum number of distinct queries to keep. Lines whose
            query would exceed this cap are skipped.
        n_docs: Maximum number of passages to keep per query.
        file_path: Path to the UTF-8 encoded TSV file.

    Returns:
        A dict mapping query text to the list of its passages, in file order.

    Raises:
        OSError: If ``file_path`` cannot be opened.
    """
    query = {}

    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Drop the line terminator so it does not leak into the passage.
            split_line = line.rstrip('\r\n').split('\t')

            # Blank or truncated lines carry no (query, passage) pair.
            if len(split_line) < 4:
                continue

            q = split_line[2]
            passage = split_line[3]

            if q not in query:
                if len(query) >= n_queries:
                    # Query budget exhausted; keep scanning because passages
                    # of already accepted queries may still follow.
                    continue
                query[q] = []

            # Strict '<' so that at most n_docs passages are stored.
            if len(query[q]) < n_docs:
                query[q].append(passage)

    return query

def prediction_function(model, query, docs):
    """Score every document in ``docs`` against ``query`` with ``model``.

    Builds one ``(query, doc)`` tuple per document, passes the list to
    ``model.predict`` and returns the result as a new NumPy array.
    """
    pairs = []
    for passage in docs:
        pairs.append((query, passage))
    raw_scores = model.predict(pairs)
    return np.array(raw_scores)

# Keep only the tokens of a text that belong to an allowed subset
def mask_text(text, allowed_tokens):
    """Return ``text`` reduced to its tokens that are in ``allowed_tokens``.

    The text is split with ``tokenize``; surviving tokens keep their original
    order and are joined with single spaces.
    """
    kept = []
    for tok in tokenize(text):
        if tok in allowed_tokens:
            kept.append(tok)
    return ' '.join(kept)

# Compute the value-function score for a ranking
def value_function(true_relevance, scores, val_func='NDCG' ):
    """Evaluate ``scores`` against ``true_relevance`` with the chosen metric.

    ``val_func`` selects the metric: ``'DCG'``, ``'binaryMAP'`` or ``'MAP'``.
    Any other value (including the default ``'NDCG'``) uses ``ndcg_score``.
    """
    if val_func == 'DCG':
        return dcg_score(true_relevance, scores)
    if val_func == 'binaryMAP':
        return mean_average_precision(true_relevance, scores)
    if val_func == 'MAP':
        return mean_average_precision_graded(true_relevance, scores)
    # Fallback for 'NDCG' and every unrecognised name.
    return ndcg_score(true_relevance, scores)

# Evaluate the value function when only a subset of tokens is visible
def compute_valuefunction_for_subset(model, query, documents, labels, token_subset,val_func):
    """Score the ranking obtained after masking query and documents.

    Every token not in ``token_subset`` is removed from the query and from
    each document, the masked texts are re-scored with ``model``, and the
    resulting scores are compared to ``labels`` using the metric named by
    ``val_func``.
    """
    keep = set(token_subset)
    masked_query = mask_text(query, keep)
    masked_docs = []
    for doc in documents:
        masked_docs.append(mask_text(doc, keep))
    predicted = prediction_function(model, masked_query, masked_docs)
    # Both arguments are given a leading axis of length one before scoring.
    return value_function(np.asarray([labels]), predicted.reshape(1, -1), val_func)

# Weighting function applied to a coalition of size s out of M features
def shap_kernel_weight(s, M):
    """Return ``(M - 1) / (s * (M - s))``, or 0 when ``s`` is 0 or ``M``.

    The empty and the full coalition are given weight 0, which also avoids a
    division by zero for those sizes.
    """
    is_trivial_coalition = s == 0 or s == M
    if is_trivial_coalition:
        return 0
    complement_size = M - s
    return (M - 1) / (s * complement_size)

# Approximate Kernel SHAP values
def approximate_shap_values(vocabulary, model, query, documents, relevance_labels, value_function = 'ndcg',num_samples=1000):
    """Estimate per-token attributions by sampling token coalitions.

    For each of ``num_samples`` rounds a coalition size is drawn uniformly
    from ``0..len(vocabulary)`` and a coalition of that size is drawn from
    ``vocabulary``. For every token in the coalition, the change in the value
    function caused by removing that token is weighted with
    ``shap_kernel_weight`` and accumulated.

    Args:
        vocabulary: List of tokens (features); expected to contain no
            duplicates. Entry ``i`` of the result belongs to ``vocabulary[i]``.
        model: Ranking model forwarded to ``compute_valuefunction_for_subset``.
        query: Query text.
        documents: List of document texts.
        relevance_labels: Reference relevance of each document.
        value_function: Name of the metric, forwarded unchanged as the
            ``val_func`` argument. Note that inside this function the name
            refers to this string, not to the module-level function.
        num_samples: Number of coalitions to sample.

    Returns:
        NumPy array of length ``len(vocabulary)`` holding the accumulated
        weighted contributions divided by ``num_samples``.
    """
    M = len(vocabulary)
    # Size the result from the argument itself rather than from a global.
    shap_values = np.zeros(M)

    for _ in range(num_samples):

        # Randomly sample a coalition size
        subset_size = int(np.random.randint(0, M + 1))

        # Randomly sample a subset of this size
        subset = set(random.sample(vocabulary, subset_size))

        # An empty coalition has no member whose removal could be measured.
        if not subset:
            continue

        # The weight depends only on the coalition size, so compute it once.
        weight = shap_kernel_weight(len(subset), M)
        if weight == 0:
            # Every contribution would be multiplied by zero; skip the
            # model evaluations for this coalition.
            continue

        # The value of the full coalition is shared by all of its members.
        f_with = compute_valuefunction_for_subset(model, query, documents, relevance_labels, subset, value_function)

        # Accumulate the marginal contribution of every token in the coalition
        for idx, token in enumerate(vocabulary):
            if token not in subset:
                continue
            subset_without_token = subset - {token}
            f_without = compute_valuefunction_for_subset(model, query, documents, relevance_labels, subset_without_token, value_function)
            shap_values[idx] += weight * (f_with - f_without)

    # Average over the number of sampled coalitions
    shap_values /= num_samples
    return shap_values

#---------------------main----------
# Driver: for each loaded query, estimate token attributions with the
# approximate Kernel SHAP routine and print the fidelity of the explanation.
num_samples = 1000  # Number of sampled coalitions per query; increase for better accuracy
val_func= 'MAP'  # Metric name handed to value_function ('MAP' selects mean_average_precision_graded)
print('HELLO WORLD')

num_queries = 5  # Cap on distinct queries read from the dataset
num_documents = 10  # Per-query passage cap passed to the loader
dataset_file_path = '../rankllama/dataset/top1000.dev'
q_dict = load_ms_marco_data(num_queries,num_documents,dataset_file_path) 

# Neural Ranking Model/LLM
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
for query in q_dict:
    documents = q_dict[query]

    # Combine tokens from query and documents into a single vocabulary. These tokens are the features for which we seek attributions.
    all_texts = [query] + documents
    vocabulary = list(set(token for text in all_texts for token in tokenize(text)))
    num_tokens = len(vocabulary)


    # The model's own scores on the unmasked texts serve as the reference
    # relevance labels; negative scores are clamped to 0 in place.
    relevance_labels = prediction_function(model, query, documents)
    relevance_labels[relevance_labels < 0] = 0
    shap_values = approximate_shap_values(vocabulary, model, query, documents, relevance_labels, value_function=val_func, num_samples=num_samples)
    # Pair each token with its attribution, largest absolute value first.
    sorted_shap_values = sorted(zip(vocabulary, shap_values), key=lambda x: -abs(x[1]))
    # print(dict(sorted_shap_values)) 

    # Evaluate the explanation against the model's ordering of the documents.
    # NOTE(review): the exact definitions of fidelity / wFidelity live in utils — confirm there.
    fidelity, wFidelity = fidelity_and_wfidelity(query, documents, model, generate_original_ordering(query, documents,model, prediction_function),dict(sorted_shap_values))
    print('Eval',fidelity,wFidelity)