import ujson as json
import numpy as np
import itertools
import math
import csv
import multiprocessing as mps
from functools import partial
import math
import os
import mpmath as mp
# ``mp`` is the mpmath module itself (``import mpmath as mp``); working precision
# is an attribute of its global context object ``mpmath.mp``. Assigning ``mp.dps``
# would only create an unused module attribute and leave the default precision.
mp.mp.dps = 1000

# Directory that receives the per-configuration result CSVs written by ``run``.
prefix_path = "../../data/results/csv/Aar-prob2-results"

# Function to calculate cosine similarity
def cosine_similarity(vec_a, vec_b):
    """Return the cosine similarity between ``vec_a`` and ``vec_b``.

    Zero-norm inputs are special-cased instead of dividing by zero: two
    zero-norm vectors count as identical (1), and exactly one zero-norm vector
    counts as orthogonal (0).
    """
    numerator = np.dot(vec_a, vec_b)
    magnitude_a = np.linalg.norm(vec_a)
    magnitude_b = np.linalg.norm(vec_b)
    a_is_zero = magnitude_a == 0
    b_is_zero = magnitude_b == 0
    if a_is_zero or b_is_zero:
        return 1 if (a_is_zero and b_is_zero) else 0
    return numerator / (magnitude_a * magnitude_b)


def avg_cossim_tuple(t12_diff):
    """Return the mean cosine similarity over the vector pairs in ``t12_diff``.

    ``t12_diff`` maps arbitrary keys (ignored) to ``(v1, v2)`` pairs. An empty
    mapping yields ``nan``, as ``np.mean`` does for an empty list.
    """
    scores = []
    for first, second in t12_diff.values():
        scores.append(cosine_similarity(first, second))
    return np.mean(scores)


def variance_sample(t12_diff):
    """Return the sample variance (ddof=1) of the cosine similarities in ``t12_diff``.

    ``t12_diff`` maps arbitrary keys to ``(v1, v2)`` vector pairs. Sample
    variance needs at least two observations: ``np.var(..., ddof=1)`` returns
    nan with a RuntimeWarning for a single value. So 0 is returned for fewer
    than two, extending the existing empty-input convention.
    """
    similarities = [cosine_similarity(v1, v2) for v1, v2 in t12_diff.values()]
    if len(similarities) < 2:
        return 0
    return np.var(similarities, ddof=1)


def rank_invariant_transform(lst):
    """Map each element to +1 if it lies in the top half by value, else -1.

    Indices are ordered by descending value with Python's stable sort, so
    tied values keep their original relative order. The first
    ``len(lst) // 2`` indices of that ordering get +1 and all the rest get -1
    (this includes the middle element when the length is odd).
    """
    size = len(lst)
    order = sorted(range(size), key=lst.__getitem__, reverse=True)
    upper_half = set(order[: size // 2])
    return [1 if pos in upper_half else -1 for pos in range(size)]

def assign_relative_rankings(probabilities, word_set):
    """Rank the entries of ``probabilities`` selected by the indices in ``word_set``.

    Among the selected indices, the highest probability gets rank 1, the next
    gets rank 2, and so on. Returns an int array shaped like ``probabilities``.
    It holds those ranks at the selected indices and 0 everywhere else. Ties
    are broken by ``np.argsort``'s default sort, which is not guaranteed stable.
    """
    # Materialize once so reading and writing use the same index order.
    selected = list(word_set)
    selected_probs = np.array([probabilities[w] for w in selected])

    # Positions into ``selected``, highest probability first.
    descending = np.argsort(-selected_probs)
    selected_ranks = np.zeros_like(selected_probs, dtype=int)
    selected_ranks[descending] = np.arange(1, len(selected) + 1)

    ranks = np.zeros_like(probabilities, dtype=int)
    for w, r in zip(selected, selected_ranks):
        ranks[w] = r
    return ranks

# Module-level score tables ({key: np.ndarray of S_wm scores}) for the two
# result sets being compared. ``run`` assigns them; ``process_pair`` reads them.
t1_wm = t2_wm = None
# Define the function to process a single pair
def process_pair(value, top_n=50):
    """Compute signed rank-difference vectors for one pair of sample keys.

    Reads the module-level ``t1_wm`` and ``t2_wm`` dicts. They must already be
    populated in the calling process.

    For each key, the ``top_n`` highest-scoring indices are taken from both
    tables. The shared vocabulary is ``(t11 ∪ t12) ∩ (t21 ∪ t22)``. Each
    table's scores are normalized by their total and ranked over that
    vocabulary (rank 1 = highest). The two keys' ranks are then compared
    element-wise: +1 if key1's rank number is larger than key2's, -1 if
    smaller, 0 if equal.

    Args:
        value: ``(key1, key2)`` pair of keys present in both tables.
        top_n: number of top-scoring indices per key to consider.

    Returns:
        ``((key1, key2), (t11_t12_rank_diff, t21_t22_rank_diff))``, where both
        diffs are lists of ints in the same word order. Returns ``None`` when
        the shared vocabulary is empty.
    """
    key1, key2 = value[0], value[1]

    def top_words(scores):
        # Indices of the top_n largest scores (ascending argsort, reversed).
        return set(np.argsort(scores)[::-1][:top_n])

    word_set = (top_words(t1_wm[key1]) | top_words(t1_wm[key2])) & (
        top_words(t2_wm[key1]) | top_words(t2_wm[key2])
    )
    if not word_set:
        # Nothing to compare; the caller in ``run`` skips ``None`` results.
        return None

    def relative_ranks(scores):
        # Dividing by a positive total does not change the ranking; the
        # normalization is kept so non-positive totals behave as before.
        return assign_relative_rankings(scores / np.sum(scores), word_set)

    def rank_signs(ranks_a, ranks_b):
        # Per shared word: +1 / -1 / 0 as ranks_a is greater / smaller / equal.
        return [
            1 if ranks_a[w] > ranks_b[w] else -1 if ranks_a[w] < ranks_b[w] else 0
            for w in word_set
        ]

    t11_rank = relative_ranks(t1_wm[key1])
    t12_rank = relative_ranks(t1_wm[key2])
    t21_rank = relative_ranks(t2_wm[key1])
    t22_rank = relative_ranks(t2_wm[key2])

    return (key1, key2), (
        rank_signs(t11_rank, t12_rank),
        rank_signs(t21_rank, t22_rank),
    )

def _init_pair_worker(t1_table, t2_table):
    """Pool initializer: install the score tables as this process's module globals.

    ``process_pair`` reads ``t1_wm`` / ``t2_wm`` from module globals. Passing
    them explicitly lets workers see them under every multiprocessing start
    method, not only ``fork``, where children inherit the parent's globals.
    """
    global t1_wm, t2_wm
    t1_wm, t2_wm = t1_table, t2_table


def _filtered_wm_scores(data, filter_threshold):
    """Return ``{key: np.array(S_wm)}`` for keys whose ``S_wm`` total is >= threshold."""
    watermarked = data["watermarked"]
    scores = {key: np.array(watermarked[key]["S_wm"]) for key in watermarked.keys()}
    return {key: arr for key, arr in scores.items() if np.sum(arr) >= filter_threshold}


def run(
    model,
    num_samples,
    prefix_length,
    filter_threshold,
    num_iters=3,
    processes=28,
    null_cossim=0.1,
):
    """Aggregate rank-difference cosine similarity over iterations and save a CSV.

    For each iteration, the function:

    - loads the p1/p2 prob2 JSON result files;
    - keeps the keys whose ``S_wm`` total is at least ``filter_threshold`` in
      both files;
    - runs ``process_pair`` on every pair of common keys in a process pool;
    - records the mean cosine similarity of the resulting vectors.

    Afterwards it computes the mean and std across iterations, a z-score
    against ``null_cossim`` and a one-sided p-value. It writes one CSV row per
    iteration under ``prefix_path``.

    Args:
        model, num_samples, prefix_length: used to build input/output file names.
        filter_threshold: minimum ``S_wm`` total for a key to be kept.
        num_iters: number of per-iteration input files (``iter-0`` ...) to process.
        processes: worker count for the multiprocessing pool.
        null_cossim: reference mean cosine similarity for the z-test.
    """
    global t1_wm, t2_wm

    # Sampling configuration; only used in file names and CSV rows.
    temp, topp, topk = 1.0, 1.0, 0

    cossim_results = []
    pair_counts = []

    for iteration in range(num_iters):
        print("Iteration: ", iteration)
        p1_file = f"../../data/results/prob2/aar-p1-{model}-temp-{temp}-topp-{topp}-topk-{topk}-prefixlen-{prefix_length}-{num_samples}-prob2-iter-{iteration}.json"
        p2_file = f"../../data/results/prob2/aar-p2-{model}-temp-{temp}-topp-{topp}-topk-{topk}-prefixlen-{prefix_length}-{num_samples}-prob2-iter-{iteration}.json"

        print("loading samples...")
        with open(p1_file, "r") as f:
            data_p1 = json.load(f)
        with open(p2_file, "r") as f:
            data_p2 = json.load(f)

        print("processing samples...")
        t1_scores = _filtered_wm_scores(data_p1, filter_threshold)
        t2_scores = _filtered_wm_scores(data_p2, filter_threshold)

        # Keep only keys that survived filtering in both files. Sorting makes
        # the pair enumeration order reproducible across interpreter runs.
        wm_common_keys = sorted(set(t1_scores) & set(t2_scores))
        t1_wm = {key: t1_scores[key] for key in wm_common_keys}
        t2_wm = {key: t2_scores[key] for key in wm_common_keys}
        wm_common_pairs = list(itertools.combinations(wm_common_keys, 2))
        pair_counts.append(len(wm_common_pairs))

        print(f"Length of common pairs: {len(wm_common_pairs)}")
        print("calculating differences...")

        with mps.Pool(
            processes=processes,
            initializer=_init_pair_worker,
            initargs=(t1_wm, t2_wm),
        ) as pool:
            results = pool.map(process_pair, wm_common_pairs)

        t12_wm_diff = {}
        skipped = 0
        for result in results:
            # A pair with no shared vocabulary comes back as None or as
            # (None, None); neither carries rank-difference vectors.
            if result is None or result[0] is None:
                skipped += 1
                continue
            key_pair, diff_data = result
            t12_wm_diff[key_pair] = diff_data
        if skipped:
            print(f"Skipped pairs with no shared words: {skipped}")

        avg_cossim_wm = avg_cossim_tuple(t12_wm_diff)
        cossim_results.append(avg_cossim_wm)
        print(f"Average cosine similarity: {avg_cossim_wm}")

    cossim_results = np.array(cossim_results)
    cossim_mean = np.mean(cossim_results)
    # Sample standard deviation across iterations (nan when num_iters < 2).
    cossim_std = np.std(cossim_results, ddof=1)

    # z-score of the mean against the null value ``null_cossim``.
    # NOTE(review): divides by the sample std rather than the standard error
    # (std / sqrt(num_iters)); confirm this is the intended statistic.
    z_score = (cossim_mean - null_cossim) / cossim_std
    # One-sided (upper-tail) p-value P(Z > z) = erfc(z / sqrt(2)) / 2. It is
    # evaluated at 1000 digits so all 50 printed digits are computed at full
    # precision; the library default is only 15.
    with mp.workdps(1000):
        p_value = mp.erfc(mp.mpf(z_score) / mp.sqrt(2)) / 2
        p_value_str = mp.nstr(p_value, 50, strip_zeros=False)

    output_path = f"{prefix_path}/Aar-prob2-{model}-{num_samples}-{prefix_length}-{filter_threshold}.csv"
    os.makedirs(prefix_path, exist_ok=True)
    csv_header = [
        "model_name",
        "temperature",
        "topk",
        "topp",
        "pair_count",
        "cossim_wm",
        "avg_cossim_wm",
        "std_cossim_wm",
        "z_score_wm",
        "p_value_wm",
    ]
    # newline="" lets the csv module control line endings (no blank rows on Windows).
    with open(output_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(csv_header)
        # One row per iteration, each with its own pair count and similarity.
        for pair_count, cossim in zip(pair_counts, cossim_results):
            writer.writerow(
                [
                    model,
                    temp,
                    topk,
                    topp,
                    pair_count,
                    cossim,
                    cossim_mean,
                    cossim_std,
                    z_score,
                    p_value_str,
                ]
            )

        


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run the script with parameters")
    # Every value feeds the input/output paths or the key filter, so each is
    # required instead of silently defaulting to None.
    parser.add_argument("--model_name", type=str, required=True, help="Model name")
    parser.add_argument(
        "--samples",
        type=int,
        required=True,
        help="Number of samples (part of the input/output file names)",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        required=True,
        help="Minimum total S_wm score for a key to be kept",
    )
    parser.add_argument("--prefix_length", type=int, required=True, help="Prefix length")

    args = parser.parse_args()
    print(args)

    run(
        model=args.model_name,
        num_samples=args.samples,
        prefix_length=args.prefix_length,
        filter_threshold=args.threshold,
    )
