import torch
import numpy as np
from scipy.linalg import toeplitz
import matplotlib.pyplot as plt
from tqdm import tqdm
# Default compute device for the helpers below. Falls back to the CPU when no
# CUDA device is available so the module can still be used without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def generate_gaussian_data(n, p, mu, Sigma, phi, beta, device=DEVICE):
    """Draw ``n`` samples from a Gaussian linear-regression model.

    Rows of ``X`` are i.i.d. ``N(mu, Sigma)`` and ``y = X @ beta + phi * eps``
    with standard-normal ``eps``.

    Args:
        n: Number of samples.
        p: Feature dimension. Not used by the computation (the dimension comes
            from ``mu``/``Sigma``); kept for call compatibility.
        mu: Mean vector (tensor, array or sequence).
        Sigma: Covariance matrix (tensor, array or nested sequence).
        phi: Noise standard deviation.
        beta: Regression coefficients (tensor, array or sequence).
        device: Device on which everything is created.

    Returns:
        Tuple ``(X, y)`` of float32 tensors on ``device``.
    """
    # ``as_tensor`` avoids the copy-construct warning (and copy) that
    # ``torch.tensor`` incurs when handed an existing tensor, as the simulation
    # loop does.
    mu = torch.as_tensor(mu, dtype=torch.float32, device=device)
    Sigma = torch.as_tensor(Sigma, dtype=torch.float32, device=device)
    # ``beta`` must match ``X`` in dtype and device for the matmul below.
    beta = torch.as_tensor(beta, dtype=torch.float32, device=device)
    X = torch.distributions.MultivariateNormal(mu, Sigma).sample((n,))
    noise = torch.randn(n, dtype=torch.float32, device=device) * phi
    y = X @ beta + noise
    return X, y

from scipy.optimize import root_scalar
import numpy as np

def solve_alpha(lambdas, n=400, p=150, n1=200):
    """Solve the scalar fixed-point equation for ``alpha`` by bisection.

    Finds the root in ``alpha`` of

        sum_i 1 / (lambda_i * alpha + 1 - p/n - alpha)
            = (p + n * alpha - n1) / (1 - p/n - alpha)

    inside the bracket ``[max(0, (n1 - p)/n + 1e-6), 1 - p/n - 1e-6]``. The
    ``1e-6`` offsets keep the search off the pole at ``alpha = 1 - p/n``.

    Args:
        lambdas: Eigenvalues (any array-like of real numbers).
        n: Total sample count.
        p: Dimension.
        n1: Number of samples from the first (training) population.

    Returns:
        The root ``alpha`` as a float.

    Raises:
        ValueError: If the bracket is empty (e.g. ``p >= n``), or if the
            equation does not change sign over it (raised by ``root_scalar``).
        RuntimeError: If bisection fails to converge.
    """
    # Accept lists/tuples and work in float64 so the bisected function is not
    # limited by float32 round-off when eigenvalues come from a float32 tensor.
    lambdas = np.asarray(lambdas, dtype=float)
    ratio = p / n

    def fixed_point_eq(alpha):
        denom = lambdas * alpha + 1 - ratio - alpha
        lhs = np.sum(1.0 / denom)
        rhs = (p + n * alpha - n1) / (1 - ratio - alpha)
        return lhs - rhs

    lower = max(0.0, (n1 - p) / n + 1e-6)
    upper = 1 - ratio - 1e-6
    if not lower < upper:
        raise ValueError(
            f"empty bracket for alpha: lower={lower!r} >= upper={upper!r} "
            f"(n={n}, p={p}, n1={n1})"
        )

    sol = root_scalar(fixed_point_eq, bracket=[lower, upper], method='bisect')
    if not sol.converged:
        raise RuntimeError("Failed to solve for alpha")
    return sol.root
    
def custom_diagonal_matrix_vary(p, c=1.0, device="cuda:0"):
    """Return a ``p x p`` diagonal matrix whose entries drift linearly with index.

    For ``k = 0, ..., p - 1`` the diagonal entry is ``1 + c * k / (2 * (p - 1))``
    when ``k < p // 2`` and ``1 - c * k / (2 * (p - 1))`` otherwise.

    Args:
        p: Matrix size; must be even.
        c: Slope of the linear drift.
        device: Device for the returned float32 tensor.

    Raises:
        ValueError: If ``p`` is odd.
    """
    # An explicit check (not ``assert``) so validation survives ``python -O``.
    if p % 2 != 0:
        raise ValueError("p must be even for a clean split")

    k = torch.arange(p, device=device, dtype=torch.float32)
    offset = c * k / (2 * (p - 1))
    diag = torch.where(k < p // 2, 1 + offset, 1 - offset)
    return torch.diag(diag)

import torch

def custom_diagonal_matrix_fix(p, c=1.0, device="cuda:0"):
    """Return a ``p x p`` two-level diagonal matrix.

    The first ``p // 2`` diagonal entries equal ``1 + c / 2``. The remaining
    entries equal ``1 - c / 2``.

    Args:
        p: Matrix size; must be even.
        c: Gap between the two diagonal levels.
        device: Device for the returned tensor.

    Raises:
        ValueError: If ``p`` is odd.
    """
    # An explicit check (not ``assert``) so validation survives ``python -O``.
    if p % 2 != 0:
        raise ValueError("p must be even for a clean split")

    k = torch.arange(1, p + 1, device=device, dtype=torch.float32)
    diag = torch.where(k <= p // 2, 1 + c / 2, 1 - c / 2)
    return torch.diag(diag)
        
def empirical_covariance(X):
    """Unbiased sample covariance of the rows of ``X`` (normalised by ``N - 1``)."""
    n_rows = X.shape[0]
    centered = X - X.mean(dim=0, keepdim=True)
    gram = centered.T @ centered
    return gram / (n_rows - 1)
    
def calculate_alpha_estimate(cov_train, cov_augment, n1, n2, p):
    """Solve for ``alpha`` using the spectrum of the whitened augment covariance.

    ``cov_augment`` is whitened as ``W @ cov_augment @ W`` with
    ``W = cov_train^{-1/2}``; the eigenvalues of ``cov_train`` are floored at
    ``1e-6`` before inverting. The eigenvalues of the whitened matrix, in
    descending order, are handed to ``solve_alpha`` with ``n = n1 + n2``.

    Returns:
        ``(alpha, estimate)`` where ``estimate`` is
        ``(p + n * alpha - n1) / (1 - p / n - alpha)``.
    """
    floor = 1e-6
    train_vals, train_vecs = torch.linalg.eigh(cov_train)
    train_vals = train_vals.clamp(min=floor)

    whitener = train_vecs @ torch.diag(1.0 / torch.sqrt(train_vals)) @ train_vecs.T
    whitened = whitener @ cov_augment @ whitener
    spectrum = torch.linalg.eigvalsh(whitened).sort(descending=True).values

    total = n1 + n2
    alpha = solve_alpha(spectrum.cpu().numpy(), n=total, p=p, n1=n1)
    return alpha, (p + total * alpha - n1) / (1 - (p / total) - alpha)

def least_squares_estimate(X, y):
    """Ordinary least-squares coefficients for ``X @ beta`` approximating ``y``.

    Args:
        X: Design matrix of shape ``(n, p)``.
        y: Response vector of shape ``(n,)``.

    Returns:
        Coefficient tensor of shape ``(p,)``, as solved by ``torch.linalg.lstsq``.
    """
    solution = torch.linalg.lstsq(X, y.unsqueeze(1)).solution
    # Remove only the column axis added above. A bare ``squeeze()`` would also
    # collapse the ``p == 1`` case into a 0-d tensor.
    return solution.squeeze(1)

def compute_excess_risk(X_test, beta_hat, beta, sign_label=False):
    """Risk of ``beta_hat`` relative to ``beta`` on the rows of ``X_test``.

    By default this is the mean squared gap between the predictions
    ``X_test @ beta_hat`` and ``X_test @ beta``. When ``sign_label`` is true it
    is instead the fraction of rows whose prediction signs disagree
    (``1 - accuracy``). Either way a Python float is returned.
    """
    truth = X_test @ beta
    prediction = X_test @ beta_hat
    if not sign_label:
        return torch.mean((prediction - truth) ** 2).item()

    matches = (torch.sign(truth) == torch.sign(prediction)).sum().item()
    return 1 - matches / truth.shape[0]

def generate_mixture_2gaussian_data(n_1, n_2, mu1, Sigma1, mu2, Sigma2, phi, beta, device=DEVICE):
    """Sample a shuffled two-component Gaussian design with linear responses.

    ``n_1`` rows are drawn from ``N(mu1, Sigma1)`` and ``n_2`` rows from
    ``N(mu2, Sigma2)``. The rows are randomly permuted, and the response is
    ``y = X @ beta + phi * eps`` with standard-normal ``eps``. All tensor
    arguments are cast to float32 on ``device`` first.

    Returns:
        Tuple ``(X, y)``.
    """
    mu1, mu2, Sigma1, Sigma2, beta = (
        t.to(dtype=torch.float32, device=device)
        for t in (mu1, mu2, Sigma1, Sigma2, beta)
    )

    components = (
        torch.distributions.MultivariateNormal(mu1, Sigma1),
        torch.distributions.MultivariateNormal(mu2, Sigma2),
    )
    X = torch.cat(
        [component.sample((count,)) for component, count in zip(components, (n_1, n_2))],
        dim=0,
    )

    total = n_1 + n_2
    shuffle = torch.randperm(total, device=device)
    X = X[shuffle]

    y = X @ beta + torch.randn(total, dtype=torch.float32, device=device) * phi
    return X, y

def get_perpendicular_vector(v):
    """Return a vector orthogonal to ``v`` with the same Euclidean norm.

    ``v`` is flattened first. The direction is obtained by Gram-Schmidt on the
    first standard basis vector ``e_0``. When ``v`` is (anti-)parallel to
    ``e_0``, ``e_0 + e_1`` is used instead. The result is moved to ``DEVICE``.

    Raises:
        ValueError: If ``v`` has fewer than two entries or is the zero vector,
            since no perpendicular direction is defined then.
    """
    v = v.flatten()
    if v.shape[0] < 2:
        raise ValueError("need at least two dimensions to find a perpendicular vector")
    v_norm = v.norm()
    if v_norm == 0:
        raise ValueError("cannot find a perpendicular direction for the zero vector")

    e = torch.zeros_like(v)
    e[0] = 1.0
    # Compare up to sign. For v along -e_0, projecting e_0 onto v gives e_0
    # itself, so the residual would be zero (and NaN after normalising).
    if torch.allclose((v / v_norm).abs(), e):
        e[1] = 1.0

    proj = torch.dot(v, e) / torch.dot(v, v) * v
    u = e - proj
    u = u / u.norm() * v_norm
    return u.to(DEVICE)

def inverse_sqrt_matrix(Sigma, eps=1e-10):
    """Return ``Sigma^{-1/2}`` for a symmetric positive semi-definite matrix.

    Uses the eigendecomposition ``Sigma = V diag(w) V^T`` and returns
    ``V diag(1 / sqrt(max(w, 0) + eps)) V^T``.

    Args:
        Sigma: Symmetric matrix (``torch.linalg.eigh`` reads its lower triangle).
        eps: Regulariser added to the eigenvalues before the square root.
    """
    eigvals, eigvecs = torch.linalg.eigh(Sigma)
    # Round-off can push eigenvalues of a PSD matrix slightly below zero, which
    # would turn the square root into NaN. Floor them at zero before
    # regularising.
    eigvals = eigvals.clamp(min=0)
    inv_sqrt_eigvals = torch.diag(1.0 / torch.sqrt(eigvals + eps))
    return eigvecs @ inv_sqrt_eigvals @ eigvecs.T

# Module-level cosine-similarity module comparing inputs along dim=1. eps=1e-6
# guards the division against vanishing norms.
cos = torch.nn.CosineSimilarity(eps=1e-6, dim=1)

def calculate_statistics(n1, n2, p, mu_train, mu_augment, Cov_train, Cov_augment, sigma):
    """Return the two predicted-risk terms ``(v1, v2)`` for the given covariances.

    ``v1`` is ``sigma**2 / (n1 + n2)`` times the estimate from
    ``calculate_alpha_estimate``; ``v2`` is always ``0``. The mean arguments are
    accepted but not used.
    """
    total_samples = n1 + n2
    _, alpha_based_estimate = calculate_alpha_estimate(Cov_train, Cov_augment, n1, n2, p)
    return sigma ** 2 / total_samples * alpha_based_estimate, 0

def get_identity_covariance(p, scale=1, device=DEVICE):
    """Return the ``p x p`` identity matrix multiplied by ``scale``, on ``device``."""
    eye = torch.eye(p, device=device)
    return eye * scale

def get_toeplitz_covariance(p, device=DEVICE):
    """Return a random symmetric positive-definite Toeplitz covariance matrix.

    The generating column is the normalised autocorrelation of a
    standard-normal sequence ``a`` of length ``p``:

        c_k = sum_i a_i a_{i+k} / sum_i a_i^2,   k = 0, ..., p - 1.

    This is the autocovariance of a moving-average process, so ``toeplitz(c)``
    is positive semi-definite. It is positive definite whenever ``a_0 != 0``,
    i.e. almost surely, and it has a unit diagonal, so it is a valid
    ``MultivariateNormal`` covariance. A Toeplitz matrix built directly from
    i.i.d. normal entries is generally indefinite and often has negative
    diagonal entries.

    Args:
        p: Matrix size.
        device: Device for the returned float32 tensor.
    """
    a = np.random.randn(p)
    # In 'full' mode index p - 1 is lag 0, so this slice holds lags 0..p-1.
    c = np.correlate(a, a, mode="full")[p - 1:]
    c = c / c[0]
    T_np = toeplitz(c)
    return torch.tensor(T_np, dtype=torch.float32, device=device)

def get_circulant_covariance(p, device=DEVICE):
    """Return the symmetric circulant matrix generated by ``[3, 2, 1, 0, ..., 0, 1, 2]``.

    Row ``i`` is the generator rolled right by ``i``. The eigenvalues are
    ``(1 + 2 cos(2 pi j / p))^2 >= 0``, so the matrix is positive semi-definite.
    It is singular when ``p`` is a multiple of 3.

    Args:
        p: Matrix size; must be at least 5 so the two bands of the generator
            do not overlap (otherwise the result is not symmetric).
        device: Device for the returned float32 tensor.

    Raises:
        ValueError: If ``p < 5``.
    """
    if p < 5:
        raise ValueError("p must be at least 5 for a symmetric circulant generator")
    c = np.zeros(p)
    c[0:3] = [3, 2, 1]
    c[-2:] = [1, 2]
    # Use the caller-supplied device rather than the module default.
    c = torch.tensor(c, device=device)
    return torch.stack([c.roll(shifts=i) for i in range(p)]).to(dtype=torch.float32)

def get_spiked_covariance(p, mu, theta, device=DEVICE):
    """Return the rank-one spiked covariance ``I_p + theta * mu mu^T`` as float32."""
    spike = theta * torch.outer(mu, mu)
    identity = torch.eye(p, device=device)
    covariance = identity + spike
    return covariance.to(dtype=torch.float32)

def power_band_covariance(p, x, reference_decay=0.9):
    """Return the banded power covariance ``x^|i-j|``, normalised against a reference.

    The raw matrix ``P_ij = x^|i-j|`` is divided by
    ``tr(R^{-1/2} P R^{-1/2}) / p``, where ``R_ij = reference_decay^|i-j|``.
    After whitening by ``R``, the returned matrix therefore has (up to the
    ``inverse_sqrt_matrix`` regulariser) unit average eigenvalue.

    Args:
        p: Matrix size.
        x: Decay of the returned covariance.
        reference_decay: Decay of the reference covariance ``R`` used for the
            normalisation. The default 0.9 matches the previous fixed value.

    Returns:
        ``(p, p)`` float32 tensor on ``DEVICE``.
    """
    idx = torch.arange(p, device=DEVICE)
    band = torch.abs(idx.unsqueeze(0) - idx.unsqueeze(1))  # |i - j|, shape (p, p)
    powers = x ** band
    reference = reference_decay ** band
    whitener = inverse_sqrt_matrix(reference)
    shift_trace = torch.trace(whitener @ powers @ whitener) / p
    return (powers / shift_trace).to(dtype=torch.float32)

def get_covariance(type, p, **kwargs):
    """Build a ``p x p`` covariance matrix of the named ``type``.

    Recognised types and the keyword arguments each reads:

    - ``"identity"``: ``scale``
    - ``"toeplitz"``, ``"circulant"``: none
    - ``"spiked_mu1"`` / ``"spiked_mu2"``: ``mean1`` / ``mean2`` and ``spike``
    - ``"power_band"``: ``decay``
    - ``"custom_diagonal_vary"`` / ``"custom_diagonal_fix"``: ``scale``

    Raises:
        ValueError: For an unrecognised ``type``.
        KeyError: If a keyword argument required by ``type`` is missing.
    """
    # Ordered (name, builder) pairs; builders are lazy so only the selected
    # type's keyword arguments are read.
    builders = (
        ("identity", lambda: get_identity_covariance(p=p, scale=kwargs['scale'])),
        ("toeplitz", lambda: get_toeplitz_covariance(p=p)),
        ("circulant", lambda: get_circulant_covariance(p=p)),
        ("spiked_mu1", lambda: get_spiked_covariance(p=p, mu=kwargs['mean1'], theta=kwargs['spike'])),
        ("spiked_mu2", lambda: get_spiked_covariance(p=p, mu=kwargs['mean2'], theta=kwargs['spike'])),
        ("power_band", lambda: power_band_covariance(p=p, x=kwargs['decay'])),
        ("custom_diagonal_vary", lambda: custom_diagonal_matrix_vary(p=p, c=kwargs['scale'])),
        ("custom_diagonal_fix", lambda: custom_diagonal_matrix_fix(p=p, c=kwargs['scale'])),
    )
    for name, build in builders:
        if type == name:
            return build()
    raise ValueError("Invalid covariance type")
  
def get_underparam_parameters(p=1000):
    """Under-parameterised setting: ``2p`` samples in each of the two groups.

    Returns ``(n_1, n_2, phi, beta)`` with ``n_1 = n_2 = 2 * p``, noise scale
    ``phi = 1``, and ``beta`` a fresh standard-normal float32 vector of length
    ``p`` on ``DEVICE``.
    """
    samples_per_group = 2 * p
    noise_scale = 1
    coefficients = torch.randn(p, device=DEVICE, dtype=torch.float32)
    return samples_per_group, samples_per_group, noise_scale, coefficients

def get_halfover_parameters(p=1000):
    """Setting with ``p // 2`` samples in the first group and ``2p`` in the second.

    Returns ``(n_1, n_2, phi, beta)`` with ``n_1 = p // 2``, ``n_2 = 2 * p``,
    noise scale ``phi = 1``, and ``beta`` a fresh standard-normal float32
    vector of length ``p`` on ``DEVICE``.
    """
    return p // 2, 2 * p, 1, torch.randn(p, device=DEVICE, dtype=torch.float32)

def get_means(p=1000, nrm=5, random_mean=False):
    """Return two mean vectors of Euclidean norm ``nrm * sqrt(p)``.

    The first is a random Gaussian direction. The second is another
    independent random direction when ``random_mean`` is true; otherwise it
    comes from ``get_perpendicular_vector`` applied to the first.
    """
    target_norm = nrm * np.sqrt(p)

    def _scaled_gaussian():
        # Random direction rescaled to the target norm.
        draw = torch.randn(p, device=DEVICE, dtype=torch.float32)
        draw /= torch.norm(draw)
        draw *= target_norm
        return draw

    first = _scaled_gaussian()
    if random_mean:
        second = _scaled_gaussian()
    else:
        second = get_perpendicular_vector(first).to(dtype=torch.float32)
    return first, second

def convert_to_sign(vector):
    """Elementwise sign (-1, 0 or +1) of ``vector`` as a float32 tensor."""
    signs = vector.sign()
    return signs.float()

def simulate_samecov(p, setting, covariance1, covariance2,
                     mean1, mean2, sample_points, trials,
                     band_decay=0.9, spike_amount=0.5, scale_train=1, scale_test=1,
                     band_decay_test=0.9, spike_amount_test=0.4, sign_label=False):
    """Compare empirical and predicted excess risk as the augment scale ``alpha`` varies.

    For each of ``trials`` repetitions and each of ``sample_points`` values of
    ``alpha`` in ``linspace(0.5, 5)``:

    1. Draw fresh means (``get_means(nrm=2, random_mean=True)``) and a fresh
       ``beta``.
    2. Build the training covariance of type ``covariance1`` and the augment
       covariance of type ``covariance2`` multiplied by ``alpha``.
    3. Fit least squares on ``n_1`` training plus ``n_2`` augment samples, and
       measure the excess risk on ``2 * n_1`` samples from the training
       distribution.
    4. Record ``v1 + v2`` from ``calculate_statistics``.

    Args:
        p: Dimension.
        setting: Callable ``setting(p=p) -> (n_1, n_2, phi, beta)``. Only the
            sample sizes and ``phi`` are used; ``beta`` is redrawn every run.
        covariance1, covariance2: ``get_covariance`` type names for the
            training and augment distributions.
        mean1, mean2: Ignored, because means are redrawn in every run. Kept
            for call compatibility.
        sample_points: Number of ``alpha`` values.
        trials: Number of repetitions per ``alpha``.
        band_decay, spike_amount, scale_train: Training-covariance options.
        band_decay_test, spike_amount_test, scale_test: Augment-covariance
            options, applied before the ``alpha`` scaling.
        sign_label: If true, training labels are converted to signs and the
            risk becomes the sign-disagreement rate.

    Returns:
        Three lists with one entry per ``alpha``: the mean excess risk, its
        standard deviation over trials, and the mean predicted score (all
        ``0`` when ``trials == 0``).
    """
    n_1, n_2, phi, _ = setting(p=p)
    alphas = np.linspace(0.5, 5, sample_points)
    excess_risks = {}
    score_estimates = {}
    for _trial in tqdm(range(trials), desc="Alpha"):
        for alpha in alphas:
            train_mean, augment_mean = get_means(p=p, nrm=2, random_mean=True)
            beta = torch.randn(p, device=DEVICE, dtype=torch.float32)
            training_covariance = get_covariance(
                covariance1, p, mean1=train_mean, mean2=augment_mean,
                decay=band_decay, spike=spike_amount, scale=scale_train)
            # ``scale_test`` is the augment type's own scale, and ``alpha`` is
            # applied exactly once. This way e.g. "identity" is scaled by
            # ``alpha``, not ``alpha ** 2``.
            augment_covariance = get_covariance(
                covariance2, p, mean1=train_mean, mean2=augment_mean,
                decay=band_decay_test, spike=spike_amount_test, scale=scale_test)
            augment_covariance = augment_covariance * alpha

            X_train, y_train = generate_mixture_2gaussian_data(
                n_1, n_2, train_mean, training_covariance,
                augment_mean, augment_covariance, phi, beta)
            if sign_label:
                y_train = convert_to_sign(y_train)
            # Test responses are not needed: the risk compares predictions
            # against ``X_test @ beta`` directly.
            X_test, _ = generate_gaussian_data(2 * n_1, p, train_mean, training_covariance, phi, beta)

            beta_hat = least_squares_estimate(X_train, y_train)
            v1, v2 = calculate_statistics(n_1, n_2, p, train_mean, augment_mean,
                                          training_covariance, augment_covariance, phi)
            excess_risk = compute_excess_risk(X_test, beta_hat, beta, sign_label)

            excess_risks.setdefault(alpha, []).append(excess_risk)
            score_estimates.setdefault(alpha, []).append(v1 + v2)

    excess_risks_means = []
    excess_risks_confidence = []
    score_estimates_means = []
    for alpha in alphas:
        risks = excess_risks.get(alpha, [])
        scores = score_estimates.get(alpha, [])
        # Buckets are empty only when ``trials == 0``; report zeros then.
        excess_risks_means.append(np.mean(risks) if risks else 0)
        excess_risks_confidence.append(np.std(risks) if risks else 0)
        score_estimates_means.append(np.mean(scores) if scores else 0)

    print(len(excess_risks_means), len(excess_risks_confidence), len(score_estimates_means))
    return excess_risks_means, excess_risks_confidence, score_estimates_means


# Experiment: under-parameterised setting with power-band covariances. The
# training decay is fixed at 0.9 and the augment decay takes each value in
# ``vals``. ``x`` holds the covariance scales swept by ``simulate_samecov``.
p = 600
sample_points = 11
trials = 100
mean1, mean2 = get_means(p=p, nrm=2)
vals = [0.9, 0.5, 0.1]
(
    (scale1_exmean, scale1_exconf, scale1_scores),
    (scale2_exmean, scale2_exconf, scale2_scores),
    (scale3_exmean, scale3_exconf, scale3_scores),
) = [
    simulate_samecov(p, get_underparam_parameters, 'power_band', 'power_band',
                     mean1, mean2, sample_points, trials,
                     band_decay_test=decay_test, band_decay=0.9)
    for decay_test in vals
]

x = np.linspace(0.5, 5, sample_points)

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

sns.set(style="whitegrid", context="talk")

palette = sns.color_palette("mako", 3)

plt.figure(figsize=(6, 5))

# One colour per augment decay. Artists are added in three passes: theory
# lines, then +/- 1 std bands, then empirical means.
series = list(zip(
    palette,
    vals,
    (scale1_scores, scale2_scores, scale3_scores),
    (scale1_exmean, scale2_exmean, scale3_exmean),
    (scale1_exconf, scale2_exconf, scale3_exconf),
))

for color, rho, scores, _, _ in series:
    sns.lineplot(x=x, y=scores, label=rf'$\rho_s={{{rho}}}$ (theory)',
                 color=color, linewidth=6, zorder=3)

for color, _, _, risk_mean, risk_std in series:
    centre = np.array(risk_mean)
    spread = np.array(risk_std)
    plt.fill_between(x, centre - spread, centre + spread, color=color, alpha=0.25)

for color, rho, _, risk_mean, _ in series:
    sns.scatterplot(x=x, y=risk_mean, label=rf'$\rho_s={{{rho}}}$ (empirical)',
                    color=color, marker='o', s=140, zorder=4)

plt.xlabel('Covariance scale', fontsize=26)
plt.ylabel('Excess risk', fontsize=26)
plt.legend(frameon=True, fontsize=14, loc='upper right', framealpha=0.6)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig('<output path>', dpi=300)
