import argparse
import os
import numpy as np
import random

def sigma_lambda(args):
    """Return the ``(sigma, lambda1, lambda2)`` hyper-parameters for ``args.method``.

    - ``'nonlinear'``: ``(1.0, 0.01, 0.005)``.
    - ``'linear'`` with ``args.prior_type == 'exist'``: ``lambda1`` starts at
      0.05 and is shifted by ``args.proportion * (args.confidence - 0.5) / 2``
      (it grows for confidence above 0.5 and shrinks below); ``lambda2`` is 0.
    - anything else: ``(1.0, 0.05, 0.0)``.
    """
    sigma = 1.0
    if args.method == 'nonlinear':
        return sigma, 0.01, 0.005

    lambda1 = 0.05
    if args.method == 'linear' and args.prior_type == 'exist':
        lambda1 += args.proportion * (args.confidence - 0.5) / 2
    return sigma, lambda1, 0.0



def _would_create_cycle(tail, head, dag):
    """Return True if adding (or orienting) the edge ``tail -> head`` would close a cycle.

    That is the case when ``head`` already reaches ``tail`` through a directed
    path in ``dag`` (a direct edge ``head -> tail`` counts), or when
    ``tail == head``. The search walks backwards from ``tail`` through its
    ancestors.

    Args:
        tail, head: node indices of the candidate edge ``tail -> head``.
        dag: square adjacency matrix; any non-zero ``dag[a, b]`` is an edge ``a -> b``.
    """
    visited = [False] * len(dag)
    stack = [tail]
    while stack:
        node = stack.pop()
        if visited[node]:
            continue
        visited[node] = True
        if node == head:
            return True
        # Push every parent of `node`: column `node` holds its incoming edges.
        stack.extend(np.flatnonzero(dag[:, node]).tolist())
    return False


def _transitive_closure(dag):
    """Return a copy of ``dag`` that also has a 1 for every indirect path ``i -> ... -> j``.

    Floyd-Warshall style reachability, vectorised over ``(i, j)`` for each
    pivot ``k``. Entries that were already non-zero and are also reachable
    through some pivot are overwritten with 1. ``dag`` itself is not modified.
    """
    reach = dag.copy()
    for k in range(reach.shape[0]):
        # (i, j) becomes reachable when both i -> k and k -> j are present.
        reach[np.outer(reach[:, k] != 0, reach[k, :] != 0)] = 1
    return reach


def _resolve_edge_count(edge_num, spec):
    """Convert a prior-size specification into a number of edges.

    A ``float`` (NumPy floats included) is a proportion of ``edge_num`` and is
    truncated to an int. Any other value is returned unchanged and used as an
    absolute count.
    """
    if isinstance(spec, float):
        # Absorb float representation error, e.g. 100 * 0.29 == 28.999999999999996.
        return int(edge_num * spec + 1e-9)
    return spec


def _is_requested_error(error_type, i, j, dag_true):
    """Return True if the wrong edge ``i -> j`` falls into category ``error_type``.

    Relations are evaluated in ``dag_true`` (non-zero entry = edge):
        'all'              -- any edge.
        'reverse_direct'   -- ``j -> i`` is an edge.
        'reverse_ancestor' -- ``j`` reaches ``i`` by a directed path.
        'ancestor'         -- ``i`` reaches ``j`` by a directed path.
        'reverse_indirect' -- ``j -> i`` is an edge or ``j`` reaches ``i``.
        'irrelevant'       -- no directed path between ``i`` and ``j`` either way.
    Unknown categories match nothing.
    """
    if error_type == 'all':
        return True
    if error_type == 'reverse_direct':
        return bool(dag_true[j, i] != 0)
    if error_type == 'reverse_ancestor':
        return _would_create_cycle(i, j, dag_true)
    if error_type == 'ancestor':
        return _would_create_cycle(j, i, dag_true)
    if error_type == 'reverse_indirect':
        return bool(dag_true[j, i] != 0) or _would_create_cycle(i, j, dag_true)
    if error_type == 'irrelevant':
        return (dag_true[j, i] == 0
                and not _would_create_cycle(i, j, dag_true)
                and not _would_create_cycle(j, i, dag_true))
    return False


def generate_prior_quasi(args, dag_true, seed=0):
    """Sample an edge prior that mixes correct edges with deliberately wrong ones.

    Wrong edges are drawn first, from the zero entries of ``dag_true``, and are
    filtered by ``args.error_prior_type``. Correct edges are then drawn from
    the non-zero entries. A wrong edge is never a self-loop, and every accepted
    edge must keep the combined prior acyclic and must not reverse an edge that
    was already accepted.

    Args:
        args: namespace providing
            ``alg`` -- if it contains ``'order'``, the transitive closure of
                ``dag_true`` is used, so ancestor pairs count as true edges;
            ``error_prior_proportion`` / ``proportion`` -- number of wrong /
                correct edges; a float is a fraction of the number of direct
                edges in ``dag_true``, anything else an absolute count;
            ``error_prior_type`` -- category of wrong edges
                (see ``_is_requested_error``);
            ``n_nodes`` -- size of the returned matrix; ``None`` means
                ``dag_true.shape[0]``.
        dag_true: square adjacency matrix; non-zero ``dag_true[i, j]`` is an
            edge ``i -> j``. It is not modified.
        seed: reseeds both the global ``random`` and ``np.random`` generators.

    Returns:
        ``(w_prior, edge_existence, error_prior)``. ``w_prior`` is an
        ``n x n`` float matrix with 1 at every prior edge. ``edge_existence``
        and ``error_prior`` list the correct and wrong edges as ``[i, j]``
        pairs.
    """
    random.seed(seed)
    np.random.seed(seed)

    # Prior sizes are relative to the number of *direct* edges, even in 'order' mode.
    edge_num = int(np.count_nonzero(dag_true))

    if 'order' in args.alg:
        dag_true = _transitive_closure(dag_true)

    # Shuffled candidate pools; edges are consumed from the end with pop().
    true_edges = np.random.permutation(np.argwhere(dag_true != 0)).tolist()
    absence_edges = np.random.permutation(np.argwhere(dag_true == 0)).tolist()

    n = args.n_nodes if args.n_nodes is not None else dag_true.shape[0]
    prior_dag = np.zeros([n, n])  # running adjacency matrix of all accepted edges
    accepted = set()              # (i, j) tuples of accepted edges
    edge_existence = []
    error_prior = []

    def conflicts(i, j):
        # i -> j would contradict an accepted j -> i, or close a longer cycle.
        return (j, i) in accepted or _would_create_cycle(i, j, prior_dag)

    def accept(i, j, bucket):
        bucket.append([i, j])
        accepted.add((i, j))
        prior_dag[i, j] = 1

    # Integer count: looping on a float such as 3.0000000000000004 would sample one extra edge.
    wrong = _resolve_edge_count(edge_num, args.error_prior_proportion)
    while wrong > 0:
        if not absence_edges:
            print("Too few desirable edges, sampling process stopped.")
            break
        i, j = absence_edges.pop()
        # A wrong edge cannot be a self-loop, contradict a known edge, or form a cycle.
        if i == j or conflicts(i, j):
            continue
        if _is_requested_error(args.error_prior_type, i, j, dag_true):
            wrong -= 1
            accept(i, j, error_prior)

    right = _resolve_edge_count(edge_num, args.proportion)
    while right > 0 and true_edges:
        i, j = true_edges.pop()
        if conflicts(i, j):
            continue
        right -= 1
        accept(i, j, edge_existence)

    return prior_dag, edge_existence, error_prior