import os
from tqdm import tqdm
import math
import numpy as np
import torch
from torch.autograd import grad
import matplotlib.pyplot as plt
from statsmodels.stats.weightstats import _zconfint_generic
from models import ScoreModel, DNNScoreModel
from utils import *


def Gradient(theta_traj,score_traj,x_1,n):
    """Return one matrix ``G_t`` for every time step ``t`` in ``range(T)``.

    Each ``G_t`` equals ``-H^{-1} @ (nab_l.T @ score_traj[t]) / n`` where
    ``H = (x_1.T @ x_1 + gamma * I) / n`` and ``nab_l`` stacks the per-row loss
    gradients evaluated at ``theta_traj[t]``, using the noiseless labels
    produced under the previous parameter (``theta0`` when ``t == 0``).

    Reads the module-level ``T``, ``theta0``, ``gamma``, ``I``, ``beta`` and
    ``mu``.

    Args:
        theta_traj: per-step parameters; presumably column vectors of shape
            (d, 1) -- TODO confirm against the caller.
        score_traj: per-step score arrays; presumably one row per row of
            ``x_1`` -- TODO confirm.
        x_1: feature matrix with one sample per row.
        n: normaliser used for both ``H`` and the gradient/score product.

    Returns:
        A list with ``T`` arrays, one per time step.
    """
    # None of these depend on t, so build (and invert) them once instead of
    # on every iteration.
    H = (x_1.T @ x_1 + gamma * I) / n  # hessian of L(z,theta_{t+1}) z\sim \theta_t
    H_i = np.linalg.inv(H)
    x_beta = x_1 @ beta
    ones = np.ones_like(x_beta)

    G = []
    for t in range(T):
        theta_t = theta_traj[t - 1] if t != 0 else theta0
        theta_t_plus = theta_traj[t]
        p_t = score_traj[t]
        # The label noise U_y is omitted: it vanishes when taking expectation.
        y = x_beta + (mu.T @ theta_t) * ones
        # gradient (theta) of \nabla l(x,y,theta_{t+1}), one row per sample
        nab_l = (x_1 @ theta_t_plus - y) * x_1 + gamma * theta_t_plus.T
        G.append(-H_i @ ((nab_l.T @ p_t) / n))
    return G


def variance_general(x_1, x_2, y, f_1, f_2, theta_b, n, N, t, gradient_list, lam):
    """Contribution of time step ``t`` to the variance at the final step ``T``.

    The sandwich ``H^{-1} @ ((n / N) * V_f + V) @ H^{-1}`` is built from the
    covariances of the per-row loss gradients at ``theta_b``. For ``t < T - 1``
    it is additionally propagated through the product of
    ``gradient_list[t + 1:T]``.

    Reads the module-level ``gamma``, ``I`` and ``T``.

    Args:
        x_1, x_2: feature matrices (rows are samples); ``x_2`` is stacked
            below ``x_1`` for the pooled gradient.
        y: labels paired with ``x_1``.
        f_1, f_2: values paired with ``x_1`` / ``x_2`` that replace ``y`` in
            the ``lam``-weighted terms (presumably model predictions --
            TODO confirm).
        theta_b: parameter at which the gradients are evaluated.
        n, N: sizes used in the ``n / N`` weight and the ``n + N`` Hessian
            normaliser.
        t: index of the time step.
        gradient_list: per-step matrices that are chained for ``t < T - 1``.
        lam: weight of the ``f``-based terms.

    Returns:
        The covariance contribution as an array.
    """
    ridge_row = gamma * theta_b.T
    pred_1 = x_1 @ theta_b
    x_all = np.concatenate([x_1, x_2], axis=0)
    f_all = np.concatenate([f_1, f_2], axis=0)

    # Per-row gradients of the regularised squared loss.
    grad_y = (pred_1 - y) * x_1 + ridge_row
    grad_f = (pred_1 - f_1) * x_1 + ridge_row
    grad_f_all = (x_all @ theta_b - f_all) * x_all + ridge_row

    hessian_inv = np.linalg.inv((x_1.T @ x_1 + x_2.T @ x_2 + gamma * I) / (n + N))

    cov_f = lam ** 2 * np.cov(grad_f_all.T)
    cov_resid = np.cov((grad_y - lam * grad_f).T)
    middle = (n / N) * cov_f + cov_resid

    if t >= T - 1:
        # Last step: nothing to propagate through.
        return hessian_inv @ middle @ hessian_inv

    chain = gradient_list[t + 1]
    for later in gradient_list[t + 2:T]:
        chain = chain @ later
    return chain.T @ hessian_inv @ middle @ hessian_inv @ chain


def variance(x_1, x_2, y, f_1, f_2, theta_b, n, N, lam,gradient,t):
    """Contribution of time step ``t`` to the variance at the final step ``T``.

    Same sandwich as ``H^{-1} @ ((n / N) * V_f + V) @ H^{-1}``, but propagated
    with a single fixed matrix: ``gradient`` raised to the power ``T - t - 1``.

    Reads the module-level ``gamma``, ``I`` and ``T``.

    Args:
        x_1, x_2: feature matrices (rows are samples).
        y: labels paired with ``x_1``.
        f_1, f_2: values paired with ``x_1`` / ``x_2`` used in the
            ``lam``-weighted terms.
        theta_b: parameter at which the gradients are evaluated.
        n, N: sizes used in the ``n / N`` weight and the ``n + N`` normaliser.
        lam: weight of the ``f``-based terms.
        gradient: square matrix applied once per remaining time step.
        t: index of the time step.

    Returns:
        The covariance contribution as an array.
    """
    propagate = np.linalg.matrix_power(gradient, T - t - 1)

    ridge_row = gamma * theta_b.T
    pred_1 = x_1 @ theta_b
    x_all = np.concatenate([x_1, x_2], axis=0)
    f_all = np.concatenate([f_1, f_2], axis=0)

    # Per-row gradients of the regularised squared loss.
    grad_y = (pred_1 - y) * x_1 + ridge_row
    grad_f = (pred_1 - f_1) * x_1 + ridge_row
    grad_f_all = (x_all @ theta_b - f_all) * x_all + ridge_row

    hessian_inv = np.linalg.inv((x_1.T @ x_1 + x_2.T @ x_2 + gamma * I) / (n + N))

    cov_f = lam ** 2 * np.cov(grad_f_all.T)
    cov_resid = np.cov((grad_y - lam * grad_f).T)
    middle = (n / N) * cov_f + cov_resid

    return propagate.T @ hessian_inv @ middle @ hessian_inv @ propagate


def compute_beta_2(x_1, theta_t, y_1):
    """Return the largest per-sample bound ``beta_i`` (at least ``0.``).

    For every row ``x_i`` of ``x_1`` with label ``y_i[0]``::

        beta_i = sqrt((x_i @ theta_t - y_i[0] + ||x_i|| * ||theta_t||)**2 + ||x_i||**2)

    Args:
        x_1: feature matrix, iterated row by row.
        theta_t: parameter such that ``x_i[None, :] @ theta_t`` has exactly
            one element.
        y_1: labels, iterated in lockstep with ``x_1``; only element ``[0]``
            of each entry is used.

    Returns:
        The maximum ``beta_i`` as a float, or ``0.`` when there are no rows.
    """
    # Seeding with 0. reproduces the "start at zero, keep the larger" fold.
    candidates = [0.]
    for row, target in zip(x_1, y_1):
        row_norm = np.linalg.norm(row)
        prediction = (np.expand_dims(row, axis=0) @ theta_t).item()
        slack = prediction - target[0] + row_norm * np.linalg.norm(theta_t)
        candidates.append(math.sqrt(slack ** 2 + row_norm ** 2))
    return max(candidates)


def collect_theta(n):
    """Run ``T`` retraining steps with ``lam = 0`` and record the trajectory.

    Features are drawn once with ``generate_data``. At every step labels are
    generated under the current parameter, the parameter is refit with
    ``min_theta`` and the step's variance contribution (using the
    module-level ``gradient``) is accumulated.

    Reads the module-level ``d``, ``N``, ``mean``, ``cov_matrix``, ``theta0``,
    ``T``, ``beta``, ``mu``, ``shift``, ``sigma_y2``, ``gamma`` and
    ``gradient``.

    Args:
        n: first size argument forwarded to ``generate_data`` and
            ``generate_labels``.

    Returns:
        Tuple ``(theta_list, data_list, cov_total, x_1, x_2, beta_max_2,
        max_norm_x)``: the squeezed per-step parameters, the per-step
        ``(x_1, x_2, y_1, f_1, f_2)`` tuples, the summed variance, the
        features, the largest ``compute_beta_2`` value over all steps, and
        the largest squared row norm of ``x_1``.
    """
    cov_total = np.zeros((d, d))
    x_1, x_2 = generate_data(n, N, mean, cov_matrix)
    # NOTE: despite the name this is the maximum *squared* row norm.
    max_norm_x = np.max(np.sum(x_1 ** 2, axis=1))

    theta_list, data_list = [], []
    current_theta = theta0
    # for computing beta_max_2 in confidence width
    beta_max_2 = 0.
    for t in range(T):
        y_1, f_1, f_2 = generate_labels(n, N, x_1, x_2, beta, mu, current_theta, shift, sigma_y2)
        current_theta = min_theta(x_1, x_2, y_1, f_1, f_2, 0, gamma)
        data_list.append((x_1, x_2, y_1, f_1, f_2))
        theta_list.append(current_theta.squeeze())
        cov_total += variance(x_1, x_2, y_1, f_1, f_2, current_theta, n, N, 0, gradient, t)
        beta_max_2 = max(beta_max_2, compute_beta_2(x_1, current_theta, y_1))
    return theta_list, data_list, cov_total, x_1, x_2, beta_max_2, max_norm_x


def compute_variance(lam, theta_list, data_list, gradient_list, batch_size):
    """Sum the ``variance_general`` contributions of all ``T`` time steps.

    Reads the module-level ``d``, ``T`` and ``N``.

    Args:
        lam: weight forwarded to ``variance_general``.
        theta_list: per-step parameters; each gets an extra axis at position 1
            before use.
        data_list: per-step ``(x_1, x_2, y, f_1, f_2)`` tuples.
        gradient_list: per-step matrices forwarded to ``variance_general``.
        batch_size: passed as the ``n`` argument of ``variance_general``.

    Returns:
        A ``(d, d)`` array holding the accumulated variance.
    """
    total = np.zeros((d, d))
    for t in range(T):
        theta_col = np.expand_dims(theta_list[t], axis=1)
        x_1, x_2, y, f_1, f_2 = data_list[t]
        total += variance_general(x_1, x_2, y, f_1, f_2, theta_col,
                                  batch_size, N, t, gradient_list, lam)
    return total


def compute_variance_lam(n, x_1, x_2, gradient_list):
    """Run the ``'greedy'`` and ``'lam1'`` trajectories on fixed features.

    For each of the ``T`` steps labels are generated for every method in
    ``methods[1:]`` (in that order), then:

    * ``'greedy'``: fit with ``lam = 0``, obtain ``lam_g`` from
      ``ada_lambda_cov`` and refit with it;
    * ``'lam1'``: fit with ``lam = 1``.

    Both accumulate a variance based on the module-level ``gradient``
    ("true") and one based on ``gradient_list`` ("estimated"), and track the
    running maximum of ``compute_beta_2``.

    Reads the module-level ``methods``, ``d``, ``theta0``, ``T``, ``N``,
    ``beta``, ``mu``, ``shift``, ``sigma_y2``, ``gamma`` and ``gradient``.

    Args:
        n: first size argument forwarded to the helpers.
        x_1, x_2: feature matrices reused at every step.
        gradient_list: per-step matrices for the estimated variance.

    Returns:
        Tuple ``(theta_dict, est_cov_dict, true_cov_dict, beta_max_2)`` of
        dicts keyed by method name.
    """
    tracked = methods[1:]
    est_cov_dict = {m: np.zeros((d, d)) for m in tracked}
    true_cov_dict = {m: np.zeros((d, d)) for m in tracked}
    theta_dict = dict.fromkeys(tracked, theta0)
    beta_max_2 = dict.fromkeys(tracked, 0.)

    for t in range(T):
        # Labels for every method are drawn first, in the order of `tracked`,
        # so the random stream is consumed in a fixed sequence.
        labels = {}
        for m in tracked:
            labels[m] = generate_labels(n, N, x_1, x_2, beta, mu, theta_dict[m], shift, sigma_y2)

        # --- greedy: choose lambda adaptively from a lam=0 fit ---
        greedy_labels = labels['greedy']
        theta_b = min_theta(x_1, x_2, *greedy_labels, 0, gamma)
        _, lam_g = ada_lambda_cov(x_1, x_2, *greedy_labels, theta_b, n, N, gradient_list, T - 1, T, gamma)
        theta_greedy = min_theta(x_1, x_2, *greedy_labels, lam_g, gamma)
        theta_dict['greedy'] = theta_greedy
        true_cov_dict['greedy'] += variance(x_1, x_2, *greedy_labels, theta_greedy, n, N, lam_g, gradient, t)
        est_cov_dict['greedy'] += variance_general(x_1, x_2, *greedy_labels, theta_greedy, n, N, t, gradient_list, lam_g)
        beta_max_2['greedy'] = max(beta_max_2['greedy'],
                                   compute_beta_2(x_1, theta_greedy, greedy_labels[0]))

        # --- lam1: fixed lambda = 1 ---
        lam1_labels = labels['lam1']
        theta_lam1 = min_theta(x_1, x_2, *lam1_labels, 1, gamma)
        theta_dict['lam1'] = theta_lam1
        true_cov_dict['lam1'] += variance(x_1, x_2, *lam1_labels, theta_lam1, n, N, 1, gradient, t)
        est_cov_dict['lam1'] += variance_general(x_1, x_2, *lam1_labels, theta_lam1, n, N, t, gradient_list, 1)
        beta_max_2['lam1'] = max(beta_max_2['lam1'],
                                 compute_beta_2(x_1, theta_lam1, lam1_labels[0]))

    return theta_dict, est_cov_dict, true_cov_dict, beta_max_2


def generate_y(x,theta_i, n):
    """Sample labels ``y = x @ beta + mu.T @ theta_i + U_y``.

    ``U_y`` holds ``n`` normal draws with mean 0 and scale ``sigma_y2``,
    reshaped to a column. Reads the module-level ``sigma_y2``, ``beta`` and
    ``mu``.

    Args:
        x: feature matrix with one sample per row.
        theta_i: parameter that shifts the label mean through ``mu``.
        n: number of noise draws; presumably equal to the number of rows of
            ``x`` -- TODO confirm.

    Returns:
        The sampled labels as an array.
    """
    noise = np.random.normal(0, sigma_y2, size=n).reshape(-1,1)
    linear_part = x @ beta
    shift_part = (mu.T @ theta_i) * np.ones_like(linear_part)
    return linear_part + shift_part + noise


def perturb_data(theta_list, data_list, eps, n):
    """Create coordinate-wise perturbed parameters and matching data.

    For every time step ``t`` and coordinate ``i`` the ``i``-th entry of the
    step's parameter is increased by ``eps``; the step's features are kept and
    fresh labels are sampled with ``generate_y`` under the perturbed
    parameter.

    Side effect: every entry of ``theta_list`` is replaced in place by its
    squeezed version.

    Reads the module-level ``d`` and ``theta0`` (for the dtype only).

    Args:
        theta_list: per-step parameters (modified in place, see above).
        data_list: per-step 5-tuples whose first entry is the feature matrix
            and whose third entry is the label array; the first step fixes
            shapes and dtypes of the outputs.
        eps: size of the perturbation.
        n: number of labels requested from ``generate_y``.

    Returns:
        Tuple ``(theta_pert_arr, x_pert_arr, y_pert_arr)`` indexed by
        ``[t, i]``.
    """
    num_steps = len(theta_list)
    first_x, _, first_y, _, _ = data_list[0]

    theta_pert_arr = np.zeros((num_steps, d, d), dtype=theta0.dtype)
    x_pert_arr = np.zeros((num_steps, d) + first_x.shape, dtype=first_x.dtype)
    y_pert_arr = np.zeros((num_steps, d) + first_y.shape, dtype=first_y.dtype)

    for t in range(num_steps):
        base_theta = theta_list[t]
        theta_list[t] = base_theta.squeeze()
        x_t, _, _, _, _ = data_list[t]
        for i in range(d):
            bumped = base_theta.copy().squeeze()
            bumped[i] += eps
            theta_pert_arr[t, i] = bumped
            x_pert_arr[t, i] = x_t
            y_pert_arr[t, i] = generate_y(x_t.copy(), bumped, n)
    return theta_pert_arr, x_pert_arr, y_pert_arr


def compute_temp_1_3(score, theta):
    '''
    First- and second-order derivative terms of ``score`` w.r.t. ``theta``.

    score: [Batch_size,]
    theta: [Batch_size, d]
    return grads = ∂ score.sum() / ∂ theta  (graph kept for further differentiation)
           temp1 = 1/B ∑_{j=1}^B ∑_{i=1}^d ( ∂ score_j / ∂ theta_{j,i} )^2
           temp3 = 1/B ∑_{i=1}^d ∑_{j=1}^B  ∂^2 score_j / ∂ theta_{j,i}^2
    '''
    (grads,) = grad(score.sum(), theta, create_graph=True)
    temp1 = grads.pow(2).sum(dim=1).mean()

    # temp3: batch mean of the i-th diagonal second derivative, summed over i
    diag_means = []
    for i in range(theta.size(1)):
        (second,) = grad(grads[:, i].sum(), theta, create_graph=True)
        diag_means.append(second[:, i].mean())
    temp3 = sum(diag_means)

    return grads, temp1, temp3


def compute_temp_2_i(score, theta, i):
    '''
    Batch mean of the derivative of ``score`` w.r.t. the i-th coordinate of ``theta``.

    score: [Batch_size,]
    theta: [Batch_size, d]
    return 1/B ∑_{j=1}^B  ∂ score_j / ∂ theta_{j,i}
    '''
    (theta_grad,) = torch.autograd.grad(score.sum(), theta, create_graph=True)
    column = theta_grad[:, i]
    return column.mean()


def log_joint(x, y, theta, Sigma_x, beta, mu, sigma_y):
    """
    Gaussian log-density of ``z = [x, y]``, split into two parts.

    The covariance of ``z`` is the block matrix
    ``[[Sigma_x, Sigma_x beta], [beta^T Sigma_x, beta^T Sigma_x beta + sigma_y^2]]``
    and its mean is ``[0, ..., 0, sum(mu * theta)]`` per row.

    x: [n, d], y: [n, 1], theta: broadcastable against ``mu`` on the last axis
    return quad  = per-row quadratic term, shape [n]
           const = normalising constant (scalar tensor); the log-density is
                   ``quad + const``
    Raises ValueError when the sign of the covariance determinant is not
    positive.
    """
    n, d = x.shape
    device, dtype = x.device, x.dtype

    # Assemble the (d + 1) x (d + 1) covariance of z row by row.
    cross = Sigma_x @ beta
    y_var = (beta @ cross + sigma_y**2)[None, None]
    upper_rows = torch.cat([Sigma_x, cross.unsqueeze(1)], dim=1)
    last_row = torch.cat([cross.unsqueeze(0), y_var], dim=1)
    Sigma_z = torch.cat([upper_rows, last_row], dim=0)

    sign, logdet = torch.linalg.slogdet(Sigma_z)
    if sign <= 0:
        raise ValueError("Covariance matrix is not positive definite.")
    precision = torch.inverse(Sigma_z)

    # Centre z: only the y coordinate has a non-zero mean.
    z = torch.cat([x, y], dim=1)
    y_mean = torch.sum(mu * theta, dim=-1, keepdim=True)
    z_mean = torch.cat([torch.zeros(n, d, device=device, dtype=dtype), y_mean], dim=1)
    centred = z - z_mean

    quad = -0.5 * torch.sum((centred @ precision) * centred, dim=1)

    two_pi = torch.tensor(2 * torch.pi, device=device, dtype=dtype)
    const = -0.5 * (d + 1) * torch.log(two_pi) \
            - 0.5 * logdet

    return quad, const


def compute_est_score(model, theta_list, data_list, batch_size, x_1, x_2):
    """Evaluate the learned score along the trajectory and derive ``Gradient``.

    For each step the gradient w.r.t. the (batched) parameter is taken of both
    the analytic ``log_joint`` density and the model's log-density. The model
    gradients are collected as score estimates; the batch-mean squared
    distance between the two gradients is averaged over the steps.

    Reads the module-level ``device``, ``net``, ``Sigma_x``, ``beta_t``,
    ``mu_t`` and ``sigma_y``.

    Args:
        model: callable ``model(x, y, theta_batch)`` returning the
            log-density (``net == 'DNN'``) or a ``(quad, const)`` pair.
        theta_list: per-step parameters (1-D arrays).
        data_list: per-step 5-tuples; the first and third entries are used as
            features and labels.
        batch_size: number of times each parameter is repeated.
        x_1: feature matrix forwarded to ``Gradient``.
        x_2: unused.

    Returns:
        Tuple ``(gradient_list, mean_grad_diff_sq)``.
    """
    def as_float_tensor(array):
        return torch.tensor(array, dtype=torch.float).to(device)

    num_steps = len(theta_list)
    score_list = []
    sq_error_total = 0.0
    for t, theta_np in enumerate(theta_list):
        step = data_list[t]
        x_t, _, y_t, _, _ = step
        x_t = as_float_tensor(x_t)
        y_t = as_float_tensor(y_t)
        theta = torch.tensor(theta_np, dtype=torch.float, requires_grad=True).to(device)
        theta_batch = theta.unsqueeze(0).repeat(batch_size, 1).to(device)

        # --- ∇θ logP (analytic density) ---
        quad_P, const_P = log_joint(x_t, y_t, theta_batch, Sigma_x, beta_t, mu_t, sigma_y)
        (gradP,) = torch.autograd.grad((quad_P + const_P).sum(), theta_batch)

        # --- ∇θ logM (learned density) ---
        if net == 'DNN':
            logM = model(x_t, y_t, theta_batch)
        else:
            quad_M, const_M = model(x_t, y_t, theta_batch)
            logM = quad_M + const_M
        (gradM,) = torch.autograd.grad(logM.sum(), theta_batch)
        score_list.append(gradM.cpu().detach().numpy())

        # --- batch mean of the squared norm of the difference ---
        gap = gradP - gradM
        sq_error_total += gap.pow(2).sum(dim=1).mean().item()

    mean_grad_diff_sq = sq_error_total / num_steps
    theta_columns = [np.expand_dims(theta, axis=1) for theta in theta_list]
    gradient_list = Gradient(theta_columns, score_list, x_1, batch_size)

    return gradient_list, mean_grad_diff_sq


def train_score_model(theta_list, data_list, theta_pert_arr, x_pert_arr, y_pert_arr, method='gaussian', epoch=1, lr=1e-3, batch_size=1000):
    """Fit a score model with SGD, one optimiser step per time step per epoch.

    The per-step loss is ``temp1 - 2 * temp2 + 2 * temp3`` where ``temp1`` and
    ``temp3`` come from ``compute_temp_1_3`` on the unperturbed data and
    ``temp2`` is a finite-difference term: for each coordinate ``i`` the batch
    mean of the ``i``-th gradient component on the perturbed data minus the
    same mean on the unperturbed data, divided by the module-level ``eps``.

    Only quantities that influence the optimisation are computed. Summing the
    loss tensors for logging would keep every step's ``create_graph=True``
    autograd graph alive until the end of the epoch.

    Reads the module-level ``d``, ``mu``, ``device`` and ``eps``.
    NOTE(review): ``eps`` presumably has to match the step used to build the
    perturbed arrays -- confirm against the caller.

    Args:
        theta_list: per-step parameters (1-D arrays).
        data_list: per-step 5-tuples; the first and third entries are used as
            features and labels.
        theta_pert_arr, x_pert_arr, y_pert_arr: perturbed parameters,
            features and labels indexed by ``[t][i]``.
        method: ``'DNN'`` selects ``DNNScoreModel``; any other value selects
            ``ScoreModel``.
        epoch: number of passes over the time steps.
        lr: SGD learning rate.
        batch_size: number of times each parameter is repeated to form the
            batch.

    Returns:
        The trained model, on ``device``.
    """
    net = method
    model = DNNScoreModel(d, 1, d) if net == 'DNN' else ScoreModel(d, 1, mu)
    model = model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    def to_tensor(array):
        return torch.tensor(array, dtype=torch.float).to(device)

    def batched_theta(theta_np):
        # The batch is a non-leaf copy of a grad-enabled tensor, so gradients
        # with respect to the batch itself can be requested.
        theta = torch.tensor(theta_np, dtype=torch.float, requires_grad=True).to(device)
        return theta.unsqueeze(0).repeat(batch_size, 1)

    def log_density(x, y, theta_batch):
        if net == 'DNN':
            return model(x, y, theta_batch)
        quad, const = model(x, y, theta_batch)
        return quad + const

    for _ in range(epoch):
        for t, theta_np in enumerate(theta_list):
            x_np, _, y_np, _, _ = data_list[t]
            x_t = to_tensor(x_np)
            y_t = to_tensor(y_np)
            theta_batch = batched_theta(theta_np)

            logM = log_density(x_t, y_t, theta_batch)
            grads_t, temp1, temp3 = compute_temp_1_3(logM, theta_batch)

            # temp2: finite difference of the mean i-th score component
            temp2 = 0
            for i in range(d):
                x_i = to_tensor(x_pert_arr[t][i])
                y_i = to_tensor(y_pert_arr[t][i])
                theta_i_batch = batched_theta(theta_pert_arr[t][i])

                logM_i = log_density(x_i, y_i, theta_i_batch)
                grads_per = torch.autograd.grad(logM_i.sum(), theta_i_batch, create_graph=True)[0]
                temp2 += (grads_per[:, i].mean() - grads_t[:, i].mean()) / eps

            loss = temp1 - 2*temp2 + 2*temp3

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return model


def prepare_param(a,b,c,j,sigma_y2):
    """Scale the base parameters and derive the quantities used by the run.

    Reads the module-level ``cov_matrix0``, ``mean0``, ``beta0``, ``mu0``,
    ``gamma``, ``I`` and ``device``.

    Args:
        a: scale applied to ``mean0``.
        b: scale applied to ``beta0``.
        c: scale applied to ``mu0``.
        j: scale applied to ``cov_matrix0``.
        sigma_y2: value converted to the ``sigma_y`` tensor.

    Returns:
        Tuple ``(cov_matrix, mean, beta, mu, gradient, Sigma_x, beta_t, mu_t,
        sigma_y)``. The first five are NumPy arrays, with
        ``gradient = inv(cov_matrix + mean mean^T + gamma I) @ mean @ mu^T``;
        the last four are float tensors on ``device`` (squeezed, except
        ``sigma_y``).
    """
    def squeezed_tensor(array):
        return torch.tensor(array, dtype=torch.float).squeeze().to(device)

    cov_matrix = j * cov_matrix0
    mean = a * mean0
    beta = b * beta0
    mu = c * mu0

    regularised_moment = cov_matrix + mean @ mean.T + gamma * I
    gradient = np.linalg.inv(regularised_moment) @ mean @ mu.T

    Sigma_x = squeezed_tensor(cov_matrix)
    beta_t = squeezed_tensor(beta)
    mu_t = squeezed_tensor(mu)
    sigma_y = torch.tensor(sigma_y2, dtype=torch.float).to(device)

    return cov_matrix, mean, beta, mu, gradient, Sigma_x, beta_t, mu_t, sigma_y


def _plot_ci_curves(coverages, true_coverages, widths, true_widths):
    """Plot coverage (left) and width (right) against ``n_list``; return the figure."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    colors = ['#2ca02c', '#1f77b4', '#ff7f0e']
    for ax in axes:
        ax.set_prop_cycle(color=colors)

    for i, m in enumerate(methods):
        ci = colors[i]
        ps_label = methods_name[i] + r" $\theta$_PS"
        axes[0].plot(n_list, coverages[m], label=methods_name[i], color=ci, linestyle='-', linewidth=3)
        axes[0].plot(n_list, true_coverages[m], label=ps_label, color=ci, alpha=0.5, linestyle='--', linewidth=3)
        axes[1].plot(n_list, widths[m], label=methods_name[i], color=ci, linestyle='-', linewidth=3)
        axes[1].plot(n_list, true_widths[m], label=ps_label, color=ci, alpha=0.5, linestyle='--', linewidth=3)

    axes[0].axhline(0.9, color='gray', linestyle='--')       # target line
    axes[0].set_title('coverage', fontsize=18)
    axes[0].set_xlabel('n', fontsize=18)
    axes[0].set_ylim(0.59, 1.01)

    axes[1].set_title('width', fontsize=18)
    axes[1].set_xlabel('n', fontsize=18)
    axes[1].set_ylim(-0.01, 0.11)
    axes[1].legend(loc='upper right', fontsize=14, ncol=2)

    for ax in axes:
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.tick_params(axis='both', labelsize=16)

    plt.tight_layout()
    return fig


def _save_ci_arrays(coverages, true_coverages, widths, true_widths):
    """Write the plotted series to an ``.npz`` file next to the figure."""
    npz_fname = f"{save_name}/CI_n_T{T}_sigmay{sigma_y2}_mu{c}_gamma{gamma}.npz"
    save_dict = {'n_list': n_list}
    for m in methods:
        save_dict[f"{m}_true_width"] = true_widths[m]
        save_dict[f"{m}_true_coverage"] = true_coverages[m]
        save_dict[f"{m}_res_width"] = widths[m]
        save_dict[f"{m}_res_coverage"] = coverages[m]

    np.savez(npz_fname, **save_dict)
    print(f"Saved plotting data to {npz_fname}")


def run_exp():
    """Run the coverage/width experiment for every ``n`` in ``n_list``.

    For each ``n`` the experiment is repeated ``repeat`` times. Each
    repetition collects the ``lam = 0`` trajectory, trains a score model on
    it, estimates the variance, and then runs the remaining methods with the
    resulting ``gradient_list``. Coverage and width are computed against both
    the step-``T`` ground truth and ``theta_PS``; results are plotted, saved
    as a PDF and stored as ``.npz`` under ``save_name``.

    The figure is written with ``fig.savefig`` *before* ``plt.show()``: on an
    interactive backend ``show`` blocks until the window is closed, after
    which the current figure no longer holds the plot and saving it would
    produce an empty file.

    Reads its whole configuration from module-level names (``methods``,
    ``n_list``, ``repeat``, ``T``, ``net``, ``epoch``, ``lr``, ``eps``, ...).
    """
    widths = {m: [] for m in methods}
    coverages = {m: [] for m in methods}

    # Same statistics, measured against theta_PS instead of the step-T truth.
    true_widths = {m: [] for m in methods}
    true_coverages = {m: [] for m in methods}

    for n in n_list:
        true_theta_t, _ = ground_truth_theta(cov_matrix,mean,beta,mu,theta0,n,d,T,gamma)
        theta_PS = compute_theta_PS(cov_matrix,mean, beta,mu,n,d,gamma)
        print('true_theta_t {}'.format(true_theta_t))
        print('theta_PS {}'.format(theta_PS))
        batch_size = n
        true_var_dict = {m: [] for m in methods}
        vars_dict = {m: [] for m in methods}
        thetas_dict = {m: [] for m in methods}
        beta_max_2_dict = {m: [] for m in methods}
        max_norm_x_list = []

        print('\trunning experiments with n={}'.format(n))
        for r in tqdm(range(repeat)):
            # first compute for lambda=0
            theta_lam0_list, data_lam0_list, cov_total, x_1, x_2, beta_max_2, max_norm_x = collect_theta(n)

            theta_pert_arr, x_pert_arr, y_pert_arr = perturb_data(theta_lam0_list, data_lam0_list, eps, n)

            model = train_score_model(theta_lam0_list, data_lam0_list, theta_pert_arr, x_pert_arr, y_pert_arr,
                                      method=net, epoch=epoch, lr=lr, batch_size=batch_size)

            gradient_list, mean_grad_diff_sq = compute_est_score(model, theta_lam0_list, data_lam0_list, batch_size,x_1,x_2)
            lam0_variance = compute_variance(0, theta_lam0_list, data_lam0_list, gradient_list, batch_size)

            max_norm_x_list.append(max_norm_x)
            thetas_dict['lam0'].append(np.expand_dims(theta_lam0_list[T-1], axis=1))
            vars_dict['lam0'].append(lam0_variance)
            true_var_dict['lam0'].append(cov_total)
            beta_max_2_dict['lam0'].append(beta_max_2)

            # remaining methods reuse the features and the estimated gradients
            theta_dict, est_cov_dict, true_cov_dict, beta_max_2 = compute_variance_lam(n, x_1, x_2, gradient_list)

            for m in methods[1:]:
                vars_dict[m].append(est_cov_dict[m])
                true_var_dict[m].append(true_cov_dict[m])
                thetas_dict[m].append(theta_dict[m])
                beta_max_2_dict[m].append(beta_max_2[m])

        for m in methods:
            avg_coverage_0, avg_width_0 = compute_coverage_and_width(vars_dict[m], thetas_dict[m], true_theta_t, n)
            print(f'\tmethod {m}, theta_t avg_coverage {avg_coverage_0}, avg_width {avg_width_0}')
            widths[m].append(avg_width_0)
            coverages[m].append(avg_coverage_0)

            avg_coverage_0_PS, avg_width_0_PS = compute_coverage_and_width_PS(vars_dict[m], thetas_dict[m], beta_max_2_dict[m], max_norm_x_list, theta0, theta_PS, n, mu, T, gamma)
            print(f'\tmethod {m}, theta_PS avg_coverage {avg_coverage_0_PS}, avg_width {avg_width_0_PS}')

            true_widths[m].append(avg_width_0_PS)
            true_coverages[m].append(avg_coverage_0_PS)

    fig = _plot_ci_curves(coverages, true_coverages, widths, true_widths)
    # Save first, then show (see docstring).
    fig.savefig(f"{save_name}/CI_n_T{T}_sigmay{sigma_y2}_mu{c}_gamma{gamma}.pdf", dpi=200)
    plt.show()
    # run_exp is called once per T; release the figure so they do not pile up.
    plt.close(fig)

    _save_ci_arrays(coverages, true_coverages, widths, true_widths)


# Script entry point. Every name assigned below is a module-level global that
# the functions above read directly, so the assignments must stay at module
# scope. The order of the np.random calls (theta0, then beta0, then mu0)
# determines the values drawn after set_seed(42).
if __name__=='__main__':
    # Score-model family: 'DNN' selects DNNScoreModel, anything else ScoreModel.
    net = 'gaussian' # 'gaussian'  'DNN'
    save_name = f'results_{net}'  # output directory for the PDF and .npz files
    os.makedirs(save_name, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    set_seed(42)

    d=2  # dimension of the features / parameter
    T_list = [2,3,4]  # numbers of time steps; one experiment is run per value
    n_list = [100,200,400,600,800,1000]  # values of n swept inside run_exp
    N=2000  # second size argument of generate_data / generate_labels
    shift=-0.2  # forwarded unchanged to generate_labels
    theta0 = np.random.rand(d)  # initial parameter, shape (d,)
    methods = ['lam0','lam1','greedy']  # 'lam0' must stay first: run_exp treats methods[1:] separately
    methods_name = [r'$\lambda=0$', r'$\lambda=1$','greedy']  # legend labels, same order as methods

    sigma_y2=0.2  # used as the scale (standard deviation) of the label noise in generate_y
    a = 0.1  # scale of mean0
    b = 1  # scale of beta0
    c = 0.05  # scale of mu0
    j = 1  # scale of cov_matrix0

    gamma = 2  # coefficient of the gamma * I / gamma * theta regularisation terms
    I = np.identity(d)

    eps = 0.2  # perturbation size for the finite-difference term of the score loss
    lr = 0.2  if net == 'gaussian' else 0.1
    epoch = 5 if net == 'gaussian' else 20
    repeat = 1000  # repetitions per value of n


    # Unscaled base parameters; prepare_param multiplies them by j, a, b and c.
    cov_matrix0 = np.identity(d)
    mean0 = np.ones(d).reshape(-1, 1)
    beta0 = np.random.randn(d).reshape(-1, 1)
    beta0 = beta0 / np.linalg.norm(beta0)  # unit-norm direction
    mu0 = np.random.randn(d).reshape(-1, 1)


    # T is the loop variable on purpose: the functions read it as a global.
    for T in T_list:
        print(f'running exp with T{T}_meanx{a}_covx{j}_beta{b}_mu{c}_shift{shift}_sigmay{sigma_y2}_gamma{gamma}')
        cov_matrix, mean, beta, mu, gradient, Sigma_x, beta_t, mu_t, sigma_y = prepare_param(a,b,c,j,sigma_y2)

        run_exp()
    
    