import numpy as np
import sys
from numpy import random
import math
from scipy import stats
from scipy.special import comb
from scipy.linalg import block_diag
import matplotlib.pyplot as plt
import torch
from sklearn.preprocessing import MinMaxScaler,StandardScaler
from sklearn.preprocessing import PolynomialFeatures
from itertools import combinations, combinations_with_replacement
import torch.nn as nn
from torch.autograd.functional import jacobian
import seaborn as sns
from captum.attr import IntegratedGradients
import torch.nn.functional as F
#from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import os


class WrapperModel(nn.Module):
    """Wrap a model so that its output always has a leading batch dimension.

    A 1-D output from the wrapped model is promoted to shape ``[1, ...]``;
    outputs with two or more dimensions are passed through unchanged.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        result = self.model(x)
        # Promote un-batched (1-D) outputs to a batch of one.
        return result.unsqueeze(0) if result.dim() == 1 else result

# Step 1: Random feature mapping network
class RandomFeatureMapper(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, feature_dim, seed=42):
        super().__init__()
        layers = []
        self.active_dim = input_dim #- 2
        current_dim = self.active_dim

        g = torch.Generator().manual_seed(seed)

        '''
        for h in hidden_dims:
            layers.append(nn.Linear(current_dim, h))
            layers.append(nn.ReLU())
            current_dim = h
        layers.append(nn.Linear(current_dim, feature_dim))
        '''

        for _ in range(num_layers):
            linear = nn.Linear(current_dim, hidden_dim)
            nn.init.kaiming_normal_(linear.weight, nonlinearity='relu', generator=g)  # Kaiming init
            layers.append(linear)
            layers.append(nn.ReLU())
            layers.append(nn.LayerNorm(hidden_dim))  # Keeps activations stable
            current_dim = hidden_dim

        # Final projection to feature_dim (without ReLU)
        final_linear = nn.Linear(current_dim, feature_dim)
        nn.init.kaiming_normal_(final_linear.weight, nonlinearity='linear', generator=g)
        layers.append(final_linear)
        self.mapper = nn.Sequential(*layers)
        
        # Freeze parameters (no training)
        for param in self.mapper.parameters():
            param.requires_grad = False

    def forward(self, x):
        x_used = x[:, :self.active_dim]  # only use the active features
        return self.mapper(x_used)
        
        
def skew_symmetric_fn(phi1, phi2):
    """Batch-summed 2-D cross product of the first two feature columns.

    Computes ``sum_b (phi1[b,0] * phi2[b,1] - phi1[b,1] * phi2[b,0])`` over the
    batch dimension and returns it as a tensor of shape ``[1]``. Swapping the
    arguments negates the result.
    """
    cross = phi1[:, 0] * phi2[:, 1] - phi1[:, 1] * phi2[:, 0]
    return cross.sum(dim=0, keepdim=True)
    
# Step 3: Wrap everything into a combined model that takes concatenated input
class CombinedModel(nn.Module):
    """Score a pair of items supplied as one concatenated input row.

    Columns ``0:4`` of ``x`` describe the first item and the remaining columns
    the second. Both halves go through the shared ``mapper``, and the resulting
    features are combined with ``skew_symmetric_fn``.
    """

    def __init__(self, mapper):
        super().__init__()
        self.mapper = mapper

    def forward(self, x):
        left_features = self.mapper(x[:, :4])
        right_features = self.mapper(x[:, 4:])
        return skew_symmetric_fn(left_features, right_features)

   
def swap_terms(expanded_features, idx1, idx2):
    """Swap rows ``idx1`` and ``idx2`` of ``expanded_features`` in place.

    The fancy-indexed right-hand side is materialised as a copy before the
    assignment, so the swap is safe. The mutated input is also returned.
    """
    source_order = [idx2, idx1]
    expanded_features[[idx1, idx2]] = expanded_features[source_order]
    return expanded_features

def quadratic_kernel_expansion(X):
    """Return the degree-2 monomial expansion of ``X`` (no bias term).

    Layout: ``[x_0..x_{n-1}, x_0^2..x_{n-1}^2, x_i*x_j for i < j]``, with the
    cross terms in lexicographic ``(i, j)`` order.
    """
    arr = np.asarray(X)
    linear_terms = list(arr)
    square_terms = list(arr ** 2)
    cross_terms = [arr[a] * arr[b] for a, b in combinations(range(len(arr)), 2)]
    return np.array(linear_terms + square_terms + cross_terms)

class QuadraticModel2(torch.nn.Module):
    """Single-sample quadratic expansion followed by block selection.

    The input is flattened to 1-D ``x`` of length ``D`` and expanded to
    ``[x_0..x_{D-1}, x_0^2..x_{D-1}^2, x_i*x_j for i < j]``. The expanded
    features are grouped into consecutive pairs ("blocks"). For each ``idx`` in
    ``blocks``, the features at positions ``2*idx`` and ``2*idx + 1`` are
    returned, giving a 1-D tensor of length ``2 * len(blocks)``.

    The expansion is built with differentiable torch ops, so gradients flow back
    to ``x``.
    """

    def __init__(self, blocks):
        super().__init__()
        self.blocks = blocks  # iterable of block indices, e.g. [0, 1]

    def forward(self, x):
        x = x.reshape(-1)  # treat the input as one flat sample
        n = x.shape[0]

        # Upper-triangle indices in row-major order reproduce
        # itertools.combinations(range(n), 2) ordering.
        i, j = torch.triu_indices(n, n, offset=1, device=x.device)
        quad_output = torch.cat([x, x ** 2, x[i] * x[j]])

        # Every requested block contributes its pair of expanded features.
        cols = [c for idx in self.blocks for c in (2 * int(idx), 2 * int(idx) + 1)]
        index = torch.tensor(cols, dtype=torch.long, device=x.device)
        return quad_output[index]
        
class QuadraticModel(torch.nn.Module):
    """Batched quadratic expansion followed by block selection.

    Each row ``x`` of length ``D`` is expanded to
    ``[x_0..x_{D-1}, x_0^2..x_{D-1}^2, x_i*x_j for i < j]``, i.e.
    ``2*D + D*(D-1)/2`` features. The expanded features are grouped into
    consecutive pairs ("blocks"). For each ``idx`` in ``blocks``, columns
    ``2*idx`` and ``2*idx + 1`` of the expansion are returned.

    Note: a block is a pair of *adjacent expanded features*, not an
    (original, squared) pair of the same input. For example, with D = 4,
    block 0 is ``(x_0, x_1)`` and block 2 is ``(x_0^2, x_1^2)``.
    """

    def __init__(self, blocks):
        super().__init__()
        self.blocks = blocks  # list of block indices (e.g., [0, 1])

    def forward(self, x):
        """Return the selected blocks as a ``[B, 2 * len(blocks)]`` tensor.

        A 1-D input is treated as a batch of one; other inputs must be 2-D.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)

        _, num_features = x.shape  # also enforces a 2-D input

        # Upper-triangle indices in row-major order reproduce
        # itertools.combinations(range(D), 2) ordering, without a Python loop.
        i, j = torch.triu_indices(num_features, num_features, offset=1, device=x.device)
        quad_output = torch.cat([x, x ** 2, x[:, i] * x[:, j]], dim=1)

        cols = [c for idx in self.blocks for c in (2 * int(idx), 2 * int(idx) + 1)]
        index = torch.tensor(cols, dtype=torch.long, device=x.device)
        return quad_output[:, index]

      


def cubic_kernel_expansion(X):
    """Return the degree-<=3 monomial expansion of ``X`` (no bias term).

    Layout (all index tuples in lexicographic order, ``i < j < k``):
        ``x_i``, ``x_i^2``, ``x_i*x_j``, ``x_i^3``,
        then for each ``(i, j)`` the pair ``x_i^2*x_j, x_i*x_j^2``,
        then ``x_i*x_j*x_k``.
    """
    X = np.asarray(X)
    n = len(X)

    expanded_features = list(X)                        # x_i
    expanded_features.extend(X ** 2)                   # x_i^2
    for i, j in combinations(range(n), 2):             # x_i * x_j
        expanded_features.append(X[i] * X[j])
    expanded_features.extend(X ** 3)                   # x_i^3
    # combinations() already yields sorted tuples, so no min/max/sort is needed.
    for i, j in combinations(range(n), 2):             # x_i^2 x_j, x_i x_j^2
        expanded_features.append(X[i] ** 2 * X[j])
        expanded_features.append(X[i] * X[j] ** 2)
    for i, j, k in combinations(range(n), 3):          # x_i x_j x_k
        expanded_features.append(X[i] * X[j] * X[k])

    return np.array(expanded_features)


def sinusoidal_mapping(x):
    """Map 4 input features to 6 sinusoidal features.

    Accepts one sample of shape ``(4,)`` (promoted to ``(1, 4)``) or a batch of
    shape ``(n, 4)``. Returns shape ``(n, 6)`` with columns
    ``sin(x1), cos(x2), sin(2*x3), cos(2*x4), sin(x1 + x2), cos(x3 - x4)``.
    """
    arr = np.asarray(x)
    if arr.ndim == 1:
        arr = arr.reshape(1, -1)
    c1, c2, c3, c4 = (arr[:, k:k + 1] for k in range(4))
    columns = [
        np.sin(c1),
        np.cos(c2),
        np.sin(2 * c3),
        np.cos(2 * c4),
        np.sin(c1 + c2),
        np.cos(c3 - c4),
    ]
    return np.concatenate(columns, axis=1)

class SinusoidalMappingModel(nn.Module):
    """Torch module mapping 4 input features to 6 sinusoidal features.

    Output columns: ``sin(x1), cos(x2), sin(2*x3), cos(2*x4), sin(x1 + x2),
    cos(x3 - x4)``. A 1-D input is treated as a batch of one.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        batch = x.unsqueeze(0) if x.dim() == 1 else x
        x1, x2, x3, x4 = (batch[:, k:k + 1] for k in range(4))
        pieces = (
            torch.sin(x1),
            torch.cos(x2),
            torch.sin(2 * x3),
            torch.cos(2 * x4),
            torch.sin(x1 + x2),
            torch.cos(x3 - x4),
        )
        return torch.cat(pieces, dim=1)

def hybrid_mapping(x):
    """Map 4 input features to 6 mixed polynomial/trigonometric features.

    Accepts one sample of shape ``(4,)`` (promoted to ``(1, 4)``) or a batch of
    shape ``(n, 4)``. Returns shape ``(n, 6)`` with columns
    ``x0^2, x1*x2, sin(x3), cos(x0), sin(x1 + x2), x3^2``.
    """
    arr = np.asarray(x)
    if arr.ndim == 1:
        arr = arr.reshape(1, -1)
    c0, c1, c2, c3 = (arr[:, k:k + 1] for k in range(4))
    columns = [
        c0 ** 2,
        c1 * c2,
        np.sin(c3),
        np.cos(c0),
        np.sin(c1 + c2),
        c3 ** 2,
    ]
    return np.concatenate(columns, axis=1)

class HybridMappingModel(nn.Module):
    """Torch module mapping 4 input features to 6 mixed features.

    Output columns: ``x0^2, x1*x2, sin(x3), cos(x0), sin(x1 + x2), x3^2``.
    A 1-D input is treated as a batch of one.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        batch = x if x.dim() != 1 else x.unsqueeze(0)
        x0, x1, x2, x3 = (batch[:, k:k + 1] for k in range(4))
        pieces = (
            x0 ** 2,             # square of feature 0
            x1 * x2,             # interaction of features 1 and 2
            torch.sin(x3),       # sine of feature 3
            torch.cos(x0),       # cosine of feature 0
            torch.sin(x1 + x2),  # sine of (feature 1 + feature 2)
            x3 ** 2,             # square of feature 3
        )
        return torch.cat(pieces, dim=1)



def generate_items_uniform(num_items, dim, t):
    """Draw a ``(dim, num_items)`` matrix of item features.

    Despite the name, entries are drawn i.i.d. from an exponential distribution
    with ``scale=1.0``. Re-seeds NumPy's global RNG with ``t``.
    """
    np.random.seed(t)
    return np.random.exponential(scale=1.0, size=(dim, num_items))

def generate_items_normal(num_items, dim, std_deviation, t):
    """Draw ``num_items`` Gaussian feature vectors and normalise each to unit length.

    Samples ``N(0, std_deviation^2)`` entries into a ``(dim, num_items)`` array
    and divides each column by its L2 norm. Re-seeds NumPy's global RNG with
    ``t``.
    """
    np.random.seed(t)
    raw = np.random.normal(0, scale=std_deviation, size=(dim, num_items))
    column_norms = np.linalg.norm(raw, 2, axis=0)
    return raw / column_norms
    
def generate_items_unit_circle(num_items, t):
    """Sample ``num_items`` points on the upper half of the unit circle.

    The x-coordinate is uniform on ``[-1, 1)`` and ``y = sqrt(1 - x^2)``, so
    every point has ``y >= 0``. Returns a ``(2, num_items)`` array (row 0 = x,
    row 1 = y). Re-seeds NumPy's global RNG with ``t`` and prints both rows.
    """
    np.random.seed(t)
    xs = np.random.uniform(low=-1, high=1, size=(1, num_items))
    ys = np.sqrt(1 - xs ** 2)

    print("items_1 = ", xs)
    print("items_2 = ", ys)

    return np.vstack((xs, ys))
    
def generate_random_points_on_circle(num_points, radius, t):
    """Sample ``num_points`` points uniformly by angle on a circle of ``radius``.

    Returns a ``(2, num_points)`` array (row 0 = x, row 1 = y). Re-seeds NumPy's
    global RNG with ``t`` and prints both coordinate rows.
    """
    np.random.seed(t)
    theta = np.random.uniform(0, 2 * np.pi, size=(1, num_points))

    x = radius * np.cos(theta)
    y = radius * np.sin(theta)

    print("x = ", x)
    print("y = ", y)

    return np.vstack((x, y))
    
def generate_random_covariance_matrix(dim):
    """Return a random symmetric positive-definite ``dim x dim`` matrix.

    Computes ``R @ R.T + dim * I`` with ``R`` standard normal. ``R @ R.T`` is
    positive semi-definite, so adding ``dim * I`` lifts every eigenvalue to at
    least ``dim``. Draws from NumPy's global RNG without re-seeding it.
    """
    r = np.random.randn(dim, dim)
    gram = r @ r.T
    return gram + dim * np.eye(dim)


    
def generate_random_points_covariance_matrix(num_points, dim, t):
    """Draw ``num_points`` samples from a standard ``dim``-variate normal.

    Re-seeds NumPy's global RNG with ``t`` and samples ``N(0, I_dim)`` with
    ``np.random.multivariate_normal``. The samples are written to
    ``original_features.txt`` in the current working directory and returned as
    an array of shape ``(num_points, dim)``. Prints the covariance matrix and
    the data shape.
    """
    np.random.seed(t)

    mean = np.zeros(dim)
    cov_matrix = np.eye(dim)  # identity covariance: independent unit-variance features
    print("cov_matrix = ", cov_matrix)

    data = np.random.multivariate_normal(mean, cov_matrix, num_points)

    np.savetxt("original_features.txt", data)
    print("data shape = ", data.shape)

    return data

def generate_matches(num_items, plays, dim, t):
    """Generate item features plus the list of every unordered item pair.

    Returns:
        U: item features of shape ``(dim, num_items)`` (transpose of the samples
           produced by ``generate_random_points_covariance_matrix``).
        pairs_1: float array of shape ``(2, C(num_items, 2))``. Column ``k``
           holds the item indices ``(i, j)``, ``i < j``, of the k-th pair in
           lexicographic order.
        total_pairs: ``C(num_items, 2)`` as an int.

    ``plays`` is only printed here.
    """
    U = generate_random_points_covariance_matrix(num_items, dim, t).T
    print("U = ", U.shape)

    total_pairs = int(comb(num_items, 2))

    # Strict upper-triangle indices enumerate (i, j), i < j, in row-major order.
    first, second = np.triu_indices(num_items, k=1)
    pairs_1 = np.vstack((first, second)).astype(float)

    print("total_pairs, plays = ", total_pairs, plays)

    return U, pairs_1, total_pairs
def block_diagonal_matrix(block, num_blocks):
    """Return a block-diagonal matrix with ``block`` repeated ``num_blocks`` times.

    Off-diagonal tiles are zeros of the same shape and dtype as ``block``.
    """
    zero_tile = np.zeros_like(block)
    grid = []
    for row in range(num_blocks):
        grid.append([block if col == row else zero_tile for col in range(num_blocks)])
    return np.block(grid)

        
def sigmoid(l, r, dim):
    """Return a skew-symmetric bilinear score and its logistic value.

    ``A`` is block diagonal with ``int(dim / 2)`` copies of ``[[0, 1], [-1, 0]]``
    (built with ``block_diagonal_matrix``). The score is
    ``x = np.dot(np.dot(A, l), r)``, i.e. ``(A l) . r`` for vector inputs.

    NOTE(review): for vectors, ``(A l) . r = r^T A l = -(l^T A r)`` because ``A``
    is skew-symmetric. Confirm this sign convention is the intended one.

    Returns:
        (x, val): the raw score and ``1 / (1 + exp(-x))``.
    """
    skew_block = np.array([[0, 1], [-1, 0]])
    A = block_diagonal_matrix(skew_block, int(dim / 2))

    rotated_left = np.dot(A, l)
    score = np.dot(rotated_left, r)

    return score, 1 / (1 + np.exp(-score))

def sigmoid2(x):
    """Element-wise logistic function ``1 / (1 + exp(-x))``."""
    return 1 / (1 + np.exp(-x))

def generate_data(num_items, num_pairs, plays, dim, t):

    #y = np.zeros(num_pairs)
    y = np.zeros(plays)

    U, pairs, total_pairs = generate_matches(num_items, plays, dim, t)  #generate plays number of pairs uniformly from the 100 items generated above using the 100C2 pairs     

    left = []
    right = []
    
    q3 = input("Original function(0) or block function(1)?")
    sys.argv.append(q3)

    q4 = input("which block? block1(0),block2(2)block3(4), etc...if no block, then put anything, it will be taken care of by the previous question:)")
    sys.argv.append(q4)

    random_matrix = np.random.rand(dim, dim)

    q5 = input("linear(0) or non-linear(1) mapping?")
    sys.argv.append(q5)
    
    
    #U_transformed = np.dot(random_matrix,U)
    q6 = list(map(int, input("which block functions you want to add?(separated by space)").split()))
    print("q6 = ", q6)

    q6_length = len(q6)
    input_block = np.zeros(q6_length)
    for q6_count in range(q6_length):
        input_block[q6_count] = q6[q6_count]
        print("input_block[q6_count] = ", input_block[q6_count])

    input_block = np.array(input_block).astype(int)

    q7 = input("Number of features for random ReLU mapping?(put double of the total block number)")
    sys.argv.append(q7)
    print("U shape = ", U.shape)
    
    


    if int(sys.argv[5]) == 0:
        U_transformed = U
        #print("U_transformed shape = ", U_transformed.shape, U_transformed)
        
    else:
        #U_transformed = np.zeros((dim, num_items))
        ##Quadratic function - 14 terms/7 blocks
        
        U_transformed = []
        jacobians = []
        IntGrad = []
        importance_scores = []
        importance_scores_perm = []
        
        print("U shape = ", U.shape)
        
        #U_x = StandardScaler().fit_transform(np.copy(np.transpose(U)))
        #U_x_tensor = torch.tensor(U_x, dtype=torch.float32)

        
        input_dim = dim #+ 2
        hidden_dim = 16
        num_layers = 4
        feature_dim = 2 * q6_length
        print("feature_dim = ", feature_dim)
        #model = RandomFeatureMapper(input_dim, hidden_dim, num_layers, feature_dim, seed = t)
        model = HybridMappingModel() #QuadraticModel(q6)##SinusoidalMappingModel() 
        #datanamearr = ["hybrid_46", "hybrid_44", "quad_24", "quad_23", "ReLU_46_4_16", "ReLU_46_8_8"]
        dataname = "hybrid_46"

        print("Hello model")
        model.eval()
        print("Hello model eval")

        def model_output_flat(x):
            return model(x.unsqueeze(0)).squeeze(0)  # Force output shape: [6] (not [1,6])

        for i in range(num_items):
            #print("i = ", i)
            #print("original features = ", np.transpose(U)[i])                 
            
            #x = torch.tensor(np.transpose(U)[i,:], dtype=torch.float32, requires_grad=True).unsqueeze(0)
            #U_quad = model(x).detach().numpy()    ##RandomFeatureMapper
            #print("U_quad.shape = ", U_quad.shape)    
            #print("transformed features = ", U_quad)          
            
            #U_quad = model(torch.from_numpy(np.transpose(U)[i]).float())  
            
            '''
                x_intgrad = U_x_tensor[i]
                ig = IntegratedGradients(model)
                print("x_intgrad shape = ", x_intgrad.shape)
                
                x_intgrad = x_intgrad.unsqueeze(0)  # make shape (1, input_dim)
            
                #baseline =  x_intgrad.mean(dim=0, keepdim=True) #torch.zeros_like(x_intgrad)
                
                baseline = torch.zeros_like(x_intgrad)  # or any other 1 × input_dim baseline

                
                print("baseline, x_intgrad = ", baseline.shape, baseline, x_intgrad.shape, x_intgrad)
                
                instance_attr = []
                print(f"x_intgrad.shape = {x_intgrad.shape}")

            
                for output_idx in range(feature_dim):        
                    attributions, _ = ig.attribute(inputs=x_intgrad, baselines=baseline, target=output_idx, return_convergence_delta=True)
                    instance_attr.append(attributions.detach().numpy())
                    
                    
                # Save per-instance IG across all outputs
                IntGrad.append(np.vstack(instance_attr))  # shape (feature_dim, input_dim)

                print(f"\nFeature contributions (IntegratedGradients): instance {i}")
                print(np.vstack(instance_attr))  # for clarity
                #print("\nFeature contributions (IntegratedGradients):", i, attributions, attributions.shape)
                
                #X = StandardScaler().fit_transform(X)
                #X_tensor = torch.tensor(X, dtype=torch.float32)
                
                deltas_ablation = []
                importance_matrix = np.zeros((feature_dim, input_dim))
                print("x_intgrad shape = ", x_intgrad.shape)
                
                
                
                #x = U_x_tensor[i]
                #x = x.unsqueeze(0)
                
                
                print("x shape = ", x_intgrad.shape)
                z_orig = model(x_intgrad).detach().numpy()
                
                
                for x_i in range(x_intgrad.shape[1]):
                    x_perturbed = x_intgrad.clone()
                    x_perturbed[:, x_i] = 0  # zero out the feature
                    x_perturbed_tensor = torch.tensor(x_perturbed, dtype=torch.float32)
                    
                    with torch.no_grad():                     
                        z_perturbed = model(x_perturbed_tensor).numpy()
                        print("x, z_orig, x_perturbed_tensor, z_perturbed, z_orig - z_perturbed = ", x, z_orig, x_perturbed_tensor, z_perturbed, z_orig - z_perturbed)
                        deltas_ablation.append(np.abs(z_orig - z_perturbed)) # L2 difference
                        importance_matrix[:, x_i] = np.abs(z_orig - z_perturbed)
                
                print("importance_matrix = ", importance_matrix.shape, importance_matrix)
                print("deltas_ablation = ", deltas_ablation)
                
                importance_scores.append(np.vstack(deltas_ablation))
                
            '''
            
            '''
                deltas_perm = []
                
                
                for x_i in range(x.shape[1]):
                    x_perm = x.clone()
                    #x_perm[:, x_i] = torch.tensor(np.random.permutation(x_perm[:, x_i]))
                    np.random.permutation(x_perm[:, x_i])
                    x_perm_tensor = torch.tensor(x_perm, dtype=torch.float32)
                    
                    with torch.no_grad():                     
                        z_perm = model(x_perm_tensor).numpy()
                        print("x, z_orig, x_perm_tensor, z_perm = ", x, z_orig, x_perm_tensor, z_perm)
                        deltas_perm.append(np.linalg.norm(z_orig - z_perm, axis=1))  
                    #importance_matrix[:, x_i] = np.linalg.norm(z_orig - z_perm, axis=1)
                
                #print("importance_matrix = ", importance_matrix.shape, importance_matrix)
                print("deltas_perm = ", deltas_perm)
                
                importance_scores_perm.append(np.vstack(deltas_perm))
            '''
            
            #U_quad = quadratic_kernel_expansion(np.transpose(U)[i,0:dim])
            #U_quad = cubic_kernel_expansion(np.transpose(U)[i,0:dim])
            #U_quad = np.exp(np.transpose(U)[i,0:dim])
            
            #U_quad = sinusoidal_mapping(np.transpose(U)[i,0:dim])
            U_quad = hybrid_mapping(np.transpose(U)[i,0:dim])

            #print("transformed features = ", U_quad)
            #print("U_transformed = ", U_transformed)
            
            U_quad = np.transpose(U_quad)
            if U_quad.shape[0] % 2 != 0:          
            	U_quad = np.append(U_quad, 1)
            	#print("transformed features = ", U_quad.shape[0])
            
            U_transformed.append(U_quad)                       
            
            #print("transformed features = ", np.array(U_transformed))
            '''
            temp = np.square(U[i])
            temp2 = U.copy()
            temp2[i] = temp
            temp2 = np.prod(temp2, axis = 0)
            U_transformed[i] = temp2
            '''

        #print("transformed features = ", np.array(U_transformed)[1,:])
        U_transformed = np.transpose(np.reshape(np.array(U_transformed),(num_items, U_quad.shape[0])))
        print("U_transformed.shape = ", U_transformed.shape)    
        
        #U_transformed = swap_terms(U_transformed, 19, 20)
        
        #U_transformed = swap_terms(U_transformed, 0, 4)
        
        #print("U_transformed after swap = ", U_transformed[0,:],U_transformed[4,:])
        
        #U_transformed = swap_terms(U_transformed, 2, 6)
        
        #print("U_transformed after swap = ", U_transformed[2,:],U_transformed[6,:])
        
        #print("U_transformed shape = ", U_transformed.shape, U_transformed[:,1])
        
        if int(sys.argv[3]) == 0:
            U_x = StandardScaler().fit_transform(np.copy((U)))
            #print("U_x = ", U_x.shape)
            U_x_tensor = torch.tensor(U_x, dtype=torch.float32)
            #print("U_x_tensor = ", U_x_tensor.shape)

            
        

            for i in range(num_items):
                x_intgrad = U_x_tensor[:,i]

                

                ig = IntegratedGradients(model)
                #ig = IntegratedGradients(WrapperModel(model))
                #print("x_intgrad shape = ", x_intgrad.shape)
                
                x_intgrad = x_intgrad.unsqueeze(0)  # make shape (1, input_dim)
                
                #baseline =  x_intgrad.mean(dim=0, keepdim=True) #torch.zeros_like(x_intgrad)
                baseline = torch.zeros_like(x_intgrad)  # or any other 1 × input_dim baseline
                
                #print("baseline, x_intgrad = ", baseline.shape, baseline, x_intgrad.shape, x_intgrad)
                        
                instance_attr = []
                #print(f"x_intgrad.shape = {x_intgrad.shape}")

                
                for output_idx in range(feature_dim):        
                    attributions, _ = ig.attribute(inputs=x_intgrad, baselines=baseline, target=output_idx, return_convergence_delta=True)
                    instance_attr.append(attributions.detach().numpy())

                #print("Jacobian : mapping")         
                jac = torch.autograd.functional.jacobian(model_output_flat, x_intgrad.view(-1))  # shape: (feature_dim, input_dim)
                #print("Jacobian shape:", jac.shape)
                jacobians.append(torch.abs(jac))  # use absolute value for magnitude

            # Save per-instance IG across all outputs
                IntGrad.append(np.vstack(instance_attr))  # shape (feature_dim, input_dim)

                #print(f"\nFeature contributions (IntegratedGradients): instance {i}")
                #print(np.vstack(instance_attr))  # for clarity
                #print("\nFeature contributions (IntegratedGradients):", i, attributions, attributions.shape)
        
            
            #print("jacobians shape = ", len(jacobians))
            jacobians = torch.stack(jacobians)
            #print("jacobians shape = ", jacobians.shape)
            
            mean_jacobian = jacobians.mean(dim=0)        
            #print("Average feature contributions (Jacobian magnitudes):", mean_jacobian)
            mean_jacobian_blocks = mean_jacobian[::2] + mean_jacobian[1::2]
            #print("Average feature contributions (Jacobian magnitudes) mean_jacobian_blocks:", mean_jacobian_blocks) 
            np.savetxt("gradient_blocks.txt", mean_jacobian_blocks)  
            
            
            #print("integradtedGradients shape = ", len(IntGrad))
            IntGrad = [torch.tensor(x) for x in IntGrad]
            IntGrad = torch.stack(IntGrad)
            #print("integradtedGradients shape = ", IntGrad.shape)
                    
            # Mean over instances
            mean_intGrad = IntGrad.abs().mean(dim=0)
            print("Mean feature contributions for each mapped feature(integrated gradient):", mean_intGrad)

            
            dataset_name =f'{dataname}'
            if not os.path.exists(dataset_name):
                os.makedirs(dataset_name)

            np.savetxt(f"{dataname}/integrated_gradient_{dataname}.txt", mean_intGrad)
            mean_intGrad_blocks = mean_intGrad[::2] + mean_intGrad[1::2]
            print("Mean feature contributions for each mapped feature(integrated gradient) mean_intGrad_blocks:", mean_intGrad_blocks)
            np.savetxt(f"{dataname}/integrated_gradient_blocks_{dataname}.txt", mean_intGrad_blocks)
            '''
            print("importance_scores = ", len(importance_scores))#, importance_scores)
            importance_scores = [torch.tensor(x) for x in importance_scores]
            importance_scores = torch.stack(importance_scores)
            print("importance_scores = ", importance_scores.shape)
            
            importance_scores_mean = np.transpose(importance_scores.mean(dim = 0))
            print("importance_scores_mean = ", importance_scores_mean.shape, importance_scores_mean)
            importance_scores_mean_blocks = importance_scores_mean[::2] + importance_scores_mean[1::2]
            print("importance_scores_mean blocks = ", importance_scores_mean_blocks.shape, importance_scores_mean_blocks)
        '''  
        '''
            print("importance_scores_perm = ", len(importance_scores_perm), importance_scores_perm)
            importance_scores_perm = [torch.tensor(x) for x in importance_scores_perm]
            importance_scores_perm = torch.stack(importance_scores_perm)
            importance_scores_perm_mean = importance_scores_perm.mean(dim = 0)
            print("importance_scores_perm_mean = ", importance_scores_perm_mean.shape, importance_scores_perm_mean)
        '''

    z = int(sys.argv[4])
    if int(sys.argv[3]) == 0:
        np.savetxt("U_transformed.txt", U_transformed)
        z = -1
        print("z = ", z)
    else:
        U_transformed = np.loadtxt("U_transformed.txt")
        print("z = ", z)

    #print("shape of U_transformed = ", U_transformed.shape, U_transformed)

    #print("U shape = ", U.shape)

    poly = PolynomialFeatures(degree=2, include_bias=False)
    U2 = np.transpose(U)
    '''
    U_mapped2 = poly.fit_transform(U2)
    print("U_mapped2 shape = ", U_mapped2.shape)
    U_mapped2 = np.transpose(U_mapped2)
    U_mapped2 = np.delete(U_mapped2, [0, 1, 2, 3, 5, 6, 7, 9, 10, 12], axis = 0)
    print("U_mapped2 shape = ", U_mapped2.shape)
    '''
    #U_final = 
    print("int(U_transformed.shape[0]/2) = ", int(U_transformed.shape[0]/2))
    y_block = np.zeros((num_pairs, int(U_transformed.shape[0]/2)))
    y_block_rand = np.zeros((num_pairs, q6_length)) #np.zeros((plays, int(U_transformed.shape[0]/2)))
    
    y_block2 = np.zeros((plays, int(U_transformed.shape[0]/2)))
    choice = np.random.rand(plays)

    #rand_choice = np.random.randint(0, int(dim/2), size=1)
    #print("single block selected = ", rand_choice)

    #rand_choice2 = np.random.choice(np.arange(int(dim/2)), size=2, replace=False)
    #print("blocks selected = ", rand_choice2)
    
    #pairs_unique = set(tuple(sorted(row)) for row in np.transpose(pairs))
    #print("Unique unordered pairs:", len(pairs_unique))
    
    #U_unique = set(tuple(sorted(row)) for row in np.transpose(U))
    #print("Unique items:", len(U_unique))
    
    print("pairs, plays = ", len(np.transpose(pairs)), plays)
    
    rand_pairs_regression = np.random.choice(len(np.transpose(pairs)), size=plays, replace=False)
    
    for i in range(plays): #(num_pairs):
        
        #i2 = np.random.randint(0, total_pairs-1, size = 1)
        i2 = rand_pairs_regression[i]
        l = int(pairs[0][i2])
        r = int(pairs[1][i2])
        
        #left.append(U[:,l])
        #right.append(U[:,r])
        
        #print("U[:,l] = ",U[:,l])
        #print("U[:,r] = ",U[:,r])
        
        
        
        f1 = U_transformed[:,l]
        f2 = U_transformed[:,r]
        
        #print("f1 shape = ", f1.shape, f1[0], f1[4])

        '''
        if z == -1:
            funct, sigm = sigmoid(U_transformed[:,l], U_transformed[:,r], dim)
        else:
            sigm_block = f1[z]*f2[z+1] - f2[z]*f1[z+1]  
            sigm = sigmoid2(sigm_block)
        '''

        ##original function and block functions all are computed.
        
        
        sigmoid_block = np.zeros(int(U_transformed.shape[0]/2))
        
        k = 0   
        #print("U_transformed.shape[0] = ", U_transformed.shape[0])     
        for j in range(U_transformed.shape[0]):
            if j % 2 == 0:
               #print("j = ", j)
               sigm_block = f1[j]*f2[j+1] - f2[j]*f1[j+1]  
               sigmoid_block[k] = sigmoid2(sigm_block)
               y_block[i,k] = sigm_block
               k = k+1
        
        if int(sys.argv[3]) != 0: 
           
           #print("choice = ", choice[i])
        
           for j in range(int(U_transformed.shape[0]/2)):
              if choice[i] <= sigmoid_block[j]:
                 #print("choice, sigmoid_block[j]", choice[i], sigmoid_block[j])
                 y_block2[i,j] = 1
              elif choice[i] > sigmoid_block[j]:
            	 #print("choice, sigmoid_block[j]", choice[i], sigmoid_block[j])
                 y_block2[i,j] = 0   

        
           
        q6_sum = 0
        for q6_count in range(q6_length):
            y_block_rand[i,q6_count] = y_block[i, input_block[q6_count]]
            q6_sum = q6_sum + y_block[i, input_block[q6_count]]    
        
        y[i] =  q6_sum 
        
        
        if int(sys.argv[3]) != 0:           
           sigm = sigmoid2(q6_sum)
           '''
           if choice[i] <= sigm:   
              y[i] =  1
           else:
              y[i] = 0
           '''
         
           
                     
        
            	
        '''
        for j in range(int(dim/2)):
            if sigmoid_block[j] >= 0.5:
               #print("choice, sigmoid_block[j]", choice, sigmoid_block[j])
               y_block[i,j] = 1
            elif sigmoid_block[j] < 0.5:
            	#print("choice, sigmoid_block[j]", choice, sigmoid_block[j])
            	y_block[i,j] = -1
        '''  
        '''   
        positive_class = U[y == 1]
    	negative_class = X[y == -1]
    	plt.scatter(positive_class[:, -1], positive_class[:, 1], c='blue', label='Positive Class')
    	plt.scatter(negative_class[:, 0], negative_class[:, 1], c='red', label='Negative Class')
    	plt.xlabel('Feature 1')
    	plt.ylabel('Feature 2')
	'''
        #print("Y = ", y)
        #print("left = ", left)
    
        left.append(torch.from_numpy(U[:,l]).unsqueeze(0))
        right.append(torch.from_numpy(U[:,r]).unsqueeze(0))
    
        #print("left = ", left)
    
        x,x_prime = torch.cat(left,dim=0),torch.cat(right,dim=0)
    
        #print("x,x_prime = ",x,x_prime)

    
    np.savetxt(f"y_original.txt", y)
    np.savetxt(f"y_block.txt", y_block_rand)
    for q6_count in range(q6_length):
        np.savetxt(f"y_block_{q6_count+1}_original.txt", y_block[:, input_block[q6_count]])

    
    
    np.savetxt(f"left_items.txt", x.float())
    np.savetxt(f"right_items.txt", x_prime.float())
    
    #x = np.transpose(x)
    #x_prime = np.transpose(x_prime)
    U = np.transpose(U)
    print("y = ", y)
    print("y_block_rand = ", y_block_rand.shape, y_block_rand)
    print("y_block = ", y_block.shape, y_block)
    
    pairs = set(tuple(sorted(row)) for row in np.concatenate([x,x_prime], axis=1))
    print("Unique unordered pairs:", len(pairs))
    
    combined = np.concatenate([x,x_prime], axis=1)
    pair_dict = {}

    for idx, row in enumerate(combined):
        # Treat (A, B) and (B, A) as unordered by sorting the row
        sorted_pair = tuple(np.sort(row))
    
        if sorted_pair in pair_dict:
           pair_dict[sorted_pair].append(idx)
        else:
           pair_dict[sorted_pair] = [idx]

    # Print duplicates (pairs that appear more than once)
    print("Duplicate unordered pairs and their row indices:\n")
    for pair, indices in pair_dict.items():
        if len(indices) > 1:
           print(f"Pair: {pair}\n  Appears at rows: {indices}\n")
    
    print("y_block_rand shape = ", y_block_rand.shape)

    return x.float(), x_prime.float(), U, y, y_block_rand
    
def generate_data_sigmoid(num_items, num_pairs, plays, dim, t):
    """Generate pairwise data whose labels come from ``sigmoid``.

    Items are produced by ``generate_random_points_covariance_matrix(num_items,
    dim, t)`` and item ``i`` is read as the column ``U[:, i]``. The first
    ``num_pairs`` unordered index pairs ``(i, j)`` with ``i < j`` are taken in
    lexicographic order ((0, 1), (0, 2), ..., (1, 2), ...). For pair ``k``,
    ``y[k]`` is the first value returned by
    ``sigmoid(U[:, i], U[:, j], dim)``.

    Args:
        num_items: number of items to generate.
        num_pairs: number of pairs (rows) to produce; coerced with ``int``.
        plays: unused by this function; kept for interface compatibility.
        dim: forwarded to the item generator and to ``sigmoid``.
        t: forwarded to ``generate_random_points_covariance_matrix``.

    Returns:
        A tuple ``(x, x_prime, U, y)``:
        - ``x``: float32 tensor whose row ``k`` is the left item of pair ``k``.
        - ``x_prime``: float32 tensor whose row ``k`` is the right item of pair ``k``.
        - ``U``: the generated item matrix, transposed. This presumably gives
          one item per row if the generator returns (dim, num_items); TODO confirm.
        - ``y``: numpy array of length ``num_pairs`` with the ``sigmoid`` values.

    NOTE: if ``num_pairs`` exceeds C(num_items, 2), the surplus slots keep the
    index pair (0, 0). This matches the original behaviour and pairs item 0
    with itself.
    """
    total_pairs = int(num_pairs)
    y = np.zeros(total_pairs)

    #U = generate_items_uniform(num_items, dim, t)
    U = generate_random_points_covariance_matrix(num_items, dim, t)

    # Column k holds the (left, right) item indices of pair k.
    pairs_1 = np.zeros((2, total_pairs), dtype=int)
    # zip stops once total_pairs slots are filled. The previous nested loop
    # kept writing past the end of pairs_1 and raised IndexError whenever
    # num_pairs < C(num_items, 2).
    for k, (i, j) in zip(range(total_pairs), combinations(range(num_items), 2)):
        pairs_1[0, k] = i
        pairs_1[1, k] = j

    for k in range(total_pairs):
        l, r = pairs_1[0, k], pairs_1[1, k]
        y[k], _ = sigmoid(U[:, l], U[:, r], dim)

    # Gather every left/right column in one step. Row k equals U[:, pairs_1[*, k]],
    # the same result as concatenating per-pair rows, but it avoids
    # re-concatenating a growing list on every iteration. It also yields empty
    # tensors (instead of an unbound name) when total_pairs == 0.
    x = torch.from_numpy(np.ascontiguousarray(U[:, pairs_1[0]].T))
    x_prime = torch.from_numpy(np.ascontiguousarray(U[:, pairs_1[1]].T))

    U = np.transpose(U)

    return x.float(), x_prime.float(), U, y


#generate_data(100, 10000, 10000, 2, 0)
    


