import ujson as json
import numpy as np
import itertools
import math
import csv
import multiprocessing as mps
from functools import partial
import math
import os
import mpmath as mp

# mpmath keeps its working precision on the context object ``mpmath.mp``;
# with ``mpmath`` imported as ``mp`` that is ``mp.mp.dps``. Assigning to
# ``mp.dps`` would only create a module attribute and leave the precision
# at mpmath's default of 15 digits.
mp.mp.dps = 1000
# Output directory for the per-combination CSV summaries written by run().
prefix_path = "../../data/results/csv/DIP-prob2-results"

def cosine_similarity(vec_a, vec_b):
    """Return the cosine similarity of ``vec_a`` and ``vec_b``.

    Degenerate inputs are handled explicitly: two zero-norm vectors count as
    identical (returns 1), and exactly one zero-norm vector counts as
    unrelated (returns 0).
    """
    numerator = np.dot(vec_a, vec_b)
    length_a, length_b = np.linalg.norm(vec_a), np.linalg.norm(vec_b)
    zero_a = length_a == 0
    zero_b = length_b == 0
    if zero_a or zero_b:
        return 1 if zero_a and zero_b else 0
    return numerator / (length_a * length_b)


def avg_cossim_tuple(t12_diff):
    """Return the mean cosine similarity over the values of ``t12_diff``.

    Each value must be a 2-item pair ``(vec_1, vec_2)``; the keys are ignored.
    An empty mapping gives NumPy's mean of an empty list (``nan`` with a
    RuntimeWarning).
    """
    scores = []
    for first_vec, second_vec in t12_diff.values():
        scores.append(cosine_similarity(first_vec, second_vec))
    return np.mean(scores)


def variance_sample(t12_diff):
    """Return the sample variance (``ddof=1``) of the pairwise cosine similarities.

    Each value of ``t12_diff`` must be a 2-item pair ``(vec_1, vec_2)``.
    Sample variance needs at least two observations; with fewer (empty
    mapping or a single pair) 0 is returned instead of NumPy's ``nan`` and
    its "Degrees of freedom <= 0" RuntimeWarning.
    """
    similarities = [cosine_similarity(v1, v2) for v1, v2 in t12_diff.values()]
    if len(similarities) < 2:
        return 0
    return np.var(similarities, ddof=1)


def rank_invariant_transform(lst):
    """Binarize ``lst`` by rank: positions of the top half of values become 1, the rest -1.

    The cut is ``len(lst) // 2``, so for an odd length the middle element
    falls into the -1 group. Python's sort is stable even with
    ``reverse=True``, so among equal values the earlier positions are the
    ones placed in the +1 group.
    """
    count = len(lst)
    ranked_positions = sorted(range(count), key=lst.__getitem__, reverse=True)
    upper_half = set(ranked_positions[: count // 2])
    return [1 if pos in upper_half else -1 for pos in range(count)]


def assign_relative_rankings(probabilities, word_set):
    """Rank the entries of ``probabilities`` selected by ``word_set`` among themselves.

    Returns an int array shaped like ``probabilities`` holding 1 at the
    highest-valued index in ``word_set``, 2 at the next, and so on. Every
    index outside ``word_set`` is 0. Ties are broken by ``np.argsort``
    applied to the selected values in ``word_set`` iteration order.
    """
    # Materialize the set once so selection and scatter use the same order.
    positions = np.array(list(word_set), dtype=int)
    chosen = np.asarray(probabilities)[positions]

    # order[k] is the position (within `chosen`) of the k-th largest value.
    order = np.argsort(-chosen)
    local_rank = np.empty(len(order), dtype=int)
    local_rank[order] = np.arange(1, len(order) + 1)

    ranked = np.zeros_like(probabilities, dtype=int)
    ranked[positions] = local_rank
    return ranked

# Per-iteration dicts mapping each watermarked key to its np.array of "S_wm"
# values. run() assigns them; process_pair() reads them. None until run()
# populates them.
t1_wm, t2_wm = None, None
def _top_n_indices(scores, top_n):
    """Return the indices of the ``top_n`` largest entries of ``scores``, largest first."""
    return np.argsort(scores)[::-1][:top_n]


def _rank_sign_diff(ranks_a, ranks_b, indices):
    """Return ``sign(ranks_a - ranks_b)`` at ``indices`` as a list of Python ints (1, -1 or 0)."""
    diff = np.asarray(ranks_a)[indices] - np.asarray(ranks_b)[indices]
    return np.sign(diff).tolist()


def process_pair(value, top_n=50):
    """Compare how two keys rank their shared top indices in the two result sets.

    Reads the module globals ``t1_wm`` and ``t2_wm``, which ``run`` fills in
    before starting the worker pool. With ``key1, key2 = value[0], value[1]``:

    1. take the ``top_n`` highest-scoring indices of each key's row in both
       ``t1_wm`` and ``t2_wm``;
    2. keep the candidate set ``(t11 ∪ t12) ∩ (t21 ∪ t22)``;
    3. rank the candidates within each normalized row (see
       ``assign_relative_rankings``) and record, per candidate, whether
       key1's rank number is larger (1), smaller (-1) or equal (0) to key2's.

    Args:
        value: sequence whose first two items are the keys to compare.
        top_n: how many top-scoring indices per row enter the candidate set.

    Returns:
        ``((key1, key2), (t1_diff, t2_diff))``, where ``t1_diff`` and
        ``t2_diff`` are equal-length lists of 1/-1/0 in the same candidate
        order. Returns ``None`` when the candidate set is empty.
    """
    key1, key2 = value[0], value[1]

    row_11, row_12 = t1_wm[key1], t1_wm[key2]
    row_21, row_22 = t2_wm[key1], t2_wm[key2]

    # (t11 ∪ t12) ∩ (t21 ∪ t22)
    t1_candidates = set(_top_n_indices(row_11, top_n)).union(set(_top_n_indices(row_12, top_n)))
    t2_candidates = set(_top_n_indices(row_21, top_n)).union(set(_top_n_indices(row_22, top_n)))
    word_set = t1_candidates.intersection(t2_candidates)

    if not word_set:
        # Nothing in common to compare; run() skips None results.
        return None

    # Normalize each row to sum to 1. For a positive row sum this is an
    # (essentially) order-preserving rescaling, so the rankings below do not
    # depend on it.
    t11_prob = row_11 / np.sum(row_11)
    t12_prob = row_12 / np.sum(row_12)
    t21_prob = row_21 / np.sum(row_21)
    t22_prob = row_22 / np.sum(row_22)

    t11_rank = assign_relative_rankings(t11_prob, word_set)
    t12_rank = assign_relative_rankings(t12_prob, word_set)
    t21_rank = assign_relative_rankings(t21_prob, word_set)
    t22_rank = assign_relative_rankings(t22_prob, word_set)

    # One fixed candidate order shared by both output vectors.
    indices = list(word_set)
    t11_t12_rank_diff = _rank_sign_diff(t11_rank, t12_rank, indices)
    t21_t22_rank_diff = _rank_sign_diff(t21_rank, t22_rank, indices)

    return (key1, key2), (t11_t12_rank_diff, t21_t22_rank_diff)

_CSV_HEADER = [
    "model_name",
    "temperature",
    "topk",
    "topp",
    "pair_count",
    "cossim_wm",
    "avg_cossim_wm",
    "std_cossim_wm",
    "z_score_wm",
    "p_value_wm",
]


def _init_worker(wm1, wm2):
    """Pool initializer: install ``wm1``/``wm2`` as the globals ``process_pair`` reads.

    Under the "spawn" start method (the default on macOS and Windows) a
    worker re-imports this module and would otherwise see ``t1_wm`` and
    ``t2_wm`` as ``None``. Under "fork" this just rebinds the inherited
    objects.
    """
    global t1_wm, t2_wm
    t1_wm, t2_wm = wm1, wm2


def _load_json(path):
    """Read and parse the JSON file at ``path``."""
    with open(path, "r") as f:
        return json.load(f)


def _filtered_wm(data, filter_threshold):
    """Map each key of ``data["watermarked"]`` to ``np.array(entry["S_wm"])``.

    Only keys whose array sums to at least ``filter_threshold`` are kept.
    """
    arrays = {key: np.array(entry["S_wm"]) for key, entry in data["watermarked"].items()}
    return {key: arr for key, arr in arrays.items() if np.sum(arr) >= filter_threshold}


def run(
    model,
    num_samples,
    alpha,
    prefix_length,
    filter_threshold,
    combinations,
    *,
    num_iterations=3,
    processes=28,
):
    """Compute rank-agreement statistics for each sampling combination and save them as CSV.

    For every combination and each of ``num_iterations`` repeats, the
    function:

    1. loads the matching ``dip-p1`` / ``dip-p2`` result files;
    2. keeps the keys whose ``S_wm`` sum is at least ``filter_threshold`` in
       both files;
    3. evaluates ``process_pair`` on every unordered pair of those keys in a
       process pool;
    4. records the mean cosine similarity of the resulting rank-difference
       vectors.

    The per-iteration means are summarised (mean, sample std, z-score against
    0.1, upper-tail p-value) and written to one CSV per combination under
    ``prefix_path``.

    Args:
        model, num_samples, alpha, prefix_length: values used to build the
            input and output file names.
        filter_threshold: minimum ``S_wm`` sum for a key to be kept.
        combinations: iterable of dicts with "temperature", "topp" and "topk".
        num_iterations: number of repeated result files per combination
            (``iter-0`` ...). The sample std, and hence the z-score, needs
            at least 2.
        processes: size of the worker pool.

    Side effects:
        Rebinds the module globals ``t1_wm`` / ``t2_wm``, prints progress,
        and writes CSV files.
    """
    global t1_wm, t2_wm

    for combo in combinations:
        temp, topp, topk = combo["temperature"], combo["topp"], combo["topk"]

        cossim_results = []
        pair_counts = []

        for iteration in range(num_iterations):
            print("Iteration: ", iteration)
            suffix = (
                f"{model}-temp-{temp}-topp-{topp}-topk-{topk}-alpha-{alpha}"
                f"-prefixlen-{prefix_length}-{num_samples}-prob2-iter-{iteration}.json"
            )
            p1_file = f"../../data/results/prob2/dip-p1-{suffix}"
            p2_file = f"../../data/results/prob2/dip-p2-{suffix}"

            print("loading samples...")
            data_p1 = _load_json(p1_file)
            data_p2 = _load_json(p2_file)

            print("processing samples...")
            wm1 = _filtered_wm(data_p1, filter_threshold)
            wm2 = _filtered_wm(data_p2, filter_threshold)
            # Drop the raw JSON before the pool starts to keep peak memory down.
            del data_p1, data_p2

            # Sorted so the pair order (and the float summation order) is the
            # same on every run; set order of str keys varies with hash seeding.
            common_keys = sorted(wm1.keys() & wm2.keys())
            t1_wm = {key: wm1[key] for key in common_keys}
            t2_wm = {key: wm2[key] for key in common_keys}
            wm_common_pairs = list(itertools.combinations(common_keys, 2))
            pair_counts.append(len(wm_common_pairs))

            print(f"Lenght of common pairs: {len(wm_common_pairs)}")
            print("calculating differences...")

            with mps.Pool(
                processes=processes,
                initializer=_init_worker,
                initargs=(t1_wm, t2_wm),
            ) as pool:
                results = pool.map(process_pair, wm_common_pairs)

            t12_wm_diff = {}
            for result in results:
                # Pairs with no shared candidates come back as None (or as a
                # (None, None) sentinel); they carry nothing to compare.
                if result is None or result[0] is None:
                    continue
                key_pair, diff_data = result
                t12_wm_diff[key_pair] = diff_data

            avg_cossim_wm = avg_cossim_tuple(t12_wm_diff)
            cossim_results.append(avg_cossim_wm)
            print(f"Average cosine similarity: {avg_cossim_wm}")

        cossim_results = np.array(cossim_results)
        cossim_mean = np.mean(cossim_results)
        # Sample standard deviation across iterations.
        cossim_std = np.std(cossim_results, ddof=1)

        # z-score of the mean against the reference value 0.1, and its
        # upper-tail normal p-value P(Z > z) = erfc(z / sqrt(2)) / 2.
        z_score = (cossim_mean - 0.1) / cossim_std
        z_mpf = mp.mpf(z_score)
        p_value = mp.erfc(z_mpf / mp.sqrt(2)) / 2
        p_value_str = mp.nstr(p_value, 50, strip_zeros=False)

        # NOTE(review): topp/topk are not part of the file name, so
        # combinations differing only in those overwrite each other.
        output_path = f"{prefix_path}/DIP-{model}-{temp}-{num_samples}-{prefix_length}-{alpha}-prob2-{filter_threshold}.csv"
        os.makedirs(prefix_path, exist_ok=True)
        # newline="" lets the csv module control line endings (no blank rows on Windows).
        with open(output_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(_CSV_HEADER)
            # One row per iteration, each with its own pair count.
            for pair_count, cossim in zip(pair_counts, cossim_results):
                writer.writerow(
                    [
                        model,
                        temp,
                        topk,
                        topp,
                        pair_count,
                        cossim,
                        cossim_mean,
                        cossim_std,
                        z_score,
                        p_value_str,
                    ]
                )

        


if __name__ == "__main__":
    import argparse

    # Sampling presets selectable with --combinations. Each temperature is
    # listed once: a repeated entry would recompute and overwrite the same
    # output file with identical results.
    combination_presets = {
        "temp": [
            {"temperature": t, "topp": 1.0, "topk": 0}
            for t in (1.5, 1.4, 1.3, 1.2, 1.1, 1.0, 0.8, 0.7, 0.6, 0.9, 0.5, 0.4, 0.3, 0.2, 0.1)
        ],
        "experiment": [
            {"temperature": 1.0, "topp": 1.0, "topk": 0},
        ],
    }

    parser = argparse.ArgumentParser(description="Run script with parameters")
    parser.add_argument("--model_name", type=str, required=True, help="Model name (used in input/output file names)")
    parser.add_argument("--samples", type=int, required=True, help="Number of samples (used in input/output file names)")
    parser.add_argument("--threshold", type=float, required=True, help="Minimum S_wm sum for a key to be kept")
    parser.add_argument("--alpha", type=float, required=True, help="Alpha value (used in input/output file names)")
    parser.add_argument("--prefix_length", type=int, required=True, help="Prefix length (used in input/output file names)")
    parser.add_argument(
        "--combinations",
        type=str,
        required=True,
        choices=sorted(combination_presets),
        help="Which preset of sampling combinations to run",
    )

    args = parser.parse_args()
    print(args)

    run(
        model=args.model_name,
        num_samples=args.samples,
        alpha=args.alpha,
        prefix_length=args.prefix_length,
        filter_threshold=args.threshold,
        combinations=combination_presets[args.combinations],
    )
