from distutils.log import error
import os
from random import triangular
from tkinter import Y
import ipdb as pdb
import numpy as np
import torch
from scipy.stats import ortho_group, wishart
from sklearn.preprocessing import scale
from sklearn.model_selection import train_test_split
import random

# from train_spline import pretrain_spline

# Base directory for this module; it is not referenced anywhere in the
# visible part of this file (the save path further down is hard-coded).
root_dir = '.'

def leaky_ReLU_1d(d, negSlope):
    """Scalar leaky ReLU.

    Args:
        d: A single number.
        negSlope: Multiplier applied when ``d`` is not strictly positive.

    Returns:
        ``d`` itself when ``d > 0``, otherwise ``d * negSlope``.
    """
    return d if d > 0 else d * negSlope

# Element-wise version of leaky_ReLU_1d so it can be applied to whole arrays.
# np.vectorize is a convenience wrapper that still loops in Python; it is not
# a compiled ufunc.
leaky1d = np.vectorize(leaky_ReLU_1d)

def leaky_ReLU(D, negSlope):
    """Apply a leaky ReLU element-wise to an array.

    Positive entries are returned unchanged; all other entries are multiplied
    by ``negSlope``.

    The computation is a single vectorized ``np.where`` rather than a
    ``np.vectorize`` wrapper around a scalar function. Besides being much
    faster on large arrays, this avoids two ``np.vectorize`` pitfalls: the
    output dtype being guessed from the first element only (which truncates
    scaled negatives when an integer array starts with a positive value), and
    a ``ValueError`` on size-0 input.

    Args:
        D: Array-like of numbers, any shape.
        negSlope: Slope used for the non-positive part; must be > 0.

    Returns:
        ``numpy.ndarray`` with the same shape as ``D``.

    Raises:
        AssertionError: If ``negSlope`` is not strictly positive.
    """
    assert negSlope > 0
    values = np.asarray(D)
    return np.where(values > 0, values, values * negSlope)

def unique_row(matrixA, a, b, matrixB=None, rng=None):
    """Restrict row ``a`` of ``matrixA`` to its ``b`` most frequent values.

    Every entry of the row that is not among the ``b`` most frequent values
    (ties are broken by first appearance) is overwritten with a randomly
    chosen one of those values. The row is edited IN PLACE, so ``matrixA``
    itself is modified and then returned.

    When ``matrixB`` is given, its row ``a`` is re-indexed so that position
    ``i`` receives the ``matrixB`` value found at the first position of the
    edited ``matrixA`` row holding the same value as position ``i``. Positions
    that share a value in ``matrixA[a]`` therefore share a value in
    ``matrixB[a]``. ``matrixB`` is modified in place as well, and the index
    list used for the re-indexing is printed.

    Args:
        matrixA: 2-D container indexable as ``matrixA[a]``; values in the
            selected row must be hashable (e.g. numbers).
        a: Index of the row to edit.
        b: Number of most frequent values to keep.
        matrixB: Optional NumPy array whose row ``a`` is re-indexed to follow
            the edited row of ``matrixA``.
        rng: Optional object exposing ``choice(sequence)`` (for example a
            ``random.Random`` or ``numpy.random.RandomState``). Defaults to
            the global ``random`` module, which is not seeded here; pass a
            seeded generator for reproducible output.

    Returns:
        ``matrixA`` if ``matrixB`` is None, otherwise ``(matrixA, matrixB)``.

    Raises:
        IndexError: If a replacement is required but no value is kept
            (e.g. ``b == 0`` with the default ``rng``).
    """
    chooser = random if rng is None else rng

    # For a NumPy matrix this is a view, so writes below land in matrixA.
    row = matrixA[a]

    counts = {}
    for value in row:
        counts[value] = counts.get(value, 0) + 1

    # sorted() is stable, so equally frequent values keep first-seen order.
    kept_values = sorted(counts, key=counts.get, reverse=True)[:b]
    kept_lookup = set(kept_values)

    for i in range(len(row)):
        if row[i] not in kept_lookup:
            row[i] = chooser.choice(kept_values)

    if matrixB is None:
        return matrixA

    # Map each value of the edited row to the first position where it occurs.
    first_seen = {}
    for position, value in enumerate(row):
        first_seen.setdefault(value, position)
    indexs = [first_seen[value] for value in row]
    print(indexs)

    # Fancy indexing already yields a fresh array, so no extra copy is needed.
    matrixB[a] = matrixB[a][indexs]
    return matrixA, matrixB

def get_index(x, y):
    """Locate each element of ``y`` inside ``x``.

    For every element of ``y``, in order, the position of its first equal
    element in ``x`` is recorded. Elements of ``y`` that do not occur in
    ``x`` are skipped without error, so the result can be shorter than ``y``.

    Args:
        x: Sequence that is searched.
        y: Sequence of values to look up.

    Returns:
        A new list of integer positions into ``x``.
    """
    positions = []
    for wanted in y:
        # First matching position in x, or None when there is no match.
        hit = next(
            (pos for pos, candidate in enumerate(x) if wanted == candidate),
            None,
        )
        if hit is not None:
            positions.append(int(hit))
    return positions


def align_matrix(matrix):
    """Build a matrix whose columns grow cumulatively from column 0.

    Column 0 of the result equals column 0 of ``matrix``. Each later column
    ``i`` is a copy of result column ``i - 1`` with 1 added to its first ``i``
    rows. Only column 0 of the input influences the output; the other input
    columns contribute nothing but the shape. The input is not modified.

    A zero-column input is returned as an (empty) copy instead of raising.

    Args:
        matrix: 2-D NumPy array.

    Returns:
        A new array with the same shape as ``matrix``.

    Raises:
        ValueError: If ``matrix`` is not 2-D.
    """
    _, n_cols = matrix.shape
    aligned = matrix + 0  # fresh copy; the caller's array stays untouched

    for col in range(1, n_cols):
        aligned[:, col] = aligned[:, col - 1]
        # A slice bound beyond the row count is clamped, so wide matrices work.
        aligned[:col, col] += 1
    return aligned


def align_matrix2(matrix):
    """Build a matrix where each column differs from the previous in one row.

    Column 0 of the result equals column 0 of ``matrix``. Each later column
    ``i`` is a copy of result column ``i - 1`` with 1 added to row ``i - 1``
    only. Only column 0 of the input influences the output; the other input
    columns contribute nothing but the shape. The input is not modified.

    A zero-column input is returned as an (empty) copy instead of raising.

    Args:
        matrix: 2-D NumPy array.

    Returns:
        A new array with the same shape as ``matrix``.

    Raises:
        ValueError: If ``matrix`` is not 2-D.
        IndexError: If the matrix has more than ``rows + 1`` columns, because
            column ``i`` needs row ``i - 1`` to exist.
    """
    _, n_cols = matrix.shape
    aligned = matrix + 0  # fresh copy; the caller's array stays untouched

    for col in range(1, n_cols):
        aligned[:, col] = aligned[:, col - 1]
        aligned[col - 1, col] += 1
    return aligned

def gen_da_data_ortho(
    Nsegment,
    Ncomp=4,
    Ncomp_s=2,
    Nlayer=3,
    var_range_l=0.01,
    var_range_r=3,
    mean_range_l=0,
    mean_range_r=3,
    NsegmentObs_train=7500,
    NsegmentObs_test=1000,
    Nobs_test=4096,
    varyMean=True,
    mixtures=True,
    seed=1,
    n_modes_range_l=2,
    n_modes_range_r=6,
    p_domains_range_l=1,
    p_domains_range_r=2,
    linear_mixing_first=False,
    mixed=False,
    source='Gaussian',
    save_all_datasets=False,
    triangle=False,
    iid_domain=False,
):
    """Generate segment-wise non-stationary sources and their nonlinear mixture.

    Multivariate data following the non-stationary non-linear ICA model of
    Hyvarinen & Morioka (2016). Mixing matrices are random orthonormal
    matrices, and the non-linearity is a leaky ReLU with negative slope 0.2
    (both fixed inside the function).

    Pipeline:
      1. Draw ``(Nobs, Ncomp)`` standardized i.i.d. sources.
      2. Draw a per-segment scale (``modMat``) and mean (``meanMat``) for the
         last ``Ncomp_s`` components and optionally post-process them
         (``triangle`` / ``iid_domain``).
      3. Rows are laid out in ``Nsegment`` contiguous segments; within each
         segment the last ``Ncomp_s`` components are multiplied by the scale
         and shifted by the mean of that segment.
      4. Apply ``Nlayer - 1`` rounds of (leaky ReLU, orthogonal mixing),
         optionally preceded by one purely linear orthogonal mixing.
      5. Split into train/test stratified by segment label, then sort each
         split by label.

    Args:
        Nsegment: Number of segments (domains).
        Ncomp: Number of components, i.e. dimensionality of sources and
            observations.
        Ncomp_s: Number of trailing components that are modulated per segment.
            The default ``triangle=False, iid_domain=False`` path edits row 1
            of the modulation matrices, so it needs ``Ncomp_s >= 2``.
        Nlayer: ``Nlayer - 1`` non-linear mixing layers are applied.
        var_range_l, var_range_r: Bounds of the uniform distribution the
            per-segment scale factors are drawn from.
        mean_range_l, mean_range_r: Bounds of the uniform distribution the
            per-segment means are drawn from (used only if ``varyMean``).
        NsegmentObs_train: Training observations per segment.
        NsegmentObs_test: Test observations per segment. Must be 0 when
            ``Nobs_test`` is non-zero.
        Nobs_test: Total number of test observations; when > 0 the per-segment
            test count becomes ``Nobs_test // Nsegment``. Must be 0 when
            ``NsegmentObs_test`` is non-zero. NOTE: the two defaults are both
            non-zero, so the caller has to set one of them to 0.
        varyMean: If False, all segment means are 0.
        mixtures, n_modes_range_l, n_modes_range_r, p_domains_range_l,
            p_domains_range_r, mixed: Accepted but not used by this function.
        seed: Seed for the NumPy ``RandomState`` that drives sampling, mixing
            matrices and the split. It also seeds the ``random`` module while
            ``unique_row`` runs so the whole result is reproducible.
        linear_mixing_first: If True, apply one orthogonal mixing before the
            first non-linearity.
        source: ``'Laplace'`` or ``'Gaussian'`` source distribution.
        save_all_datasets: If True, additionally save cumulative per-domain
            datasets (entry ``k`` holds all samples with label ``< k``) under
            ``./data``.
        triangle: If True, rebuild the modulation matrices with
            ``align_matrix``.
        iid_domain: If False (and ``triangle`` is False), row 1 of the
            modulation matrices is restricted to 3 distinct values via
            ``unique_row``.

    Returns:
        ``(train, test)``: two dicts with keys ``"y"`` (sources), ``"x"``
        (mixed observations) and ``"c"`` (segment labels, stored as floats),
        each sorted by segment label.

    Raises:
        AssertionError: If both ``Nobs_test`` and ``NsegmentObs_test`` are
            non-zero.
        Exception: If ``source`` is neither ``'Laplace'`` nor ``'Gaussian'``.
    """

    negSlope = 0.2
    NonLin = 'leaky'
    # np.random.seed(seed)
    randomstate = np.random.RandomState(seed)

    # generate non-stationary data:
    train_size = NsegmentObs_train * Nsegment
    assert Nobs_test == 0 or NsegmentObs_test == 0, (
        "give the test size through either Nobs_test or NsegmentObs_test; "
        "the other one must be 0"
    )
    if Nobs_test > 0:
        NsegmentObs_test = int(Nobs_test // Nsegment)
    test_size = NsegmentObs_test * Nsegment
    NsegmentObs_total = NsegmentObs_train + NsegmentObs_test
    Nobs = train_size + test_size  # total number of observations
    labels = np.zeros(Nobs, dtype=int)  # segment label per observation (filled below)
    # First modulated column. An explicit start index is used instead of
    # ``-Ncomp_s:`` because ``-0:`` would select every column when Ncomp_s == 0.
    Ncomp_c = Ncomp - Ncomp_s

    # generate data, which we will then modulate in a non-stationary manner:
    if source == 'Laplace':
        dat = randomstate.laplace(0, 1, (Nobs, Ncomp))
    elif source == 'Gaussian':
        dat = randomstate.normal(0, 1, (Nobs, Ncomp))
    else:
        raise Exception("wrong source distribution")
    dat = scale(dat)  # set to zero mean and unit variance

    # get modulation parameters
    modMat = randomstate.uniform(var_range_l, var_range_r, (Ncomp_s, Nsegment))
    if varyMean:
        meanMat = randomstate.uniform(mean_range_l, mean_range_r, (Ncomp_s, Nsegment))
    else:
        meanMat = np.zeros((Ncomp_s, Nsegment))

    print('meanMat is', meanMat)
    print('modMat is', modMat)
    if triangle:
        print("number of sources are", Ncomp_s)
        print('number of domains are', Nsegment)
        # meanMat = align_matrix2(meanMat)
        # modMat = align_matrix2(modMat)
        meanMat = align_matrix(meanMat)
        modMat = align_matrix(modMat)

        print('meanMat is', meanMat)
        print('modMat is', modMat)

    elif iid_domain is False:
        # unique_row draws its replacements from the global ``random`` module,
        # which ``randomstate`` does not control. Seed it for the duration of
        # the call so the output depends only on ``seed``, then restore the
        # previous state so callers' own use of ``random`` is unaffected.
        saved_random_state = random.getstate()
        random.seed(seed)
        try:
            if varyMean:
                modMat, meanMat = unique_row(modMat, 1, 3, meanMat)
            else:
                modMat = unique_row(modMat, 1, 3)
                meanMat = np.zeros((Ncomp_s, Nsegment))
        finally:
            random.setstate(saved_random_state)

    # now we adjust the variance within each segment in a non-stationary manner
    for seg in range(Nsegment):
        rows = slice(NsegmentObs_total * seg, NsegmentObs_total * (seg + 1))
        dat[rows, Ncomp_c:] *= modMat[:, seg]
        dat[rows, Ncomp_c:] += meanMat[:, seg]
        labels[rows] = seg

    # now we are ready to apply the non-linear mixtures:
    mixedDat = np.copy(dat)

    # generate mixing matrices:
    if linear_mixing_first:
        A = ortho_group.rvs(Ncomp, random_state=randomstate)
        mixedDat = np.dot(mixedDat, A)
    for l in range(Nlayer - 1):
        # we first apply non-linear function, then causal matrix!
        if NonLin == 'leaky':
            mixedDat = leaky_ReLU(mixedDat, negSlope)
        elif NonLin == 'sigmoid':
            mixedDat = sigmoidAct(mixedDat)

        # generate causal matrix first:
        A = ortho_group.rvs(Ncomp, random_state=randomstate)  # generateUniformMat( Ncomp, condThresh )
        # apply mixing:
        mixedDat = np.dot(mixedDat, A)

    # stratified split
    x_train, x_test, z_train, z_test, u_train, u_test = train_test_split(
        mixedDat, dat, labels, train_size=train_size, test_size=test_size, random_state=randomstate, shuffle = True, stratify=labels
    )

    # Glue observations, sources and labels together so one argsort on the
    # label column reorders all three consistently. The hstack promotes the
    # integer labels to float, which is why "c" is returned as floats.
    train_tmp = np.hstack((x_train, z_train, u_train.reshape(len(u_train), -1)))
    test_tmp = np.hstack((x_test, z_test, u_test.reshape(len(u_test), -1)))

    train_ord = train_tmp[train_tmp[:, train_tmp.shape[1]-1].argsort()]
    test_ord = test_tmp[test_tmp[:, test_tmp.shape[1]-1].argsort()]

    x_train, z_train, u_train = train_ord[:,0:Ncomp], train_ord[:,Ncomp:2*Ncomp], train_ord[:,2*Ncomp]
    x_test, z_test, u_test = test_ord[:,0:Ncomp], test_ord[:,Ncomp:2*Ncomp], test_ord[:,2*Ncomp]

    if save_all_datasets is True:
        all_datasets = {}
        for domID in range(Nsegment):
            # Cumulative datasets: entry domID+1 contains segments 0..domID.
            train_indices = u_train<=domID
            test_indices = u_test<=domID
            all_datasets[domID+1] = {
                "train": {"y": z_train[train_indices], "x": x_train[train_indices], "c": u_train[train_indices]},
                "test": {"y": z_test[test_indices], "x": x_test[test_indices], "c": u_test[test_indices]}
            }

        # torch.save does not create missing parent directories.
        os.makedirs("./data", exist_ok=True)
        torch.save(all_datasets, f"./data/all_datasets_{Nsegment}_seed_{seed}_domain_validation_size_{NsegmentObs_test}_mean_{mean_range_r}_var_{var_range_r}_n_components_{Ncomp}.pth")

    return {"y": z_train, "x": x_train, "c": u_train}, {"y": z_test, "x": x_test, "c": u_test}
