import numpy as np
import torch
import torch.nn as nn
import cvxpy as cp
import matplotlib.pyplot as plt
import wfdb
import random

##################################################
# Straight-through sign
##################################################

class Sign(torch.autograd.Function):
    """Sign activation with a straight-through gradient estimator.

    Forward maps each entry to +1.0 when it is >= 0 and to -1.0 otherwise
    (so sign(0) == +1). Backward passes the incoming gradient through
    unchanged instead of using the true derivative, which is zero almost
    everywhere.
    """

    @staticmethod
    def forward(ctx, x):
        nonnegative = x.ge(0)
        return torch.where(nonnegative, 1.0, -1.0)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: treat d sign(x) / dx as the identity.
        passthrough = grad_output
        return passthrough

def sign(x):
    """Apply the straight-through sign activation (the ``Sign`` autograd function) to ``x``."""
    out = Sign.apply(x)
    return out

def sign_np(x):
    """NumPy sign with sign(0) == +1: 1.0 where ``x >= 0``, -1.0 elsewhere.

    Plain array computation; no gradient handling.
    """
    nonnegative = x >= 0
    return np.where(nonnegative, 1.0, -1.0)


##################################################
# Parallel Network
##################################################

class ParallelNet(nn.Module):
    """Weighted sum of ``mL`` independent deep sign networks evaluated in parallel.

    Unit ``i`` maps each input row through layers of sizes ``[d] + widths + [1]``.
    Every layer (including the last) computes ``sigma(s * (h @ W) + b)``. The
    network output is ``sum_i alpha[i] * unit_i(x)``.

    Args:
        d: input dimension.
        widths: hidden-layer widths (list or tuple; may be empty).
        mL: number of parallel units.
        sigma: activation; defaults to the straight-through ``sign``.
        fixed_alpha: if True, ``alpha`` is a non-trainable buffer of ones;
            otherwise a trainable parameter initialised from ``randn``.
        fixed_s: if True, every per-layer scale ``s`` is a tensor of ones that
            is excluded from gradient updates; otherwise a trainable ``randn``
            parameter.
    """

    def __init__(self, d, widths, mL, sigma=sign, fixed_alpha=False, fixed_s=False):
        super().__init__()

        self.sigma = sigma
        self.mL = mL

        layer_sizes = [d] + list(widths) + [1]

        self.W = nn.ParameterList()
        self.b = nn.ParameterList()
        self.s = nn.ParameterList()

        for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            # Leading dim indexes the parallel unit, so one batched matmul
            # evaluates every unit at once.
            self.W.append(nn.Parameter(torch.randn(mL, fan_in, fan_out)))
            self.b.append(nn.Parameter(torch.randn(mL, 1, fan_out)))

            if fixed_s:
                # "Fixed" must mean frozen: requires_grad=False keeps an
                # optimiser from moving s away from 1 (mirrors fixed_alpha).
                self.s.append(
                    nn.Parameter(torch.ones(mL, 1, fan_out), requires_grad=False)
                )
            else:
                self.s.append(nn.Parameter(torch.randn(mL, 1, fan_out)))

        if fixed_alpha:
            self.register_buffer("alpha", torch.ones(mL))
        else:
            self.alpha = nn.Parameter(torch.randn(mL))

    def forward(self, x):
        """Evaluate the network.

        Args:
            x: tensor of shape (N, d); a 1-D tensor is treated as (N, 1).

        Returns:
            Tensor of shape (N,): the ``alpha``-weighted sum of unit outputs.
        """
        if x.ndim == 1:
            x = x[:, None]

        N = x.shape[0]

        # Share the same inputs across all units: (mL, N, d).
        h = x.unsqueeze(0).expand(self.mL, N, x.shape[1])

        for W, b, s in zip(self.W, self.b, self.s):
            h = torch.matmul(h, W)
            h = self.sigma(s * h + b)

        # The last layer has width 1: (mL, N, 1) -> (mL, N).
        h = h.squeeze(-1)

        return (h * self.alpha[:, None]).sum(dim=0)


##################################################
# Feature sampling
##################################################

def sample_parallel_feature(X, d, widths, seed):
    """Draw one random sign-network feature and evaluate it on ``X``.

    Seeds torch's global RNG with ``seed``, which makes the draw reproducible
    and leaves the global RNG state changed. Builds a single-unit
    ``ParallelNet`` with output weight 1 and unit scales, and records the
    sampled weights and biases.

    Returns:
        (h, params): ``h`` is a flat numpy array of the feature's outputs on
        ``X``. ``params`` is ``{"W": [...], "b": [...]}`` with one tensor per
        layer, shaped ``(1, fan_in, fan_out)`` / ``(1, 1, fan_out)``.
    """
    torch.manual_seed(seed)

    model = ParallelNet(d=d, widths=widths, mL=1, fixed_alpha=True, fixed_s=True)

    X_t = torch.tensor(X, dtype=torch.float32)

    with torch.no_grad():
        h = model(X_t)

    # detach() so the stored tensors are plain data, like those built by the
    # 1-D dictionary builders, rather than grad-requiring clones that keep an
    # autograd graph alive.
    params = {
        "W": [w.detach().clone() for w in model.W],
        "b": [b.detach().clone() for b in model.b],
    }

    return h.numpy().flatten(), params


##################################################
# Dictionary
##################################################

def build_dictionary(X, d, widths, M, seed=0, do3lyr1d=False, do2lyr1d=False):
    """Build a Lasso dictionary of sign-network features evaluated on ``X``.

    ``do2lyr1d`` takes precedence and returns the full 2-layer 1-D dictionary.
    Otherwise ``do3lyr1d`` returns the random 3-layer 1-D dictionary with
    ``m1 = widths[0]``. The default path draws ``M`` per-feature seeds from
    ``seed``, samples one random ``ParallelNet`` feature per seed, and removes
    duplicate columns. ``np.unique`` also reorders the surviving columns
    lexicographically.

    Returns:
        (A, store): dictionary matrix with one feature per column, and the
        parameter dict for each column, in matching order.
    """
    if do2lyr1d:
        return build_full_2lyr_1d_dictionary(X)
    if do3lyr1d:
        return build_3lyr_1d_dictionary(X=X, m1=widths[0], M=M, seed=seed)

    feature_seeds = np.random.default_rng(seed).integers(0, 10_000_000, size=M)

    samples = [sample_parallel_feature(X, d, widths, int(fs)) for fs in feature_seeds]
    columns = [feat for feat, _ in samples]
    param_sets = [prm for _, prm in samples]

    # Deduplicate on rows of the transpose, i.e. on whole feature columns.
    unique_rows, first_idx = np.unique(
        np.stack(columns, axis=1).T, axis=0, return_index=True
    )

    return unique_rows.T, [param_sets[j] for j in first_idx]

def w1_from_3lyr_feat(m1, d=1):
    """First-layer weights for the hand-built 3-layer 1-D feature: all ones, shape (1, d, m1)."""
    shape = (1, d, m1)
    return torch.ones(shape)

def w2_from_3lyr_feat(m1, d=1):
    """Second-layer weights: shape (1, m1, 1), +1 on even rows and -1 on odd rows.

    ``d`` is accepted only for signature symmetry with ``w1_from_3lyr_feat``;
    it is unused.
    """
    weights = torch.ones(1, m1, 1)
    weights[0, 1::2, 0] = -1
    return weights

def b1_from_3lyr_feat(h, X, m1):
    """First-layer biases that make the 3-layer 1-D sign net reproduce ``h``.

    For the ``i``-th sign change of ``h`` (an index ``n >= 1`` with
    ``h[n] != h[n-1]``), neuron ``i`` gets bias ``-X[n-1]``. Neurons left over
    after the last change get ``-X[N-1]``. NOTE(review): treating ``X[N-1]``
    as the smallest sample suggests ``X`` is sorted in decreasing order —
    confirm with callers.

    Args:
        h: feature values, length N.
        X: sample points, length N (scalars or size-1 arrays/tensors).
        m1: number of first-layer neurons; must be >= the number of changes.

    Returns:
        float32 tensor of shape (1, 1, m1).

    Raises:
        ValueError: if ``h`` and ``X`` differ in length, or ``h`` has more
            than ``m1`` sign changes.
    """
    if len(h) != len(X):
        raise ValueError(f"h and X must have the same length, got {len(h)} and {len(X)}")
    N = len(h)

    def _scalar(v):
        # .item() covers numpy/torch scalars and size-1 arrays; fall back for
        # plain Python numbers.
        return v.item() if hasattr(v, "item") else float(v)

    h = np.asarray(h)
    # Index n of every switch, i.e. each n >= 1 with h[n] != h[n-1].
    switches = np.flatnonzero(h[1:] != h[:-1]) + 1
    if len(switches) > m1:
        raise ValueError(f"h has {len(switches)} sign changes but only m1={m1} neurons")

    # Neurons without a switch default to -X[N-1] (no switch => constant output).
    b1 = torch.full((1, 1, m1), -_scalar(X[N - 1]), dtype=torch.float32)
    for i, n in enumerate(switches):
        b1[0, 0, i] = -_scalar(X[n - 1])

    return b1
        

def b2_from_3lyr_feat(m1):
    """Second-layer bias for the 3-layer 1-D feature, shape (1, 1, 1).

    0 when ``m1`` is even and -1 when it is odd. This offsets the +1 that the
    alternating +/-1 weights of ``w2_from_3lyr_feat`` sum to when ``m1`` is
    odd.
    """
    offset = -(m1 % 2)
    return torch.full((1, 1, 1), float(offset))
                      

def build_3lyr_1d_dictionary(X, m1, M, seed=0):
    """Build ``M`` random features realisable by a 3-layer 1-D sign network.

    Each feature is a +/-1 vector over the N samples. It starts at +1 and
    flips sign at ``k`` distinct random indices in ``1..N-1``, with ``k``
    drawn uniformly from ``0..min(m1, N-1)``. The matching network
    parameters come from ``w1_from_3lyr_feat``, ``w2_from_3lyr_feat``,
    ``b1_from_3lyr_feat`` and ``b2_from_3lyr_feat``.

    Args:
        X: the N sample points. NOTE(review): ``b1_from_3lyr_feat`` appears to
            assume ``X`` is sorted in decreasing order — confirm.
        m1: first hidden-layer width (upper bound on sign changes).
        M: number of features to draw. Duplicates are not removed.
        seed: seed for all draws; equal seeds give equal dictionaries.

    Returns:
        (A, store): ``A`` is an (N, M) float array whose columns are the
        features, and ``store[j]`` is the ``{"W": [...], "b": [...]}`` dict
        for column ``j``.
    """
    N = len(X)
    rng = np.random.default_rng(seed)
    seeds = rng.integers(0, 10_000_000, size=M)

    # At most N-1 interior indices can host a switch.
    max_switches = max(0, min(m1, N - 1))

    A, store = [], []

    for s in seeds:
        # Draw from a generator derived from `seed`, not the global `random`
        # module, so the dictionary is reproducible.
        feat_rng = np.random.default_rng(int(s))

        k = int(feat_rng.integers(0, max_switches + 1))
        switch_points = np.sort(feat_rng.permutation(np.arange(1, N))[:k])

        # Parity of the number of switches at or before each index gives the
        # sign: even -> +1, odd -> -1.
        flips = np.zeros(N)
        flips[switch_points] = 1.0
        h = np.where(np.cumsum(flips) % 2 == 0, 1.0, -1.0)

        A.append(h)
        store.append({
            "W": [w1_from_3lyr_feat(m1), w2_from_3lyr_feat(m1)],
            "b": [b1_from_3lyr_feat(h, X, m1), b2_from_3lyr_feat(m1)],
        })

    A = np.stack(A, axis=1) if A else np.empty((N, 0))

    return A, store

def build_full_2lyr_1d_dictionary(X):
    """Build the complete dictionary of 2-layer 1-D sign features on ``X``.

    Column ``j`` is the feature ``sign(x - X[j])`` at every sample, i.e.
    ``A[i, j] = sign_np(X[i] - X[j])``. ``store[j]`` holds the matching
    single-layer parameters ``W = 1`` and ``b = -X[j]``.

    Args:
        X: the N sample points, shaped (N,) or (N, 1).

    Returns:
        (A, store): ``A`` is an (N, N) array of +/-1 values. ``store`` is a
        list of N ``{"W": [...], "b": [...]}`` dicts of (1, 1, 1) float32
        tensors.

    Raises:
        ValueError: if ``X`` is not one-dimensional data.
    """
    x = np.asarray(X, dtype=float)
    if x.ndim > 2 or (x.ndim == 2 and x.shape[1] != 1):
        raise ValueError(f"X must have shape (N,) or (N, 1), got {x.shape}")

    # Force a column vector so the difference broadcasts to (N, N); with a
    # 1-D X, X - X.T would be an all-zero vector.
    x = x.reshape(-1, 1)
    A = sign_np(x - x.T)

    store = [
        {
            "W": [torch.ones(1, 1, 1)],
            "b": [torch.tensor(-xj, dtype=torch.float32).view(1, 1, 1)],
        }
        for xj in x[:, 0]
    ]

    return A, store


##################################################
# Lasso (convex)
##################################################

def solve_lasso(A, y, beta, solver=None, **solve_kwargs):
    """Solve ``min_z 0.5 * ||A z - y||_2^2 + beta * ||z||_1`` with CVXPY.

    Args:
        A: (N, M) dictionary matrix.
        y: N targets; (N,) and (N, 1) are both accepted.
        beta: L1 regularisation weight.
        solver: CVXPY solver name; ``None`` uses SCS. SCS is a first-order
            solver with modest default accuracy, so "zero" coefficients may
            come back only approximately zero. Choose downstream thresholds
            accordingly, or pass a more accurate solver.
        **solve_kwargs: forwarded to ``Problem.solve``.

    Returns:
        (z, value): the coefficient array of length M and the optimal
        objective value.

    Raises:
        RuntimeError: if the solver returns no solution.
    """
    if solver is None:
        solver = cp.SCS

    # Flatten so a column-vector y matches the (N,) shape of A @ z.
    y = np.asarray(y, dtype=float).reshape(-1)

    z = cp.Variable(A.shape[1])

    obj = cp.Minimize(
        0.5 * cp.sum_squares(A @ z - y)
        + beta * cp.norm1(z)
    )

    prob = cp.Problem(obj)
    prob.solve(solver=solver, **solve_kwargs)

    if z.value is None:
        raise RuntimeError(f"Lasso solve failed with status {prob.status!r}")

    return z.value, prob.value


##################################################
# Reconstruction
##################################################

def reconstruct_parallel_net(store, z, d, widths, tol=1e-10):
    """Assemble a ``ParallelNet`` from Lasso coefficients and stored features.

    Each dictionary column ``j`` with ``|z[j]| > tol`` becomes one parallel
    unit. Its weights and biases are copied from ``store[j]`` and its output
    weight ``alpha`` is set to ``z[j]``. Scales ``s`` are ones
    (``fixed_s=True``).

    Args:
        store: per-column parameter dicts ``{"W": [...], "b": [...]}``, with
            one tensor per layer and a leading singleton unit dimension.
        z: coefficients, one per entry of ``store``; any array-like, (M,) or
            (M, 1).
        d, widths: architecture; must match the tensor shapes in ``store``.
        tol: magnitude threshold below which a coefficient is dropped.
            NOTE(review): with an approximate solver (e.g. SCS), 1e-10 may
            keep nearly every column — consider a looser value.

    Returns:
        The assembled ``ParallelNet`` with ``mL`` equal to the number of kept
        columns.

    Raises:
        ValueError: if ``store`` and ``z`` have different lengths.
    """
    z = np.asarray(z, dtype=float).reshape(-1)
    if len(store) != len(z):
        # zip() below would otherwise silently pair the wrong entries.
        raise ValueError(f"store has {len(store)} entries but z has {len(z)}")

    mask = np.abs(z) > tol
    z = z[mask]
    store = [p for p, keep in zip(store, mask) if keep]

    model = ParallelNet(
        d=d,
        widths=widths,
        mL=len(z),
        fixed_alpha=True,
        fixed_s=True
    )

    with torch.no_grad():
        for i, (p, coef) in enumerate(zip(store, z)):
            for l, (W, b) in enumerate(zip(model.W, model.b)):
                # Drop the stored leading unit dim: (1, ...) -> (...).
                W[i].copy_(p["W"][l][0])
                b[i].copy_(p["b"][l][0])

            model.alpha[i] = float(coef)

    return model


##################################################
# L2 penalty
##################################################

def l2_penalty(model):
    """Return the sum of squared entries over every registered parameter of ``model``.

    All parameters are included regardless of ``requires_grad``. A model with
    no parameters yields the float ``0.0``.
    """
    return sum(((param ** 2).sum() for param in model.parameters()), 0.0)


##################################################
# Nonconvex training
##################################################

def train_parallel_net(X, y, d, widths, mL,
                       beta=0.0, lr=1e-3, iters=2000, σ=sign):
    """Train a ``ParallelNet`` with full-batch Adam on a regularised MSE loss.

    Each step minimises ``mean((model(X) - y)^2) + beta * l2_penalty(model)``.
    Weights get a per-unit Glorot-normal init, biases start at zero, and the
    output weights ``alpha`` start at one. Scales ``s`` keep ParallelNet's
    random init.

    Args:
        X: inputs, shaped (N, d) or (N,).
        y: N targets, shaped (N,) or (N, 1).
        d, widths, mL: architecture passed to ``ParallelNet``.
        beta: L2 penalty coefficient.
        lr: Adam learning rate.
        iters: number of gradient steps.
        σ: activation, passed to ``ParallelNet`` as ``sigma``.

    Returns:
        (model, losses): the trained model and the total loss at every step.
    """
    X_t = torch.as_tensor(X, dtype=torch.float32)
    # Flatten so (N, 1) targets can't broadcast against the (N,) predictions
    # into an (N, N) residual matrix.
    y_t = torch.as_tensor(y, dtype=torch.float32).reshape(-1)

    model = ParallelNet(d=d, widths=widths, mL=mL, sigma=σ)

    with torch.no_grad():
        for W in model.W:
            # W is (mL, fan_in, fan_out). nn.init.xavier_normal_ assumes a
            # conv-style (out, in, *kernel) layout for 3-D tensors and would
            # use fans of fan_in*fan_out and mL*fan_out. Use the per-unit
            # Glorot std instead.
            fan_in, fan_out = W.shape[1], W.shape[2]
            W.normal_(0.0, (2.0 / (fan_in + fan_out)) ** 0.5)
        for b in model.b:
            b.zero_()
        model.alpha.fill_(1.0)

    opt = torch.optim.Adam(model.parameters(), lr=lr)

    losses = []

    for _ in range(iters):

        opt.zero_grad()

        pred = model(X_t)

        loss = ((pred - y_t) ** 2).mean() + beta * l2_penalty(model)

        loss.backward()
        opt.step()

        losses.append(loss.item())

    return model, losses