import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import random
import itertools
import math

# Default device used by helpers in this file that allocate tensors themselves:
# CUDA when PyTorch reports it as available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class Custom_BN(nn.Module):
    """Batch normalization that always uses the statistics of the current batch.

    No running mean/variance is tracked, so the layer computes the same thing
    in training and evaluation mode.

    Args:
        num_features: Number of features; size of the learnable scale/shift.
        eps: Constant added to the batch variance before the square root for
            numerical stability.
    """

    def __init__(self, num_features, eps=1e-5):
        super(Custom_BN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        # Learnable per-feature scale and shift, initialised to the identity map.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        """Normalize ``x`` over dim 0, then apply the per-feature affine map.

        Assumes ``x`` is laid out as (batch, num_features) -- TODO confirm
        against callers.
        """
        batch_mean = x.mean(dim=0, keepdim=True)
        # Biased (population) variance, i.e. divide by N rather than N - 1.
        batch_var = x.var(dim=0, unbiased=False, keepdim=True)
        x_hat = (x - batch_mean) / torch.sqrt(batch_var + self.eps)
        return self.weight.unsqueeze(0) * x_hat + self.bias.unsqueeze(0)


class Poly(nn.Module):
    """Element-wise polynomial activation: raises its input to a fixed power."""

    def __init__(self, power = 2):
        super().__init__()
        # Exponent applied in forward(); a plain attribute, not a learnable parameter.
        self.power = power

    def forward(self, input):
        """Return ``input`` raised element-wise to ``self.power``."""
        exponent = self.power
        return input ** exponent

def weight_update_(out_model, model, mult, old_model, mult2):
    """Overwrite ``out_model``'s parameters with ``mult * (model - mult2 * old_model)``.

    The three modules are walked in lockstep via ``parameters()``, so they are
    expected to share the same architecture. For every parameter triple:

        out_param = mult * (param - mult2 * old_param)

    The update is written through ``.data`` and reads only ``.data``, so no
    autograd graph is built and nothing is recorded for backward.

    Args:
        out_model: Module whose parameters are overwritten in place.
        model: Module providing the minuend parameters.
        mult: Scalar applied to the difference.
        old_model: Module providing the subtrahend parameters.
        mult2: Scalar applied to ``old_model``'s parameters before subtracting.
    """
    triples = zip(out_model.parameters(), model.parameters(), old_model.parameters())
    for out_param, param, old_param in triples:
        # Use .data on every operand so the arithmetic stays outside autograd.
        out_param.data.copy_(mult * (param.data - mult2 * old_param.data))

def noise_indices(y, noise):
    """Replace a random subset of the labels in ``y`` with random +/-1 values.

    Each entry along dim 0 is selected independently with probability
    ``noise``; selected entries are overwritten with the sign of a standard
    normal draw. ``y`` is modified in place and also returned.

    Args:
        y: Label tensor, indexed along dim 0.
        noise: Probability of resampling each label; values <= 0 leave ``y``
            untouched.

    Returns:
        The same tensor object ``y``.
    """
    if noise > 0:
        noised_indices = torch.rand(y.shape[0]) < noise
        # Tensor.sum() counts in one call; the built-in sum() would iterate
        # the mask element by element in Python.
        num_noised = int(noised_indices.sum())
        noises = torch.randn(num_noised, requires_grad=False).sign()
        # Match y's device and dtype so the masked assignment also works for
        # non-CPU or non-float32 labels.
        y[noised_indices.to(y.device)] = noises.to(device=y.device, dtype=y.dtype)
    return y

def get_batch_sparse_parity(n, params, batch_size=32, noise=0.0):
    """Sample a batch for the sparse-parity task.

    Inputs are i.i.d. random signs; the label is the product of the first
    ``k`` coordinates, optionally corrupted by ``noise_indices``.

    Args:
        n: Input dimension.
        params: Dictionary with
            'k': number of leading coordinates whose product is the label;
            'use_pm1' (optional, default True): if False, every negative
                entry of both inputs and labels is set to 0;
            'label_pm1' (optional, default True): only consulted when
                'use_pm1' is true; if False, negative labels are set to 0
                while inputs stay in {-1, +1}.
        batch_size: Number of samples.
        noise: Per-label probability passed to ``noise_indices``.

    Returns:
        (x, y) with x of shape (batch_size, n) and y of shape (batch_size,).
    """
    k = params['k']
    # Default to the +/-1 encoding so a params dict holding only 'k' works.
    use_pm1 = params.get('use_pm1', True)
    label_pm1 = params.get('label_pm1', True)

    x = torch.randn(batch_size, n, requires_grad=False).sign()
    y = torch.prod(x[:, :k], dim=1)
    y = noise_indices(y, noise)
    if not use_pm1:
        x[x < 0] = 0
        y[y < 0] = 0
    elif not label_pm1:
        y[y < 0] = 0
    return x, y



def get_batch_gaussian(n, params, batch_size, noise=0.0):
    """Sample Gaussian inputs and targets produced by averaged features.

    Args:
        n: Input dimension.
        params: Dictionary with keys
            'ws': projection matrix applied to the inputs (the original code
                comment describes it as (n, k), i.e. k features);
            'act_fn': callable applied to the projected inputs;
            'add_spike': if > 0, one random direction (shared by the whole
                batch) scaled by this value is added to every input;
            'H_sqrt': matrix the inputs are right-multiplied by, or None to
                keep them standard normal.
        batch_size: Number of samples.
        noise: If > 0, standard deviation of Gaussian noise added to targets.

    Returns:
        (X, y): inputs and the mean over dim 1 of ``act_fn(X @ ws)``.
    """
    projection = params['ws']
    activation = params['act_fn']
    spike_scale = params['add_spike']
    cov_sqrt = params['H_sqrt']

    inputs = torch.randn(batch_size, n, requires_grad=False)
    if cov_sqrt is not None:
        inputs = inputs @ cov_sqrt
    if spike_scale > 0:
        # A single (1, n) direction is drawn per call and broadcast over the batch.
        inputs += spike_scale * torch.randn(1, n)

    features = activation(inputs @ projection)
    targets = features.mean(dim=1)
    if noise > 0:
        targets += torch.randn_like(targets, requires_grad=False) * noise
    return inputs, targets

    
def get_batch_gaussian_sanity(n, params, batch_size, noise=0.0):
    """Sample a sanity-check regression batch with a rank-one "noise" block.

    ``k = len(w)`` Gaussian coefficients are drawn per sample. The first
    ``k - noise_rank`` coefficients are copied into the leading columns of X.
    Each of the remaining ``n - k + noise_rank`` columns holds
    ``noise_sigma`` times the sum of the last ``noise_rank`` coefficients.
    Targets are ``coeff @ w``.

    Args:
        n: Input dimension (number of columns of X).
        params: Dictionary with keys
            'w': weight vector of length k;
            'noise_rank': how many trailing coefficients feed the noise block
                (0 gives an all-zero block);
            'noise_sigma': scale applied to those coefficients.
        batch_size: Number of samples.
        noise: If > 0, standard deviation of Gaussian noise added to targets.

    Returns:
        (X, y) with X of shape (batch_size, n).

    Raises:
        ValueError: If ``noise_rank`` is not in ``[0, len(w)]``.
    """
    w = params['w']
    k = len(w)
    noise_rank = params['noise_rank']
    noise_sigma = params['noise_sigma']
    if not 0 <= noise_rank <= k:
        raise ValueError(f"noise_rank must be in [0, {k}], got {noise_rank}")

    # Explicit split index: a negative slice ``[-noise_rank:]`` would select
    # every column when noise_rank == 0.
    signal_dim = k - noise_rank
    tail_dim = n - signal_dim

    coeff = torch.randn(batch_size, k, requires_grad=False)

    coeff_noise = coeff[:, signal_dim:] * noise_sigma
    # Embed the noise coefficients into the first noise_rank tail coordinates.
    bases = torch.eye(tail_dim, noise_rank, requires_grad=False).T
    X_noise = torch.matmul(coeff_noise, bases)
    # Multiplying by the all-ones matrix writes each row's sum into every column.
    R = torch.ones(tail_dim, tail_dim)
    X_noise = torch.matmul(X_noise, R)
    X = torch.cat([coeff[:, :signal_dim], X_noise], dim=1)

    y = torch.matmul(coeff, w)
    if noise > 0:
        y += torch.randn_like(y, requires_grad=False) * noise

    return X, y


def get_batch_staircase(n, params, batch_size=32, noise=0.0):
    """
    Sample a batch for the staircase task: the label is a sum of block parities.

    params is a dictionary whose values are (start, stop) index pairs; the
    keys (e.g. 'k_1', 'k_2', ...) are not inspected. Each pair contributes the
    product of x[:, start:stop] to the label.

    Illustrative example (inputs are random, so actual values differ):
    n = 12
    params = {'k_1': (0,4), 'k_2': (4,8), 'k_3': (8,12)}
    batch_size = 2
    noise = 0.0

    x = tensor([[ 1.,  1., -1.,  1., -1.,  1.,  1., -1.,  1., -1.,  1.,  1.],
                [-1.,  1.,  1., -1.,  1., -1., -1.,  1., -1.,  1., -1.,  1.]])
    y = tensor([-1., 3.])
        # row 0: (1*1*-1*1) + (-1*1*1*-1) + (1*-1*1*1)   = -1 + 1 - 1 = -1
        # row 1: (-1*1*1*-1) + (1*-1*-1*1) + (-1*1*-1*1) =  1 + 1 + 1 =  3

    If noise > 0, an independent random sign scaled by noise is added to
    every label.
    """
    x = torch.randn(batch_size, n, requires_grad=False).sign()
    y = torch.zeros(batch_size, requires_grad=False)
    # Accumulate the parity of each index block.
    for start, stop in params.values():
        y = y + x[:, start:stop].prod(dim=1)
    if noise > 0:
        y = y + torch.randn(batch_size, requires_grad=False).sign() * noise
    return x, y

def get_ds_sparse_parity(n, params, noise=0.0):
    """
    Build the full sparse-parity dataset over all 2**n sign vectors.

    params is a dictionary with key 'k': the label of each row is the product
    of its first k coordinates, after which ``noise_indices`` is applied with
    the given noise level.

    Returns (x, y) with x of shape (2**n, n) and y of shape (2**n,).
    """
    k = params['k']
    # Every length-n combination of +1/-1, enumerated in itertools.product order.
    sign_vectors = list(itertools.product([1, -1], repeat=n))
    x = torch.from_numpy(np.array(sign_vectors)).type(torch.FloatTensor)
    x.requires_grad = False
    parity = torch.prod(x[:, :k], dim=1)
    y = noise_indices(parity, noise)
    return x, y

def get_ds_staircase(n, params, noise=0.0):
    """
    Build the full staircase dataset over all 2**n sign vectors.

    params is a dictionary whose values are (start, stop) index pairs; each
    pair adds the product of x[:, start:stop] to the label. If noise > 0, an
    independent random sign scaled by noise is added to every label.

    Returns (x, y) with x of shape (2**n, n) and y of shape (2**n,).
    """
    # Every length-n combination of +1/-1, enumerated in itertools.product order.
    sign_vectors = list(itertools.product([1, -1], repeat=n))
    x = torch.from_numpy(np.array(sign_vectors)).type(torch.FloatTensor)
    x.requires_grad = False

    num_rows = x.shape[0]
    y = torch.zeros(num_rows, requires_grad=False)
    for start, stop in params.values():
        y = y + x[:, start:stop].prod(dim=1)
    if noise > 0:
        y = y + torch.randn(num_rows, requires_grad=False).sign() * noise
    return x, y

def create_mlp(n, width, depth, output_dim=1, mean_red_factor=0.0,
               act='relu', power=2, gain=1.0, use_bn=False,
               use_hidden_bias=True, skip_output_layer=False):
    """Build a fully connected network as a ``torch.nn.Sequential``.

    Each of the ``depth`` hidden blocks is Linear -> [Custom_BN] -> activation.
    Unless ``skip_output_layer`` is set, a bias-free Linear(width, output_dim)
    is appended. For polynomial activations every Linear layer (hidden and
    output) is re-initialised with Xavier-uniform weights using ``gain``.

    Args:
        n: Input dimension.
        width: Hidden width.
        depth: Number of hidden blocks.
        output_dim: Output dimension of the final linear layer.
        mean_red_factor: Accepted for interface compatibility; not used here.
        act: 'relu', 'tanh', 'none', or any string containing 'poly'.
        power: Exponent of the ``Poly`` activation.
        gain: Xavier gain used for polynomial activations.
        use_bn: Insert ``Custom_BN`` between each hidden Linear and activation.
        use_hidden_bias: Whether hidden Linear layers have a bias.
        skip_output_layer: If True, omit the final linear layer.

    Returns:
        The assembled ``torch.nn.Sequential``.

    Raises:
        ValueError: If ``act`` is unsupported (checked while building hidden
            blocks, so only when ``depth > 0``).
    """
    use_poly = 'poly' in act
    layers = []
    for i in range(depth):
        in_dim = n if i == 0 else width
        linear = torch.nn.Linear(in_dim, width, bias=use_hidden_bias)
        layers.append(linear)
        if use_bn:
            layers.append(Custom_BN(width))
        if act == 'relu':
            layers.append(torch.nn.ReLU())
        elif use_poly:
            # Initialise the Linear layer itself: with use_bn the most recently
            # appended module is the batch-norm layer, whose weight is 1-D.
            torch.nn.init.xavier_uniform_(linear.weight, gain=gain)
            layers.append(Poly(power=power))
        elif act == 'tanh':
            layers.append(torch.nn.Tanh())
        elif act == 'none':
            pass
        else:
            raise ValueError(f"Activation function {act} not supported")
    if not skip_output_layer:
        output_layer = torch.nn.Linear(width, output_dim, bias=False)
        if use_poly:
            # Only touch the output layer when it exists; with
            # skip_output_layer the last module is an activation with no weight.
            torch.nn.init.xavier_uniform_(output_layer.weight, gain=gain)
        layers.append(output_layer)
    return torch.nn.Sequential(*layers)


def show_results_(
        results,
        alpha=0.5,
        max_error=None
    ):
    """
    Plot the accuracy curves stored in a results dictionary.

    Args:

    results: dictionary. Always read: 'problem_type', 'accs', 'ema_accs',
        'n', 'problem_params', 'width', 'depth'. If 'ou_filter' is present and
        truthy, 'ou_ema_accs' is plotted too; likewise 'ou_mle_filter' and
        'ou_mle_accs'.
    alpha: transparency of lines
    max_error: if not None, the y-axis is limited to [0, max_error]

    Returns the (fig, ax) pair. Raises ValueError if 'problem_type' is neither
    'sparse_parity' nor 'staircase'.
    """

    problem_type = results['problem_type']

    fig, ax = plt.subplots(1, 1, figsize=(10, 5))

    # Curves that are always drawn.
    for key, label in (("accs", "acc"), ("ema_accs", "ema_acc")):
        ax.plot(results[key], label=label, alpha=alpha)

    # Curves that are drawn only when their flag exists and is truthy.
    optional_curves = (
        ('ou_filter', "ou_ema_accs", "ou_ema_acc", 'red'),
        ('ou_mle_filter', "ou_mle_accs", "ou_mle_acc", 'black'),
    )
    for flag, key, label, color in optional_curves:
        if flag in results.keys() and results[flag]:
            ax.plot(results[key], label=label, alpha=alpha, color=color)

    if max_error is not None:
        ax.set_ylim(0, max_error)
    ax.legend()
    ax.set_ylabel("acc")
    ax.set_xlabel("step")

    if problem_type not in ('sparse_parity', 'staircase'):
        raise ValueError(f"problem_type {problem_type} not recognized")

    dim_part = f'n: {results["n"]}'
    if problem_type == 'sparse_parity':
        task_part = f'k: {results["problem_params"]["k"]}'
    else:
        task_part = f'staircase: {list(results["problem_params"].values())}'
    size_part = f'width {results["width"]}, depth: {results["depth"]}'
    fig.suptitle(f'{dim_part}, {task_part}, {size_part}')

    return fig, ax


def seed_everything(seed=13337):
    """
    Seed Python's ``random``, NumPy's global RNG and PyTorch with the same value.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)


def make_weight_dictionary(sequential_model, iters, save_bias=True, device='cpu'):
    """
    Allocate zero-filled buffers for recording a model's weights over training.

    Only modules at even positions of ``sequential_model`` are considered, and
    each must expose a 2-D ``weight``. The j-th such module gets an entry
    'weight_j' of shape (iters, dim_out, dim_in) and, when ``save_bias`` is
    true, 'bias_j' of shape (iters, dim_out), all allocated on ``device``.
    """
    history = {}
    for slot, layer_idx in enumerate(range(0, len(sequential_model), 2)):
        n_out, n_in = sequential_model[layer_idx].weight.data.shape
        history[f'weight_{slot}'] = torch.zeros((iters, n_out, n_in), device=device)
        if save_bias:
            history[f'bias_{slot}'] = torch.zeros((iters, n_out), device=device)
    return history


def store_weights_(weight_dictionary, model, iter, save_bias=True):
    """
    Copy the current weights (and biases) of ``model`` into slot ``iter``.

    Only modules at even positions of ``model`` are recorded; the j-th such
    module is written to 'weight_j' and, when ``save_bias`` is true and the
    module actually has a bias, to 'bias_j'. Modules whose ``bias`` is None
    (e.g. ``Linear(..., bias=False)``) leave their bias buffer untouched.

    Args:
        weight_dictionary: Buffers keyed 'weight_j' / 'bias_j', indexed by
            iteration along dim 0.
        model: Indexable sequence of modules (e.g. ``nn.Sequential``).
        iter: Index along dim 0 of the buffers to write to.
        save_bias: Whether to record biases.
    """
    for slot, layer_idx in enumerate(range(0, len(model), 2)):
        layer = model[layer_idx]
        weight_dictionary[f'weight_{slot}'][iter, :, :] = layer.weight.data
        # Bias-free layers have ``bias is None``; skip them instead of crashing.
        if save_bias and layer.bias is not None:
            weight_dictionary[f'bias_{slot}'][iter, :] = layer.bias.data


def generate_covariance(dim, spectrum_type='log_decay', alphas=None, poly_decay_degree=2, exponential_decay_base=2):
    """
    Generate a covariance matrix H = Q diag(eigvals) Q^T with a chosen spectrum.

    Args:
        dim: Dimension of the covariance matrix.
        spectrum_type: Eigenvalue profile used when ``alphas`` is empty:
            'log_decay' (exp(-i * log(dim) / dim)), 'poly_decay'
            (1 / (1 + i)**poly_decay_degree), 'exponential_decay'
            (exponential_decay_base**(-i)), or 'random' (log-spaced from 1
            down to 1e-6, randomly permuted). The first three are normalised
            so the largest eigenvalue is 1.
        alphas: Optional explicit eigenvalues (length ``dim``). When given and
            non-empty they override ``spectrum_type`` and are paired with a
            random orthogonal eigenbasis.
        poly_decay_degree: Exponent for 'poly_decay'.
        exponential_decay_base: Base for 'exponential_decay'.

    Returns:
        H: Covariance matrix
        eigvals: Eigenvalues
        eigvecs: Eigenvectors (columns of a random orthogonal matrix)

    Raises:
        ValueError: On an unknown ``spectrum_type`` or when ``alphas`` does
            not contain exactly ``dim`` values.

    Side effects:
        Prints Tr(H)/d next to log(d)/d, and for 'random' the condition number.
    """
    if alphas is not None and len(alphas) > 0:
        eigvals = torch.as_tensor(alphas, dtype=torch.float32).detach().flatten().clone()
        if eigvals.numel() != dim:
            raise ValueError(f"alphas must contain {dim} eigenvalues, got {eigvals.numel()}")
        # An eigenbasis is needed here too; without it H cannot be formed.
        eigvecs = torch.nn.init.orthogonal_(torch.randn(dim, dim))
    elif spectrum_type == 'random':
        log10_cond = 6
        rank_frac = 1.0
        # 1.  Eigenvalues: log-space from 1 down to 10^{-log10_cond}
        active_dim = math.ceil(rank_frac * dim)
        eigvals_full = torch.logspace(0, -log10_cond, steps=active_dim).float()
        if rank_frac < 1.0:                             # pad with tiny values
            pad = torch.full((dim - active_dim,), 10.0**(-log10_cond)).float()
            eigvals_full = torch.cat([eigvals_full, pad])

        # 2.  Randomly permute so "big" directions aren't axis-aligned
        perm = torch.randperm(dim)
        eigvals = eigvals_full[perm]

        # 3.  Random orthogonal eigenvectors
        eigvecs = torch.linalg.qr(torch.randn(dim, dim)).Q

        cond = eigvals.max() / eigvals.min()
        print(f"constructed Sigma with condition number ≈ {cond:,.1e}")
    else:
        index = torch.arange(dim).float()
        if spectrum_type == 'log_decay':
            eigvals = torch.exp(-index / (dim / np.log(dim)))
        elif spectrum_type == 'poly_decay':
            eigvals = 1.0 / (1.0 + index)**poly_decay_degree
        elif spectrum_type == 'exponential_decay':
            eigvals = 1.0 / (exponential_decay_base ** index)
        else:
            raise ValueError(f"Unknown spectrum type: {spectrum_type}")
        # Ensure max eigenvalue is 1
        eigvals = eigvals / eigvals[0]
        # Generate random orthogonal matrix for eigenvectors
        eigvecs = torch.nn.init.orthogonal_(torch.randn(dim, dim))

    # Create covariance matrix H = Q * Λ * Q^T
    H = eigvecs @ torch.diag(eigvals) @ eigvecs.T

    # Report the normalised trace next to log(d)/d for comparison.
    trace = eigvals.sum().item()
    expected = np.log(dim)
    print(f"Tr(H)/d = {trace/dim:.4f}, log(d)/d = {expected/dim:.4f}")

    return H, eigvals, eigvecs


def get_teacher_student_batch(params, batch_size=32, noise=0.0, scale=10.0):
    """
    Generate a batch of data for the teacher-student problem.

    Inputs are drawn from N(0, covariance); targets are the teacher's outputs
    multiplied by ``scale``, plus optional Gaussian noise.

    Args:
        params: Dictionary with keys
            'covariance': covariance matrix of the inputs;
            'teacher_model': module mapping inputs to targets (moved to
                'device' on every call);
            'device': device the teacher is evaluated on;
            'X_DATATYPE': dtype the inputs are cast to before the teacher.
        batch_size: Number of samples to generate
        noise: Standard deviation of Gaussian noise added to the scaled
            targets; no noise is added when <= 0.
        scale: Factor applied to the teacher outputs.

    Returns:
        x: Input data, as sampled (on the covariance's device and dtype)
        y: Target data, computed on 'device' without gradient tracking
    """
    # Extract parameters
    covariance = params['covariance']
    teacher_model = params['teacher_model']
    device = params['device']
    X_DATATYPE = params['X_DATATYPE']
    # Generate inputs x ~ N(0, H)
    mean = torch.zeros(covariance.shape[0], device=covariance.device, dtype=covariance.dtype)
    x = torch.distributions.MultivariateNormal(mean, covariance).sample((batch_size,))

    # Generate targets using the teacher model
    x_ = x.type(X_DATATYPE).to(device)
    teacher_model = teacher_model.to(device)
    with torch.no_grad():
        y = teacher_model(x_) * scale
        if noise > 0:
            # Previously accepted but ignored; noise is added after scaling,
            # so it is expressed in target units.
            y = y + torch.randn_like(y) * noise

    return x, y



def get_batch_balanced_logistic(n, params, batch_size=32, noise=0.0):
    """
    Generate a batch of data for the balanced logistic loss problem.

    Two modes, selected by ``params['use_samples']``:

    - Sampled: ``batch_size`` one-hot inputs with a uniformly random active
      coordinate; every entry of the (batch_size, n) label matrix is an
      independent Bernoulli(p) draw.
    - Full population: ``n * batch_size`` rows arranged in n consecutive
      groups of ``batch_size``. Rows of group i are the one-hot vector e_i,
      and column i of their labels holds round(batch_size * p) ones followed
      by zeros; all other label entries are 0.

    Args:
        n: Input dimension
        params: Dictionary with keys 'p' (fraction/probability of positive
            labels) and 'use_samples' (mode switch described above)
        batch_size: Number of samples (sampled mode) or rows per coordinate
            (population mode)
        noise: Noise level (not used in this task)

    Returns:
        x: One-hot input data on the module-level ``device``
        y: Target labels on the module-level ``device``
    """
    p = params['p']
    use_samples = params['use_samples']

    if use_samples:
        # Sample one-hot inputs, and sample the labels, w.p. p for each dim
        x = torch.zeros(batch_size, n, device=device)
        samples = torch.randint(0, n, (batch_size,))
        x[torch.arange(batch_size), samples] = 1

        # Keep labels on the same device as the inputs.
        y = torch.bernoulli(torch.ones(batch_size, n) * p).to(device)
    else:
        # Get the full population
        x = torch.zeros(n * batch_size, n, device=device)
        y = torch.zeros(n * batch_size, n, device=device)
        # Derive the zero count from the one count so the column always has
        # exactly batch_size entries; truncating both products independently
        # can come up one short (e.g. batch_size=10, p=0.35).
        num_pos = int(round(float(batch_size * p)))
        per_dim_y = torch.cat(
            [torch.ones(num_pos), torch.zeros(batch_size - num_pos)], dim=0
        ).to(device)
        for i in range(n):
            first, last = i * batch_size, (i + 1) * batch_size
            x[first:last, i] = 1
            y[first:last, i] = per_dim_y

    return x, y


def create_balanced_logistic_net(n, alphas=None, initialization_scale=1.0, activation='none',
                                vector_version=0):
    """
    Creates the two-factor network used for the balanced logistic task.

    The returned model is ``nn.Sequential(factor_layer, scale_layer)``:

    1. ``factor_layer`` holds two weight tensors and computes
       ``act(x W1^T) W2^T`` (matrix version, weights initialised to
       ``initialization_scale * I``) or ``act(x * w1) * w2`` (vector version,
       weights initialised to ``initialization_scale * ones``).
    2. ``scale_layer`` right-multiplies by ``diag(alphas)``.

    No sigmoid is applied inside the model; presumably the caller applies it
    (or a logits-based loss) to the output -- verify against the training code.

    Args:
        n: Input dimension
        alphas: Per-output scaling factors (defaults to ``[1.0]``); for the
            final matmul to be valid there must be one per feature
        initialization_scale: Scale factor for weight initialization
        activation: 'relu', 'tanh', 'sigmoid' or 'none' (identity), applied
            between the two factors
        vector_version: 0 for n x n weight matrices, anything else for
            length-n weight vectors

    Returns:
        model

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    if alphas is None:
        alphas = [1.0]

    activations = {
        'relu': torch.relu,
        'tanh': torch.tanh,
        'sigmoid': torch.sigmoid,
        'none': lambda x: x,
    }
    # Fail at construction time rather than on the first forward pass.
    if activation not in activations:
        raise ValueError(f"Activation function {activation} not supported")
    act_fn = activations[activation]

    class ScaleLayer(nn.Module):
        """Right-multiplies its input by the fixed matrix diag(alphas)."""

        def __init__(self, alphas):
            super().__init__()
            scale_matrix = torch.diag(torch.tensor(alphas)).to(device)
            # Non-persistent buffer: follows model.to(...) but stays out of
            # the state_dict, so existing checkpoints keep loading.
            self.register_buffer('alphas', scale_matrix, persistent=False)

        def forward(self, x):
            return x @ self.alphas

    class SquaredNormLayer(nn.Module):
        """Two n x n factors with an activation in between."""

        def __init__(self, input_dim, act_fn, scale=1.0):
            super().__init__()
            self.act_fn = act_fn
            self.weight1 = nn.Parameter(torch.eye(input_dim) * scale)
            self.weight2 = nn.Parameter(torch.eye(input_dim) * scale)

        def forward(self, x):
            h = self.act_fn(torch.matmul(x, self.weight1.t()))
            out = torch.matmul(h, self.weight2.t())
            return out

    class VectorSquaredNormLayer(nn.Module):
        """Element-wise variant: two length-n factors with an activation in between."""

        def __init__(self, input_dim, act_fn, scale=1.0):
            super().__init__()
            self.act_fn = act_fn
            self.weight1 = nn.Parameter(torch.ones(input_dim) * scale)
            self.weight2 = nn.Parameter(torch.ones(input_dim) * scale)

        def forward(self, x):
            h = self.act_fn(x * self.weight1)
            out = h * self.weight2
            return out

    factor_cls = SquaredNormLayer if vector_version == 0 else VectorSquaredNormLayer
    model = nn.Sequential(
        factor_cls(n, act_fn, initialization_scale),  # two-factor product
        ScaleLayer(alphas),                           # scale by alphas
    )

    return model
