from  collections import defaultdict
import numpy as np
import pickle
from pathlib import Path
import matplotlib.pyplot as plt
import heapq
import pandas as pd
from src import logger
import scipy
import math
from functools import cache

@cache
def compute_alpha(q, v, r):
    """Return the largest windowed binomial mass per unit of window width.

    For each width ``i`` in ``1..v`` the running mass
    ``P[v - i <= Binomial(r, q) < v]`` is accumulated, and the result is the
    largest ``mass / i`` encountered. Results are memoised per ``(q, v, r)``.

    Args:
        q: Success probability of the binomial distribution.
        v: Number of correct guesses; the windows end just below ``v``.
        r: Number of binomial trials (guesses).

    Returns:
        The maximal ratio, or ``0`` when ``v < 1`` (no window is visited).
    """
    best_ratio = 0
    window_mass = 0  # = P[v > Binomial(r, q) >= v - width]
    # Walk the success counts v-1, v-2, ..., 0 while the window widens by one.
    for width, successes in enumerate(range(v - 1, -1, -1), start=1):
        window_mass = window_mass + scipy.stats.binom.pmf(successes, r, q)
        if window_mass > width * best_ratio:
            best_ratio = window_mass / width
    return best_ratio
    
def p_value_DP_audit(m, r, v, eps, delta, cardinality, topk):
    """Return the p-value of seeing at least ``v`` correct guesses under the null.

    The null hypothesis is that the audited algorithm is (eps, delta)-DP.

    Args:
        m: Number of examples (canaries) in the audit.
        r: Number of guesses made, i.e. excluding abstentions.
        v: Number of correct guesses by the auditor.
        eps: Epsilon of the DP guarantee being tested (>= 0).
        delta: Delta of the DP guarantee being tested (in [0, 1]).
        cardinality: Number of candidates per example.
            NOTE(review): presumably each candidate is equally likely a
            priori (probability 1 / cardinality) -- confirm against the
            experiment setup.
        topk: Number of top-ranked candidates that count as a correct guess.

    Returns:
        The p-value capped at 1: ``P[Binomial(r, q) >= v]`` plus, when
        ``delta > 0``, the correction ``alpha * delta * cardinality * m``.

    Raises:
        AssertionError: If the arguments violate ``0 <= v <= r <= m``,
            ``eps >= 0`` or ``0 <= delta <= 1``.
    """
    assert 0 <= v <= r <= m
    assert eps >= 0
    assert 0 <= delta <= 1
    # Per-guess success bound: accuracy of eps-DP randomized response over
    # `cardinality` options, scaled by the `topk` accepted positions.
    q = topk / (1 + (cardinality - 1) * math.exp(-eps))
    # For topk > 1 the expression above can exceed 1, which is outside the
    # binomial's domain and makes scipy return NaN. A bound above 1 on a
    # probability is vacuous, so clamp it; the p-value then becomes 1.
    q = min(q, 1.0)
    beta = scipy.stats.binom.sf(v - 1, r, q)  # = P[Binomial(r, q) >= v]
    if delta > 0:
        alpha = compute_alpha(q, v, r)
        p = beta + alpha * delta * cardinality * m
    else:
        p = beta
    return min(p, 1)

def get_eps_audit(m, r, v, delta, p, cardinality, topk, eps_max):
    """Binary-search a lower bound on epsilon from audit outcomes.

    Args:
        m: Number of examples, each included with probability 1/cardinality.
        r: Number of guesses (i.e. excluding abstentions).
        v: Number of correct guesses by the auditor.
        delta: Delta of the DP guarantee under test (in [0, 1]).
        p: 1 - confidence, e.g. ``p=0.05`` corresponds to 95%.
        cardinality: Forwarded to ``p_value_DP_audit``.
        topk: Forwarded to ``p_value_DP_audit``.
        eps_max: Upper end of the search interval; the result never
            exceeds it.

    Returns:
        Lower bound on eps, i.e. the algorithm is not (eps, delta)-DP at the
        requested confidence. ``0`` when no probed epsilon is rejected.

    Raises:
        AssertionError: If ``0 <= v <= r <= m``, ``0 <= delta <= 1`` or
            ``0 < p < 1`` does not hold.
    """
    assert 0 <= v <= r <= m
    assert 0 <= delta <= 1
    assert 0 < p < 1
    # Invariant: every value assigned to `lower` had a p-value below `p`.
    lower, upper = 0, eps_max
    for _step in range(50):  # 50 interval halvings
        midpoint = (lower + upper) / 2
        # Keep the `< p` form: a NaN p-value compares False and therefore
        # shrinks the interval from above.
        if p_value_DP_audit(m, r, v, midpoint, delta, cardinality, topk) < p:
            lower = midpoint
        else:
            upper = midpoint
    return lower

def top_k_indices(lst, k):
    """Return the positions of the ``k`` largest values in ``lst``, largest first.

    Ties are resolved in favour of the earlier position. Fewer than ``k``
    positions are returned when ``lst`` is shorter than ``k``.
    """
    ranked = heapq.nlargest(k, enumerate(lst), key=lambda pair: pair[1])
    return [position for position, _value in ranked]

def bottom_k_indices(lst, k):
    """Return the positions of the ``k`` smallest values in ``lst``, smallest first.

    Ties are resolved in favour of the earlier position. Fewer than ``k``
    positions are returned when ``lst`` is shorter than ``k``.
    """
    ranked = heapq.nsmallest(k, enumerate(lst), key=lambda pair: pair[1])
    return [position for position, _value in ranked]


def plot_rankings(
    canary_ranking,
    title,
    bins,
    save_path=None
):
    """Draw a sorted bar chart and a histogram of the canary rankings.

    Args:
        canary_ranking: Sequence of ranking values, one per example.
        title: Figure-level title.
        bins: ``bins`` argument forwarded to the histogram.
        save_path: When truthy the figure is saved there; otherwise it is
            shown interactively.
    """
    fig, (bar_ax, hist_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: rankings in ascending order, one bar per example.
    positions = range(len(canary_ranking))
    ordered = np.sort(canary_ranking)
    bar_ax.bar(positions, ordered, label='Canary Ranking', color='blue')
    bar_ax.set_title('Plot of Canary Rankings per index')
    bar_ax.set_xlabel('Index')
    bar_ax.set_ylabel('Ranking Value')
    bar_ax.legend()

    # Right panel: distribution of the ranking values.
    hist_ax.hist(canary_ranking, bins=bins, color='green', alpha=0.7, label='Ranking Histogram')
    hist_ax.set_title('Histogram of Canary Rankings')
    hist_ax.set_xlabel('Ranking Value')
    hist_ax.set_ylabel('Frequency')
    hist_ax.legend()

    fig.suptitle(f'{title}', fontsize=16)
    # Leave the top 5% of the canvas free for the figure title.
    plt.tight_layout(rect=[0, 0, 1, 0.95])

    if save_path:
        plt.savefig(save_path)
    else:
        plt.show()
    plt.close()


    
def mink(v, k):
    """Return, for each sequence in ``v``, the mean of its lowest ``k`` fraction.

    Args:
        v: Iterable of 1-D numeric sequences (e.g. per-token values).
        k: Fraction of each sequence to average, e.g. ``0.2`` for the lowest
            20% of values.

    Returns:
        A list with one mean per sequence. At least one value is always
        averaged, so short sequences no longer produce NaN when
        ``len(s) * k`` truncates to 0. An empty sequence still yields NaN.
    """
    scores = []
    for s in v:
        # int() truncates, so short sequences would otherwise select zero
        # values and np.mean of the empty slice would be NaN.
        k_length = max(1, int(len(s) * k))
        lowest_values = np.sort(s)[:k_length]
        scores.append(np.mean(lowest_values))
    return scores

            
def calculate_scores(initial, final, attack=None, **kwargs):
    """Turn raw per-token values before and after training into attack scores.

    Args:
        initial: One entry per example, each convertible to a 2-D array
            (the mean is taken along axis 1).
            NOTE(review): presumably rows are candidates and columns are
            tokens -- confirm against the producer of the pickles.
        final: Same layout as ``initial``, measured after training.
        attack: ``'minkk'`` selects the min-k attack; any other value
            selects the loss attack.
        **kwargs: ``k`` is the fraction forwarded to ``mink`` for the min-k
            attack.

    Returns:
        Array of scores: ``initial - final`` of the row means for the loss
        attack, or ``final - initial`` of the ``mink`` values computed on the
        negated inputs for the min-k attack.
    """
    if attack != 'minkk':
        logger.info("Calculating loss attack")
        initial_arrays = [np.array(rows) for rows in initial]
        final_arrays = [np.array(rows) for rows in final]
        initial_means = np.array([np.mean(rows, axis=1) for rows in initial_arrays])
        final_means = np.array([np.mean(rows, axis=1) for rows in final_arrays])
        return initial_means - final_means

    # The min-k attack works on the negated values.
    negated_initial = [-np.array(rows) for rows in initial]
    negated_final = [-np.array(rows) for rows in final]
    logger.info("Calculating min-k attack")
    before, after = [], []
    # zip() stops at the shorter input, so both lists stay the same length.
    for start_rows, end_rows in zip(negated_initial, negated_final):
        before.append(mink(start_rows, kwargs.get('k')))
        after.append(mink(end_rows, kwargs.get('k')))
    before = np.array(before)
    after = np.array(after)
    results = after - before
    logger.info(f"Results shape: {results.shape}")
    return results


def get_rankings(scores, f_indices, topk):
    """Locate each canary in its row's score ranking and collect audit counts.

    Args:
        scores: 2-D array; row ``i`` holds the scores of example ``i``.
        f_indices: For each row, the column index of the canary.
        topk: Number of highest (and lowest) positions that count as a hit.

    Returns:
        A ``defaultdict(list)`` with the keys ``canary_ranking`` (position of
        each canary in the descending ranking), ``top`` / ``bottom``
        (per-row hit flags), ``v`` / ``vb`` (hit counts), ``m``, ``r``,
        ``cardinality``, ``delta`` and ``p``.
    """
    results = defaultdict(list)
    for row, canary_idx in zip(scores, f_indices):
        best = top_k_indices(row, topk)
        worst = bottom_k_indices(row, topk)
        # Full descending order of the row, used to find the canary's rank.
        descending_order = np.array(top_k_indices(row, len(row)))
        results['canary_ranking'].append(np.argwhere(descending_order == canary_idx))
        results['top'].append(canary_idx in best)
        results['bottom'].append(canary_idx in worst)
    # v = number of rows whose canary is among the top-k scores
    results['v'] = sum(results['top'])
    # vb = number of rows whose canary is among the bottom-k scores
    results['vb'] = sum(results['bottom'])
    # m = number of examples, r = number of guesses; both are the row count
    # here, i.e. no abstentions are modelled.
    results['m'] = scores.shape[0]
    results['r'] = scores.shape[0]
    results['cardinality'] = scores.shape[1]
    results['delta'] = 1e-5
    results['p'] = 0.05
    results['canary_ranking'] = np.array(results['canary_ranking']).flatten().tolist()
    return results


def show_audit(final_scores_path, attack, topk, out_path=None, **kwargs):
    """Run one audit on a saved scores file, log it and optionally persist it.

    Loads ``final_scores_path`` and the sibling ``initial_scores.pkl``,
    converts them to scores with ``calculate_scores``, ranks the canaries
    with ``get_rankings`` and derives an epsilon lower bound with
    ``get_eps_audit``. When ``out_path`` is truthy, one row is appended to
    ``out_path / "results.csv"`` and a ranking plot is saved next to it.

    Run metadata is parsed from the directory names above the file:
    ``parents[1].name`` split on ``-`` provides the learning rate (field 1),
    the epoch (field 2) and the real epsilon (third field from the end).
    After stripping ``-only_canaries`` and ``-steinke``, the same name
    provides the train dataset (field 0) and the sequence length (field 3).

    Args:
        final_scores_path: Path object of the pickled final scores.
        attack: Attack name forwarded to ``calculate_scores``.
        topk: Number of top positions that count as a correct guess.
        out_path: Output directory; nothing is written when falsy.
        **kwargs: Forwarded to ``calculate_scores``; ``k`` is also appended
            to the attack label for the ``'minkk'`` attack.

    Returns:
        None. Returns early, writing nothing, when ``topk / cardinality``
        exceeds 0.5.
    """
    run_dir = final_scores_path.parents[1].name
    name = run_dir.replace('-only_canaries', "").replace('-steinke', '')
    name_fields = name.split('-')
    train_dataset = name_fields[0]
    seq_len = name_fields[3]
    run_fields = run_dir.split('-')
    epsilon = run_fields[-3]

    # The directory two levels up is either the model (when it mentions
    # 'pythia') or the secret type, with the model one level further up.
    grandparent = final_scores_path.parents[2].name
    if 'pythia' in str(grandparent):
        model = grandparent
        secret = "sha256_sha512_rsa_private_key"
    else:
        model = final_scores_path.parents[3].name
        secret = grandparent
    lr = str(float(run_fields[1]))
    epoch = run_fields[2]
    # NOTE(review): this looks for "only canaries" with a space, while the
    # suffix stripped above is "-only_canaries" -- confirm which spelling the
    # directory names actually use.
    only_canaries = any("only canaries" in str(ancestor) for ancestor in final_scores_path.parents)

    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    initial_scores_path = final_scores_path.parent / 'initial_scores.pkl'
    with open(final_scores_path, 'rb') as handle:
        final_scores = pickle.load(handle)
    with open(initial_scores_path, 'rb') as handle:
        initial_scores = pickle.load(handle)

    f_indices = final_scores['canary_indices']
    i_indices = initial_scores['canary_indices']
    assert f_indices == i_indices, "Indices do not match"

    scores = calculate_scores(
        initial_scores['scores'] if initial_scores else None,
        final_scores['scores'],
        attack=attack,
        **kwargs
    )

    ranking_dict = get_rankings(scores, f_indices, topk=topk)
    ranking_dict['attack'] = attack
    ranking_dict['topk'] = topk
    ranking_dict['only_canaries'] = only_canaries

    # Skip configurations where more than half of the candidates would count
    # as a correct guess.
    if ranking_dict['topk'] / ranking_dict['cardinality'] > 0.5:
        logger.info(f"Topk: {topk} Cardinality: {ranking_dict['cardinality']}")
        logger.info("Topk / Cardinality > 0.5")
        return None

    # The search is capped at the real epsilon parsed from the run directory.
    estimated_epsilon = get_eps_audit(
        m=ranking_dict['m'],
        r=ranking_dict['r'],
        v=ranking_dict['v'],
        delta=ranking_dict['delta'],
        p=ranking_dict['p'],
        cardinality=ranking_dict['cardinality'],
        topk=topk,
        eps_max=float(epsilon)
    )

    if attack == 'minkk':
        attack = f"{attack} {kwargs.get('k', '')}"
    logger.info(f"Attack {attack}: estimated ε: {estimated_epsilon} real ε: {epsilon}")
    logger.info(f"m: {ranking_dict['m']} r: {ranking_dict['r']} v: {ranking_dict['v']} cardinality: {ranking_dict['cardinality']}")
    logger.info(10*"***")

    row = {
        "m": ranking_dict['m'],
        "r": ranking_dict['r'],
        "v": ranking_dict['v'],
        "vb": ranking_dict['vb'],
        "delta": ranking_dict['delta'],
        "p": ranking_dict['p'],
        "only_canaries": ranking_dict['only_canaries'],
        "train-dataset": train_dataset,
        "cardinality": ranking_dict['cardinality'],
        "topk": topk,
        "epsilon": float(epsilon),
        "estimated_epsilon": estimated_epsilon,
        "model": model,
        "secret": secret,
        "lr": lr,
        "epoch": epoch,
        "attack": attack,
        "seq len": seq_len,
    }
    out_df = pd.DataFrame([row])

    if out_path:
        out_path = Path(out_path)
        out_path.mkdir(parents=True, exist_ok=True)

        # Write the header only when the CSV does not exist yet; later calls
        # append their row to the same file.
        csv_path = out_path / "results.csv"
        out_df.to_csv(
            csv_path,
            mode='a+',
            header=not csv_path.exists()
        )

        title = name.split('.')[0]
        plot_rankings(
            ranking_dict['canary_ranking'],
            title=title,
            bins=ranking_dict['cardinality'],
            save_path=out_path / f"{title}.png"
        )
        

if __name__ == '__main__':
    # Audit every final-scores pickle under results/ whose path contains
    # '410' and does not mention 'steinke'.
    root_path = Path('results')
    files = [
        candidate
        for candidate in root_path.rglob('final*pkl')
        if 'steinke' not in str(candidate).lower() and '410' in str(candidate)
    ]

    out_path = root_path / 'results_final'
    out_path.mkdir(parents=True, exist_ok=True)

    # One loss attack followed by min-k attacks with ratios 0.1 .. 0.9.
    min_r, max_r = 1, 10
    attack_plan = [('loss', None)]
    attack_plan += [('minkk', step / 10) for step in range(min_r, max_r)]

    for scores_file in files:
        for topk in range(1, 5):
            print(scores_file)

            for attack, ratio in attack_plan:
                show_audit(
                    final_scores_path=scores_file,
                    attack=attack,
                    k=ratio,
                    out_path=out_path / "black_box",
                    topk=topk
                )
