import os
import random
import math
import numpy as np
import torch
from statsmodels.stats.weightstats import _zconfint_generic


def set_seed(seed=42):
    """Seed the global random number generators used by this module.

    Seeds Python's ``random`` module, NumPy's legacy global generator
    (``np.random``) and PyTorch's CPU generator. When CUDA is available it
    also seeds the current and all CUDA devices. The seed is additionally
    written to the ``PYTHONHASHSEED`` environment variable.

    Note: ``PYTHONHASHSEED`` is read only at interpreter start-up, so setting
    it here influences child processes, not ``hash()`` in the running
    interpreter. Generators built with an unseeded
    ``np.random.default_rng()`` are not affected by this function.

    Args:
        seed: Integer seed shared by every generator (default 42).
    """
    seed_text = str(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = seed_text
    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    # Multi-GPU: seed the current device and then every visible device.
    for seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)


def generate_data(n, N, mean, cov_matrix, rng=None):
    """Draw labeled and unlabeled covariates from one multivariate normal.

    Args:
        n: Number of labeled rows.
        N: Number of unlabeled rows.
        mean: Mean vector; any shape is accepted and flattened to 1-D.
        cov_matrix: (d, d) covariance matrix.
        rng: Optional ``numpy.random.Generator``. When omitted, a fresh
            generator is seeded from NumPy's legacy global state. The draw is
            therefore reproducible after ``np.random.seed`` (e.g. via
            ``set_seed``). The previous unseeded ``default_rng()`` ignored
            any global seed.

    Returns:
        (x_1, x_2): labeled covariates of shape (n, d) and unlabeled
        covariates of shape (N, d), drawn independently.
    """
    if rng is None:
        # Derive the generator's seed from the legacy global RNG so that
        # np.random.seed()/set_seed() controls this draw too.
        rng = np.random.default_rng(np.random.randint(np.iinfo(np.int32).max))
    loc = np.asarray(mean).reshape(-1)
    x_1 = rng.multivariate_normal(loc, cov=cov_matrix, size=n)  # labeled data
    x_2 = rng.multivariate_normal(loc, cov=cov_matrix, size=N)  # unlabeled data
    return x_1, x_2


def generate_labels(n, N, x_1, x_2, beta, mu, theta, shift, sigma_y2=0.1):
    """Simulate ground-truth labels and shifted predictions.

    Model:
        y   = x_1 @ beta + mu.T @ theta + e_y,         e_y ~ Normal(0, sigma_y2)
        f_1 = y + shift
        f_2 = x_2 @ beta + mu.T @ theta + e_f,         e_f ~ Normal(shift, sigma_y2)

    NOTE(review): ``sigma_y2`` is passed as the ``scale`` (standard
    deviation) of ``np.random.normal``, not as a variance, despite its name.
    Confirm this is intended.

    Args:
        n, N: Row counts of ``x_1`` and ``x_2``.
        x_1, x_2: Labeled / unlabeled covariate matrices.
        beta, mu, theta: Model parameters; ``mu.T @ theta`` is the
            theta-dependent offset added to every label.
        shift: Bias added to the predictions.
        sigma_y2: Noise standard deviation (see note above).

    Returns:
        (y, f_1, f_2)
    """
    # The three draws keep their original order so a fixed np.random seed
    # reproduces the same noise. The middle draw is never used in the outputs
    # but still advances the global RNG stream.
    noise_y = np.random.normal(0, sigma_y2, size=n).reshape(-1, 1)
    np.random.normal(shift, sigma_y2, size=n)
    noise_f2 = np.random.normal(shift, sigma_y2, size=N).reshape(-1, 1)

    offset = mu.T @ theta
    y = x_1 @ beta + offset + noise_y  # ground truth label
    f_1 = y + shift
    f_2 = x_2 @ beta + offset + noise_f2
    return y, f_1, f_2


def s_cov(X, Y):
    """Return the unbiased sample cross-covariance of the columns of X and Y.

    Rows are observations. Entry [i, j] of the result is the covariance of
    column i of X with column j of Y, normalized by ``X.shape[0] - 1``.
    """
    centered_x = X - np.mean(X, axis=0)
    centered_y = Y - np.mean(Y, axis=0)
    dof = X.shape[0] - 1
    return centered_x.T @ centered_y / dof


def ground_truth_theta(cov, mean, beta,mu,initial_theta, n, d, T, gamma):
    """Compute the ground-truth trajectory theta_1, ..., theta_T.

    Each step solves
        (cov + mean mean^T + (gamma/n) I) theta_{t+1}
            = (cov + mean mean^T) beta + mean mu^T theta_t,
    starting from ``initial_theta``. This is the population normal equation
    of a ridge fit when labels are generated as x beta + mu^T theta_t + noise.

    Args:
        cov: (d, d) covariance of the covariates.
        mean: (d, 1) column mean of the covariates (``mean @ mean.T`` must
            be d x d).
        beta, mu: Model parameters.
        initial_theta: Starting point; reshaped to a (d, 1) column.
        n: Sample size used only to scale the ridge term ``gamma / n``.
        d: Dimension.
        T: Number of iterations (may be 0).
        gamma: Ridge strength.

    Returns:
        (theta, true_theta_list): the final iterate as a (d, 1) column and
        the list of all T iterates. For T == 0 this is the reshaped
        ``initial_theta`` and an empty list; previously that case raised
        UnboundLocalError.
    """
    # Everything except the feedback term is constant across iterations.
    second_moment = cov + mean @ mean.T
    lhs = second_moment + (gamma / n) * np.identity(d)
    base_rhs = second_moment @ beta
    feedback = mean @ mu.T

    theta = initial_theta.reshape(-1, 1)
    print(f'theta0 {initial_theta}')
    true_theta_list = []
    for _ in range(T):
        theta = np.linalg.solve(lhs, base_rhs + feedback @ theta)
        true_theta_list.append(theta)
    return theta, true_theta_list


def compute_theta_PS(cov,mean, beta,mu,n,d,gamma):
    """Return the fixed point of the ground-truth iteration in closed form.

    Solves
        (cov + mean mean^T - mean mu^T + (gamma/n) I) theta = (cov + mean mean^T) beta,
    i.e. the theta that the update theta <- solve(cov + mean mean^T +
    (gamma/n) I, (cov + mean mean^T) beta + mean mu^T theta) maps to itself.
    "PS" presumably stands for "performatively stable" (TODO confirm).
    """
    second_moment = cov + mean @ mean.T
    lhs = second_moment - mean @ mu.T + (gamma / n) * np.identity(d)
    return np.linalg.solve(lhs, second_moment @ beta)


def min_theta(x_1, x_2, y, f_1, f_2, lam, gamma):
    """Closed-form minimizer theta_hat of the lambda-weighted loss L^lambda.

    Solves the normal equations A theta = b with
        A = (1 - lam)(X1^T X1 + gamma I) + lam (n/N)(X2^T X2 + gamma I)
        b = X1^T (y - lam f_1) + lam (n/N) X2^T f_2,
    which are the first-order conditions of
        1/2||X1 theta - y||^2 - lam/2 ||X1 theta - f_1||^2
        + lam (n/N)/2 ||X2 theta - f_2||^2 + ridge terms.

    Args:
        x_1, y, f_1: Labeled covariates (n rows), labels, predictions.
        x_2, f_2: Unlabeled covariates (N rows) and predictions.
        lam: Weight on the prediction-based correction.
        gamma: Ridge strength.
    """
    n, d = x_1.shape[0], x_1.shape[1]
    N = x_2.shape[0]
    eye = np.identity(d)
    weight = lam * (n / N)

    gram_labeled = x_1.T @ x_1 + gamma * eye
    gram_unlabeled = x_2.T @ x_2 + gamma * eye
    lhs = (1 - lam) * gram_labeled + weight * gram_unlabeled
    rhs = (x_1.T @ (y - lam * f_1)) + weight * (x_2.T @ f_2)
    return np.linalg.solve(lhs, rhs)


def ada_lambda_cov(x_1, x_2, y, f_1, f_2, theta_b, n, N,gradient_list,t,T,gamma): # calculate the optimal selection of lambda at time t as in our note
    """Choose the adaptive weight lambda at step t and return the matching covariance.

    Builds per-sample squared-loss gradients at ``theta_b`` for:
      * nab_l:   labeled rows with true labels ``y``,
      * nab_lf:  labeled rows with predictions ``f_1``,
      * nab_lfs: pooled labeled + unlabeled rows with predictions ``f_1``/``f_2``.
    It then sets
        lambda = a / (2 (1 + n/N) b),  clipped to [0, 1],
    with a = tr(G H^-1 (C + C^T) H^-1 G), C = s_cov(nab_l, nab_lf),
    and  b = tr(G H^-1 Cov(nab_lfs) H^-1 G).

    Args:
        x_1, y, f_1: Labeled covariates, labels and predictions.
        x_2, f_2: Unlabeled covariates and predictions.
        theta_b: Point at which gradients are evaluated. The broadcasting
            below only works for a (d, 1) column.
        n, N: Labeled / unlabeled sample sizes used in the weights. They are
            not checked against the array lengths.
        gradient_list: Only ``gradient_list[0]`` is read; it is raised to the
            power T - t - 1 to form G.
        t, T: Current step and horizon.
        gamma: Ridge strength.

    Returns:
        (cov, ada_lam): cov = G^T H^-1 ((n/N) V_f + V) H^-1 G (d x d), and the
        clipped lambda.
    """
    I = np.identity(x_1.shape[1])
    # np.linalg.matrix_power returns the identity for exponent 0.
    # NOTE(review): for t > T-1 the exponent is negative and matrix_power
    # inverts the matrix. Confirm callers always pass t <= T-1.
    G=np.linalg.matrix_power(gradient_list[0], T-t-1) # identity matrix for t = T-1
    # Pooled labeled + unlabeled covariates and predictions.
    x = np.concatenate([x_1, x_2], axis=0)
    f= np.concatenate([f_1, f_2], axis=0)
    # Row i is (x_i theta_b - target_i) x_i + gamma theta_b^T. The ridge row
    # gamma theta_b^T is identical for every sample, so it cancels in the
    # centred covariances (s_cov / np.cov) computed below.
    nab_l=(x_1 @ theta_b - y)*x_1 + gamma*theta_b.T
    nab_lf=(x_1 @ theta_b - f_1)*x_1 + gamma*theta_b.T
    nab_lfs=(x @ theta_b - f)*x + gamma*theta_b.T
    # Hessian estimate pooled over all n + N rows; ridge term scaled by 1/(n+N).
    H=(x_1.T@x_1+x_2.T@x_2 + gamma*I)/(n+N)
    H_i= np.linalg.inv(H)
    # Numerator: cross-covariance between label gradients and prediction gradients.
    a= np.trace(G@H_i@( s_cov(nab_l,nab_lf) + s_cov(nab_lf,nab_l) )@H_i@G)
    # Denominator: Cov(nab_lfs) appears to stand in for the prediction-gradient
    # covariance on both the labeled and unlabeled samples, hence the (1 + n/N)
    # factor in ada_lam.
    b= np.trace(G@H_i@np.cov(nab_lfs.T)@H_i@G)
    # Appears to be the stationary point of the (quadratic in lambda) trace of
    # the variance. NOTE(review): b == 0 gives inf/nan with a numpy warning;
    # np.clip maps inf to 1 but leaves nan unchanged.
    ada_lam=a/(2*(1+n/N)*b)
    ada_lam = np.clip(ada_lam, 0, 1)
    V_f=ada_lam**2*np.cov(nab_lfs.T)
    V=np.cov((nab_l-ada_lam*nab_lf).T)
    # NOTE(review): a and b above use G @ ... @ G, while cov here uses
    # G.T @ ... @ G. The two agree only if G is symmetric. Confirm which form
    # the derivation intends.
    cov=G.T@H_i@((n/N)*V_f+V)@H_i@G
    return cov, ada_lam


def compute_coverage_and_width(all_covariance, all_estimates, theta_t, n, alpha=0.1):
    """Empirical coverage and mean width of Bonferroni z-boxes for theta_t.

    For each trial i, builds the per-coordinate two-sided interval
        theta_hat_i +/- z_{1 - alpha/(2d)} * sqrt(diag(Sigma_i) / n)
    (Bonferroni split of alpha over the d coordinates). The trial counts as
    covered when every coordinate of ``theta_t`` lies inside its interval.

    Args:
        all_covariance: Per-trial covariance matrices; only their diagonals
            are used. Presumably the asymptotic covariance of
            sqrt(n)(theta_hat - theta), given the ``/ n`` scaling (TODO confirm).
        all_estimates: Per-trial estimates, as (d,) or (d, 1) arrays.
        theta_t: Target parameter, as (d,) or (d, 1).
        n: Sample size used to scale the standard errors.
        alpha: Family-wise miscoverage level.

    Returns:
        (avg_coverage, avg_width): the fraction of covered trials, and the
        mean over trials of the mean interval width across coordinates.
    """
    from statistics import NormalDist

    # Flatten so (d, 1) columns do not broadcast against the (d,) standard
    # errors into a (d, d) matrix. That matrix would check every coordinate
    # against every other coordinate's interval.
    target = np.asarray(theta_t).reshape(-1)
    d = target.size
    # Two-sided z critical value for level alpha/d. This is the same value
    # statsmodels' private _zconfint_generic uses for alternative='two-sided'.
    z_crit = NormalDist().inv_cdf(1 - (alpha / d) / 2)

    num = len(all_estimates)
    coverage_count = 0
    total_width = 0
    for i, theta_hat in enumerate(all_estimates):
        center = np.asarray(theta_hat).reshape(-1)
        half_width = z_crit * np.sqrt(np.diag(all_covariance[i]) / n)
        lower = center - half_width
        upper = center + half_width
        total_width += np.mean(upper - lower)
        if np.all((lower <= target) & (target <= upper)):
            coverage_count += 1

    avg_coverage = coverage_count / num
    avg_width = total_width / num
    return avg_coverage, avg_width


def compute_coverage_and_width_PS(all_covariance, all_estimates, beta_max_2_list, max_norm_x_list, theta0, theta_PS, n, mu, T, gamma, alpha=0.1):
    """Relaxed coverage of theta_PS by Bonferroni z-boxes, plus mean width.

    Each trial i builds the same per-coordinate box as the plain coverage
    routine:
        theta_hat_i +/- z_{1 - alpha/(2d)} * sqrt(diag(Sigma_i) / n).
    The trial counts as covered when the Euclidean distance from theta_PS to
    that box is at most
        r_i = 2 B (||mu|| * beta_i / gamma) ** T,
    where B = ||theta0 - theta_PS|| and
    beta_i = max(max_norm_x_list[i], beta_max_2_list[i]).

    Args:
        all_covariance: Per-trial covariance matrices (diagonals used).
        all_estimates: Per-trial estimates, as (d,) or (d, 1).
        beta_max_2_list, max_norm_x_list: Per-trial scalars entering beta_i.
        theta0: Initial parameter, as (d,) or (d, 1).
        theta_PS: Target fixed point, as (d,) or (d, 1).
        n: Sample size used to scale the standard errors.
        mu: Vector whose Euclidean norm enters r_i.
        T: Exponent in r_i.
        gamma: Ridge strength in r_i.
        alpha: Family-wise miscoverage level.

    Returns:
        (avg_coverage, avg_width)
    """
    from statistics import NormalDist

    # Flatten every vector so that mixing (d,) and (d, 1) shapes cannot
    # broadcast into (d, d) matrices. That would inflate B and turn the
    # box distance into a Frobenius norm over a meaningless matrix.
    target = np.asarray(theta_PS).reshape(-1)
    d = target.size
    B = np.linalg.norm(np.asarray(theta0).reshape(-1) - target)
    vareps = np.linalg.norm(np.asarray(mu).reshape(-1))
    # Two-sided z critical value at level alpha/d (same value as statsmodels'
    # private _zconfint_generic with alternative='two-sided').
    z_crit = NormalDist().inv_cdf(1 - (alpha / d) / 2)

    num = len(all_covariance)
    coverage_count = 0
    total_width = 0
    for i in range(num):
        center = np.asarray(all_estimates[i]).reshape(-1)
        half_width = z_crit * np.sqrt(np.diag(all_covariance[i]) / n)
        lower = center - half_width
        upper = center + half_width
        total_width += np.mean(upper - lower)

        betaa = max(max_norm_x_list[i], beta_max_2_list[i])
        r = 2 * B * ((vareps * betaa / gamma) ** T)

        # Per-coordinate gap from theta_PS to [lower, upper]; zero inside.
        dist_lower = np.maximum(lower - target, 0)
        dist_upper = np.maximum(target - upper, 0)
        dist_to_box = np.linalg.norm(np.sqrt(dist_lower**2 + dist_upper**2))
        if dist_to_box <= r:
            coverage_count += 1

    avg_coverage = coverage_count / num
    avg_width = total_width / num
    return avg_coverage, avg_width