from numpy.random import seed
import numpy as np
import math
import matplotlib.pyplot as plt
import pickle
from scipy.special import comb
import pandas as pd
from scipy import stats
import sys
import matplotlib.pyplot as plt
# Select matplotlib's built-in 'ggplot' style sheet globally for this process.
# NOTE(review): no plotting is visible in this file — confirm this is still needed.
plt.style.use('ggplot')

        
def sigmoid2(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of *x*.

    Works on scalars and, elementwise, on NumPy arrays.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator
    
def sigmoid(l, r, dim):
    """Return the logistic sigmoid of the skew-symmetric bilinear form ``l @ A @ r``.

    ``A`` is block diagonal with ``int(dim / 2)`` copies of the 2x2 block
    ``[[0, -1], [1, 0]]``, so it has ``2 * int(dim / 2)`` rows and columns.

    Args:
        l: vector whose length matches ``A`` (``dim`` for even ``dim``).
        r: vector whose length matches ``A``.
        dim: dimension of the space; the block count is ``int(dim / 2)``.

    Returns:
        ``1 / (1 + exp(-(l @ A @ r)))``.
    """
    block = np.array([[0, -1], [1, 0]])
    # kron(I_n, block) places `block` on the diagonal n times, zeros elsewhere.
    A = np.kron(np.eye(int(dim / 2), dtype=block.dtype), block)

    x = np.dot(np.dot(l, A), r)
    return 1 / (1 + np.exp(-x))

def block_diagonal_matrix(block, num_blocks):
    """Return a matrix with *block* repeated *num_blocks* times along the diagonal.

    Off-diagonal positions hold ``np.zeros_like(block)``, and the nested
    layout is assembled with ``np.block``.
    """
    zero = np.zeros_like(block)
    layout = []
    for row in range(num_blocks):
        layout.append([block if col == row else zero for col in range(num_blocks)])
    return np.block(layout)

def generate_data_block(full_data, j, length, dim):
    """Assemble the feature/label matrix for block ``j``.

    Args:
        full_data: 2-D array whose columns are ``[f1 (dim) | f2 (dim) | labels...]``;
            column ``2*dim + j`` holds the labels for block ``j``.
        j: index of the block whose label column is used.
        length: number of leading rows of ``full_data`` to copy.
        dim: number of columns in each of ``f1`` and ``f2``.

    Returns:
        float64 array of shape ``(length, 2*dim + 1)``: ``f1``, then ``f2``,
        then the block-``j`` label column.

    Raises:
        IndexError: if ``j`` selects a missing column or ``length`` exceeds
            the number of rows in ``full_data``.
    """
    f1 = full_data[:, 0:dim]
    f2 = full_data[:, dim:2 * dim]
    block = full_data[:, 2 * dim + j]
    feature_length = f1.shape[1]

    block_data_full = np.zeros((length, 2 * feature_length + 1))
    print("f1 shape = ", f1.shape)

    if length > f1.shape[0]:
        raise IndexError(
            f"length {length} exceeds the {f1.shape[0]} rows available in full_data"
        )

    # Vectorized copy of the first `length` rows into the three column groups.
    block_data_full[:, 0:feature_length] = f1[:length]
    block_data_full[:, feature_length:2 * feature_length] = f2[:length]
    block_data_full[:, 2 * feature_length] = block[:length]

    return block_data_full

def _pairwise_wedge(phi_1, phi_2):
    """Return the 2-D wedge products of consecutive column pairs of phi_1 and phi_2.

    Column ``b`` of the result is
    ``phi_1[:, 2b] * phi_2[:, 2b+1] - phi_1[:, 2b+1] * phi_2[:, 2b]``.

    Raises:
        ValueError: if the number of phi columns is odd (columns cannot be paired).
    """
    k = phi_1.shape[1]
    if k % 2:
        raise ValueError(f"phi dimension must be even to form column pairs, got {k}")
    for b in range(k // 2):
        print("block number = ", b)
    return phi_1[:, 0::2] * phi_2[:, 1::2] - phi_1[:, 1::2] * phi_2[:, 0::2]


def _sample_bernoulli_labels(phi):
    """Draw 0/1 labels with P(label = 1) = sigmoid2(phi), elementwise.

    Uses NumPy's global (legacy) RNG, so results depend on np.random.seed.
    """
    # One uniform draw per entry, in row-major order. This consumes the global
    # RNG exactly like a per-element loop of np.random.rand() calls.
    draws = np.random.rand(*phi.shape)
    return (draws <= sigmoid2(phi)).astype(float)


def generate_data_decomposed(t, dim=None, dim0=None, path="merged_output.txt"):
    """Build per-block labelled datasets and write each to ``block_data_{j}.txt``.

    Input columns (loaded with ``np.loadtxt(path)``):
    ``[f1 (dim) | f2 (dim) | phi_1 (dim0) | phi_2 (dim0) | ...]``.
    For each consecutive pair of phi columns, a wedge product is formed, and a
    0/1 label is sampled with probability ``sigmoid2(wedge)``.

    Args:
        t: seed for NumPy's global RNG.
        dim: width of f1/f2; defaults to ``int(sys.argv[1])``.
        dim0: width of phi_1/phi_2; defaults to ``int(sys.argv[2])``.
        path: input file; defaults to ``"merged_output.txt"``.

    Returns:
        Array of shape ``(rows, 2*dim + 1, blocks)``; slice ``[:, :, j]`` is the
        content written to ``block_data_{j}.txt``.
    """
    np.random.seed(t)

    if dim0 is None:
        dim0 = int(sys.argv[2])
    if dim is None:
        dim = int(sys.argv[1])

    input_data = np.loadtxt(path)
    print("merged output shape = ", int(input_data.shape[1]))
    print("dim = ", dim)
    f1 = input_data[:, 0:dim]
    f2 = input_data[:, dim:2 * dim]
    phi_1 = input_data[:, 2 * dim:2 * dim + dim0]
    phi_2 = input_data[:, 2 * dim + dim0:2 * (dim + dim0)]

    # Report how many rows remain after de-duplicating on (f1, f2) and on
    # (phi_1, phi_2). NOTE(review): the de-duplicated data is only reported,
    # not used below — confirm this is intended.
    _, unique_idx = np.unique(input_data[:, 0:2 * dim], axis=0, return_index=True)
    print("merged output shape after removing duplicate input data1 =",
          (unique_idx.size, input_data.shape[1]))
    _, unique_idx = np.unique(input_data[:, 2 * dim:2 * (dim + dim0)], axis=0,
                              return_index=True)
    print("merged output shape after removing duplicate input data2 =",
          (unique_idx.size, input_data.shape[1]))

    length = phi_1.shape[0]
    print("length = ", length)
    k = int(phi_1.shape[1])
    print("phi dim = ", k)

    phi = _pairwise_wedge(phi_1, phi_2)
    print(phi.shape)
    print(f1.shape)
    print(phi_1.shape)
    print(input_data.shape)

    phi_y = _sample_bernoulli_labels(phi)

    full_data = np.concatenate((f1, f2, phi_y), axis=1)
    print("full_data size = ", full_data.shape)

    n_rows = int(input_data.shape[0])
    blocks = phi.shape[1]
    block_data = np.zeros((n_rows, int(2 * dim) + 1, blocks))
    print(block_data.shape)
    print("blocks = ", blocks)

    for j in range(blocks):
        block_data[:, :, j] = generate_data_block(full_data, j, n_rows, dim)
        np.savetxt(f"block_data_{j}.txt", block_data[:, :, j])

    print(block_data[:, :, 0:1])
    print("block data shape = ", block_data.shape)
    return block_data
    
# Run the pipeline only when executed as a script (python <file> DIM DIM0),
# not when this module is imported.
if __name__ == "__main__":
    generate_data_decomposed(0)
