import torch
import os
import tqdm
import math
import math

def load_pdb_res(path,letter2idx,device=None):
    """Read the sequence stored on the fourth line of a file as an index tensor.

    Args:
        path: Path of the text file to read.
        letter2idx: Mapping from each sequence character to its integer index.
        device: Optional target device for the returned tensor. When ``None``
            (the default) the tensor is moved to the current CUDA device,
            which is the historical behaviour of this function.

    Returns:
        A 1-D ``torch.Tensor`` holding one index per character of the
        (whitespace-stripped) fourth line.

    Raises:
        IndexError: If the file has fewer than four lines.
        KeyError: If the sequence contains a character missing from
            ``letter2idx``.
    """
    with open(path,'r') as f:
        lines=f.readlines()

    # The sequence lives on the fourth line (index 3) of the file.
    seq=lines[3].strip()
    ids=torch.tensor([letter2idx[c] for c in seq])

    if device is None:
        # Preserve the original default: always place the result on CUDA.
        return ids.cuda()
    return ids.to(device)
    
def get_fpr_w_comb(score_list):
    """Return the upper binomial tail probability for the observed score.

    With ``n = len(score_list)`` trials and a per-trial success probability
    of 10/21, this is the probability of seeing at least
    ``int(sum(score_list))`` successes. The numerator is accumulated with
    exact integers and divided by ``21**n`` only once at the end.
    """
    total=len(score_list)
    observed=int(sum(score_list))

    # Each term is C(n, hits) * 10^hits * 11^(n-hits); dividing by 21^n later
    # turns the integer count into a probability.
    tail=sum(
        math.comb(total,hits)*10**hits*11**(total-hits)
        for hits in range(observed,total+1)
    )
    return tail/21**total
    

def get_red_green_list(vocab_size,device,*,private_key=42,gamma=0.5):
    """Split the vocabulary into a seeded pseudo-random green and red list.

    A permutation of ``range(vocab_size)`` is drawn from a generator seeded
    with ``private_key`` on ``device``; the first ``int(vocab_size * gamma)``
    entries form the green list and the remainder form the red list.

    Args:
        vocab_size: Number of token ids to partition.
        device: Device on which the generator and the permutation are created.
        private_key: Seed for the generator (default 42, the value that was
            previously hard-coded).
        gamma: Fraction of the vocabulary assigned to the green list
            (default 0.5, the value that was previously hard-coded).

    Returns:
        A tuple ``(green_list_ids, red_list_ids)`` of 1-D tensors.
    """
    greenlist_size=int(vocab_size*gamma)

    # A dedicated generator keeps the split reproducible and independent of
    # the global RNG state.
    rng=torch.Generator(device=device)
    rng.manual_seed(private_key)

    vocab_permutation=torch.randperm(vocab_size,device=device,generator=rng)
    green_list_ids=vocab_permutation[:greenlist_size]
    red_list_ids=vocab_permutation[greenlist_size:]
    return green_list_ids,red_list_ids
        
########################################################

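# Calculate the FPR from alternation patterns in the recovered red/green sequence.
# A pattern is an alternating run that starts with either colour (rg... or gr...).
# pattern_k: the alternation spans k consecutive 'edges' (adjacent positions with
# different colours), i.e. k+1 points. E.g. rg is pattern_1, rgr is pattern_2, rgrg is pattern_3.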
def detect_markov_chain_universal(output_dir,fpr_thres,k):
    """Run the alternation-pattern detection test on every file in a directory.

    For each file, the sequence is loaded with ``load_pdb_res``, every
    position is mapped to 1 (green) or 0 (red) using the green list from
    ``get_red_green_list``, and the number ``m`` of pattern_k occurrences is
    counted (an occurrence is counted at every edge where the current run of
    consecutive colour changes has reached length ``k``). The per-sequence
    FPR is the probability, under a model where each adjacent pair differs
    with probability 0.5, of observing at least ``m`` occurrences in a
    sequence of that length. A sequence is flagged when its FPR is at most
    ``fpr_thres``.

    Only sequences with ``400 <= length < 500`` are evaluated; all others are
    skipped.

    Args:
        output_dir: Directory whose entries are all passed to ``load_pdb_res``.
        fpr_thres: FPR threshold at or below which a sequence is flagged.
        k: Number of consecutive edges that make up one pattern
            (``1 <= k <= 500``).

    Returns:
        The fraction of evaluated sequences that were flagged. The same value
        is also printed.

    Raises:
        AssertionError: If ``k`` is outside ``[1, 500]``.
        ValueError: If no sequence in ``output_dir`` has a length in
            ``[400, 500)`` (previously this surfaced as a bare
            ``ZeroDivisionError``).
    """
    alphabet='ACDEFGHIKLMNPQRSTVWYX'
    max_len=500  # table extent; evaluated sequences are strictly shorter
    min_len=400  # shortest sequence length that is evaluated
    assert 1<=k, 'k must be at least 1'
    assert k<=500, 'k must be at most 500'

    letter2idx={c:idx for idx,c in enumerate(alphabet)}

    # final_fpr_table[n, m]: probability that a length-n sequence contains
    # exactly m occurrences of pattern_k.
    if k>1:
        # fpr_table[n, m, s]: n = sequence length, m = number of pattern_k
        # occurrences so far, s = length of the current run of colour changes,
        # capped at k-1. Only s <= k-1 is ever used, so the last axis has size
        # k rather than a fixed 500.
        fpr_table=torch.zeros([max_len,max_len,k],dtype=torch.float64).cuda()
        # Initial state: a single point has no edges and no occurrences.
        fpr_table[1,0,0]=1

        for n in tqdm.tqdm(range(2,max_len)):
            # pattern_k involves k+1 points, so k+1+m-1<=n, i.e. m<=n-k.
            # When n-k<=0 only m==0 is reachable.
            for m in range(max(n-k,0)+1):
                if m==0:
                    # s==0: the new point repeats the previous colour.
                    fpr_table[n,0,0]=0.5*torch.sum(fpr_table[n-1,0,:])
                    # 0<s<=k-1: the new point extends the current run by one.
                    fpr_table[n,0,1:k]=0.5*fpr_table[n-1,0,:k-1]
                elif m==n-k:
                    # Maximum m is only reachable when every edge alternates,
                    # so the single contributing state is s==k-1.
                    fpr_table[n,m,k-1]=0.5*fpr_table[n-1,m-1,k-1]
                else:
                    # s==0
                    fpr_table[n,m,0]=0.5*torch.sum(fpr_table[n-1,m,:])
                    # s==k-1: either the run just reached k-1, or it was
                    # already there and this edge completes one more pattern.
                    fpr_table[n,m,k-1]=0.5*fpr_table[n-1,m,k-2]+0.5*fpr_table[n-1,m-1,k-1]
                    # 0<s<k-1
                    if k>2:
                        fpr_table[n,m,1:k-1]=0.5*fpr_table[n-1,m,:k-2]

        final_fpr_table=torch.sum(fpr_table,dim=-1)
    else:
        # k==1: every differing edge is an occurrence, so m follows a
        # Binomial(n-1, 0.5) distribution. Only evaluated lengths are filled.
        final_fpr_table=torch.zeros([max_len,max_len],dtype=torch.float64).cuda()
        for n in range(min_len,max_len):
            for m in range(n):
                final_fpr_table[n,m]=math.comb(n-1,m)/2**(n-1)

    def get_fpr_w_fpr_table(key_seq):
        """Return P(at least the observed number of pattern_k occurrences)."""
        n=len(key_seq)
        m=0      # pattern_k occurrences
        cur_k=0  # length of the current run of colour changes
        for prev,cur in zip(key_seq,key_seq[1:]):
            if cur!=prev:
                cur_k+=1
                if cur_k>=k:
                    m+=1
            else:
                cur_k=0
        return torch.sum(final_fpr_table[n,m:])

    # The green list depends only on the vocabulary size and the device, so it
    # is computed once per device instead of once per residue.
    green_sets={}

    def green_ids_on(device):
        """Return the green token ids for ``device`` as a set of ints."""
        if device not in green_sets:
            # TODO: change this -- the same fixed green list is used for
            # every position.
            green_ids,_=get_red_green_list(len(alphabet),device)
            green_sets[device]=set(green_ids.tolist())
        return green_sets[device]

    def extract_fpr(cur_dir):
        """Compute the FPR of every in-range sequence found in ``cur_dir``."""
        fpr_list=[]
        for filename in tqdm.tqdm(os.listdir(cur_dir)):
            cur_path=os.path.join(cur_dir,filename)
            seq=load_pdb_res(cur_path,letter2idx)
            cur_len=seq.shape[0]

            if cur_len>=max_len or cur_len<min_len:
                continue

            green=green_ids_on(seq.device)
            recovered_key_seq=[1 if token in green else 0 for token in seq.tolist()]
            fpr_list.append(get_fpr_w_fpr_table(recovered_key_seq))
        return fpr_list

    fpr_list=extract_fpr(output_dir)
    if not fpr_list:
        raise ValueError(
            'No sequence with length in [{}, {}) found in {!r}; '
            'TPR is undefined.'.format(min_len,max_len,output_dir)
        )

    detected=sum(1 for fpr in fpr_list if fpr<=fpr_thres)
    tpr=detected/len(fpr_list)
    print('Using markov chain universal with k={}. FPR:{}. Average TPR:{}'.format(k,fpr_thres,tpr))
    return tpr
    

# Script entry point: currently does nothing when the module is run directly.
if __name__=='__main__':
    pass
    