import numpy as np
from scipy.linalg import svd
from scipy.stats import ortho_group, unitary_group
from itertools import combinations
import math
import matplotlib.pyplot as plt

# Default number of samples N for the data_gen_* generators and the
# sphere / normal samplers defined in this module.
# The actual value used was 10000.
MAX_SAMPLE = 1_000

### Helper functions ###

def random_Grassmannian_vector(n, k, seed=42):
    """
    Return the Plücker coordinates of a random k-dimensional subspace of R^n.

    NumPy's *global* RNG is seeded with ``seed`` (a side effect that later
    np.random / scipy.stats draws continue from), and an n x n orthogonal
    matrix is drawn with scipy's ``ortho_group``. The subspace is the span of
    its first k columns; its coordinates are the k x k minors of those
    columns, one per size-k row subset, in
    ``itertools.combinations(range(n), k)`` order.

    n, k : int, int
        Intended use is n > k.
    seed : int

    Returns
    -------
    np.ndarray of shape (comb(n, k),)
    """
    np.random.seed(seed)
    basis = ortho_group.rvs(dim=n)

    # One minor per choice of k rows, restricted to the first k columns.
    minors = [np.linalg.det(basis[rows, :k]) for rows in combinations(range(n), k)]
    return np.array(minors, dtype=float)

# Example usage
# vec = random_Grassmannian_vector(k=2,n=3)
# print('Grassmannian : ', vec.shape)

def random_Flag_vector(n,k1,k2,seed = 42):
    '''
    Return the flattened outer product of the Plücker coordinates of two
    nested subspaces taken from one random orthogonal matrix.

    NumPy's *global* RNG is seeded with ``seed`` and an n x n orthogonal
    matrix ``m`` is drawn with scipy's ``ortho_group``. The spans of its
    first k1 columns and of its first k2 columns (one contains the other)
    are encoded by their k x k minors over every size-k row subset, in
    ``itertools.combinations(range(n), k)`` order. The result is
    outer(p1, p2) flattened in row-major order.

    n, k1, k2 : int, int, int
        Meaningful for 1 <= k1 <= n and 1 <= k2 <= n; there is no
        constraint on k1 + k2.
    seed : int

    Returns
    -------
    np.ndarray of shape (comb(n, k1) * comb(n, k2),)
    '''
    np.random.seed(seed)
    m = ortho_group.rvs(dim=n)

    # Each minor depends on a single row subset, so compute the two
    # coordinate vectors once (comb(n,k1) + comb(n,k2) determinants) instead
    # of re-evaluating both determinants for every pair of subsets.
    plucker1 = np.array(
        [np.linalg.det(m[c1, :k1]) for c1 in combinations(range(n), k1)],
        dtype=float,
    )
    plucker2 = np.array(
        [np.linalg.det(m[c2, :k2]) for c2 in combinations(range(n), k2)],
        dtype=float,
    )

    # outer(...)[i, j] == plucker1[i] * plucker2[j]; row-major flattening
    # matches the original (i, j) double loop followed by reshape(-1).
    return np.outer(plucker1, plucker2).reshape(-1)

# Example usage
# vec = random_Flag_vector(n=4,k1=2,k2=3)
# print('Flag : ', vec.shape)
# print('exp : ', math.comb(4,2)*math.comb(4,3))

def kron(*args):
    """
    Kronecker product of one or more arrays, folded from the right.

    kron(a, b, c) == np.kron(a, np.kron(b, c)). A single argument is returned
    unchanged (the same object); calling with no arguments raises IndexError.
    """
    result = args[-1]
    for factor in reversed(args[:-1]):
        result = np.kron(factor, result)
    return result


def random_embedding(n,seed=42):
    """
    Build a real state vector of length n**4 from n**2 random amplitudes.

    NumPy's *global* RNG is seeded with ``seed``; n**2 values are drawn
    uniformly from [-1, 1) and normalised to unit Euclidean norm. For every
    pair (a, b) the value amplitudes[n*a + b] / sqrt(n) is placed on the n
    basis vectors e_c (x) e_{c+a} (x) e_{c+b} (x) e_{c+a+b} (indices mod n,
    c = 0..n-1), and the whole vector is then divided by sqrt(n) once more.

    n : int
    seed : int

    Returns
    -------
    np.ndarray of shape (n**4,)

    NOTE(review): because of the two sqrt(n) divisions the returned vector
    has Euclidean norm 1/sqrt(n), not 1. Kept as-is so previously generated
    data is reproduced; confirm whether unit norm was intended.
    """
    np.random.seed(seed)
    amplitudes = np.random.random(n**2) * 2 - 1
    amplitudes = amplitudes / np.linalg.norm(amplitudes)

    state = np.zeros(n**4)
    c = np.arange(n)
    for a in range(n):
        for b in range(n):
            # Flat index of e_i (x) e_j (x) e_k (x) e_l in the Kronecker
            # product is i*n**3 + j*n**2 + k*n + l. The n indices differ in
            # i = c, and (a, b) is recoverable from any of them, so no entry
            # is written twice. Direct assignment therefore reproduces the
            # kron-and-accumulate construction exactly, in O(n**3) instead
            # of O(n**7) work.
            idx = (c * n**3
                   + ((c + a) % n) * n**2
                   + ((c + b) % n) * n
                   + (c + a + b) % n)
            state[idx] = amplitudes[n * a + b] / np.sqrt(n)

    return state / np.sqrt(n)


def random_unitary_mod_pauli(embedding, seed = 42):
    """
    Rotate ``embedding`` by U (x) U (x) conj(U) (x) conj(U) for a random unitary U.

    embedding : np.array
        Vector of length n**4; n is recovered as the rounded fourth root of
        its length.
    seed : int
        Seeds NumPy's *global* RNG before U (n x n) is drawn from scipy's
        ``unitary_group``.

    Returns
    -------
    np.ndarray of length 2 * n**4: the real parts followed by the imaginary
    parts of the rotated vector.
    """
    np.random.seed(seed)
    dim = int(round(len(embedding) ** (1 / 4)))
    unitary = unitary_group.rvs(dim=dim)
    conjugate = unitary.conj()

    # Operator acting on the four tensor factors: two copies of U, two of conj(U).
    rotated = kron(unitary, unitary, conjugate, conjugate) @ embedding
    return np.concatenate((rotated.real, rotated.imag))

def compute_hofstadter_butterfly(q_max=100):
    """
    Compute the Hofstadter butterfly: Harper-model energies at rational flux.

    For each q in 1..q_max and each p in 0..q-1 (flux phi = p / q, so
    non-reduced fractions such as 2/4 repeat earlier flux values), the
    q x q Hamiltonian is diagonalised. It has on-site term
    2*cos(2*pi*phi*i) and unit nearest-neighbour hopping with periodic
    wrap-around. All q eigenvalues are recorded against phi.

    q_max : int
        Largest denominator. The total number of points is
        sum_{q=1}^{q_max} q**2.

    Returns
    -------
    phi_vals, energy_vals : np.ndarray, np.ndarray
        Equal-length 1-D arrays of flux values and eigenvalues.
    """
    phi_vals = []
    energy_vals = []

    # Assumes ky = 0 (and no Bloch phase on the wrap-around bond).
    # H   = -t \sum_{i=1}^{N} c_{x+1,k_y}^\dagger c_{x,k_y} + h.c.
    #       + cos(2*pi*n_\phi x +  2 pi k_y/N)c_{x,k_y}^\dagger c_{x,k_y}
    # Source : https://www.xif.fr/uploads/td-topo.pdf (Eq on slide 10/21)

    for q in range(1, q_max + 1):
        for p in range(q):
            phi = p / q
            H = np.zeros((q, q), dtype=np.complex128)
            for i in range(q):
                j = (i + 1) % q
                H[i, i] = 2 * np.cos(2 * np.pi * phi * i)
                # Accumulate (+=) rather than assign. For q == 1 the periodic
                # bond lands on the diagonal, and assignment would overwrite
                # the on-site term (E = 1 instead of 4). For q == 2 the forward
                # and wrap-around bonds join the same two sites and must add.
                # For q >= 3 each hopping entry is written exactly once.
                H[i, j] += 1
                H[j, i] += 1
            eigenvalues = np.linalg.eigvalsh(H)
            phi_vals.extend([phi] * len(eigenvalues))
            energy_vals.extend(eigenvalues)

    return np.array(phi_vals), np.array(energy_vals)

def plot_hofstadter_butterfly(phi_vals, energy_vals):
    """
    Scatter-plot energy against flux (tiny black markers) and show the figure.

    phi_vals, energy_vals : array-like of equal length.
    Opens a 10 x 8 inch matplotlib figure and blocks in ``plt.show()``.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(phi_vals, energy_vals, s=0.1, color='black')
    ax.set_xlabel("Magnetic flux per plaquette (ϕ/ϕ₀)")
    ax.set_ylabel("Energy")
    ax.set_title("Hofstadter's Butterfly")
    ax.grid(False)
    fig.tight_layout()
    plt.show()

def sample_from_butterfly(phi_vals, energy_vals, N=1000, sigma=0.0):
    """
    Draw N distinct (phi, energy) points from the butterfly, optionally jittered.

    phi_vals, energy_vals : np.ndarray
        Equal-length 1-D arrays (e.g. the output of
        compute_hofstadter_butterfly).
    N : int
        Number of points. Indices are drawn without replacement, so N must
        not exceed len(phi_vals); np.random.choice raises ValueError
        otherwise.
    sigma : float
        Standard deviation of isotropic Gaussian noise added to every
        sampled point. The default 0.0 adds no noise and draws no extra
        random numbers, so the noise-free result is unchanged.

    Returns
    -------
    np.ndarray of shape (N, 2) with columns (phi, energy).

    Raises
    ------
    ValueError
        If sigma is negative.

    Uses NumPy's global RNG.
    """
    if sigma < 0:
        raise ValueError(f"sigma must be non-negative, got {sigma}")

    indices = np.random.choice(len(phi_vals), size=N, replace=False)
    samples = np.stack([phi_vals[indices], energy_vals[indices]], axis=1)

    # Only touch the RNG when noise is requested, so sigma=0 keeps the
    # global random stream identical to a plain noise-free draw.
    if sigma > 0:
        samples = samples + sigma * np.random.randn(*samples.shape)
    return samples

#### Helper functions end #####

def data_gen_lin_nsp(N=MAX_SAMPLE,params=(200,100,42)):
    '''
    Sample N random vectors from the null space of a random low-rank matrix.

    params : (d1: int, d2: int, seed: int)
        A = U @ V is d1 x d1, built from Gaussian U (d1 x d2) and V (d2 x d1).
        For d2 < d1 it generically has rank d2, so its null space has
        dimension d1 - d2. ``seed`` seeds NumPy's global RNG.

    Returns
    -------
    dict with
        'samples' : np.ndarray (N, d1), standard-Gaussian combinations of a
                    null-space basis,
        'params'  : the params tuple as given,
        'matrixA' : the d1 x d1 matrix A.
    '''
    d1, d2, seed = params
    np.random.seed(seed)

    left = np.random.randn(d1, d2)
    right = np.random.randn(d2, d1)
    A = np.dot(left, right)

    # Right-singular vectors beyond the first d2 span ker(A) when rank(A) == d2.
    _, _, vh = svd(A)
    kernel_rows = vh[d2:]  # (d1 - d2, d1); each row is a basis vector

    weights = np.random.randn(N, d1 - d2)
    return {'samples': np.dot(weights, kernel_rows),
            'params': params,
            'matrixA': A}

# Example usage
# info_dict = data_gen_lin_nsp(N=100, params = (20,10,0))
# samples = info_dict['samples']
# print('Lin nullspaces : ', samples.shape)


def data_gen_gman_proj(N=MAX_SAMPLE, params=(6,3,42)):
    '''
    Sample N points of the Grassmannian Gr(k, n) as rank-k orthogonal projectors.

    params : (n: int, k: int, seed: int)
        Each sample is P = Q D Q^T, where Q is a random n x n orthogonal
        matrix and D = diag(1, ..., 1, 0, ..., 0) has k ones. ``seed`` seeds
        NumPy's global RNG.

    Returns
    -------
    dict with
        'samples' : np.ndarray (N, n, n), symmetric idempotent matrices of trace k,
        'params'  : the params tuple as given,
        'listQs'  : np.ndarray of the N orthogonal matrices Q used.
    '''
    n, k, seed = params
    np.random.seed(seed)

    projectors = np.zeros((N, n, n))

    D = np.zeros((n, n))
    np.fill_diagonal(D[:k, :k], 1)

    list_of_matrices = []
    for i in range(N):
        Q, R = np.linalg.qr(np.random.randn(n, n))
        # Householder QR leaves the signs of diag(R) input-dependent, which
        # makes Q non-Haar. With the LAPACK convention used by
        # np.linalg.qr, Q[0, 0] comes out non-positive for Gaussian input.
        # Flipping column j whenever R[j, j] < 0 yields Haar-distributed Q
        # (Mezzadri 2007). P = Q D Q^T does not depend on these column signs.
        Q = Q * np.where(np.diag(R) < 0, -1.0, 1.0)
        list_of_matrices.append(Q)

        projectors[i] = Q @ D @ Q.T

    info_dict = {'samples': projectors,
                 'params': params,
                 'listQs': np.array(list_of_matrices),
                 }

    return info_dict

# Example usage
# info_dict = data_gen_gman_proj(N=10)
# samples = info_dict['samples']
# print('Gman projector : ', samples.shape)
# print(np.mean(np.abs(samples) ) )

def data_gen_gman_vec(N=MAX_SAMPLE, params=(6,3,42)):
    '''
    Sample N points of Gr(k, n) as vectors of k x k minors (Plücker coordinates).

    params : (n: int, k: int, seed: int)
        Row i is random_Grassmannian_vector(n, k, seed=seed + i). ``seed``
        also seeds NumPy's global RNG up front.

    Returns
    -------
    dict with 'samples' (np.ndarray of shape (N, comb(n, k))) and 'params'.
    '''
    n, k, seed = params
    np.random.seed(seed)

    samples = np.zeros((N, math.comb(n, k)))
    for offset, row in enumerate(samples):
        # ``row`` is a view into ``samples``, so this fills the output in place.
        row[:] = random_Grassmannian_vector(n=n, k=k, seed=seed + offset)

    return {'samples': samples,
            'params': params,
            }

# Example usage
# info_dict = data_gen_gman_vec(N=10)
# samples = info_dict['samples']
# print('Gman vector : ', samples.shape)
# print(np.mean(np.abs(samples) ) )

def data_gen_steifel_proj(N=MAX_SAMPLE, params=(6,3,42)):
    '''
    Sample N points of the Stiefel manifold V_k(R^n) (orthonormal k-frames).

    params : (n: int, k: int, seed: int)
        Each sample is the first k columns of a random n x n orthogonal
        matrix. ``seed`` seeds NumPy's global RNG.

    Returns
    -------
    dict with 'samples' (np.ndarray of shape (N, n, k)) and 'params'.
    '''
    n, k, seed = params
    np.random.seed(seed)

    frames = np.zeros((N, n, k))
    for i in range(N):
        Q, R = np.linalg.qr(np.random.randn(n, n))
        # QR of a Gaussian matrix gives a Haar-distributed Q only once the
        # sign ambiguity of diag(R) is removed (Mezzadri 2007). With the
        # Householder convention used by np.linalg.qr, Q[0, 0] is otherwise
        # never positive, so frames with a positive first entry would never
        # be produced. Flip column j whenever R[j, j] < 0.
        Q = Q * np.where(np.diag(R) < 0, -1.0, 1.0)
        frames[i] = Q[:, :k]

    info_dict = {'samples': frames,
                 'params': params,
                 }

    return info_dict

# Example usage
# info_dict = data_gen_steifel_proj(N=10)
# samples = info_dict['samples']
# print('Steifel projector : ', samples.shape)
# print(np.mean(np.abs(samples) ) )

def data_gen_steifel_vec1(N=MAX_SAMPLE, params = (10,5,42)):
    '''
    Sample N vectors pairing a Grassmannian point with a k x k orthogonal matrix.

    params : (n: int, k: int, seed: int)
        Row i is [random_Grassmannian_vector(n, k, seed=seed + i) |
        flattened ortho_group.rvs(dim=k)]. The orthogonal block is drawn from
        NumPy's global RNG right after the helper re-seeds it with seed + i.

    Returns
    -------
    dict with 'samples' (np.ndarray of shape (N, comb(n, k) + k**2)) and 'params'.
    '''
    n, k, seed = params
    np.random.seed(seed)

    n_minors = math.comb(n, k)
    samples = np.zeros((N, n_minors + k**2))
    for offset in range(N):
        samples[offset, :n_minors] = random_Grassmannian_vector(n=n, k=k, seed=seed + offset)
        samples[offset, n_minors:] = ortho_group.rvs(dim=k).reshape(-1)

    return {'samples': samples,
            'params': params,
            }

# Example usage
# info_dict = data_gen_steifel_vec1(N=10)
# samples = info_dict['samples']
# print('Steifel vector 1 : ', samples.shape)
# print('Expected : ', math.comb(10,5)+5**2)
# print(np.mean(np.abs(samples) ) )

def data_gen_flag_vec(N=MAX_SAMPLE, params=(4,3,2,42)):
    '''
    Sample N flag-manifold points as flattened outer products of Plücker vectors.

    params : (n: int, k1: int, k2: int, seed: int)
        Row i is random_Flag_vector(n, k1, k2, seed=seed + i). ``seed`` also
        seeds NumPy's global RNG up front.

    Returns
    -------
    dict with 'samples' (np.ndarray of shape (N, comb(n, k1) * comb(n, k2)))
    and 'params'.
    '''
    n, k1, k2, seed = params
    np.random.seed(seed)

    samples = np.zeros((N, math.comb(n, k1) * math.comb(n, k2)))
    for offset, row in enumerate(samples):
        row[:] = random_Flag_vector(n=n, k1=k1, k2=k2, seed=seed + offset)

    return {'samples': samples,
            'params': params,
            }

# Example usage
# info_dict = data_gen_flag_vec(N=10)
# samples = info_dict['samples']
# print('Flag vector : ', samples.shape)
# print('Expected : ', math.comb(4,3)*math.comb(4,2))
# print(np.mean(np.abs(samples) ) )

def data_gen_sun_mod_pauli(N=MAX_SAMPLE, params=(3,42)):
    '''
    Sample N random-unitary rotations of random n**4-dimensional embeddings.

    params : (n: int, seed: int)
        Row i uses random_embedding(n, seed=seed + i), rotated by
        random_unitary_mod_pauli(..., seed=seed + i). ``seed`` also seeds
        NumPy's global RNG up front.

    Returns
    -------
    dict with 'samples' (np.ndarray of shape (N, 2 * n**4): real parts then
    imaginary parts) and 'params'.
    '''
    n, seed = params
    np.random.seed(seed)

    samples = np.zeros((N, 2 * n**4))
    for offset, row in enumerate(samples):
        row_seed = seed + offset
        row[:] = random_unitary_mod_pauli(random_embedding(n, seed=row_seed), seed=row_seed)

    return {'samples': samples,
            'params': params,
            }

# Example usage
# info_dict = data_gen_sun_mod_pauli(N=10)
# samples = info_dict['samples']
# print('Random unitary : ', samples.shape)
# print('Expected : ', 2*(3**4))
# print(np.mean(np.abs(samples) ) )
    
def data_gen_fractal_hofstadter(N=MAX_SAMPLE, params = (2,42), sample=False,verbose=False):
    '''
    Generate 2-D points (phi, energy) on the Hofstadter butterfly (q_max = 100).

    params : (d: int, seed: int)
        ``d`` is not used here (NOTE(review): presumably kept so params has
        the same layout as the other generators; confirm). ``seed`` seeds
        NumPy's global RNG, which drives the random subsample / shuffle below.
    N : int
        Number of points drawn without replacement when ``sample`` is True.
        Ignored when ``sample`` is False.
    sample : bool
        If True, return N random spectrum points. Otherwise return every
        point of the spectrum in a random order.
    verbose : bool
        If True, plot the full butterfly first.

    Returns
    -------
    dict with 'samples' (np.ndarray with two columns: phi, energy) and 'params'.
    '''
    _, seed = params
    # The seed matters: it fixes the random subsample / permutation below.
    np.random.seed(seed)
    phi_vals, energy_vals = compute_hofstadter_butterfly(q_max=100)
    if verbose:
        plot_hofstadter_butterfly(phi_vals, energy_vals)
    if sample:
        # Noise-free draw: only the (phi, energy, N) arguments are passed.
        samples = sample_from_butterfly(phi_vals, energy_vals, N=N)
    else:
        ids = np.random.permutation(len(phi_vals))
        samples = np.stack([phi_vals[ids], energy_vals[ids]], axis=1)

    info_dict = {'samples': samples,
                 'params' : params,
                 }

    return info_dict

# Example usage
# info_dict = data_gen_fractal_hofstadter(N=10)
# samples = info_dict['samples']
# print('Fractals : ', samples.shape)
# print('Expected : ', 2)
# print(np.mean(np.abs(samples) ) )

def hSphere_isometry(di=3, da=6, N=MAX_SAMPLE, seed= 42):
    """
    Sample N points on the unit sphere S^di and map them linearly into R^da.

    Points are normalised standard-Gaussian vectors in R^(di+1), multiplied
    by a fixed standard-Gaussian matrix of shape (di+1, da). NumPy's global
    RNG is seeded with ``seed``; the matrix is drawn before the points.

    NOTE(review): the map is a random Gaussian matrix, not an orthonormal
    one, so despite the name it is a linear embedding rather than an
    isometry. Confirm whether that was intended.

    Returns
    -------
    np.ndarray of shape (N, da)
    """
    np.random.seed(seed)
    lift = np.random.randn(di + 1, da)
    points = np.random.randn(N, di + 1)
    points /= np.linalg.norm(points, axis=1, keepdims=True)
    return points @ lift

def hSphere(di=3,  N=MAX_SAMPLE, seed= 42):
    """
    Sample N points uniformly from the unit sphere S^di in R^(di+1).

    Uses normalised standard-Gaussian vectors. NumPy's global RNG is seeded
    with ``seed``.

    Returns
    -------
    np.ndarray of shape (N, di + 1) with unit-norm rows.
    """
    np.random.seed(seed)
    gaussian = np.random.randn(N, di + 1)
    radii = np.linalg.norm(gaussian, axis=1, keepdims=True)
    return gaussian / radii

def isoNormal(di=3, da=6, N=MAX_SAMPLE, seed= 210):
    """
    Sample N standard-normal points in R^di and map them linearly into R^da.

    NumPy's global RNG is seeded with ``seed``. The (di, da) standard-Gaussian
    map is drawn first, then the (N, di) latent points.

    Returns
    -------
    np.ndarray of shape (N, da)
    """
    np.random.seed(seed)
    embed = np.random.randn(di, da)
    latent = np.random.randn(N, di)
    return latent @ embed

def mBeta(di=3, da=6, N=MAX_SAMPLE, seed= 840):
    """
    Sample N points in R^di, warp them non-linearly, and map them into R^da.

    Latent points are uniform on [0, 1)^di, transformed elementwise by
    sin(cos(x)), then multiplied by a (di, da) standard-Gaussian matrix.
    NumPy's global RNG is seeded with ``seed``; the matrix is drawn before
    the points.

    Returns
    -------
    np.ndarray of shape (N, da)
    """
    np.random.seed(seed)
    embed = np.random.randn(di, da)
    latent = np.sin(np.cos(np.random.rand(N, di)))
    return latent @ embed