import torch
import numpy as np
from scipy.linalg import toeplitz
import matplotlib.pyplot as plt
from tqdm import tqdm
# Device used for every tensor created in this file. Falls back to the CPU
# when no CUDA device is available so the script still runs without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def generate_gaussian_data(n, p, mu, Sigma, phi, beta, device=DEVICE):
    """Draw a Gaussian linear-regression sample ``y = X @ beta + noise``.

    Args:
        n: Number of rows to sample.
        p: Not used by this function; kept so existing call sites still work.
        mu: Mean vector of the features (array-like or tensor).
        Sigma: Covariance matrix of the features (array-like or tensor).
        phi: Multiplier applied to the standard-normal noise.
        beta: Coefficient vector. It is used directly in ``X @ beta`` and is
            not moved, so it must already be on ``device``.
        device: Device on which the samples are generated.

    Returns:
        Tuple ``(X, y)``: ``X`` holds ``n`` draws from ``N(mu, Sigma)`` (one
        per row) and ``y`` is ``X @ beta`` plus ``phi``-scaled noise.
    """
    # as_tensor reuses the input when it is already a float32 tensor on the
    # right device; torch.tensor() would copy it and emit a UserWarning.
    mu = torch.as_tensor(mu, dtype=torch.float32, device=device)
    Sigma = torch.as_tensor(Sigma, dtype=torch.float32, device=device)
    X = torch.distributions.MultivariateNormal(mu, Sigma).sample((n,))
    noise = torch.randn(n, dtype=torch.float32, device=device) * phi
    y = X @ beta + noise
    return X, y

from scipy.optimize import root_scalar
import numpy as np

def solve_alpha(lambdas, n=400, p=150, n1=200):
    """Solve the scalar fixed-point equation for ``alpha`` by bisection.

    The root of ``sum(1 / (lambdas * a + 1 - p/n - a))
    - (p + n * a - n1) / (1 - p/n - a)`` is searched on the interval
    ``[max(0, (n1 - p) / n + 1e-6), 1 - p/n - 1e-6]``.

    Args:
        lambdas: NumPy array of eigenvalues entering the left-hand sum.
        n: Total sample size.
        p: Dimension.
        n1: Size of the first sample.

    Returns:
        The root found by ``scipy.optimize.root_scalar``.

    Raises:
        RuntimeError: If the solver reports that it did not converge.
    """
    aspect = p / n

    def residual(a):
        left = np.sum(1.0 / (lambdas * a + 1 - aspect - a))
        right = (p + n * a - n1) / (1 - aspect - a)
        return left - right

    # Both ends are nudged inwards by 1e-6: the upper end approaches the pole
    # of the right-hand side at a = 1 - p/n.
    bracket_lo = max(0.0, (n1 - p) / n + 1e-6)
    bracket_hi = 1 - aspect - 1e-6

    result = root_scalar(residual, bracket=[bracket_lo, bracket_hi], method='bisect')
    if not result.converged:
        raise RuntimeError("Failed to solve for alpha")
    return result.root
        
def empirical_covariance(x):
    """Return the sample covariance of the rows of ``x``.

    Columns are centred by their mean and the scatter matrix is divided by
    ``rows - 1``.
    """
    samples = np.asarray(x)
    n_rows = samples.shape[0]
    deviations = samples - samples.mean(axis=0)
    return deviations.T @ deviations / (n_rows - 1)
    
def calculate_alpha_estimate(cov_train, cov_augment, n1, n2, p):
    """Solve for ``alpha`` from the whitened augmentation spectrum.

    ``cov_augment`` is sandwiched between two copies of ``cov_train^{-1/2}``
    and the eigenvalues of the result are passed to ``solve_alpha``.

    Args:
        cov_train: Training covariance as a NumPy array.
        cov_augment: Augmentation covariance as a NumPy array.
        n1: Size of the first sample.
        n2: Size of the second sample.
        p: Dimension.

    Returns:
        Tuple ``(alpha, estimate)`` where ``estimate`` is
        ``(p + n * alpha - n1) / (1 - p / n - alpha)`` with ``n = n1 + n2``.
    """
    floor = 1e-6
    train_vals, train_vecs = np.linalg.eigh(cov_train)
    # Clip tiny or negative eigenvalues so the inverse square root stays finite.
    train_vals = np.maximum(train_vals, floor)
    whitener = train_vecs @ np.diag(1.0 / np.sqrt(train_vals)) @ train_vecs.T

    spectrum = np.linalg.eigvalsh(whitener @ cov_augment @ whitener)
    lambdas = np.sort(spectrum)[::-1]  # descending order

    total = n1 + n2
    alpha = solve_alpha(lambdas, n=total, p=p, n1=n1)
    estimate = (p + total * alpha - n1) / (1 - (p / total) - alpha)
    return alpha, estimate

def least_squares_estimate(X, y):
    """Return the least-squares coefficients for ``X @ beta ~ y``.

    ``y`` is lifted to a single-column matrix for ``torch.linalg.lstsq`` and
    the solution is squeezed before being returned.
    """
    target_column = y.unsqueeze(1)
    fit = torch.linalg.lstsq(X, target_column)
    return fit.solution.squeeze()

def compute_excess_risk(X_test, beta_hat, beta):
    """Return the mean squared gap between the two predictions on ``X_test``.

    The result is a Python float: the mean of
    ``(X_test @ beta_hat - X_test @ beta) ** 2``.
    """
    prediction_gap = X_test @ beta_hat - X_test @ beta
    squared_gap = prediction_gap ** 2
    return torch.mean(squared_gap).item()

def generate_mixture_2gaussian_data(n_1, n_2, mu1, Sigma1, mu2, Sigma2, phi, beta, device=DEVICE):
    """Sample regression data whose features come from two Gaussians.

    ``n_1`` rows are drawn from ``N(mu1, Sigma1)`` and ``n_2`` rows from
    ``N(mu2, Sigma2)``. The rows are placed at randomly permuted positions
    of ``X``, and responses are ``y = X @ beta + phi * standard_normal``.

    Args:
        n_1: Number of rows drawn from the first component.
        n_2: Number of rows drawn from the second component.
        mu1, Sigma1: Mean and covariance of the first component.
        mu2, Sigma2: Mean and covariance of the second component.
        phi: Multiplier applied to the standard-normal noise.
        beta: Coefficient vector. It is used directly in ``X @ beta`` and is
            not moved, so it must already be on ``device``.
        device: Device on which the samples are generated.

    Returns:
        Tuple ``(X, y)`` with ``X`` of shape ``(n_1 + n_2, len(mu1))``.
    """
    # as_tensor reuses inputs that are already suitable tensors instead of
    # copying them (torch.tensor() on a tensor also emits a UserWarning).
    mu1 = torch.as_tensor(mu1, dtype=torch.float32, device=device)
    mu2 = torch.as_tensor(mu2, dtype=torch.float32, device=device)
    Sigma1 = torch.as_tensor(Sigma1, dtype=torch.float32, device=device)
    Sigma2 = torch.as_tensor(Sigma2, dtype=torch.float32, device=device)

    total = n_1 + n_2
    d = mu1.shape[0]
    dist1 = torch.distributions.MultivariateNormal(mu1, Sigma1)
    dist2 = torch.distributions.MultivariateNormal(mu2, Sigma2)

    # Random row positions: the first n_1 entries receive component-1 draws,
    # the remaining n_2 entries receive component-2 draws.
    permutation = torch.from_numpy(np.random.permutation(total)).to(
        device=device, dtype=torch.long
    )

    # Draw each component in one batched call rather than one row at a time.
    X = torch.empty((total, d), dtype=torch.float32, device=device)
    X[permutation[:n_1]] = dist1.sample((n_1,))
    X[permutation[n_1:]] = dist2.sample((n_2,))

    noise = torch.randn(total, dtype=torch.float32, device=device) * phi
    y = X @ beta + noise
    return X, y

def get_perpendicular_vector(v):
    """Return a vector orthogonal to ``v`` with the same Euclidean norm.

    A seed vector is built from the first standard basis vector, its
    component along ``v`` is removed, and the remainder is rescaled to
    ``v.norm()``. The result is moved to ``DEVICE``.

    Args:
        v: Tensor; it is flattened before use. It must be non-zero, and must
            have at least two entries when it is aligned with the first axis.

    Returns:
        1-D tensor on ``DEVICE`` that is orthogonal to the flattened ``v``.
    """
    v = v.flatten()
    e = torch.zeros_like(v)
    e[0] = 1.0
    unit_v = v / v.norm()
    # If v points along the first axis in either direction, removing the
    # projection leaves the zero vector and normalising it yields NaNs.
    # Mixing in the second axis keeps the seed linearly independent of v.
    if torch.allclose(unit_v, e) or torch.allclose(unit_v, -e):
        e[1] = 1.0
    proj = torch.dot(v, e) / torch.dot(v, v) * v
    u = e - proj
    u = u / u.norm() * v.norm()
    return u.to(DEVICE)

def inverse_sqrt_matrix(Sigma, eps=1e-10):
    """
    Return Sigma^{-1/2} for a symmetric positive semi-definite matrix.

    Built from the eigendecomposition of ``Sigma``; ``eps`` is added to each
    eigenvalue before the reciprocal square root is taken.
    """
    spectrum, basis = torch.linalg.eigh(Sigma)
    scaling = torch.diag(1.0 / torch.sqrt(spectrum + eps))
    return basis @ scaling @ basis.T

# Cosine-similarity module that compares its inputs along dim=1 (eps=1e-6).
# NOTE(review): `cos` is not referenced anywhere else in the visible code.
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
def calculate_statistics(n1, n2, p, mu_train, mu_augment, Cov_train, Cov_augment, sigma):
    """Return the pair ``(v1, v2)`` of theoretical risk terms.

    ``v1`` is ``sigma**2 / (n1 + n2)`` times the estimate returned by
    ``calculate_alpha_estimate``; ``v2`` is always 0. ``mu_train`` and
    ``mu_augment`` are accepted but not used.
    """
    total = n1 + n2
    _, estimate = calculate_alpha_estimate(Cov_train, Cov_augment, n1, n2, p)
    first_term = sigma**2 / total * estimate
    second_term = 0
    return first_term, second_term

def get_identity_covariance(p, scale=1, device=DEVICE):
    """Return the ``p x p`` identity matrix on ``device`` multiplied by ``scale``."""
    identity = torch.eye(p, device=device)
    return identity * scale

def get_toeplitz_covariance(p, device=DEVICE):
    """Return a ``p x p`` Toeplitz matrix built from a random first column.

    The first column is drawn from a standard normal with NumPy and the
    result is converted to a float32 tensor on ``device``.

    NOTE(review): nothing here makes the matrix positive semi-definite, so
    it may not be usable as a covariance; confirm with callers.
    """
    first_column = np.random.randn(p)
    return torch.tensor(toeplitz(first_column), dtype=torch.float32, device=device)

def get_circulant_covariance(p, device=DEVICE):
    """Return a ``p x p`` circulant matrix with first row ``[3, 2, 1, 0, ..., 0, 1, 2]``.

    Row ``i`` is the first row rolled right by ``i`` positions.

    Args:
        p: Matrix size. Values below 3 fail on the slice assignment, and for
            ``p`` of 3 or 4 the trailing ``[1, 2]`` overwrites part of the
            leading ``[3, 2, 1]``.
        device: Device on which the matrix is created.

    Returns:
        float32 tensor of shape ``(p, p)`` on ``device``.
    """
    first_row = np.zeros(p)
    first_row[0:3] = [3, 2, 1]
    first_row[-2:] = [1, 2]
    # Honour the caller's `device`; the global DEVICE was used here before,
    # which silently ignored the argument.
    first_row = torch.tensor(first_row, device=device)
    rows = [first_row.roll(shifts=i) for i in range(p)]
    return torch.stack(rows).to(dtype=torch.float32)

def get_spiked_covariance(p, mu, theta, device=DEVICE):
    """Return ``I_p + theta * mu mu^T`` as a float32 tensor.

    The identity is created on ``device``; ``mu`` is used as given.
    """
    identity = torch.eye(p, device=device)
    rank_one = theta * torch.outer(mu, mu)
    spiked = identity + rank_one
    return spiked.to(dtype=torch.float32)

def power_band_covariance(p, x, reference_decay=0.9, device=DEVICE):
    """Return the matrix ``x ** |i - j|`` rescaled against a reference matrix.

    With ``R = reference_decay ** |i - j|`` and ``P = x ** |i - j|``, the
    result is ``P / s`` where ``s = trace(R^{-1/2} P R^{-1/2}) / p``. After
    the rescaling, the whitened matrix has average diagonal entry 1.

    Args:
        p: Matrix size.
        x: Decay base of the returned matrix.
        reference_decay: Decay base of the reference matrix. The default
            0.9 matches the value that used to be hard-coded.
        device: Device on which the matrices are built. The default is the
            module-wide ``DEVICE``, as before.

    Returns:
        float32 tensor of shape ``(p, p)``.
    """
    idx = torch.arange(p, device=device)
    band = torch.abs(idx.unsqueeze(0) - idx.unsqueeze(1))  # |i - j|, shape (p, p)
    powers = x ** band
    reference = reference_decay ** band
    whitener = inverse_sqrt_matrix(reference)
    shift_trace = torch.trace(whitener @ powers @ whitener) / p
    return (powers / shift_trace).to(dtype=torch.float32)

def get_covariance(type, p, **kwargs):
    """Build the ``p x p`` covariance matrix selected by ``type``.

    Keyword arguments are read only by the branch that needs them:
    ``scale`` for "identity", ``mean1``/``mean2`` and ``spike`` for
    "spiked_mu1"/"spiked_mu2", and ``decay`` for "power_band".
    "toeplitz" and "circulant" read none.

    Raises:
        ValueError: If ``type`` is not one of the names above.
    """
    if type == "identity":
        return get_identity_covariance(p=p, scale=kwargs['scale'])
    if type == "toeplitz":
        return get_toeplitz_covariance(p=p)
    if type == "circulant":
        return get_circulant_covariance(p=p)
    if type == "spiked_mu1" or type == "spiked_mu2":
        # The two spiked variants differ only in which mean forms the spike.
        mean_key = 'mean1' if type == "spiked_mu1" else 'mean2'
        return get_spiked_covariance(p=p, mu=kwargs[mean_key], theta=kwargs['spike'])
    if type == "power_band":
        return power_band_covariance(p=p, x=kwargs['decay'])
    raise ValueError("Invalid covariance type")
  
def get_underparam_parameters(p=1000):
    """Return ``(n_1, n_2, phi, beta)`` with both sample sizes equal to ``2 * p``.

    ``phi`` is 1 and ``beta`` is a standard-normal float32 vector of length
    ``p`` on ``DEVICE``.
    """
    sample_size = 2 * p
    noise_scale = 1
    coefficients = torch.randn(p, device=DEVICE, dtype=torch.float32)
    return sample_size, sample_size, noise_scale, coefficients

def get_halfover_parameters(p=1000):
    """Return ``(n_1, n_2, phi, beta)`` with ``n_1 = p // 2`` and ``n_2 = 2 * p``.

    ``phi`` is 1 and ``beta`` is a standard-normal float32 vector of length
    ``p`` on ``DEVICE``.
    """
    first_size, second_size = p // 2, 2 * p
    noise_scale = 1
    coefficients = torch.randn(p, device=DEVICE, dtype=torch.float32)
    return first_size, second_size, noise_scale, coefficients

def get_means(p=1000, nrm=5):
    """Return two mean vectors of norm ``nrm * sqrt(p)``.

    The first is a random direction rescaled to that norm; the second comes
    from ``get_perpendicular_vector`` applied to the first.
    """
    target_norm = nrm * np.sqrt(p)
    direction = torch.randn(p, device=DEVICE, dtype=torch.float32)
    direction = direction / torch.norm(direction)
    mean1 = direction * target_norm
    mean2 = get_perpendicular_vector(mean1).to(dtype=torch.float32)
    return mean1, mean2

def simulate_samecov(p, setting, covariance1, covariance2,
                     mean1, mean2, sample_points, trials,
                     band_decay=0.9, spike_amount=0.5, scale_train=1, scale_test=1,
                     band_decay_test=0.9, spike_amount_test=0.4):
    """Run the mixture-regression simulation over a grid of mean alignments.

    For each ``alpha`` in ``np.linspace(0, 1, sample_points)`` and each
    trial, fresh coefficients and means are drawn and the augmentation mean
    is set to ``alpha * mean1 + sqrt(1 - alpha**2) * mean2``. A
    least-squares fit on the two-component mixture is then scored on fresh
    data drawn from the training distribution.

    Args:
        p: Dimension.
        setting: Callable taking ``p=`` and returning ``(n_1, n_2, phi, beta)``.
        covariance1: Covariance type name for the training component.
        covariance2: Covariance type name for the augmentation component.
        mean1, mean2: Accepted, but overwritten by ``get_means`` in every
            trial before they are read.
        sample_points: Number of ``alpha`` grid points.
        trials: Number of repetitions per grid point.
        band_decay, spike_amount, scale_train: Options for the training
            covariance. ``band_decay`` also multiplies the resulting matrix.
        band_decay_test, spike_amount_test, scale_test: The same options for
            the augmentation covariance.

    Returns:
        Three lists with one entry per ``alpha``: the mean excess risk, its
        standard deviation across trials, and the theoretical score, which
        is computed on the first trial only.
    """
    # The beta returned by `setting` is replaced by a fresh draw in every trial.
    n_1, n_2, phi, beta = setting(p=p)
    risk_means = []
    risk_spreads = []
    theory_scores = []
    for alpha in np.linspace(0, 1, sample_points):
        trial_risks = []
        trial_scores = []
        for trial in tqdm(range(trials), desc=f"Alpha {alpha}"):
            beta = torch.randn(p, device=DEVICE, dtype=torch.float32)
            mean1, mean2 = get_means(p=p, nrm=2)
            augment_mean = alpha * mean1 + np.sqrt(1-alpha**2) * mean2

            training_covariance = band_decay*get_covariance(
                covariance1, p, mean1=mean1, mean2=mean2,
                decay=band_decay, spike=spike_amount, scale=scale_train)
            augment_covariance = band_decay_test*get_covariance(
                covariance2, p, mean1=mean1, mean2=mean2,
                decay=band_decay_test, spike=spike_amount_test, scale=scale_test)

            X_train, y_train = generate_mixture_2gaussian_data(
                n_1, n_2, mean1, training_covariance,
                augment_mean, augment_covariance, phi, beta)
            # Test responses are generated but only the features are scored.
            X_test, _ = generate_gaussian_data(
                2*n_1, p, mean1, training_covariance, phi, beta)
            beta_hat = least_squares_estimate(X_train, y_train)

            if trial == 0:
                v1, v2 = calculate_statistics(
                    n_1, n_2, p, mean1, augment_mean,
                    training_covariance.cpu().numpy(),
                    augment_covariance.cpu().numpy(), phi)
                trial_scores.append(v1 + v2)

            trial_risks.append(compute_excess_risk(X_test, beta_hat, beta))

        theory_scores.append(np.mean(trial_scores))
        risk_means.append(np.mean(trial_risks))
        risk_spreads.append(np.std(trial_risks))

    return risk_means, risk_spreads, theory_scores


# Experiment configuration: dimension, alpha grid size, repetitions per point.
p = 600
sample_points = 11
trials = 100
mean1, mean2 = get_means(p=p, nrm=2)

# Augmentation decay values; one simulation is run per value, in this order.
vals = [0.9, 0.5, 0.1]
_runs = [
    simulate_samecov(p, get_underparam_parameters, 'power_band', 'power_band',
                     mean1, mean2, sample_points, trials,
                     band_decay_test=decay_value, band_decay=0.9)
    for decay_value in vals
]

(
    (scale1_exmean, scale1_exconf, scale1_scores),
    (scale2_exmean, scale2_exconf, scale2_scores),
    (scale3_exmean, scale3_exconf, scale3_scores),
) = _runs

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 1, sample_points)

import seaborn as sns

sns.set(style="whitegrid", context="talk")
palette = sns.color_palette("mako", 3)

# One entry per run: (decay value, theory curve, empirical mean,
# empirical spread, colour).
_series = list(zip(
    vals,
    (scale1_scores, scale2_scores, scale3_scores),
    (scale1_exmean, scale2_exmean, scale3_exmean),
    (scale1_exconf, scale2_exconf, scale3_exconf),
    palette,
))

plt.figure(figsize=(6, 5))

# Theory curves come first so their legend entries precede the empirical ones.
for _rho, _theory, _, _, _colour in _series:
    sns.lineplot(x=x, y=_theory, label=rf'$\rho_s={{{_rho}}}$ (theory)',
                 color=_colour, linewidth=6, zorder=3)

# Shaded band of one standard deviation around each empirical mean.
for _, _, _emp_mean, _emp_spread, _colour in _series:
    plt.fill_between(x,
                     np.array(_emp_mean) - np.array(_emp_spread),
                     np.array(_emp_mean) + np.array(_emp_spread),
                     color=_colour, alpha=0.25)

# Empirical means drawn on top of the curves.
for _rho, _, _emp_mean, _, _colour in _series:
    sns.scatterplot(x=x, y=_emp_mean, label=rf'$\rho_s={{{_rho}}}$ (empirical)',
                    color=_colour, marker='o', s=140, zorder=4)

plt.xlabel('Cosine similarity of means', fontsize=26)
plt.ylabel('Excess risk', fontsize=26)
plt.legend(frameon=True, fontsize=14, loc='upper right', framealpha=0.6)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig('<output path>', dpi=300)