import os
import torch
import tqdm
import math
from matplotlib import pyplot as plt
from scipy.stats import norm
import hashlib
import json
import argparse

def get_red_green_list(vocab_size, device, *, private_key=42, gamma=0.5):
    """Split the vocabulary into a keyed "green" list and a "red" list.

    A generator created on ``device`` is seeded with ``private_key`` and used
    to draw a random permutation of ``range(vocab_size)``. The first
    ``int(vocab_size * gamma)`` ids of that permutation are the green list and
    the remaining ids are the red list.

    Args:
        vocab_size: Number of token ids to partition.
        device: Device on which the generator and the permutation are created.
        private_key: Seed of the permutation (keyword-only, defaults to 42).
        gamma: Fraction of the vocabulary placed in the green list
            (keyword-only, defaults to 0.5). Must lie in ``[0, 1]``.

    Returns:
        A ``(green_list_ids, red_list_ids)`` tuple of 1-D id tensors located
        on ``device``.

    Raises:
        ValueError: If ``gamma`` is outside ``[0, 1]``.
    """
    # NOTE(review): the permutation presumably depends on the generator's
    # device type as well as on the seed, so generation and detection should
    # use the same device type -- confirm against the PyTorch RNG docs.
    if not 0 <= gamma <= 1:
        # Out-of-range values would otherwise be silently reinterpreted by
        # the slicing below (e.g. a negative size slices from the end).
        raise ValueError(f"gamma must be in [0, 1], got {gamma!r}")

    greenlist_size = int(vocab_size * gamma)

    rng = torch.Generator(device=device)
    rng.manual_seed(private_key)

    vocab_permutation = torch.randperm(vocab_size, device=device, generator=rng)
    green_list_ids = vocab_permutation[:greenlist_size]
    red_list_ids = vocab_permutation[greenlist_size:]
    return green_list_ids, red_list_ids

def detect_markov_chain_universal(save_path, fpr_thres_list, k, vocab_size=34984):
    """Compute the detection rate of the alternation-pattern test per threshold.

    Every hypothesis is mapped to a 0/1 "key sequence" (1 when the token is in
    the green list returned by ``get_red_green_list``). The statistic ``m``
    counts the positions at which the current run of consecutive alternations
    (a token whose colour differs from the previous one) has reached length
    ``k`` or more. Its p-value is the probability, under independent fair
    coin flips, of observing at least ``m`` such positions in a sequence of
    the same length; these probabilities come from a precomputed table.

    Args:
        save_path: Path passed to ``torch.load``; the loaded object is
            iterated and each item is treated as a tensor of token ids
            (assumed 1-D -- TODO confirm against the producer of the file).
        fpr_thres_list: Iterable of p-value thresholds.
        k: Alternation-run length that defines the pattern, ``1 <= k <= 500``.
        vocab_size: Vocabulary size used to rebuild the green list
            (defaults to 34984, the value previously hard-coded here).

    Returns:
        A list with one float per threshold: the fraction of hypotheses whose
        p-value is ``<=`` that threshold (called TPR in the code, i.e. the
        file is presumably expected to contain only watermarked text).

    Notes:
        * Requires CUDA: tables and hypotheses are moved with ``.cuda()``.
        * The table covers sequence lengths below 500; a longer hypothesis
          indexes past the table and fails.
        * An empty hypothesis file makes the final division fail with
          ``ZeroDivisionError``.
    """
    max_len = 500  # side length of the probability table (n and m axes)

    assert 1 <= k
    assert k <= max_len

    # pattern_k involves at least k+1 points;
    # e.g. rg is pattern_1, rgr is pattern_2, rgrg is pattern_3.
    if k > 1:
        # fpr_table[n, m, s]: probability that a length-n sequence contains
        # exactly m pattern_k occurrences and currently ends in an alternation
        # run of state s. Only s <= k-1 is ever needed for the induction, so
        # the last axis has size k (the original 500 left the rest at zero).
        fpr_table = torch.zeros([max_len, max_len, k], dtype=torch.float64).cuda()
        # Initial state: a single token, no pattern, no alternation yet.
        fpr_table[1, 0, 0] = 1

        for n in tqdm.tqdm(range(2, max_len)):
            # pattern_k involves k+1 points, so k+1+m-1 <= n, i.e. m <= n-k.
            # For n <= k only m == 0 is reachable.
            max_m = max(n - k, 0)
            for m in range(max_m + 1):
                if m == 0:
                    # s == 0: the new token repeats the previous colour.
                    fpr_table[n, 0, 0] = 0.5 * torch.sum(fpr_table[n - 1, 0, :])
                    # 0 < s <= k-1: the new token extends the alternation run.
                    fpr_table[n, 0, 1:k] = 0.5 * fpr_table[n - 1, 0, :k - 1]
                elif m == max_m:
                    # Maximal count: every transition alternated, so the only
                    # reachable state is s == k-1.
                    fpr_table[n, m, k - 1] = 0.5 * fpr_table[n - 1, m - 1, k - 1]
                else:
                    # s == 0
                    fpr_table[n, m, 0] = 0.5 * torch.sum(fpr_table[n - 1, m, :])
                    # s == k-1: either the run just reached k-1, or it was
                    # already there and this alternation added one occurrence.
                    fpr_table[n, m, k - 1] = (
                        0.5 * fpr_table[n - 1, m, k - 2]
                        + 0.5 * fpr_table[n - 1, m - 1, k - 1]
                    )
                    # 0 < s < k-1
                    if k > 2:
                        fpr_table[n, m, 1:k - 1] = 0.5 * fpr_table[n - 1, m, :k - 2]

        # Marginalise over the run state: P(exactly m occurrences | length n).
        final_fpr_table = torch.sum(fpr_table, dim=-1)
    else:
        # k == 1: every alternating transition is an occurrence, so m follows
        # a Binomial(n-1, 1/2) distribution. Fill on the CPU and move once
        # instead of issuing one tiny GPU write per cell.
        cpu_table = torch.zeros([max_len, max_len], dtype=torch.float64)
        for n in range(1, max_len):
            for m in range(n):
                cpu_table[n, m] = math.comb(n - 1, m) / 2 ** (n - 1)
        final_fpr_table = cpu_table.cuda()

    def get_fpr_w_fpr_table(key_seq):
        """Return P(at least the observed pattern count) for ``key_seq``."""
        n = len(key_seq)
        m = 0  # number of pattern_k occurrences
        cur_k = 0  # length of the current alternation run
        for prev_bit, cur_bit in zip(key_seq, key_seq[1:]):
            if cur_bit != prev_bit:
                cur_k += 1
                if cur_k >= k:
                    m += 1
            else:
                cur_k = 0

        # NOTE: row 0 of the table is never written, so an empty sequence
        # gets p-value 0.0 and is counted as detected.
        # Convert once to a Python float so later threshold comparisons do
        # not each trigger a device synchronisation.
        return torch.sum(final_fpr_table[n, m:]).item()

    def extract_fpr(save_path):
        """Load the hypotheses and return one p-value (float) per hypothesis."""
        # NOTE(security): torch.load unpickles the file; only use it on files
        # from a trusted source.
        total_hypos = torch.load(save_path)

        green_list_ids = None
        fpr_list = []
        for hypo in tqdm.tqdm(total_hypos):
            hypo = hypo.cuda()
            if green_list_ids is None:
                # The green list depends only on the vocabulary size and the
                # device, so build it once rather than once per token.
                green_list_ids, _ = get_red_green_list(vocab_size, hypo.device)

            # Same test as ``token in green_list_ids`` for every position,
            # done in one broadcasted comparison (assumes a 1-D hypothesis).
            is_green = (hypo.unsqueeze(1) == green_list_ids.unsqueeze(0)).any(dim=1)
            recovered_key_seq = is_green.long().tolist()

            fpr_list.append(get_fpr_w_fpr_table(recovered_key_seq))
        return fpr_list

    fpr_list = extract_fpr(save_path)

    res_tpr_list = []
    for fpr_thres in fpr_thres_list:
        num_detected = sum(1 for fpr in fpr_list if fpr <= fpr_thres)
        res_tpr_list.append(num_detected / len(fpr_list))
    return res_tpr_list

    
if __name__=='__main__':
    # As written, running this file as a script does nothing; the functions
    # above are only usable by importing the module.
    pass