from src import notears_prior
import numpy as np
from rich import print as rprint
from evaluation import evaluation
import time,copy
import torch
import os,logging

class adaptive():
    """Two-stage, prior-guided structure learner.

    ``learn`` fits a soft-prior ``notears_prior`` model once, evaluates that
    first fit, derives a per-edge ``adaptive_degree`` matrix from the
    triangles found around the prior edges in the first estimate, and then
    refits a fresh model with that matrix passed to ``load_prior``.

    Besides the values stored by :meth:`load_prior`, ``learn`` reads
    ``self.model``, ``self.loss_type``, ``self.device_type`` and
    ``self.adaptive_degree``; this base class does not set them, so a
    subclass (or the caller) must.
    """

    def __init__(self):
        pass

    def _build_model(self):
        """Return a fresh soft-prior ``notears_prior`` built from this instance's settings.

        Raises:
            ValueError: if ``self.model`` is not ``'notears_soft'`` (the only
                backend this class knows how to build).
        """
        if self.model != 'notears_soft':
            # Previously this fell through and crashed later with an
            # UnboundLocalError on ``model``; fail early and say why.
            raise ValueError(
                "Unsupported model %r; only 'notears_soft' is implemented." % (self.model,)
            )
        return notears_prior(lambda1=self.lambda1, sigma=self.sigma,
                             loss_type=self.loss_type, prior_type='soft',
                             device_type=self.device_type)

    def learn(self, X):
        """Run the two-stage fit on ``X``.

        The result of the second fit is stored in
        ``self.weight_causal_matrix`` and ``self.causal_matrix``.

        Raises:
            ValueError: if ``self.model`` is not ``'notears_soft'``.
        """
        # ---- Stage 1: plain soft-prior fit, timed and evaluated. ----
        model = self._build_model()
        self.args.alg = 'notears_soft'
        model.load_prior(self.w_prior, self.prob_prior)

        time1 = time.time()
        model.learn(X)
        time2 = time.time()
        evaluation(model, self.true_dag, self.weight_true_dag, time1, time2,
                   self.lambda1, self.lambda2, self.sigma, self.args, self.output_path)

        # ---- Per-edge adaptive degree from the first estimate. ----
        W_initial = model.weight_causal_matrix
        n_nodes = W_initial.shape[0]
        penalty = np.zeros((n_nodes, n_nodes))
        # Every prior edge (those listed in ``right`` and in ``error``) is
        # penalised in proportion to the number of triangles it closes in
        # the first-stage estimate.
        for (par, var) in list(self.right) + list(self.error):
            n_triangles = len(find_triangles_with_edge(W_initial, par, var))
            penalty[par, var] += n_triangles * self.adaptive_degree
        adaptive_degree = 1 - penalty
        adaptive_degree[adaptive_degree < 0] = 0  # clamp to [0, 1]

        # ---- Stage 2: refit a fresh model with the adaptive degrees. ----
        model = self._build_model()
        model.load_prior(self.w_prior, self.prob_prior, adaptive_degree=adaptive_degree)
        model.learn(X)
        self.weight_causal_matrix = model.weight_causal_matrix
        self.causal_matrix = model.causal_matrix

    def load_prior(self, w_prior=None, prob_prior=0, weight_true_dag=None, right=None, error=None,
                   lambda1=0.1, lambda2=0.1, sigma=1, args=None, output_path=None):
        """Store the prior, ground truth and run settings used by :meth:`learn`.

        Args:
            w_prior: forwarded unchanged to the underlying model's ``load_prior``.
            prob_prior: forwarded unchanged to the underlying model's ``load_prior``.
            weight_true_dag: weighted ground-truth matrix; ``true_dag`` is its
                0/1 non-zero pattern (``None`` when this is omitted).
            right: iterable of ``(parent, child)`` index pairs; defaults to a
                new empty list.
            error: iterable of ``(parent, child)`` index pairs; defaults to a
                new empty list.
            lambda1, lambda2, sigma: hyper-parameters, stored and passed to
                the model / ``evaluation``.
            args: object that is deep-copied and given an ``alg`` attribute.
                Must not be ``None`` (setting ``alg`` on ``None`` fails).
            output_path: passed through to ``evaluation``.
        """
        self.w_prior = w_prior
        self.prob_prior = prob_prior
        self.weight_true_dag = weight_true_dag
        # ``np.where(None != 0, 1, 0)`` would silently produce a scalar 1,
        # so only binarise when a matrix was actually supplied.
        if weight_true_dag is None:
            self.true_dag = None
        else:
            self.true_dag = np.where(weight_true_dag != 0, 1, 0)
        # Fresh lists per call: a mutable default would be shared by every
        # instance that relied on it.
        self.right = [] if right is None else right
        self.error = [] if error is None else error
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.sigma = sigma
        self.args = copy.deepcopy(args)
        self.args.alg = 'adaptive'
        self.output_path = output_path
        
class notears_adaptive(adaptive):
    """``adaptive`` learner configured for the ``'notears_soft'`` backend.

    The constructor only records hyper-parameters on the instance and
    resolves the torch device; the fitting logic lives in the base class.

    Args:
        lambda1, sigma, loss_type: stored and later read by the base class
            when it builds the underlying model.
        max_iter, h_tol, rho_max, w_threshold, prior_type, regular_type,
        W_initial, sigma_MCP: stored on the instance; not otherwise used in
            this class.
        device_type: ``'gpu'`` selects ``torch.device('cuda')``; any other
            value selects ``torch.device('cpu')``.
        device_ids: when truthy and ``device_type == 'gpu'``, written to the
            ``CUDA_VISIBLE_DEVICES`` environment variable.
        adaptive_degree: per-triangle factor stored for the base class.

    Raises:
        ValueError: if ``device_type == 'gpu'`` but CUDA is not available.
    """

    def __init__(self, lambda1=0.1,
                 sigma=1.0,
                 loss_type='l2',
                 max_iter=100,
                 h_tol=1e-8,
                 rho_max=1e+16,
                 w_threshold=0.3,
                 prior_type=None,
                 regular_type='l1',
                 device_type='gpu',
                 device_ids=0,
                 W_initial=None,
                 sigma_MCP = 0.2,
                 adaptive_degree=0.9):

        super().__init__()

        # Optimisation hyper-parameters.
        self.lambda1, self.sigma, self.loss_type = lambda1, sigma, loss_type
        self.max_iter, self.h_tol, self.rho_max = max_iter, h_tol, rho_max
        self.w_threshold = w_threshold

        # Prior-related state starts out empty.
        self.w_prior = None
        self.prob_prior = 0
        self.ground_truth = None

        # Remaining configuration.
        self.prior_type = prior_type
        self.regular_type = regular_type
        self.device_type = device_type
        self.device_ids = device_ids
        self.W_initial = W_initial
        self.sigma_MCP = sigma_MCP

        wants_gpu = self.device_type == 'gpu'

        # Report CUDA availability and refuse a GPU request that cannot be met.
        if not torch.cuda.is_available():
            logging.info('GPU is unavailable.')
            if wants_gpu:
                raise ValueError("GPU is unavailable, please set device_type = 'cpu'.")
        else:
            logging.info('GPU is available.')

        if not wants_gpu:
            self.device = torch.device('cpu')
        else:
            # NOTE(review): the default ``device_ids=0`` is falsy, so the
            # environment variable is left untouched in that case.
            if self.device_ids:
                os.environ['CUDA_VISIBLE_DEVICES'] = str(self.device_ids)
            self.device = torch.device('cuda')

        self.model = 'notears_soft'
        self.adaptive_degree = adaptive_degree

def list2dag(A: list, n=None):
    """Build an ``n x n`` 0/1 adjacency matrix from ``(parent, child)`` pairs.

    Args:
        A: edge list; each item is a ``(parent, child)`` pair of node indices.
        n: number of nodes. When ``None`` it is taken as the largest index
            appearing in ``A`` plus one (so ``A`` must then be non-empty).

    Returns:
        A float ``numpy`` array with ``1`` at ``[parent, child]`` for every
        listed edge and ``0`` elsewhere.
    """
    if n is None:
        n = 1 + max(max(edge) for edge in A)
    adjacency = np.zeros((n, n))
    for parent, child in A:
        adjacency[parent, child] = 1
    return adjacency

def find_triangles_with_edge(adj_matrix, beta, j):
    if adj_matrix[beta, j] == 0:
        return []
    
    neighbors_of_i = set(np.where(adj_matrix[beta,:] != 0)[0]) | set(np.where(adj_matrix[:,beta] != 0)[0])
    neighbors_of_j = set(np.where(adj_matrix[:,j] > 0)[0])
    common_neighbors = neighbors_of_i.intersection(neighbors_of_j)
    return common_neighbors
