import ujson as json
import numpy as np
import itertools
import math
import csv
import multiprocessing as mps
from functools import partial
import math
import os
import mpmath as mp

# `mp` is the mpmath *module* (see the import above); the working precision is
# an attribute of its default context object `mpmath.mp`. Assigning `dps` on
# the module itself only creates an unused attribute and leaves the precision
# at its default, so set it on the context instead.
mp.mp.dps = 1000

# Directory that receives the per-configuration CSV result files.
prefix_path = "../../data/results/csv/unwatermark-prob2-results"

# Function to calculate cosine similarity
def cosine_similarity(vec_a, vec_b):
    """Return the cosine of the angle between ``vec_a`` and ``vec_b``.

    Zero-norm inputs are special-cased instead of dividing by zero:
    two zero vectors yield ``1`` and exactly one zero vector yields ``0``.
    """
    numerator = np.dot(vec_a, vec_b)
    magnitudes = (np.linalg.norm(vec_a), np.linalg.norm(vec_b))

    # Count how many of the two vectors have zero length.
    zero_count = sum(1 for magnitude in magnitudes if magnitude == 0)
    if zero_count == 2:
        return 1
    if zero_count == 1:
        return 0

    return numerator / (magnitudes[0] * magnitudes[1])


def avg_cossim_tuple(t12_diff):
    """Average the cosine similarity over every vector pair in ``t12_diff``.

    ``t12_diff`` maps arbitrary keys to 2-tuples of vectors. Each tuple is
    passed to ``cosine_similarity`` and the mean of the results is returned
    via ``np.mean``.

    Best-effort behaviour: if ``cosine_similarity`` raises for any pair, the
    exception, both offending vectors and the whole mapping are printed and
    ``0`` is returned immediately.
    """
    scores = []

    for pair in t12_diff.values():
        left, right = pair
        try:
            score = cosine_similarity(left, right)
        except Exception as err:
            # Dump everything needed to diagnose the failing pair, then bail out.
            for item in (err, left, right, t12_diff):
                print(item)
            return 0
        scores.append(score)

    return np.mean(scores)


def variance_sample(t12_diff):
    """Return the sample variance (``ddof=1``) of the pairwise cosine similarities.

    ``t12_diff`` maps keys to 2-tuples of vectors; each tuple is scored with
    ``cosine_similarity``. An empty mapping yields ``0``.
    """
    scores = []
    for first, second in t12_diff.values():
        scores.append(cosine_similarity(first, second))

    if not scores:
        return 0
    return np.var(scores, ddof=1)


def rank_invariant_transform(lst):
    """Map the larger half of ``lst`` to ``1`` and the remainder to ``-1``.

    Positions are ordered by value, largest first (ties keep their original
    relative order). The first ``len(lst) // 2`` positions of that ordering
    receive ``1``; every other position receives ``-1``. The result is a new
    list aligned with ``lst``.
    """
    count = len(lst)
    descending = sorted(range(count), key=lambda pos: lst[pos], reverse=True)
    upper_half = set(descending[: count // 2])
    return [1 if pos in upper_half else -1 for pos in range(count)]

def assign_relative_rankings(probabilities, word_set):
    """Rank the entries of ``probabilities`` selected by ``word_set``.

    Only the indices in ``word_set`` are ranked against each other: the
    largest selected probability gets rank 1, the next gets rank 2, and so on.
    The return value is an integer array shaped like ``probabilities`` that
    holds each selected index's rank and ``0`` everywhere else.
    """
    # Probabilities of the selected indices, in word_set iteration order.
    selected = np.array([probabilities[member] for member in word_set])

    # Positions within `selected`, ordered from largest to smallest value.
    descending = np.argsort(-selected)

    # Position descending[k] receives rank k + 1.
    ranks = np.zeros_like(selected, dtype=int)
    ranks[descending] = np.arange(1, len(selected) + 1)

    # Scatter the ranks back to their positions in the full-size array.
    output = np.zeros_like(probabilities, dtype=int)
    for member, rank in zip(word_set, ranks):
        output[member] = rank

    return output

# Per-key score arrays read by `process_pair`. `run` assigns them before it
# creates the worker pool.
# NOTE(review): pool workers only see these values (and the module-level
# `diff`) if the multiprocessing start method copies module state into the
# child (e.g. "fork") -- confirm for the target platform.
t1_wm = None
t2_wm = None


def _top_indices(scores, top_n):
    """Return the indices of the ``top_n`` largest entries of ``scores``."""
    return np.argsort(scores)[::-1][:top_n]


def _comparison_signs(left, right):
    """Return 1 / -1 / 0 per position for ``left`` >, < or == ``right``."""
    return [1 if a > b else -1 if a < b else 0 for a, b in zip(left, right)]


# Define the function to process a single pair
def process_pair(value, top_n=50):
    """Compute the difference vectors of one key pair for both score tables.

    Reads the module globals ``t1_wm`` and ``t2_wm`` (key -> score array) and
    ``diff`` (``"rank"`` or ``"sub"``).

    The comparison is restricted to the index set
    ``(top(t1[key1]) | top(t1[key2])) & (top(t2[key1]) | top(t2[key2]))``
    where ``top`` selects the ``top_n`` highest-scoring indices.

    Args:
        value: Sequence whose first two items are the keys ``key1``, ``key2``.
        top_n: Number of highest-scoring indices taken from each array.

    Returns:
        ``None`` when the index set is empty. Otherwise
        ``((key1, key2), (t1_diff, t2_diff))`` where, per selected index:

        * ``diff == "sub"``: each entry is the difference of the normalised
          scores, ``key1`` minus ``key2``;
        * ``diff == "rank"``: each entry is 1, -1 or 0 depending on whether
          the relative rank under ``key1`` is greater than, less than or
          equal to the rank under ``key2``.

    Raises:
        ValueError: If ``diff`` is neither ``"rank"`` nor ``"sub"``.
    """
    key1 = value[0]
    key2 = value[1]

    # Get the intersection
    # (t11 ∪ t12) ∩ (t21 ∪ t22)
    t1_top = set(_top_indices(t1_wm[key1], top_n)).union(
        _top_indices(t1_wm[key2], top_n)
    )
    t2_top = set(_top_indices(t2_wm[key1], top_n)).union(
        _top_indices(t2_wm[key2], top_n)
    )
    word_set = t1_top.intersection(t2_top)

    if not word_set:
        # Nothing to compare. The previous implementation incremented an
        # undefined local counter here and therefore always raised
        # UnboundLocalError; callers filter on `None`.
        return None

    # Normalize the probabilities
    t11_prob = t1_wm[key1] / np.sum(t1_wm[key1])
    t12_prob = t1_wm[key2] / np.sum(t1_wm[key2])
    t21_prob = t2_wm[key1] / np.sum(t2_wm[key1])
    t22_prob = t2_wm[key2] / np.sum(t2_wm[key2])

    if diff == "rank":
        t11_rank = assign_relative_rankings(t11_prob, word_set)
        t12_rank = assign_relative_rankings(t12_prob, word_set)
        t21_rank = assign_relative_rankings(t21_prob, word_set)
        t22_rank = assign_relative_rankings(t22_prob, word_set)

        t11_t12_rank_diff = _comparison_signs(
            [t11_rank[i] for i in word_set],
            [t12_rank[i] for i in word_set],
        )
        t21_t22_rank_diff = _comparison_signs(
            [t21_rank[i] for i in word_set],
            [t22_rank[i] for i in word_set],
        )
        return (key1, key2), (t11_t12_rank_diff, t21_t22_rank_diff)

    if diff == "sub":
        t11_t12_diff = [t11_prob[i] - t12_prob[i] for i in word_set]
        t21_t22_diff = [t21_prob[i] - t22_prob[i] for i in word_set]
        return (key1, key2), (t11_t12_diff, t21_t22_diff)

    raise ValueError("Invalid diff type")

    

def _load_filtered_scores(path, filter_threshold):
    """Load the ``S_uw`` arrays from ``path`` whose sum reaches ``filter_threshold``."""
    with open(path, "r") as f:
        data = json.load(f)

    scores = {}
    for key, entry in data["unwatermarked"].items():
        values = np.array(entry["S_uw"])
        if np.sum(values) >= filter_threshold:
            scores[key] = values
    return scores


def run(
    model,
    num_samples,
    filter_threshold,
    combinations,
    num_iterations=3,
    processes=42,
):
    """Compute cosine-similarity statistics for each sampling configuration.

    For every entry of ``combinations`` and every iteration, the matching
    ``fixkey-uw-p1`` / ``fixkey-uw-p2`` JSON files are loaded, keys whose
    score sum is below ``filter_threshold`` are dropped, and every pair of
    keys present in both files is handed to ``process_pair`` on a worker
    pool. The per-iteration average cosine similarity, its mean and sample
    standard deviation across iterations, the z-score against 0.1 and the
    one-sided p-value are written to one CSV file per configuration under
    ``prefix_path``.

    Args:
        model: Model name used in the input and output file names.
        num_samples: Sample count used in the input and output file names.
        filter_threshold: Minimum score sum for a key to be kept.
        combinations: Iterable of dicts with keys ``"temperature"``,
            ``"topp"`` and ``"topk"``.
        num_iterations: Number of ``iter-<i>`` file pairs to process. At
            least two are needed for the sample standard deviation
            (``ddof=1``) to be defined.
        processes: Size of the worker pool.

    Side effects:
        Rebinds the module globals ``t1_wm`` and ``t2_wm``, prints progress
        messages and writes CSV files.
    """
    global t1_wm, t2_wm
    # Batch processing

    for combo in combinations:
        temp, topp, topk = combo["temperature"], combo["topp"], combo["topk"]

        cossim_results = []
        pair_counts = []

        for iteration in range(num_iterations):
            print("Iteration: ", iteration)
            p1_file = f"../../data/results/prob2/fixkey-uw-p1-{model}-temp-{temp}-topp-{topp}-topk-{topk}-{num_samples}-iter-{iteration}.json"
            p2_file = f"../../data/results/prob2/fixkey-uw-p2-{model}-temp-{temp}-topp-{topp}-topk-{topk}-{num_samples}-iter-{iteration}.json"

            print("loading samples...")
            scores_p1 = _load_filtered_scores(p1_file, filter_threshold)
            scores_p2 = _load_filtered_scores(p2_file, filter_threshold)

            print("processing samples...")
            # Sorted so that the pair order (and hence the float summation
            # order) does not depend on set iteration order.
            wm_common_keys = sorted(set(scores_p1).intersection(scores_p2))

            # The globals must be assigned before the pool is created:
            # process_pair reads them inside the worker processes.
            t1_wm = {key: scores_p1[key] for key in wm_common_keys}
            t2_wm = {key: scores_p2[key] for key in wm_common_keys}
            wm_common_pairs = list(itertools.combinations(wm_common_keys, 2))
            pair_counts.append(len(wm_common_pairs))

            print(f"Lenght of common pairs: {len(wm_common_pairs)}")

            print("calculating differences...")
            with mps.Pool(processes=processes) as pool:
                # Use map to process in parallel
                results = pool.map(process_pair, wm_common_pairs)

            # Keep only pairs that produced data. An "empty" result may be
            # signalled either as None or as a tuple whose key is None.
            t12_wm_diff = {}
            for result in results:
                if result is None or result[0] is None:
                    continue
                key_pair, diff_data = result
                t12_wm_diff[key_pair] = diff_data

            # Calculate metrics
            avg_cossim_wm = avg_cossim_tuple(t12_wm_diff)
            cossim_results.append(avg_cossim_wm)

            print(f"Average cosine similarity: {avg_cossim_wm}")

        # Get average cosine similarity
        cossim_results = np.array(cossim_results)
        cossim_mean = np.mean(cossim_results)

        # Get standard deviation
        cossim_std = np.std(cossim_results, ddof=1)

        # Get z-score for avg cossim = 0.1
        z_score = (cossim_mean - 0.1) / (cossim_std)

        # Convert z_score to mpf type
        z_mpf = mp.mpf(z_score)
        # Calculate one-sided p-value
        p_value = mp.erfc(z_mpf / mp.sqrt(2)) / 2
        p_value_str = mp.nstr(p_value, 50, strip_zeros=False)

        # Save results
        # CSV file header
        csv_header = [
            "model_name",
            "temperature",
            "topk",
            "topp",
            "pair_count",
            "cossim_uw",
            "avg_cossim_uw",
            "std_cossim_uw",
            "z_score_uw",
            "p_value_uw",
        ]
        output_path = f"{prefix_path}/unwatermark-prob2-{temp}-{model}-{num_samples}-{filter_threshold}.csv"
        os.makedirs(prefix_path, exist_ok=True)

        # newline="" lets the csv module control line endings itself.
        with open(output_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(csv_header)

            # One row per iteration, each with that iteration's own pair count.
            for pair_count, cossim in zip(pair_counts, cossim_results):
                writer.writerow(
                    [
                        model,
                        temp,
                        topk,
                        topp,
                        pair_count,
                        cossim,
                        cossim_mean,
                        cossim_std,
                        z_score,
                        p_value_str,
                    ]
                )

        


if __name__ == "__main__":
    import argparse

    # Named parameter sweeps selectable via --combinations.
    combination_presets = {
        # Temperatures 1.5, 1.4, ..., 0.1 with top-p / top-k fixed.
        "temp": [
            {"temperature": tenths / 10, "topp": 1.0, "topk": 0}
            for tenths in range(15, 0, -1)
        ],
        "experiment": [
            {"temperature": 1.0, "topp": 1.0, "topk": 0},
        ],
    }

    parser = argparse.ArgumentParser(description="Run the script with parameters")
    parser.add_argument("--model_name", type=str, required=True, help="Model name")
    parser.add_argument(
        "--samples", type=int, required=True, help="Number of samples"
    )
    parser.add_argument("--threshold", type=float, required=True, help="Threshold")
    parser.add_argument(
        "--combinations",
        type=str,
        required=True,
        choices=sorted(combination_presets),
        help="Parameter combination list",
    )
    parser.add_argument(
        "--diff",
        type=str,
        required=True,
        choices=["rank", "sub"],
        help="Difference type: 'rank' (sign of rank difference) or 'sub' (probability difference)",
    )

    args = parser.parse_args()
    print(args)
    model = args.model_name
    # Module-level on purpose: process_pair reads `diff` as a global.
    diff = args.diff

    # argparse `choices` guarantees the key exists, so `combinations` can no
    # longer be left undefined for an unrecognised value.
    combinations = combination_presets[args.combinations]

    run(
        model=model,
        num_samples=args.samples,
        filter_threshold=args.threshold,
        combinations=combinations,
    )
