from tqdm import tqdm
import numpy as np


class Base_Solver:
    """Monte-Carlo simulator for a linear system with instantaneous effects.

    Each simulation step advances the state vector as

        v_t = (I - A_true)^{-1} (B_true @ v_{t-1} + e_t),

    where ``e_t`` is a fresh noise draw from :meth:`Noise_sampler`.
    :meth:`AUF_prob` runs ``val_times`` independent trajectories, averages
    the last ``len(nodes_Y)`` state entries over the final ``T`` steps, and
    returns the fraction of trajectories that ``evaluate_func`` accepts.
    """

    def __init__(self, nodes, Para_true, task_para, binary_edge, val_times, nodes_stage, evaluate_func):
        """Store the true system parameters and the task configuration.

        Args:
            nodes: Sequence of nodes; only its length is used in this class.
            Para_true: Tuple ``(A_true, B_true, C_true, noise_type)``:
                instantaneous matrix (square), lagged matrix, noise
                covariance, and the noise-family name passed to
                :meth:`Noise_sampler`.
            task_para: Tuple ``(t0, T)``. Each trajectory runs ``t0 + T``
                steps and only the last ``T`` are averaged.
            binary_edge: Stored as ``self.bi_edge``; not used in this class.
            val_times: Number of Monte-Carlo trajectories in :meth:`AUF_prob`.
            nodes_stage: Tuple ``(nodes_X, nodes_Z, nodes_Y)``. Only
                ``len(nodes_Y)`` is used here.
            evaluate_func: Callable applied to each averaged Y vector. Its
                results are summed to count successes (presumably returns
                0/1 or bool -- confirm with callers).
        """
        self.A_true, self.B_true, self.C_true, self.noise_type = Para_true
        # (I - A)^{-1} resolves the instantaneous (same-step) effects.
        self.ins_mat = np.linalg.inv(np.eye(*self.A_true.shape) - self.A_true)
        self.t0, self.T = task_para
        self.nodes = nodes
        self.bi_edge = binary_edge
        self.nodes_X, self.nodes_Z, self.nodes_Y = nodes_stage
        self.val_times = val_times
        self.evaluate_Y = evaluate_func

    def AUF_prob(self, ):
        """Estimate the probability that the time-averaged Y values succeed.

        All trajectories start from the same initial state ``v0``, which is
        drawn once from the noise distribution before the trials begin.

        Returns:
            Fraction of the ``val_times`` trajectories for which
            ``evaluate_func`` returned a truthy/1 value.

        Raises:
            ValueError: if ``T < 1``, because the averaging window would be empty.
        """
        if self.T < 1:
            raise ValueError(f"averaging window T must be >= 1, got {self.T}")
        n_nodes = len(self.nodes)
        n_y = len(self.nodes_Y)
        v0 = self.Noise_sampler(n_nodes, 1, self.noise_type)[0].reshape(-1, 1)

        avg_y_list = []
        for _ in tqdm(range(self.val_times)):
            v = v0
            window = []  # Y slices of the last T states only
            for step in range(self.t0 + self.T):
                noise = self.Noise_sampler(n_nodes, 1, self.noise_type)[0].reshape(-1, 1)
                v = self.ins_mat.dot(self.B_true.dot(v) + noise)
                # Steps before t0 are discarded; only the final T are averaged.
                if step >= self.t0:
                    # Assumes the Y nodes occupy the last rows of the state vector.
                    window.append(v[-n_y:, 0].reshape(-1, 1))
            avg_y_list.append(np.sum(window, axis=0) / self.T)

        # Evaluate only after every trajectory is simulated, as the original did.
        success_count = sum(self.evaluate_Y(avg_y) for avg_y in avg_y_list)
        return success_count / self.val_times

    def Noise_sampler(self, dim, num, noise_type = "gaussian"):
        """Draw ``num`` zero-mean noise vectors of length ``dim``.

        Supported families:
            * ``"gaussian"``: multivariate normal with covariance ``C_true``.
            * ``"laplace"``: independent Laplace coordinates whose variances
              equal ``diag(C_true)``. Off-diagonal covariance is ignored.
            * ``"uniform"``: unit-variance uniform coordinates, correlated
              through the Cholesky factor of ``C_true`` so that the
              covariance equals ``C_true``.

        Returns:
            ``np.ndarray`` of shape ``(num, dim)``.

        Raises:
            ValueError: if ``noise_type`` is not one of the families above.
        """
        if noise_type == "gaussian":
            return np.random.multivariate_normal(np.zeros((dim, )), self.C_true, num)

        if noise_type == "laplace":
            # Var(Laplace(scale=b)) = 2 b^2, so b = sqrt(var / 2).
            scales = np.sqrt(np.diag(self.C_true) / 2)
            return np.random.laplace(loc=0.0, scale=scales, size=(num, dim))

        if noise_type == "uniform":
            # U(-sqrt(3), sqrt(3)) has unit variance. U(-1, 1) would only have
            # variance 1/3 and would shrink the covariance to C_true / 3.
            half_width = np.sqrt(3.0)
            z = np.random.uniform(low=-half_width, high=half_width, size=(num, dim))
            L = np.linalg.cholesky(self.C_true)
            return z @ L.T

        raise ValueError(f"unsupported noise_type: {noise_type!r}")

