import ujson as json
import numpy as np
import itertools
import math
import csv
import multiprocessing as mps
from functools import partial
import math
import os
import mpmath as mp

# Work with 1000 significant decimal digits in mpmath. The working precision
# is a property of the global context object ``mpmath.mp``; assigning ``dps``
# on the module itself (``mp.dps = ...``) only creates an unused attribute and
# leaves the default precision in effect.
mp.mp.dps = 1000

# Directory that receives the per-configuration CSV result files.
prefix_path = "../../data/results/csv/KGW-prob1-results"

def cosine_similarity(vec_a, vec_b):
    """Return the cosine similarity of ``vec_a`` and ``vec_b``.

    Degenerate inputs are handled explicitly: when both vectors have zero
    norm the result is 1, and when exactly one of them has zero norm the
    result is 0.
    """
    numerator = np.dot(vec_a, vec_b)
    magnitude_a, magnitude_b = np.linalg.norm(vec_a), np.linalg.norm(vec_b)

    a_is_zero = magnitude_a == 0
    b_is_zero = magnitude_b == 0
    if a_is_zero and b_is_zero:
        return 1
    if a_is_zero or b_is_zero:
        return 0

    return numerator / (magnitude_a * magnitude_b)


def avg_cossim_tuple(t12_diff):
    """Return the mean cosine similarity over the vector pairs in ``t12_diff``.

    Every value of ``t12_diff`` is unpacked as a pair of vectors; the keys are
    not used.
    """
    scores = []
    for first, second in t12_diff.values():
        scores.append(cosine_similarity(first, second))
    return np.mean(scores)


def variance_sample(t12_diff):
    """Return the sample variance (``ddof=1``) of the per-pair cosine similarities.

    Every value of ``t12_diff`` is unpacked as a pair of vectors. An empty
    mapping yields 0.
    """
    scores = []
    for first, second in t12_diff.values():
        scores.append(cosine_similarity(first, second))

    if not scores:
        return 0
    return np.var(scores, ddof=1)


def rank_invariant_transform(lst):
    """Map each element of ``lst`` to +1 or -1 according to its rank.

    The ``len(lst) // 2`` largest elements become 1 and all others become -1.
    Ties are broken by original position (earlier elements rank higher).
    """
    size = len(lst)
    by_value_desc = sorted(range(size), key=lambda pos: lst[pos], reverse=True)

    # Start with everything in the lower half, then promote the top half.
    signs = [-1] * size
    for pos in by_value_desc[: size // 2]:
        signs[pos] = 1
    return signs


def assign_relative_rankings(probabilities, word_set):
    """Rank the entries of ``probabilities`` selected by ``word_set``.

    Returns an integer array shaped like ``probabilities``. Positions listed
    in ``word_set`` hold their 1-based rank among the selected entries (1 is
    the largest probability); every other position is 0.
    """
    # Probabilities of the selected positions, in word_set iteration order.
    selected = np.array([probabilities[token] for token in word_set])

    # Negating the values makes the ascending argsort a descending ordering.
    descending = np.argsort(-selected)

    # Scatter ranks 1..n back to the slot each selected entry came from.
    ranks = np.zeros_like(selected, dtype=int)
    ranks[descending] = np.arange(1, len(descending) + 1)

    # Place every rank at its original index in a full-size array.
    full = np.zeros_like(probabilities, dtype=int)
    for slot, token in enumerate(word_set):
        full[token] = ranks[slot]
    return full

# Score tables read by process_pair(). run() rebinds both names for every
# iteration before it creates the worker pool.
t1_wm = None
t2_wm = None


def _top_indices(scores, top_n):
    """Return the indices of the ``top_n`` largest entries of ``scores``, largest first."""
    return np.argsort(scores)[::-1][:top_n]


def _sign_of_difference(left, right):
    """Return 1, -1 or 0 per position depending on whether left is >, < or == right."""
    return [1 if a > b else -1 if a < b else 0 for a, b in zip(left, right)]


def process_pair(value):
    """Compute the difference vectors of one key pair in both score tables.

    Reads the module-level ``t1_wm`` and ``t2_wm`` tables and the module-level
    ``diff`` mode.

    Args:
        value: Sequence whose first two items are the keys ``key1`` and
            ``key2`` to compare.

    Returns:
        ``((key1, key2), (t1_diff, t2_diff))`` where each ``*_diff`` is a list
        with one entry per shared index, in ascending index order. With
        ``diff == "sub"`` the entries are differences of the normalised
        scores (``key1`` minus ``key2``); with ``diff == "rank"`` they are the
        sign (1, -1 or 0) of the difference of the relative rankings.
        Returns ``None`` when the two tables share no top-ranked index for
        this pair, so callers can drop the pair with an ``is not None`` test.

    Raises:
        ValueError: If ``diff`` is neither ``"rank"`` nor ``"sub"``.
    """
    key1, key2 = value[0], value[1]
    top_n = 50

    # (t11 ∪ t12) ∩ (t21 ∪ t22): indices that are top-ranked in both tables.
    t1_top = set(_top_indices(t1_wm[key1], top_n)).union(
        _top_indices(t1_wm[key2], top_n)
    )
    t2_top = set(_top_indices(t2_wm[key1], top_n)).union(
        _top_indices(t2_wm[key2], top_n)
    )
    word_set = t1_top.intersection(t2_top)

    # The original incremented an undefined counter here, which raised
    # UnboundLocalError instead of skipping the pair.
    if not word_set:
        return None

    # A fixed iteration order keeps all four vectors aligned and makes the
    # output independent of set ordering.
    word_ids = sorted(word_set)

    # Normalise each score vector so that it sums to 1.
    t11_prob = t1_wm[key1] / np.sum(t1_wm[key1])
    t12_prob = t1_wm[key2] / np.sum(t1_wm[key2])
    t21_prob = t2_wm[key1] / np.sum(t2_wm[key1])
    t22_prob = t2_wm[key2] / np.sum(t2_wm[key2])

    if diff == "rank":
        t11_rank = assign_relative_rankings(t11_prob, word_ids)
        t12_rank = assign_relative_rankings(t12_prob, word_ids)
        t21_rank = assign_relative_rankings(t21_prob, word_ids)
        t22_rank = assign_relative_rankings(t22_prob, word_ids)

        t11_t12_rank_diff = _sign_of_difference(
            [t11_rank[i] for i in word_ids], [t12_rank[i] for i in word_ids]
        )
        t21_t22_rank_diff = _sign_of_difference(
            [t21_rank[i] for i in word_ids], [t22_rank[i] for i in word_ids]
        )
        return (key1, key2), (t11_t12_rank_diff, t21_t22_rank_diff)

    if diff == "sub":
        t11_t12_diff = [t11_prob[i] - t12_prob[i] for i in word_ids]
        t21_t22_diff = [t21_prob[i] - t22_prob[i] for i in word_ids]
        return (key1, key2), (t11_t12_diff, t21_t22_diff)

    raise ValueError("Invalid diff type")
        
        
def _filter_scores(samples, filter_threshold):
    """Map each key of ``samples`` to its ``S_wm`` array, keeping sums >= ``filter_threshold``."""
    kept = {}
    for key, entry in samples.items():
        scores = np.array(entry["S_wm"])
        if np.sum(scores) >= filter_threshold:
            kept[key] = scores
    return kept


def run(
    model,
    num_samples,
    gamma,
    delta,
    prefix_length,
    scheme,
    fill_length,
    filter_threshold,
    combinations,
    num_iterations=3,
    num_processes=28,
):
    """Compute cosine-similarity statistics per sampling combination and write them to CSV.

    For every combination, ``num_iterations`` pairs of ``kgw-p1`` / ``kgw-p2``
    JSON files are loaded. Their ``watermarked`` entries are filtered by
    ``filter_threshold`` and every pair of keys present in both files is
    passed to ``process_pair`` in a process pool. The mean cosine similarity
    of each iteration is collected, then summarised by mean, sample standard
    deviation, a z-score against 0.1 and the corresponding one-sided upper-tail
    normal p-value. One CSV file per combination is written below
    ``prefix_path``, with one row per iteration.

    Reads the module-level ``diff`` and ``prefix_path`` and rebinds the
    module-level ``t1_wm`` / ``t2_wm`` tables that ``process_pair`` reads.

    Args:
        model, num_samples, gamma, delta, prefix_length, scheme, fill_length:
            Values interpolated into the input and output file names.
        filter_threshold: Minimum sum of an ``S_wm`` vector for its key to be kept.
        combinations: Iterable of dicts with the keys ``"temperature"``,
            ``"topp"`` and ``"topk"``.
        num_iterations: Number of ``iter-<i>`` file pairs to process per
            combination. With fewer than two, the sample standard deviation
            (``ddof=1``) is not defined.
        num_processes: Size of the worker pool.
    """
    # NOTE(review): workers see t1_wm / t2_wm / diff only if they inherit this
    # process's module globals (the "fork" start method) -- confirm on the
    # target platform.
    global t1_wm, t2_wm

    for combo in combinations:
        temp, topp, topk = combo["temperature"], combo["topp"], combo["topk"]

        cossim_results = []
        # Pair count of every iteration, so each CSV row reports its own
        # count rather than the count of the last iteration.
        pair_counts = []

        for iteration in range(num_iterations):
            print("Iteration: ", iteration)
            p1_file = f"../../data/results/prob1/kgw-p1-{model}-{scheme}-temp-{temp}-topp-{topp}-topk-{topk}-gamma-{gamma}-delta-{delta}-prefixlen-{prefix_length}-{num_samples}-{fill_length}-iter-{iteration}.json"
            p2_file = f"../../data/results/prob1/kgw-p2-{model}-{scheme}-temp-{temp}-topp-{topp}-topk-{topk}-gamma-{gamma}-delta-{delta}-prefixlen-{prefix_length}-{num_samples}-{fill_length}-iter-{iteration}.json"

            print("loading samples...")
            with open(p1_file, "r") as f:
                data_p1 = json.load(f)
            with open(p2_file, "r") as f:
                data_p2 = json.load(f)

            print("processing samples...")
            t1_scores = _filter_scores(data_p1["watermarked"], filter_threshold)
            t2_scores = _filter_scores(data_p2["watermarked"], filter_threshold)

            # Only keys that survive the filter in both files are compared.
            # Sorting makes the pair order reproducible across runs.
            wm_common_keys = sorted(set(t1_scores).intersection(t2_scores))
            t1_wm = {key: t1_scores[key] for key in wm_common_keys}
            t2_wm = {key: t2_scores[key] for key in wm_common_keys}
            wm_common_pairs = list(itertools.combinations(wm_common_keys, 2))
            pair_counts.append(len(wm_common_pairs))

            print(f"Lenght of common pairs: {len(wm_common_pairs)}")

            print("calculating differences...")
            # The tables must be bound before the pool is created (see the
            # note at the top of the function).
            with mps.Pool(processes=num_processes) as pool:
                results = pool.map(process_pair, wm_common_pairs)

            # Drop skipped pairs; accept both a bare None and a
            # (None, None) sentinel.
            t12_wm_diff = {}
            for result in results:
                if result is None:
                    continue
                key_pair, diff_data = result
                if key_pair is None or diff_data is None:
                    continue
                t12_wm_diff[key_pair] = diff_data

            avg_cossim_wm = avg_cossim_tuple(t12_wm_diff)
            cossim_results.append(avg_cossim_wm)

            print(f"Average cosine similarity: {avg_cossim_wm}")

        # Summary statistics over the iterations.
        cossim_results = np.array(cossim_results)
        cossim_mean = np.mean(cossim_results)
        cossim_std = np.std(cossim_results, ddof=1)

        # z-score of the mean against a reference average similarity of 0.1.
        z_score = (cossim_mean - 0.1) / (cossim_std)
        # One-sided p-value P(Z >= z) = erfc(z / sqrt(2)) / 2, evaluated with
        # mpmath and rendered with 50 significant digits.
        z_mpf = mp.mpf(z_score)
        p_value = mp.erfc(z_mpf / mp.sqrt(2)) / 2
        p_value_str = mp.nstr(p_value, 50, strip_zeros=False)

        csv_header = [
            "model_name",
            "temperature",
            "topk",
            "topp",
            "pair_count",
            "cossim_wm",
            "avg_cossim_wm",
            "std_cossim_wm",
            "z_score_wm",
            "p_value_wm",
        ]

        output_path = f"{prefix_path}/KGW-{scheme}-{diff}-{model}-{temp}-{num_samples}-{prefix_length}-{gamma}-{delta}-{fill_length}-{filter_threshold}.csv"
        os.makedirs(prefix_path, exist_ok=True)
        # newline="" lets the csv module control line endings itself.
        with open(output_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(csv_header)

            for pair_count, cossim in zip(pair_counts, cossim_results):
                writer.writerow(
                    [
                        model,
                        temp,
                        topk,
                        topp,
                        pair_count,
                        cossim,
                        cossim_mean,
                        cossim_std,
                        z_score,
                        p_value_str,
                    ]
                )

        


if __name__ == "__main__":
    import argparse

    # Named sampling sweeps selectable with --combinations.
    combination_presets = {
        # Temperatures 1.5, 1.4, ..., 0.1 with top-p / top-k fixed.
        "temp": [
            {"temperature": tenths / 10, "topp": 1.0, "topk": 0}
            for tenths in range(15, 0, -1)
        ],
        "experiment": [
            {"temperature": 1.0, "topp": 1.0, "topk": 0},
        ],
    }

    parser = argparse.ArgumentParser(description="Run KGW experiments for prob1")
    # These values are interpolated into file names or used in comparisons,
    # so a missing one is rejected here instead of silently becoming None.
    parser.add_argument("--model_name", type=str, required=True, help="Model name")
    parser.add_argument("--samples", type=int, required=True, help="Number of samples")
    parser.add_argument("--threshold", type=float, required=True, help="Threshold")
    parser.add_argument("--gamma", default=0.5, type=float, help="Gamma")
    parser.add_argument("--delta", default=2.0, type=float, help="Delta")
    parser.add_argument("--prefix_length", default=4, type=int, help="Prefix length")
    parser.add_argument("--scheme", type=str, required=True, help="Scheme")
    parser.add_argument("--fill_length", default=50, type=int, help="Fill length")
    # Without validation an unknown or missing preset left `combinations`
    # undefined and the script died with a NameError.
    parser.add_argument(
        "--combinations",
        type=str,
        required=True,
        choices=list(combination_presets),
        help="Combinations",
    )
    # Reject an unsupported mode up front rather than inside the workers.
    parser.add_argument(
        "--diff", default="rank", type=str, choices=["rank", "sub"], help="Diff"
    )

    args = parser.parse_args()
    print(args)
    model = args.model_name
    # `diff` has to stay a module-level name: process_pair() and run() read it
    # as a global.
    diff = args.diff
    combinations = combination_presets[args.combinations]

    run(
        model=model,
        num_samples=args.samples,
        gamma=args.gamma,
        delta=args.delta,
        prefix_length=args.prefix_length,
        fill_length=args.fill_length,
        scheme=args.scheme,
        filter_threshold=args.threshold,
        combinations=combinations,
    )
