import ujson as json
import numpy as np
import itertools
import math
import csv
import multiprocessing as mp
from functools import partial
import math
import os

# Directory that run() writes its per-configuration CSV result files into
# (created on demand with os.makedirs). Relative path: resolved against the
# current working directory, not this file's location.
prefix_path = "../../data/results/csv/KTH-prob2-5gram-results"

# Function to calculate cosine similarity
def cosine_similarity(vec_a, vec_b):
    """Return the cosine similarity of two vectors.

    Degenerate inputs are handled explicitly: two zero-length (all-zero)
    vectors are considered identical and give ``1``; a zero vector compared
    with a non-zero one gives ``0``.
    """
    numerator = np.dot(vec_a, vec_b)
    magnitude_a = np.linalg.norm(vec_a)
    magnitude_b = np.linalg.norm(vec_b)

    a_is_zero = magnitude_a == 0
    b_is_zero = magnitude_b == 0
    if a_is_zero or b_is_zero:
        # Similarity is undefined here; treat "both empty" as a perfect match.
        return 1 if (a_is_zero and b_is_zero) else 0

    return numerator / (magnitude_a * magnitude_b)


def avg_cossim_tuple(t12_diff):
    """Average cosine similarity over the vector pairs stored in ``t12_diff``.

    Each value of ``t12_diff`` must be a ``(vec1, vec2)`` pair. For an empty
    mapping ``np.mean`` returns ``nan`` (with a RuntimeWarning).
    """
    scores = []
    for first, second in t12_diff.values():
        scores.append(cosine_similarity(first, second))
    return np.mean(scores)


def variance_sample(t12_diff):
    """Sample variance (``ddof=1``) of the per-pair cosine similarities.

    Each value of ``t12_diff`` must be a ``(vec1, vec2)`` pair.

    Returns 0 when fewer than two pairs are available. The unbiased sample
    variance is undefined there, and ``np.var(..., ddof=1)`` would return
    ``nan`` with a RuntimeWarning for a single element.
    """
    similarities = [cosine_similarity(v1, v2) for v1, v2 in t12_diff.values()]
    if len(similarities) < 2:
        return 0
    return np.var(similarities, ddof=1)


def rank_invariant_transform(lst):
    """Map each element to +1 or -1 depending on its rank.

    The ``len(lst) // 2`` largest elements become ``1`` and all others become
    ``-1`` (so for odd lengths the middle element is ``-1``). Ties are broken
    by original position, earlier elements ranking higher.
    """
    # sorted() stays stable with reverse=True, so equal values keep list order.
    order = sorted(range(len(lst)), key=lst.__getitem__, reverse=True)
    upper_half = set(order[: len(lst) // 2])
    return [1 if position in upper_half else -1 for position in range(len(lst))]

def assign_relative_rankings(probabilities, word_set):
    """Rank the entries of ``probabilities`` selected by ``word_set``.

    Among the selected indices, the highest value gets rank 1, the next rank
    2, and so on. Ties are ordered by ``np.argsort``'s default algorithm.

    Args:
        probabilities: Array-like of scores (presumably 1-D; TODO confirm),
            indexed by the entries of ``word_set``.
        word_set: Iterable of indices into ``probabilities``.

    Returns:
        Integer array shaped like ``probabilities`` holding each selected
        index's rank, and 0 at every index not in ``word_set``.
    """
    chosen = np.array([probabilities[w] for w in word_set])
    order = np.argsort(-chosen)  # positions within `chosen`, highest value first

    # Invert the permutation: position_rank[p] is the 1-based rank of chosen[p].
    position_rank = np.empty(len(order), dtype=int)
    position_rank[order] = np.arange(1, len(order) + 1)

    ranked = np.zeros_like(probabilities, dtype=int)
    for pos, w in enumerate(word_set):
        ranked[w] = position_rank[pos]
    return ranked

# Per-iteration score dicts (key -> np.ndarray built from the "S_wm" lists)
# for the p1 and p2 result files. run() assigns these module globals before
# starting the worker pool, and process_pair() reads them. They stay None
# until run() populates them.
t1_wm = None
t2_wm = None
# Define the function to process a single pair
def _top_indices(scores, top_n):
    """Indices of the ``top_n`` largest entries of ``scores``, largest first."""
    return np.argsort(scores)[::-1][:top_n]


def _word_set_ranks(scores, word_set):
    """Ranks (1 = most probable) of ``word_set`` under normalized ``scores``,
    listed in ``word_set`` iteration order."""
    probs = scores / np.sum(scores)
    ranks = assign_relative_rankings(probs, word_set)
    return [ranks[i] for i in word_set]


def _rank_sign_diff(ranks_a, ranks_b):
    """Element-wise sign of ``ranks_a - ranks_b`` as Python ints (1, -1 or 0)."""
    return [1 if a > b else -1 if a < b else 0 for a, b in zip(ranks_a, ranks_b)]


def process_pair(value):
    """Compare how two keys rank their shared top words in run 1 and in run 2.

    Reads the module globals ``t1_wm`` and ``t2_wm`` (key -> score array),
    which ``run`` populates before dispatching pairs to the worker pool.

    Args:
        value: ``(key1, key2)`` pair of keys present in both dicts.

    Returns:
        ``None`` when the candidate word sets of the two runs do not overlap.
        Otherwise ``((key1, key2), (t1_diff, t2_diff))``. Each diff has one
        entry per shared word: ``1`` if the word's rank number under ``key1``
        is larger than under ``key2`` (i.e. it is less probable under
        ``key1``), ``-1`` if smaller, ``0`` if equal.
    """
    key1, key2 = value[0], value[1]
    top_n = 50

    # (t11 ∪ t12) ∩ (t21 ∪ t22): words in the top_n of either key in run 1
    # that are also in the top_n of either key in run 2.
    t1_top = set(_top_indices(t1_wm[key1], top_n)) | set(_top_indices(t1_wm[key2], top_n))
    t2_top = set(_top_indices(t2_wm[key1], top_n)) | set(_top_indices(t2_wm[key2], top_n))
    word_set = t1_top & t2_top

    if not word_set:
        # No shared words: nothing to compare for this pair. The caller
        # filters out None results.
        return None

    t11_ranks = _word_set_ranks(t1_wm[key1], word_set)
    t12_ranks = _word_set_ranks(t1_wm[key2], word_set)
    t21_ranks = _word_set_ranks(t2_wm[key1], word_set)
    t22_ranks = _word_set_ranks(t2_wm[key2], word_set)

    return (key1, key2), (
        _rank_sign_diff(t11_ranks, t12_ranks),
        _rank_sign_diff(t21_ranks, t22_ranks),
    )

def _init_worker(t1, t2):
    """Pool initializer: install the score dicts that ``process_pair`` reads.

    ``process_pair`` reads the module globals ``t1_wm`` / ``t2_wm``. Under the
    ``spawn`` / ``forkserver`` start methods (default on macOS and Windows,
    and on Linux since Python 3.14), workers re-import this module instead of
    inheriting the parent's memory. Without this they would see the initial
    ``None`` values.
    """
    global t1_wm, t2_wm
    t1_wm, t2_wm = t1, t2


def _watermark_scores(data, filter_threshold):
    """Return ``{key: np.array(S_wm)}`` for entries of ``data["watermarked"]``
    whose summed score is at least ``filter_threshold``."""
    scores = {
        key: np.array(entry["S_wm"]) for key, entry in data["watermarked"].items()
    }
    return {
        key: arr for key, arr in scores.items() if np.sum(arr) >= filter_threshold
    }


def run(
    model,
    num_samples,
    keylen,
    filter_threshold,
    combinations,
    num_iterations=3,
    num_processes=28,
):
    """Compute and save cross-run ranking agreement for each sampling setup.

    For every combination and each iteration, this:
      1. loads the p1/p2 JSON result files;
      2. keeps keys whose summed ``S_wm`` score is >= ``filter_threshold`` in
         both files;
      3. compares every pair of those keys with ``process_pair`` in a process
         pool;
      4. averages the cosine similarity between the two runs' rank-difference
         vectors.

    One CSV per combination is written under ``prefix_path``. It has one row
    per iteration plus the mean, the sample std (``ddof=1``) and a z-score of
    the mean against 0.1.

    Args:
        model: Model name used in the input and output file names.
        num_samples: Sample count used in the file names.
        keylen: Key length used in the file names.
        filter_threshold: Minimum summed ``S_wm`` score for a key to be kept.
        combinations: Iterable of dicts with ``"temperature"``, ``"topp"`` and
            ``"topk"`` entries.
        num_iterations: Number of ``-iter-{i}`` input files per combination.
        num_processes: Worker processes for the pair comparison pool.
    """
    global t1_wm, t2_wm

    for combo in combinations:
        temp, topp, topk = combo["temperature"], combo["topp"], combo["topk"]

        cossim_results = []
        pair_counts = []  # one per iteration, written alongside its cossim

        for iteration in range(num_iterations):
            print("Iteration: ", iteration)
            p1_file = f"../../data/results/prob2/kth-5gram-p1-{model}-temp-{temp}-{keylen}-topp-{topp}-topk-{topk}-{num_samples}-iter-{iteration}.json"
            p2_file = f"../../data/results/prob2/kth-5gram-p2-{model}-temp-{temp}-{keylen}-topp-{topp}-topk-{topk}-{num_samples}-iter-{iteration}.json"

            print("loading samples...")
            with open(p1_file, "r") as f:
                data_p1 = json.load(f)
            with open(p2_file, "r") as f:
                data_p2 = json.load(f)

            print("processing samples...")
            t1_filtered = _watermark_scores(data_p1, filter_threshold)
            t2_filtered = _watermark_scores(data_p2, filter_threshold)
            # The raw JSON is no longer needed; release it before starting workers.
            del data_p1, data_p2

            # Keep only keys that survived the filter in both runs.
            wm_common_keys = set(t1_filtered) & set(t2_filtered)
            t1_wm = {key: t1_filtered[key] for key in wm_common_keys}
            t2_wm = {key: t2_filtered[key] for key in wm_common_keys}
            wm_common_pairs = list(itertools.combinations(wm_common_keys, 2))
            pair_counts.append(len(wm_common_pairs))

            print(f"Lenght of common pairs: {len(wm_common_pairs)}")
            print("calculating differences...")

            # The initializer hands the dicts to every worker regardless of
            # the multiprocessing start method.
            with mp.Pool(
                processes=num_processes,
                initializer=_init_worker,
                initargs=(t1_wm, t2_wm),
            ) as pool:
                results = pool.map(process_pair, wm_common_pairs)

            t12_wm_diff = {}
            for result in results:
                # Skip pairs with no shared words (signalled as None, or as
                # (None, None) by older versions of process_pair).
                if result is None or result[0] is None:
                    continue
                key_pair, diff_data = result
                t12_wm_diff[key_pair] = diff_data

            avg_cossim_wm = avg_cossim_tuple(t12_wm_diff)
            cossim_results.append(avg_cossim_wm)
            print(f"Average cosine similarity: {avg_cossim_wm}")

        cossim_results = np.array(cossim_results)
        cossim_mean = np.mean(cossim_results)
        cossim_std = np.std(cossim_results, ddof=1)
        # z-score of the mean average cosine similarity against 0.1.
        z_score = (cossim_mean - 0.1) / cossim_std

        csv_header = [
            "model_name",
            "temperature",
            "topk",
            "topp",
            "pair_count",
            "cossim_wm",
            "avg_cossim_wm",
            "std_cossim_wm",
            "z_score_wm",
        ]
        # NOTE: topp/topk are not part of the file name, so combinations that
        # differ only in those values overwrite each other's CSV.
        output_path = f"{prefix_path}/KTH-5gram-{model}-{temp}-{num_samples}-{keylen}-{filter_threshold}.csv"
        os.makedirs(prefix_path, exist_ok=True)

        # newline="" is what the csv module expects, so no blank rows on Windows.
        with open(output_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(csv_header)
            for cossim, pair_count in zip(cossim_results, pair_counts):
                writer.writerow(
                    [
                        model,
                        temp,
                        topk,
                        topp,
                        pair_count,
                        cossim,
                        cossim_mean,
                        cossim_std,
                        z_score,
                    ]
                )

        


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run the script with parameters")
    parser.add_argument(
        "--model_name",
        default="llama-2-7b-hf",
        type=str,
        help="Model name used in the input/output file names",
    )
    # These three are interpolated into the input file names and used for
    # filtering; without them the run fails later with an obscure error.
    parser.add_argument(
        "--samples",
        type=int,
        required=True,
        help="Number of samples (num_samples in the file names)",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        required=True,
        help="Keys whose summed S_wm score is below this value are dropped",
    )
    parser.add_argument(
        "--keylen",
        type=int,
        required=True,
        help="Key length (keylen in the file names)",
    )
    # Restricting the choices prevents an undefined `combinations` below.
    parser.add_argument(
        "--combinations",
        default="experiment",
        type=str,
        choices=["temp", "experiment"],
        help="Parameter combination list",
    )

    args = parser.parse_args()
    print(args)
    model = args.model_name

    if args.combinations == "temp":
        # Each temperature appears once. The output CSV name depends only on
        # model/temp/samples/keylen/threshold, so a repeated entry would just
        # recompute and overwrite the same file.
        combinations = [
            {"temperature": 1.5, "topp": 1.0, "topk": 0},
            {"temperature": 1.4, "topp": 1.0, "topk": 0},
            {"temperature": 1.3, "topp": 1.0, "topk": 0},
            {"temperature": 1.2, "topp": 1.0, "topk": 0},
            {"temperature": 1.1, "topp": 1.0, "topk": 0},
            {"temperature": 1.0, "topp": 1.0, "topk": 0},
            {"temperature": 0.8, "topp": 1.0, "topk": 0},
            {"temperature": 0.7, "topp": 1.0, "topk": 0},
            {"temperature": 0.6, "topp": 1.0, "topk": 0},
            {"temperature": 0.9, "topp": 1.0, "topk": 0},
            {"temperature": 0.5, "topp": 1.0, "topk": 0},
            {"temperature": 0.4, "topp": 1.0, "topk": 0},
            {"temperature": 0.3, "topp": 1.0, "topk": 0},
            {"temperature": 0.2, "topp": 1.0, "topk": 0},
            {"temperature": 0.1, "topp": 1.0, "topk": 0},
        ]
    else:  # "experiment" (argparse rejects anything else)
        combinations = [
            {"temperature": 1.0, "topp": 1.0, "topk": 0},
        ]

    run(
        model=model,
        num_samples=args.samples,
        keylen=args.keylen,
        filter_threshold=args.threshold,
        combinations=combinations,
    )
