import pickle as pkl
import random
from collections import Counter
from typing import List, Optional

import numpy as np
from joblib import Parallel, delayed
from scipy.stats import fisher_exact

from src.attacks.kgw_detection import test_kgw_detection
from src.attacks.stanford_detection import test_stanford
from src.models import _load_tokenizer
from src.utils import sample_vectorized, dip_reweight, stanford_sampling


def apply_1test(
    test: str,
    wm_scheme: str,
    model_name: str,
    custom_name: str,
    temperature: float = 1.0,
    seeding_scheme: str = "lefthash",
    gamma: float = 0.25,
    keys: List = [1],
    delta: float = 2.0,
    context: int = 5,
    max_new_tokens: int = 1000,
    n_samples: int = 0,
    key_size: int = 256,
    num_permutations=1000,
    debug: bool = False,
    bayesian: bool = False,
    alpha: float = 0.5,
    sample_bootstrap: int = -1,
    disable_watermark_every: int = 0,
    n_queries: int = None,
):
    if test == "KGW":
        if wm_scheme == "no_watermark":
            path = f"pkl_results/{test}/{model_name}/{wm_scheme}/{custom_name}_context{context}_{temperature}.pkl"
            delta = 0
        elif wm_scheme == "KGW":
            path = f"pkl_results/{test}/{model_name}/{wm_scheme}/{seeding_scheme}/{custom_name}_context{context}_gamma{gamma}_{temperature}_{keys}.pkl"
        elif wm_scheme == "dipmark":
            path = f"pkl_results/{test}/{model_name}/no_watermark/{custom_name}_context{context}_{temperature}.pkl"
            delta = 0
        elif wm_scheme == "stanford":
            path = f"pkl_results/{test}/{model_name}/no_watermark/{custom_name}_context{context}_{temperature}.pkl"
            delta = 0
        elif wm_scheme == "DeltaReweight":
            path = f"pkl_results/{test}/{model_name}/no_watermark/{custom_name}_context{context}_{temperature}.pkl"
            delta = 0

        with open(path, "rb") as f:
            out = pkl.load(f)

        data = out[delta]

        if wm_scheme == "dipmark":
            # We use the same permutation for fixed
            reweighted_data = np.zeros_like(data)
            for t2 in range(data.shape[1]):
                reweighted_data[:, t2, :] = dip_reweight(
                    data[:, t2, :], alpha, same_permutation=True
                )

            data = reweighted_data

        if disable_watermark_every > 0:
            unwatermarked_data_path = f"pkl_results/{test}/{model_name}/no_watermark/{custom_name}_context{context}_{temperature}.pkl"
            with open(unwatermarked_data_path, "rb") as f:
                unwatermark_out = pkl.load(f)
            unwatermark_data = unwatermark_out[0]
        else:
            unwatermark_data = None

        if n_samples != 0:
            if wm_scheme == "stanford":
                data = stanford_sampling(
                    data,
                    n_samples,
                    key_size,
                    bayesian=bayesian,
                    unwatermark_prob=unwatermark_data,
                    disable_watermark_every=disable_watermark_every,
                )
            else:
                data = sample_vectorized(
                    data,
                    n_samples,
                    bayesian=bayesian,
                    unwatermark_prob=unwatermark_data,
                    disable_watermark_every=disable_watermark_every,
                )

        if sample_bootstrap != -1:

            def thread(data, num_permutations, n_samples):
                probs = sample_vectorized(
                    data,
                    int(0.9 * n_samples),
                    stability=True
                )
                _, _, p_value = test_kgw_detection(probs, num_permutations)
                return p_value

            p_values = Parallel(n_jobs=-1)(
                delayed(thread)(data, num_permutations, n_samples)
                for _ in range(sample_bootstrap)
            )
            p_value = np.median(p_values)

        else:
            observed_statistics, statistics, p_value = test_kgw_detection(
                data, num_permutations
            )

        if debug:
            print(observed_statistics, statistics)

    elif test == "stanford":
        if wm_scheme == "no_watermark":
            path = f"pkl_results/{test}/{model_name}/{wm_scheme}/{custom_name}_maxtokens{max_new_tokens}_{temperature}.pkl"
        elif wm_scheme == "KGW":
            path = f"pkl_results/{test}/{model_name}/{wm_scheme}/{seeding_scheme}/{custom_name}_maxtokens{max_new_tokens}_gamma{gamma}_delta{delta}_{temperature}_{keys}.pkl"
        elif wm_scheme == "stanford":
            wm_scheme = "Stanford"
            path = f"pkl_results/{test}/{model_name}/{wm_scheme}/{custom_name}_maxtokens{max_new_tokens}_{temperature}_{keys}_{key_size}.pkl"
        elif wm_scheme == "dipmark":
            wm_scheme = "DiPMark"
            path = f"pkl_results/{test}/{model_name}/{wm_scheme}/{custom_name}_maxtokens{max_new_tokens}_{temperature}_{keys}_alpha{alpha}.pkl"
        elif wm_scheme == "DeltaReweight":
            path = f"pkl_results/{test}/{model_name}/{wm_scheme}/{custom_name}_maxtokens{max_new_tokens}_{temperature}_{keys}.pkl"

        with open(path) as f:
            whole_txt = f.read()
            split = whole_txt.split("###NEW_RESPONSE###")[1:]
            # Create a histogram of the lines. First hash the lines to a number
            answer_dic = {}
            
            tokenizer = _load_tokenizer("meta-llama/Meta-Llama-3-8B-Instruct")

            if n_queries is not None:
                
                if len(split) < n_queries:
                    print(f"Warning: n_queries is larger than the number of queries in the file. n_queries is set to {len(split)}")
                
                # Shuffle split
                random.shuffle(split)
                split = split[:n_queries]
                


            for line in split:
                if False:
                    tokenized_line = tokenizer.encode(line, return_tensors='pt')
                    tokenized_line = tokenized_line[0].tolist()
                    tokenized_line = tokenized_line[1:15]
                    line = tokenizer.decode(tokenized_line)
                    
                if line not in answer_dic:
                    answer_dic[line] = 1
                else:
                    answer_dic[line] += 1

        data = [
            sentence for sentence, count in answer_dic.items() for _ in range(count)
        ]

        p_value = test_stanford(data)

    elif test == "cache":
        if wm_scheme == "no_watermark":
            path = f"pkl_results/{test}/{model_name}/{wm_scheme}/{custom_name}_{temperature}.pkl"
            delta = 0
        elif wm_scheme == "KGW":
            path = f"pkl_results/{test}/{model_name}/{wm_scheme}/{seeding_scheme}/{custom_name}_gamma{gamma}_{temperature}_{keys}.pkl"
        elif wm_scheme == "dipmark":
            path = f"pkl_results/{test}/{model_name}/no_watermark/{custom_name}_{temperature}.pkl"
            delta = 0
        elif wm_scheme == "stanford":
            path = f"pkl_results/{test}/{model_name}/no_watermark/{custom_name}_{temperature}.pkl"
            delta = 0
        elif wm_scheme == "dipmark":
            path = f"pkl_results/{test}/{model_name}/no_watermark/{custom_name}_{temperature}.pkl"
            delta = 0
        elif wm_scheme == "DeltaReweight":
            path = f"pkl_results/{test}/{model_name}/no_watermark/{custom_name}_{temperature}.pkl"
            delta = 0
        else:
            raise ValueError("Unknown watermarking scheme")

        with open(path, "rb") as f:
            out = pkl.load(f)

        data = out[delta]
        data = data / np.sum(data, axis=1)[:, None]  # rejection sampling

        og_probs = np.array(
            [data[0]]
        )  # We only need one t2; if there are multiple t2 it is for alpha estimation
        probs = og_probs
        if wm_scheme == "dipmark":
            og_probs = dip_reweight(og_probs, alpha)
        phase_1_samples = sample_vectorized(
            og_probs, n_samples, return_samples=True, bayesian=bayesian
        )

        if wm_scheme == "cache_dipmark":
            probs = dip_reweight(og_probs, alpha)
        elif wm_scheme == "DeltaReweight":
            choice = np.random.choice([0, 1], p=og_probs[0])
            probs = np.array([[choice, 1 - choice]])

        if disable_watermark_every > 0:
            unwatermarked_data_path = f"pkl_results/{test}/{model_name}/no_watermark/{custom_name}_{temperature}.pkl"
            with open(unwatermarked_data_path, "rb") as f:
                unwatermark_out = pkl.load(f)
            unwatermark_data = unwatermark_out[0]
            unwatermark_data = (
                unwatermark_data / np.sum(unwatermark_data, axis=1)[:, None]
            )  # rejection sampling
        else:
            unwatermark_data = None

        if n_samples != 0:
            if wm_scheme == "stanford":
                samples = stanford_sampling(
                    probs,
                    n_samples,
                    key_size,
                    return_samples=True,
                    bayesian=bayesian,
                    unwatermark_prob=unwatermark_data,
                    disable_watermark_every=disable_watermark_every,
                )
                samples = [samples]
            elif wm_scheme == "KGW":
                samples = sample_vectorized(
                    probs,
                    n_samples,
                    return_samples=True,
                    bayesian=bayesian,
                    disable_watermark_every=0,
                )
            else:
                samples = sample_vectorized(
                    probs,
                    n_samples,
                    return_samples=True,
                    bayesian=bayesian,
                    unwatermark_prob=unwatermark_data,
                    disable_watermark_every=disable_watermark_every,
                )
        else:
            raise ValueError("n_samples should be different than 0 for cache test")

        phase_1_samples = phase_1_samples[0]
        samples = samples[0]

        res = fisher_exact([samples, phase_1_samples])
        p_value = res.pvalue

    return p_value
