import torch,time
from src import notears_prior,notears_adaptive
from preparation import *
from generation_prior import *
from rich import print as rprint
import numpy as np 
from evaluation import evaluation


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# ``device_type`` is the string handed to the model constructors ('gpu'/'cpu');
# it is derived from the resolved device so the two can never disagree on a
# CPU-only machine. To force CPU execution, set: device = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_type = 'gpu' if device.type == 'cuda' else 'cpu'

# Human-readable tensor/array printing: no scientific notation.
torch.set_printoptions(sci_mode=False)
np.set_printoptions(suppress=True, precision=4)

# Algorithms whose ``load_prior`` is called with the extended keyword set
# (ground-truth weights, prior-correctness/error info, regularisation constants,
# args and output path). Note that this script only *constructs* two of them
# (see ``_build_model``); the rest are listed for parity with other drivers.
_EXTENDED_PRIOR_ALGS = frozenset({
    'notears_resilience', 'dagma_resilience', 'GOLEM_resilience',
    'notears_soft_adaptive', 'notears_resilience_order', 'dagma_adaptive',
    'GOLEM_adaptive', 'notears_logll_adaptive', 'notears_mlp_soft_adaptive',
    'notears_mlp_soft_logll_adaptive',
})


def _load_weight_true_dag(args):
    """Load the ground-truth weighted adjacency matrix for the configured experiment.

    Raises:
        ValueError: if ``args.method`` is neither ``'linear'`` nor ``'nonlinear'``.
    """
    if args.method == 'linear':
        path = f'data/W_true/{args.n_nodes}_{args.ER}_{args.graph_type}.csv'
    elif args.method == 'nonlinear':
        # NOTE(review): this file name uses only the first 3 chars of sem_type,
        # while the X file name in ``main`` uses the full sem_type — confirm
        # this matches the files on disk.
        path = f'data/W_truenonlinear/{args.n_nodes}_{args.ER}_{args.size}_{args.graph_type}_{args.random}_{args.method}_{args.sem_type[:3]}.csv'
    else:
        raise ValueError(f"Unsupported method {args.method!r}; expected 'linear' or 'nonlinear'")
    return np.loadtxt(path, delimiter=',')


def _build_model(args, lambda1, sigma):
    """Instantiate the structure-learning model selected by ``args.alg``.

    Raises:
        ValueError: if ``args.alg`` is not an algorithm this script can build.
    """
    if args.alg == 'notears_soft':
        return notears_prior(lambda1=lambda1, sigma=sigma, loss_type=args.loss_type, prior_type='soft', device_type=device_type)
    if args.alg == 'notears_soft_adaptive':
        return notears_adaptive(lambda1=lambda1, sigma=sigma, loss_type=args.loss_type, adaptive_degree=args.adaptive_degree, prior_type='soft', device_type=device_type)
    raise ValueError(f"Unsupported alg {args.alg!r}; expected 'notears_soft' or 'notears_soft_adaptive'")


def main(output_path):
    """Run one prior-guided structure-learning experiment and evaluate it.

    Steps:
      1. Build the default config dict and pass it through ``get_config``.
      2. Load the ground-truth weighted DAG and binarise it into ``true_dag``.
      3. Load the observational data X from ``data/X/`` and normalise it per
         ``args.scale``.
      4. Get ``sigma``, ``lambda1`` and ``lambda2`` from ``sigma_lambda``. A
         positive ``args.lambda1`` overrides the returned ``lambda1``.
      5. Generate the (possibly erroneous) prior, seeded by ``args.random``.
      6. Build the model, load the prior, time ``model.learn(X)``, and hand
         everything to ``evaluation``.

    Args:
        output_path: Passed through to ``evaluation`` and, for algorithms in
            ``_EXTENDED_PRIOR_ALGS``, to ``model.load_prior``. Presumably a
            results CSV path (the script entry point uses 'out/output.csv').

    Raises:
        ValueError: for an unsupported ``args.method`` or ``args.alg``.
    """
    args = {
        'n_nodes': 20, 'ER': 2, 'size': 2, 'graph_type': 'ER', 'random': 0,
        'method': 'linear', 'sem_type': 'gauss', 'scale': 'std',
        'prior_type': 'exist', 'proportion': 0.4, 'confidence': 0.9,
        'error_prior_proportion': 0., 'error_prior_type': 'reverse_direct',
        'alg': 'notears_soft_adaptive', 'adaptive_degree': 0.2, 'lambda1': 0.1,
    }
    # NOTE(review): get_config is assumed to turn the dict into an attribute
    # namespace and supply defaults (e.g. ``loss_type``) — confirm.
    args = get_config(args)
    rprint(vars(args))

    weight_true_dag = _load_weight_true_dag(args)
    true_dag = np.where(weight_true_dag != 0, 1, 0)
    X = np.loadtxt(f'data/X/{args.n_nodes}_{args.ER}_{args.size}_{args.graph_type}_{args.random}_{args.method}_{args.sem_type}.csv', delimiter=',')
    X = normalize(X, args.scale)

    sigma, lambda1, lambda2 = sigma_lambda(args)
    if args.lambda1 > 0:
        # An explicit positive lambda1 in the config takes precedence.
        lambda1 = args.lambda1

    w_prior, edge_existence, error_prior = generate_prior_quasi(args, true_dag, seed=args.random)

    model = _build_model(args, lambda1, sigma)

    if args.alg in _EXTENDED_PRIOR_ALGS:
        model.load_prior(w_prior, args.confidence, weight_true_dag, right=edge_existence, error=error_prior, lambda1=lambda1, lambda2=lambda2, sigma=sigma, args=args, output_path=output_path)
    else:
        model.load_prior(w_prior, args.confidence)

    time1 = time.time()
    model.learn(X)
    time2 = time.time()
    evaluation(model, true_dag, weight_true_dag, time1, time2, lambda1, lambda2, sigma, args, output_path)

if __name__ == '__main__':
    import os
    import sys

    main('out/output.csv')
    # os._exit terminates immediately without flushing Python's stdio buffers,
    # so flush them explicitly; otherwise buffered output (e.g. plain print()
    # when stdout is redirected to a file) would be lost.
    sys.stdout.flush()
    sys.stderr.flush()
    # NOTE(review): a hard exit is presumably used to kill lingering worker /
    # CUDA threads. Status -1 (255 in the shell) signals failure even after a
    # successful run; confirm whether 0 is intended.
    os._exit(-1)
