import ujson as json
import numpy as np
import itertools
import math
import csv
import multiprocessing as mps
from functools import partial
import math
import os
import mpmath as mp

# Working precision, in decimal digits, for the mpmath p-value computation.
# mpmath keeps its precision on the context object ``mpmath.mp``; assigning
# ``dps`` on the module object itself leaves the working precision unchanged.
mp.mp.dps = 1000

# Directory that receives the result CSV files.
prefix_path = "../../data/results/csv/Aar-prob1-results"

def cosine_similarity(vec_a, vec_b):
    """Return the cosine similarity of two vectors.

    Zero-length vectors are special-cased instead of dividing by zero:
    two zero vectors yield ``1`` and a single zero vector yields ``0``.
    """
    inner = np.dot(vec_a, vec_b)
    magnitude_a = np.linalg.norm(vec_a)
    magnitude_b = np.linalg.norm(vec_b)

    a_is_zero = magnitude_a == 0
    b_is_zero = magnitude_b == 0
    if a_is_zero or b_is_zero:
        # Both degenerate -> treated as identical; only one -> no similarity.
        return 1 if (a_is_zero and b_is_zero) else 0

    return inner / (magnitude_a * magnitude_b)


def avg_cossim_tuple(t12_diff):
    """Return the mean cosine similarity over the vector pairs in *t12_diff*.

    Each value of *t12_diff* is unpacked as a ``(v1, v2)`` pair and compared
    with ``cosine_similarity``; the dictionary keys are not used.
    """
    scores = []
    for first_vec, second_vec in t12_diff.values():
        scores.append(cosine_similarity(first_vec, second_vec))
    return np.mean(scores)


def variance_sample(t12_diff):
    """Return the sample variance (``ddof=1``) of the pairwise cosine similarities.

    Each value of *t12_diff* is unpacked as a ``(v1, v2)`` pair and compared
    with ``cosine_similarity``.

    Returns ``0`` when fewer than two pairs are available: with ``ddof=1``
    the variance of zero or one sample is undefined and numpy would return
    ``nan`` together with a RuntimeWarning.
    """
    similarities = [cosine_similarity(v1, v2) for v1, v2 in t12_diff.values()]
    if len(similarities) < 2:
        return 0
    return np.var(similarities, ddof=1)


def rank_invariant_transform(lst):
    """Map each element of *lst* to ``1`` or ``-1`` by its rank.

    The ``len(lst) // 2`` largest elements become ``1`` and the remaining
    ones become ``-1``. Ties are resolved by the stable descending sort, so
    among equal values the earlier position ranks higher.
    """
    count = len(lst)
    by_value_desc = sorted(
        range(count), key=lambda position: lst[position], reverse=True
    )
    upper_half = set(by_value_desc[: count // 2])
    return [1 if position in upper_half else -1 for position in range(count)]

def assign_relative_rankings(probabilities, word_set):
    """Rank the entries of *probabilities* selected by *word_set*.

    Only the positions listed in *word_set* are ranked against each other,
    in descending order of probability, with ranks starting at 1. The
    returned integer array has the shape of *probabilities*; positions
    outside *word_set* keep the value 0.
    """
    # Freeze one iteration order so values and ranks stay aligned.
    positions = list(word_set)
    selected = np.array([probabilities[position] for position in positions])

    # order[k] is the index (within ``selected``) of the k-th largest value.
    order = np.argsort(-selected)

    ranks = np.zeros_like(selected, dtype=int)
    ranks[order] = np.arange(1, len(order) + 1)

    # Scatter the ranks back to their positions in the full-size array.
    full_ranks = np.zeros_like(probabilities, dtype=int)
    for position, rank in zip(positions, ranks):
        full_ranks[position] = rank

    return full_ranks

# Score tables read by process_pair(); run() assigns them before the worker
# pool is created.
t1_wm = None
t2_wm = None


def process_pair(value):
    """Compare the rank ordering of two keys in both score tables.

    For the pair ``(key1, key2)`` the 50 highest-scoring indices of each key
    are collected from ``t1_wm`` and from ``t2_wm``. The word set is
    ``(t1[key1] ∪ t1[key2]) ∩ (t2[key1] ∪ t2[key2])``. Within that word set
    every index is ranked per key and per table, and for each table the sign
    of ``rank(key1) - rank(key2)`` is recorded for every index.

    Args:
        value: Sequence whose first two items are keys present in both the
            module-level ``t1_wm`` and ``t2_wm`` dictionaries.

    Returns:
        ``((key1, key2), (t1_signs, t2_signs))`` where both sign lists hold
        ``1``, ``-1`` or ``0`` per word-set index, in the same order; or
        ``(None, None)`` when the word set is empty.
    """
    key1 = value[0]
    key2 = value[1]

    top_n = 50

    def top_indices(scores):
        # Indices of the top_n largest scores, largest first.
        return np.argsort(scores)[::-1][:top_n]

    t1_top_key1 = top_indices(t1_wm[key1])
    t1_top_key2 = top_indices(t1_wm[key2])
    t2_top_key1 = top_indices(t2_wm[key1])
    t2_top_key2 = top_indices(t2_wm[key2])

    # (t11 ∪ t12) ∩ (t21 ∪ t22)
    word_set = (
        set(t1_top_key1).union(set(t1_top_key2))
    ).intersection(set(t2_top_key1).union(set(t2_top_key2)))

    if not word_set:
        # Nothing to compare for this pair. (An increment of an undefined
        # ``zero_cnt`` counter used to raise UnboundLocalError here.)
        return (None, None)

    # Normalization
    # NOTE(review): a table row summing to zero would produce nan here;
    # run() filters rows by their sum, so confirm the threshold is positive.
    t11_prob = t1_wm[key1] / np.sum(t1_wm[key1])
    t12_prob = t1_wm[key2] / np.sum(t1_wm[key2])
    t21_prob = t2_wm[key1] / np.sum(t2_wm[key1])
    t22_prob = t2_wm[key2] / np.sum(t2_wm[key2])

    t11_rank = assign_relative_rankings(t11_prob, word_set)
    t12_rank = assign_relative_rankings(t12_prob, word_set)
    t21_rank = assign_relative_rankings(t21_prob, word_set)
    t22_rank = assign_relative_rankings(t22_prob, word_set)

    def rank_signs(first_rank, second_rank):
        # Sign of (first - second) for every word-set index; iterating the
        # same unmodified set keeps both result lists aligned.
        signs = []
        for index in word_set:
            first, second = first_rank[index], second_rank[index]
            signs.append(1 if first > second else -1 if first < second else 0)
        return signs

    t11_t12_rank_diff = rank_signs(t11_rank, t12_rank)
    t21_t22_rank_diff = rank_signs(t21_rank, t22_rank)

    # Return the calculation results
    return (key1, key2), (t11_t12_rank_diff, t21_t22_rank_diff)

def _init_pair_worker(first_scores, second_scores):
    """Publish the score tables as module globals inside a pool worker.

    ``process_pair`` reads ``t1_wm`` and ``t2_wm`` from module scope. Setting
    them through the pool initializer makes them available under every
    multiprocessing start method, not only ``fork``.
    """
    global t1_wm, t2_wm
    t1_wm = first_scores
    t2_wm = second_scores


def _extract_scores(data, filter_threshold):
    """Return ``{key: S_wm array}`` for entries whose array sum meets the threshold."""
    scores = {}
    for key, entry in data["watermarked"].items():
        values = np.array(entry["S_wm"])
        if np.sum(values) >= filter_threshold:
            scores[key] = values
    return scores


def run(
    model,
    num_samples,
    prefix_length,
    fill_length,
    filter_threshold,
    num_iterations=3,
    num_processes=28,
):
    """Compute per-iteration average cosine similarities and write them to CSV.

    For every iteration the ``aar-p1`` and ``aar-p2`` JSON result files are
    loaded, the ``S_wm`` arrays under ``"watermarked"`` are filtered by their
    sum, and every pair of keys common to both files is passed to
    ``process_pair`` in a process pool. The mean cosine similarity of the
    resulting sign vectors is recorded per iteration. Afterwards the mean,
    the sample standard deviation, the z-score against a reference value of
    0.1 and the one-sided upper-tail normal p-value are written, one row per
    iteration, to a CSV file below ``prefix_path``.

    Args:
        model: Model name; used in the input and output file names.
        num_samples: Sample count; used in the input and output file names.
        prefix_length: Prefix length; used in the input and output file names.
        fill_length: Fill length; used in the input and output file names.
        filter_threshold: Minimum ``np.sum`` of an ``S_wm`` array for its key
            to be kept.
        num_iterations: Number of ``iter-<i>`` file pairs to process. With
            fewer than two iterations the sample standard deviation, and
            therefore the z-score and p-value, is ``nan``.
        num_processes: Size of the worker pool.

    Side effects:
        Rebinds the module globals ``t1_wm`` and ``t2_wm``, prints progress
        messages, creates ``prefix_path`` if needed and overwrites the output
        CSV file.
    """
    global t1_wm, t2_wm

    temp, topp, topk = 1.0, 1.0, 0

    cossim_results = []
    # Pair count of each iteration, so every CSV row reports its own count.
    pair_counts = []

    for iteration in range(num_iterations):
        print("Iteration: ", iteration)
        p1_file = f"../../data/results/prob1/aar-p1-{model}-temp-{temp}-topp-{topp}-topk-{topk}-prefixlen-{prefix_length}-{num_samples}-{fill_length}-iter-{iteration}.json"
        p2_file = f"../../data/results/prob1/aar-p2-{model}-temp-{temp}-topp-{topp}-topk-{topk}-prefixlen-{prefix_length}-{num_samples}-{fill_length}-iter-{iteration}.json"

        print("loading samples...")
        # Load JSON data
        with open(p1_file, "r") as f:
            data_p1 = json.load(f)
        with open(p2_file, "r") as f:
            data_p2 = json.load(f)

        print("processing samples...")
        # Build the score tables and drop keys below the threshold
        t1_scores = _extract_scores(data_p1, filter_threshold)
        t2_scores = _extract_scores(data_p2, filter_threshold)

        # Only keys that survived the filter in both files can be compared
        wm_common_keys = set(t1_scores.keys()).intersection(set(t2_scores.keys()))
        t1_wm = {key: t1_scores[key] for key in wm_common_keys}
        t2_wm = {key: t2_scores[key] for key in wm_common_keys}
        wm_common_pairs = list(itertools.combinations(wm_common_keys, 2))
        pair_counts.append(len(wm_common_pairs))

        print(f"Lenght of common pairs: {len(wm_common_pairs)}")
        t12_wm_diff = {}

        print("calculating differences...")

        # The initializer hands the tables to each worker explicitly instead
        # of relying on globals being inherited through fork.
        with mps.Pool(
            processes=num_processes,
            initializer=_init_pair_worker,
            initargs=(t1_wm, t2_wm),
        ) as pool:
            results = pool.map(process_pair, wm_common_pairs)

        # Collect results; process_pair reports an empty word set as
        # (None, None), which must not end up in the similarity dictionary.
        for result in results:
            if result is None:
                continue
            key_pair, diff_data = result
            if key_pair is None:
                continue
            t12_wm_diff[key_pair] = diff_data

        # Calculate metrics
        avg_cossim_wm = avg_cossim_tuple(t12_wm_diff)
        cossim_results.append(avg_cossim_wm)

        print(f"Average cosine similarity: {avg_cossim_wm}")

    # Mean and sample standard deviation across iterations
    cossim_results = np.array(cossim_results)
    cossim_mean = np.mean(cossim_results)
    cossim_std = np.std(cossim_results, ddof=1)

    # z-score of the mean against the reference value 0.1
    z_score = (cossim_mean - 0.1) / (cossim_std)

    # One-sided (upper-tail) normal p-value, evaluated with mpmath so that
    # very small probabilities do not underflow to zero.
    z_mpf = mp.mpf(z_score)
    p_value = mp.erfc(z_mpf / mp.sqrt(2)) / 2
    p_value_str = mp.nstr(p_value, 50, strip_zeros=False)

    output_path = f"{prefix_path}/Aar-prob1-{model}-{num_samples}-{prefix_length}-{fill_length}-{filter_threshold}.csv"
    os.makedirs(prefix_path, exist_ok=True)

    csv_header = [
        "model_name",
        "temperature",
        "topk",
        "topp",
        "pair_count",
        "cossim_wm",
        "avg_cossim_wm",
        "std_cossim_wm",
        "z_score_wm",
        "p_value_wm",
    ]
    # newline="" lets the csv module control line endings itself.
    with open(output_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(csv_header)

        for pair_count, cossim in zip(pair_counts, cossim_results):
            writer.writerow(
                [
                    model,
                    temp,
                    topk,
                    topp,
                    pair_count,
                    cossim,
                    cossim_mean,
                    cossim_std,
                    z_score,
                    p_value_str,
                ]
            )

        


if __name__ == "__main__":
    import argparse

    # Every option is forwarded to run() and ends up in a file name or a
    # numeric comparison, so a missing option is rejected up front instead of
    # being passed on as None.
    parser = argparse.ArgumentParser(description="Run script with parameters")
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="Model name used in the input and output file names",
    )
    parser.add_argument(
        "--samples",
        type=int,
        required=True,
        help="Number of samples used in the input and output file names",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        required=True,
        help="Threshold",
    )
    parser.add_argument(
        "--prefix_length",
        type=int,
        required=True,
        help="Prefix length",
    )
    parser.add_argument(
        "--fill_length",
        type=int,
        required=True,
        help="Fill length",
    )

    args = parser.parse_args()
    print(args)
    model = args.model_name

    run(
        model=model,
        num_samples=args.samples,
        prefix_length=args.prefix_length,
        fill_length=args.fill_length,
        filter_threshold=args.threshold,
    )
