import ujson as json
import numpy as np
import itertools
import math
import csv
import multiprocessing as mp
from functools import partial
import math
import os

# Directory into which run() writes its per-configuration CSV summary
# (created on demand; relative to the current working directory).
prefix_path = "../../data/results/csv/KGW-Fix-prob1-results"

def cosine_similarity(vec_a, vec_b):
    """Return the cosine of the angle between ``vec_a`` and ``vec_b``.

    Zero vectors are handled explicitly rather than dividing by zero:
    two zero-norm inputs are treated as identical (returns 1), and exactly
    one zero-norm input is treated as unrelated (returns 0).
    """
    # Computed up front so mismatched shapes fail the same way regardless
    # of the norms.
    dot = np.dot(vec_a, vec_b)
    len_a = np.linalg.norm(vec_a)
    len_b = np.linalg.norm(vec_b)

    a_is_zero = len_a == 0
    b_is_zero = len_b == 0
    if a_is_zero or b_is_zero:
        return 1 if (a_is_zero and b_is_zero) else 0

    return dot / (len_a * len_b)


def avg_cossim_tuple(t12_diff):
    """Average ``cosine_similarity(v1, v2)`` over the ``(v1, v2)`` values of a dict.

    The dict keys are ignored. An empty dict produces NaN, since that is
    what ``np.mean`` returns for an empty sequence.
    """
    pair_scores = []
    for first, second in t12_diff.values():
        pair_scores.append(cosine_similarity(first, second))
    return np.mean(pair_scores)


def variance_sample(t12_diff):
    """Sample variance (ddof=1) of the cosine similarities of the dict's ``(v1, v2)`` values.

    Args:
        t12_diff: Mapping whose values are ``(v1, v2)`` vector pairs; keys
            are ignored.

    Returns:
        ``np.var(similarities, ddof=1)``. With fewer than two pairs, the
        sample variance is undefined: numpy would return NaN with a
        "degrees of freedom <= 0" warning. The empty case already returned
        0, so 0 is returned for any count below two.
    """
    if len(t12_diff) < 2:
        return 0
    similarities = [cosine_similarity(v1, v2) for v1, v2 in t12_diff.values()]
    return np.var(similarities, ddof=1)


def rank_invariant_transform(lst):
    """Binarize ``lst`` by rank: the top half of positions become 1, the rest -1.

    Positions are ordered by value, largest first. Ties keep their original
    left-to-right order, because ``sorted`` stays stable with
    ``reverse=True``. For an odd length, the extra element falls into the
    -1 half.
    """
    n = len(lst)
    by_value_desc = sorted(range(n), key=lst.__getitem__, reverse=True)
    upper = set(by_value_desc[: n // 2])
    return [1 if pos in upper else -1 for pos in range(n)]


def assign_relative_rankings(probabilities, word_set):
    """Rank the entries of ``probabilities`` selected by ``word_set``.

    Args:
        probabilities: Indexable vector of values.
        word_set: Collection of indices into ``probabilities`` to rank.

    Returns:
        An int array shaped like ``probabilities``. Each index in
        ``word_set`` holds its 1-based rank among the selected values
        (1 = largest); every other position is 0. Ties receive distinct
        ranks in whatever order ``np.argsort`` produces.
    """
    # Freeze one iteration order of word_set and use it for both the
    # gather and the scatter below.
    members = list(word_set)
    picked = np.array([probabilities[w] for w in members])

    # Descending order. Inverting the permutation gives each member's rank.
    order = np.argsort(-picked)
    local_rank = np.zeros_like(picked, dtype=int)
    local_rank[order] = np.arange(1, len(order) + 1)

    ranked = np.zeros_like(probabilities, dtype=int)
    for w, r in zip(members, local_rank):
        ranked[w] = r
    return ranked

# Count-vector dicts read by process_pair(); run() assigns both before it
# dispatches work. They start out unset.
t1_wm = t2_wm = None
def process_pair(value, top_n=50):
    """Compare how two keys rank a shared set of indices in runs t1 and t2.

    Reads the module-level dicts ``t1_wm`` and ``t2_wm``, which map each key
    to a 1-D vector of values.

    For the pair ``(key1, key2)``:

    1. Take the ``top_n`` highest-valued indices of each key's vector in each
       run, and form
       ``word_set = (t1[key1] ∪ t1[key2]) ∩ (t2[key1] ∪ t2[key2])``.
    2. Rank ``word_set`` inside each sum-normalized vector with
       ``assign_relative_rankings`` (rank 1 = largest value).
    3. For every index in ``word_set``, record the sign of
       ``rank(key1) - rank(key2)``, separately for run t1 and run t2.

    Args:
        value: A ``(key1, key2)`` pair of keys present in both dicts.
        top_n: Number of top indices taken from each vector (default 50).

    Returns:
        ``((key1, key2), (t11_t12_rank_diff, t21_t22_rank_diff))``. Each diff
        is a list of ints in {-1, 0, 1}, and the two lists are aligned
        element-wise over ``word_set``. Returns ``None`` when ``word_set`` is
        empty; the caller in this file skips ``None`` results.
    """
    key1 = value[0]
    key2 = value[1]

    def _top_indices(vec):
        # np.argsort is ascending; reverse it so the largest entries come first.
        return np.argsort(vec)[::-1][:top_n]

    # (t11 ∪ t12) ∩ (t21 ∪ t22)
    word_set = (
        set(_top_indices(t1_wm[key1])).union(set(_top_indices(t1_wm[key2])))
    ).intersection(
        set(_top_indices(t2_wm[key1])).union(set(_top_indices(t2_wm[key2])))
    )

    if not word_set:
        # No shared indices, so there is nothing to compare for this pair.
        return None

    # Fix one iteration order of word_set so that both diff lists line up
    # element-wise.
    indices = list(word_set)

    def _rank_sign_diff(vec_a, vec_b):
        # Per index in word_set: +1 if vec_a gives it a larger rank number
        # (i.e. ranks it lower) than vec_b, -1 if smaller, 0 if equal.
        rank_a = assign_relative_rankings(vec_a / np.sum(vec_a), word_set)
        rank_b = assign_relative_rankings(vec_b / np.sum(vec_b), word_set)
        return np.sign(rank_a[indices] - rank_b[indices]).tolist()

    t11_t12_rank_diff = _rank_sign_diff(t1_wm[key1], t1_wm[key2])
    t21_t22_rank_diff = _rank_sign_diff(t2_wm[key1], t2_wm[key2])

    return (key1, key2), (t11_t12_rank_diff, t21_t22_rank_diff)

def _init_pair_worker(t1, t2):
    """Pool initializer: install the count dicts as this module's globals for process_pair."""
    global t1_wm, t2_wm
    t1_wm, t2_wm = t1, t2


def _filtered_counts(data, filter_threshold):
    """Map each key under ``data["watermarked"]`` to ``np.array(S_wm)``.

    Only vectors whose sum is >= ``filter_threshold`` are kept.
    """
    vectors = {
        key: np.array(entry["S_wm"]) for key, entry in data["watermarked"].items()
    }
    return {
        key: vec for key, vec in vectors.items() if np.sum(vec) >= filter_threshold
    }


def run(
    model,
    num_samples,
    gamma,
    delta,
    prefix_length,
    keylen,
    filter_threshold,
    num_iters=3,
    processes=28,
):
    """Compute the average rank-difference cosine similarity over several iterations and save it as CSV.

    For each iteration this function:

    1. Loads the p1 and p2 JSON result files whose names are built from the
       arguments.
    2. Keeps the keys whose ``S_wm`` sum is >= ``filter_threshold`` in both
       files.
    3. Runs ``process_pair`` over every pair of common keys in a process pool.
    4. Averages the resulting cosine similarities.

    It then computes the mean, the sample standard deviation, and a z-score
    against 0.1 across iterations. One CSV row per iteration is written under
    ``prefix_path``.

    Args:
        model, num_samples, gamma, delta, prefix_length, keylen: Values
            substituted into the input and output file names.
        filter_threshold: Minimum ``S_wm`` sum for a key to be kept.
        num_iters: Number of iterations, i.e. input file pairs (default 3).
        processes: Worker process count for the pool (default 28).
    """
    global t1_wm, t2_wm

    temp, topp, topk = 1.0, 1.0, 0

    cossim_results = []
    pair_counts = []  # per-iteration pair count, reported on the matching CSV row

    for iteration in range(num_iters):
        print("Iteration: ", iteration)
        p1_file = f"../../data/results/prob1/kgwfix-p1-{model}-temp-{temp}-kgwfix-topp-{topp}-topk-{topk}-gamma-{gamma}-delta-{delta}-prefixlen-{prefix_length}-{num_samples}-{keylen}-prob1-iter-{iteration}.json"
        p2_file = f"../../data/results/prob1/kgwfix-p2-{model}-temp-{temp}-kgwfix-topp-{topp}-topk-{topk}-gamma-{gamma}-delta-{delta}-prefixlen-{prefix_length}-{num_samples}-{keylen}-prob1-iter-{iteration}.json"

        print("loading samples...")
        with open(p1_file, "r") as f:
            data_p1 = json.load(f)
        with open(p2_file, "r") as f:
            data_p2 = json.load(f)

        print("processing samples...")
        t1_counts = _filtered_counts(data_p1, filter_threshold)
        t2_counts = _filtered_counts(data_p2, filter_threshold)

        # Only keys that survived the filter in both files can be compared.
        wm_common_keys = set(t1_counts.keys()).intersection(set(t2_counts.keys()))
        t1_wm = {key: t1_counts[key] for key in wm_common_keys}
        t2_wm = {key: t2_counts[key] for key in wm_common_keys}
        wm_common_pairs = list(itertools.combinations(wm_common_keys, 2))
        pair_counts.append(len(wm_common_pairs))

        print(f"Lenght of common pairs: {len(wm_common_pairs)}")

        print("calculating differences...")
        # Hand the dicts to the workers explicitly. Under the "spawn" and
        # "forkserver" start methods, workers do not inherit the parent's
        # module globals the way "fork" children do.
        with mp.Pool(
            processes=processes,
            initializer=_init_pair_worker,
            initargs=(t1_wm, t2_wm),
        ) as pool:
            results = pool.map(process_pair, wm_common_pairs)

        t12_wm_diff = {}
        for result in results:
            # Skip pairs reported as having no shared indices: either None
            # or a (None, None) placeholder.
            if result is None or result[0] is None:
                continue
            key_pair, diff_data = result
            t12_wm_diff[key_pair] = diff_data

        avg_cossim_wm = avg_cossim_tuple(t12_wm_diff)
        cossim_results.append(avg_cossim_wm)

        print(f"Average cosine similarity: {avg_cossim_wm}")

    cossim_results = np.array(cossim_results)
    cossim_mean = np.mean(cossim_results)
    # Sample standard deviation across iterations.
    cossim_std = np.std(cossim_results, ddof=1)
    # z-score of the mean against a reference similarity of 0.1. A zero std
    # yields inf/nan (numpy float division) rather than raising.
    z_score = (cossim_mean - 0.1) / cossim_std

    csv_header = [
        "model_name",
        "temperature",
        "topk",
        "topp",
        "pair_count",
        "cossim_wm",
        "avg_cossim_wm",
        "std_cossim_wm",
        "z_score_wm",
    ]

    output_path = f"{prefix_path}/KGWFix-{keylen}-{model}-{num_samples}-{prefix_length}-{gamma}-{delta}-50-{filter_threshold}.csv"
    os.makedirs(prefix_path, exist_ok=True)
    # newline="" is what the csv module expects; otherwise some platforms
    # emit blank lines between rows.
    with open(output_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(csv_header)
        for cossim_wm, pair_count in zip(cossim_results, pair_counts):
            writer.writerow(
                [
                    model,
                    temp,
                    topk,
                    topp,
                    pair_count,
                    cossim_wm,
                    cossim_mean,
                    cossim_std,
                    z_score,
                ]
            )

        


if __name__ == "__main__":
    # Command-line entry point: parse the experiment configuration and call run().
    import argparse

    parser = argparse.ArgumentParser(description="Run KGW-Fix experiments")
    parser.add_argument(
        "--model_name",
        default="llama-2-7b-hf",
        type=str,
        help="Model name used in the input and output file names",
    )
    # Required: without it the input paths would contain the literal "None".
    parser.add_argument(
        "--samples",
        type=int,
        required=True,
        help="Sample count used in the input and output file names",
    )
    parser.add_argument(
        "--threshold",
        default=20,
        type=float,
        help="Minimum S_wm total for a key to be kept",
    )
    parser.add_argument(
        "--gamma",
        default=0.5,
        type=float,
        help="Gamma value used in the input and output file names",
    )
    parser.add_argument(
        "--delta",
        default=2.0,
        type=float,
        help="Delta value used in the input and output file names",
    )
    parser.add_argument(
        "--prefix_length",
        default=4,
        type=int,
        help="Prefix length used in the input and output file names",
    )
    # Required for the same reason as --samples.
    parser.add_argument(
        "--keylen",
        type=int,
        required=True,
        help="Key length used in the input and output file names",
    )

    args = parser.parse_args()
    print(args)
    model = args.model_name

    run(
        model=model,
        num_samples=args.samples,
        gamma=args.gamma,
        delta=args.delta,
        prefix_length=args.prefix_length,
        keylen=args.keylen,
        filter_threshold=args.threshold,
    )
