# Imports
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch
import numpy as np
import os
import math
import argparse
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm
from transformers import AutoImageProcessor, Dinov2Model
import torch
from datasets import load_dataset
from kmeans_pytorch import kmeans

# Command-line interface: a single switch selecting the class-frequency
# weighted cross entropy used when training the individual pathways.
parser = argparse.ArgumentParser(
    description="arguments",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "-w",
    "--weighted",
    action="store_true",
    help="weighted version",
)
args = parser.parse_args()

weighted_version = args.weighted
print("weighted version = " + str(weighted_version))
if not weighted_version:
    # Remind the user that the weighted variant exists
    print("use option -w to use the weighted version of the cross entropy")

# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The dataset identifier is derived from the number of classes
nb_categories = 101
dataset_string = "food%d" % nb_categories
print("dataset = %s" % dataset_string)

# Load data from huggingface
dataset = load_dataset(dataset_string)

# The 100-class dataset stores its classes under "fine_label"; every other
# configuration uses the plain "label" column
label_string = "fine_label" if nb_categories == 100 else "label"

# Training split: images and integer class labels
train_data = dataset["train"]
train_images = train_data['image']
train_labels = train_data[label_string]

# The "validation" split is used as the test set
test_data = dataset["validation"]
test_images = test_data['image']
test_labels = test_data[label_string]


# Load the DINOv2 preprocessor and encoder from the same checkpoint
dino_checkpoint = "facebook/dinov2-base"
image_processor = AutoImageProcessor.from_pretrained(dino_checkpoint)
# The encoder lives on `device` so image encoding is fast on a GPU
model = Dinov2Model.from_pretrained(dino_checkpoint).to(device)

# Preprocess data using DinoV2
# Note: it may be better to preprocess all data once and save it into a
# torch tensor that can be loaded afterwards

# Maximum number of samples taken from each split (extracting everything can be slow)
extract = 75800

# Hidden width of the pathway networks
hidden_layer_size = 1000

# Training labels as an (N, 1) float32 tensor on `device`.
# The tensor is built in a single call: growing it with torch.cat once per
# label re-copies all previous rows at every step (quadratic in N).
ydata_train = torch.tensor(
    list(train_labels[:extract]), dtype=torch.float32, device=device
).view(-1, 1)

# DinoV2 latent dimension
latent_size = 768
# Extraction Batch size
extraction_batch_size = 50

# Encode the training images batch by batch; the per-batch encodings are
# collected in a list and concatenated once at the end.
num_train_images = len(train_images[:extract])
train_feature_batches = []
for i in tqdm(range(0, num_train_images, extraction_batch_size)):
    # Cap the slice so that no image beyond `extract` is encoded; otherwise
    # the last batch could yield more feature rows than there are labels.
    batch_images = train_images[i:min(i + extraction_batch_size, num_train_images)]

    # Preprocess for Dino
    inputs = image_processor(batch_images, return_tensors="pt").to(device)

    # Get latent without gradients
    with torch.no_grad():
        outputs = model(**inputs)

    # Flatten the pooled latent into one feature vector per image
    last_hidden_states = outputs.pooler_output.view(len(batch_images), -1)
    train_feature_batches.append(last_hidden_states)

# (N, latent_size) matrix of DINOv2 image encodings
if train_feature_batches:
    xdata_train = torch.cat(train_feature_batches, dim=0)
else:
    # No images selected: keep the empty (0, latent_size) shape
    xdata_train = torch.empty(0, latent_size, device=device)
    
    
    
    
# Same extraction for the test split

# Test labels as an (N, 1) float32 tensor on `device`, built in one call
# instead of one torch.cat per label (which is quadratic in N).
ydata_test = torch.tensor(
    list(test_labels[:extract]), dtype=torch.float32, device=device
).view(-1, 1)

# Encode the test images batch by batch and concatenate once at the end
num_test_images = len(test_images[:extract])
test_feature_batches = []
for i in tqdm(range(0, num_test_images, extraction_batch_size)):
    # Cap the slice so that no image beyond `extract` is encoded; otherwise
    # the last batch could yield more feature rows than there are labels.
    batch_images = test_images[i:min(i + extraction_batch_size, num_test_images)]

    # Preprocess for Dino
    inputs = image_processor(batch_images, return_tensors="pt").to(device)

    # Get latent without gradients
    with torch.no_grad():
        outputs = model(**inputs)

    # Flatten the pooled latent into one feature vector per image
    last_hidden_states = outputs.pooler_output.view(len(batch_images), -1)
    test_feature_batches.append(last_hidden_states)

# (N, latent_size) matrix of DINOv2 image encodings
if test_feature_batches:
    xdata_test = torch.cat(test_feature_batches, dim=0)
else:
    # No images selected: keep the empty (0, latent_size) shape
    xdata_test = torch.empty(0, latent_size, device=device)
    
    
# Replace the integer class ids with one-hot float vectors; the encoding
# is appended as a new last dimension of size nb_categories
ydata_train = F.one_hot(ydata_train.long(), num_classes=nb_categories).float()
ydata_test = F.one_hot(ydata_test.long(), num_classes=nb_categories).float()

# Pair every feature vector with its one-hot target: one (x, y) tuple per sample
train_dataset = [(features, target) for features, target in zip(xdata_train, ydata_train)]
test_dataset = [(features, target) for features, target in zip(xdata_test, ydata_test)]



class Neural_Pathways(nn.Module):
    """Set of shallow MLP "pathways", each paired with a learnable prototype.

    Every pathway consists of a trunk (``self.mlps[k]``) and a separate
    prediction head (``self.output_layer[k]``); the head is kept apart so
    that layers can later be inserted between trunk and head. Each pathway
    also owns one prototype, a ``(1, input_size)`` parameter living in the
    input space.

    Args:
        input_size: dimension of the input features (defaults to the
            module-level ``latent_size``).
        hidden_size: width of the two linear layers of each trunk.
        output_size: number of output logits per pathway.
        num_mlps: number of pathways / prototypes.
    """

    def __init__(self,
                 input_size=latent_size,
                 hidden_size=10,
                 output_size=100,
                 num_mlps=3,
                 ):
        super().__init__()

        def build_trunk():
            # Linear -> BN -> PReLU -> Linear -> BN, no prediction layer
            return nn.Sequential(
                nn.Linear(input_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.PReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
            )

        def build_head():
            # PReLU with slope 1 (acts as the identity at initialization)
            # followed by the class logits
            return nn.Sequential(
                nn.PReLU(init=1.0),
                nn.Linear(hidden_size, output_size),
            )

        # Construction order (all trunks, all heads, all prototypes) fixes the
        # order in which random numbers are drawn for the initial weights
        self.mlps = nn.ModuleList([build_trunk() for _ in range(num_mlps)])
        self.output_layer = nn.ModuleList([build_head() for _ in range(num_mlps)])

        # Prototypes drawn uniformly from [-6, 6); the range is a heuristic
        # guess for the input scale and may need adjusting
        self.prototypes = nn.ParameterList([
            nn.Parameter(12 * torch.rand(size=(1, input_size)) - 6)
            for _ in range(num_mlps)
        ])

    def distances_to_prototypes(self, x):
        """Return the Euclidean distance from every row of ``x`` to every prototype.

        ``torch.cdist`` against a ``(1, input_size)`` prototype yields one
        column per prototype; the columns are concatenated along dim 1.
        """
        per_prototype = [torch.cdist(x, prototype) for prototype in self.prototypes]
        return torch.cat(per_prototype, dim=1)

    def predict_discover_prototypes(self, x):
        """Return the distance-weighted mixture of all pathway predictions.

        The mixture weights are a softmax over the negated distances, so the
        pathways whose prototype is closest to a sample dominate its output.
        """
        # Mixture weights, one per prototype, summing to 1 along dim 1
        weightings = F.softmax(-self.distances_to_prototypes(x), dim=1)

        # Logits of every pathway stacked along a new last dimension
        predictions = torch.stack(
            [head(trunk(x)) for trunk, head in zip(self.mlps, self.output_layer)],
            dim=2,
        )

        # Broadcast the weights over the class dimension and sum the pathways
        return (predictions * weightings.unsqueeze(1)).sum(dim=2)

    def calculate_W(self):
        """Return the matrix of pairwise Euclidean distances between prototypes."""
        stacked = torch.stack([prototype[0] for prototype in self.prototypes])
        return torch.cdist(stacked, stacked)

    def W_penalty(self):
        """Return ``exp(-||W||)``, which shrinks as the prototypes spread apart."""
        spread = torch.norm(self.calculate_W())
        return torch.exp(-spread)
        
        
        
class Pathway(nn.Module):
    """A pretrained trunk and head with a trainable extension inserted between them.

    Data flows ``MLP -> MLP_extension -> output_layer``. The extension maps
    ``proto_hidden_size`` features to ``hidden_size``, keeps that width for
    one more layer, and maps back to ``proto_hidden_size`` so the original
    head can be reused unchanged.

    Args:
        proto_neural_pathway: trunk module producing ``proto_hidden_size``
            features (a member of ``Neural_Pathways.mlps``).
        output_layer: head module consuming ``proto_hidden_size`` features.
        proto_hidden_size: feature width at the trunk output / head input.
        hidden_size: inner width of the extension.
    """

    def __init__(self,
                 proto_neural_pathway,
                 output_layer,
                 proto_hidden_size = 10,
                 hidden_size=10,
                 ):
        super().__init__()

        # Shallow network for discovering prototypes (without output head)
        self.MLP = proto_neural_pathway

        # Shallow network output layer head
        self.output_layer = output_layer

        # Network extension. The layer widths are chained so that the
        # extension also works when hidden_size != proto_hidden_size: the
        # second Linear consumes hidden_size features and the last BatchNorm
        # normalizes the proto_hidden_size features handed to the head.
        self.MLP_extension = nn.Sequential(
            nn.PReLU(init=1.0),
            nn.Linear(proto_hidden_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.PReLU(init=1.0),
            nn.Linear(hidden_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.PReLU(init=1.0),
            nn.Linear(hidden_size, proto_hidden_size),
            nn.BatchNorm1d(proto_hidden_size),
        )

        # Start every linear layer from identity weights and zero bias, then
        # perturb both with Gaussian noise so that gradients can flow.
        # NOTE(review): the noise has unit standard deviation, so the layers
        # are not close to the identity after this step — confirm the scale.
        for m in self.MLP_extension:
            if isinstance(m, nn.Linear):
                torch.nn.init.eye_(m.weight)
                torch.nn.init.zeros_(m.bias)

                # Weight noise is drawn before bias noise
                noise_weight = 1e-0 * torch.randn_like(m.weight)
                m.weight.data.add_(noise_weight)

                noise_bias = 1e-0 * torch.randn_like(m.bias)
                m.bias.data.add_(noise_bias)

    def forward(self, x):
        """Return the head's output for ``x`` after trunk and extension."""
        # Pass to pretrained MLP
        z = self.MLP(x)
        # Pass to new MLP extension
        z = self.MLP_extension(z)
        # Pass to output layer
        return self.output_layer(z)
        
# Number of neural pathways (and therefore of prototypes)
num_mlps = 4

# Full-batch training: a single batch holds the whole training set
batch_size = len(train_dataset)  # Adjust as needed
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Standard loss for multiclass classification tasks
criterion = nn.CrossEntropyLoss()

# Number of epochs of the prototype-discovery stage
num_epochs = 10

# Model used to learn the prototypes
neural_pathways = Neural_Pathways(
    num_mlps=num_mlps,
    hidden_size=hidden_layer_size,
    output_size=nb_categories,
).to(device)

# Two parameter groups so that prototypes can get their own learning rate.
# Ideally the prototypes learn faster, so the algorithm focuses on placing
# the neural pathways in the input space; both rates are currently equal.
network_parameters = [
    param
    for name, param in neural_pathways.named_parameters()
    if 'prototypes' not in name
]
prototype_group = [
    neural_pathways.prototypes[prototype_idx] for prototype_idx in range(num_mlps)
]
optimizer = torch.optim.Adam([
    {'params': network_parameters, 'lr': 1e-5},
    {'params': prototype_group, 'lr': 1e-5},
])

# Weight of the W penalty (a prior that evenly spread prototypes are better,
# which may not necessarily be the case); 0.0 disables it
gamma = 0.0

# History of prototype positions, seeded with the initial location of each one
prototype_locations = {
    f'prototype_{prototype_idx}_location': [
        list(neural_pathways.prototypes[prototype_idx].detach().cpu().numpy()[0])
    ]
    for prototype_idx in range(num_mlps)
}

# Whether to clip prototypes to the input range (in this particular case between 0 and 1)
clip_in_range = False

# Training metrics of the prototype stage
training_loss_list = []
total_loss_list = []
w_penalty_list = []

# Prototypes training loop.
# The loop is currently disabled by range(0); iterate over
# tqdm(range(num_epochs)) instead to enable it.
for epoch in range(0):
    for inputs, targets in train_loader:

        optimizer.zero_grad()

        # Forward pass: distance-weighted mixture of all pathways
        outputs = neural_pathways.predict_discover_prototypes(inputs.to(device))

        # Flatten each target to one row per sample. The row count is taken
        # from the batch itself so that a smaller last batch also works.
        loss = criterion(outputs, targets.to(device).view(targets.shape[0], -1))

        # Optional prior pushing the prototypes apart
        if gamma:
            W_penalty = neural_pathways.W_penalty()
            total_loss = loss + gamma * W_penalty
            w_penalty_list.append(gamma * W_penalty.item())
        else:
            total_loss = loss

        # Backward pass and optimization
        total_loss.backward()
        optimizer.step()

        # Store training metrics
        training_loss_list.append(loss.item())
        total_loss_list.append(total_loss.item())

        # Record prototype locations every 10th epoch only, to limit memory.
        # The parentheses matter: `epoch+1 % 10` parses as `epoch + (1 % 10)`,
        # which never equals 0.
        if (epoch + 1) % 10 == 0:
            for prototype_idx in range(num_mlps):
                prototype_locations[f'prototype_{prototype_idx}_location'].append(
                    list(neural_pathways.prototypes[prototype_idx].detach().cpu().numpy()[0])
                )

# Replace the prototypes with k-means cluster centers of the training features.
# The loader yields the training set in batches; the centers of the last
# batch processed are the ones kept (a single batch with the current settings).
for inputs, targets in train_loader:
    cluster_ids_x, cluster_centers = kmeans(X=inputs.to(device), num_clusters=num_mlps, distance='euclidean', device=device)
print(cluster_centers)

# Each prototype gets its own copy of its center, placed explicitly on
# `device` so the prototypes sit next to the rest of the model.
# NOTE(review): kmeans_pytorch appears to return the centers on the CPU —
# confirm for the installed version; the .to(device) is harmless either way.
neural_pathways.prototypes = nn.ParameterList([
    nn.Parameter(cluster_centers[prototype_id, :].detach().clone().view(1, -1).to(device))
    for prototype_id in range(num_mlps)
])
            
            
            
# Pairwise prototype distances once the prototypes are in place
W = neural_pathways.calculate_W()

# One row of W per prototype
rows = W.unbind(0)

# Smallest non-zero entry of each row, i.e. the distance from a prototype to
# its nearest distinct neighbour
nearest_neighbour_distances = []
for row in rows:
    nonzero_positions = row.nonzero(as_tuple=True)[0]
    nearest_neighbour_distances.append(row[nonzero_positions].min())
min_values = torch.tensor(nearest_neighbour_distances)

# Column tensor holding one radius per prototype
radii = min_values.unsqueeze(1)



# Build a Pathway from the first trunk/head pair of neural_pathways. The
# network is imported "chopped" (trunk + output layer) and new layers are
# inserted between the two.
# The module is moved to `device` rather than with .cuda(), so the script
# also runs on machines without a GPU.
pathway1 = Pathway(
    proto_neural_pathway=neural_pathways.mlps[0],
    output_layer=neural_pathways.output_layer[0],
    proto_hidden_size=hidden_layer_size,
    hidden_size=hidden_layer_size,
).to(device)



# Cross entropy for points outside the Voronoi cell: the per-sample losses are
# not averaged directly but weighted, e.g. by a function of the distance to
# the prototype.
class WeightedCrossEntropyLoss(nn.Module):
    """Cross entropy whose per-sample terms are scaled by caller-supplied weights."""

    def __init__(self):
        super().__init__()

    def forward(self, input, target, weights):
        """Return the mean of the per-sample weighted cross-entropy losses.

        Args:
            input (torch.Tensor): raw logits from the model
            target (torch.Tensor): ground truth labels
            weights (torch.Tensor): weight tensor for each value in the batch

        Returns:
            torch.Tensor: scalar, ``mean(cross_entropy_i * weights_i)``.
        """
        # Per-sample losses (no reduction) so each can be weighted separately.
        # Nothing is printed here: the loss runs on every training step.
        ce_loss = nn.functional.cross_entropy(input, target, reduction='none')

        # Apply element-wise weights, then average over the batch
        weighted_loss = ce_loss * weights
        return torch.mean(weighted_loss)
        
        
# Number of epochs of the pathway stage
num_epochs = 2000

# Lambda controls the importance of points outside the radius during
# training; 0.0 disables that term
lambd = 0.0

# Weighted cross entropy for points outside the Voronoi cell
w_crossentropy = WeightedCrossEntropyLoss()

# Scalar +inf on `device`, used to mask out zero entries when taking minima
infTensor = torch.tensor(float('inf')).to(device)

# Timestamped output directory for the pathway checkpoints
now = datetime.now()
directory = "Model_" + now.strftime("%Y_%m_%d_%H_%M_%S")
os.mkdir(directory)

# Copy the prototypes into one (num_mlps, latent_size) matrix, one row each
prototype_parameters = neural_pathways.prototypes
prototypes = torch.zeros(num_mlps, latent_size).to(device)
for i in range(num_mlps):
    prototypes[i, :] = prototype_parameters[i].view(-1)

# The baseline network is trained for nb_iterations epochs
nb_iterations = 10000
num_epochs_baseline = nb_iterations

# Learning rate shared by the pathway and baseline optimizers
step_size = 1e-3
# Per-pathway stage: for every prototype i, wrap the matching trunk/head pair
# of neural_pathways in a Pathway, run the (currently disabled) training loop
# on the samples whose nearest prototype is prototype i, and save the
# pathway's weights to `directory` as model_<i>.pt.
for i in range(num_mlps):
  print("leaf %d/%d" % (i+1, num_mlps))
  pathway1 = Pathway(proto_neural_pathway = neural_pathways.mlps[i], output_layer = neural_pathways.output_layer[i], proto_hidden_size=hidden_layer_size, hidden_size=hidden_layer_size).to(device)

  # Full-batch loader: one batch holds the whole training set.
  # batch_size and train_loader are globals and are reused after this loop.
  batch_size = int(len(train_dataset))  # Adjust as needed
  train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

  # Optimizer over all parameters of the pathway (trunk, extension and head)
  optimizer = optim.AdamW(pathway1.parameters(), lr=step_size)

  # Containers for training metrics; nothing is appended to them at present
  training_loss_pathway = []
  training_loss_inside_radius = []
  training_loss_outside_radius = []

  # Neural Pathway training loop.
  # NOTE(review): range(0) means the body below never executes, so the
  # optimizer never steps and the weights saved at the end of this iteration
  # are the untrained ones. Iterate over range(num_epochs) to enable training;
  # confirm which is intended.
  for epoch in tqdm(range(0)):
    for inputs, targets in train_loader:
      optimizer.zero_grad()

      # Reshape targets to one row of nb_categories values per sample
      targets = targets.to(device).view(-1, nb_categories)

      inputs = inputs.to(device)

      # Forward pass on the whole batch
      outputs = pathway1.forward(inputs)

      # Distance from every input to every prototype (detached: no gradient
      # reaches the prototypes from here)
      distance_to_prototype = torch.cdist(inputs,prototypes).clone().detach()
      # True for the samples whose nearest prototype is prototype i
      mask_inside_radius = torch.argmin(distance_to_prototype, dim=1) == i

      # Use boolean indexing to keep the predictions/targets of this cell
      outputs_inside_radius = outputs[mask_inside_radius]
      targets_inside_radius = targets[mask_inside_radius]

      # Weighted variant: on the first epoch only, rebind the global
      # `criterion` to a cross entropy weighted by inverse class frequency
      # inside this cell
      if (not epoch) and weighted_version:
        # Column sums of the targets (per-class counts when targets are one-hot)
        weight_frequency = torch.sum(targets_inside_radius, dim=0)

        print(weight_frequency)
        # Classes absent from the cell get a count of 1 to avoid dividing by zero
        weight_frequency[weight_frequency < 0.5] = 1.0

        weight_frequency = 1.0 /weight_frequency
        criterion = nn.CrossEntropyLoss(weight=weight_frequency)

      # Loss on the samples assigned to this pathway
      loss_inside_radius = criterion(outputs_inside_radius, targets_inside_radius)

      loss = loss_inside_radius

      # Optional extra term for the samples outside the cell (disabled when lambd == 0)
      if lambd:
        outputs_outside_radius = outputs[~mask_inside_radius]
        targets_outside_radius = targets[~mask_inside_radius]
        with torch.no_grad():

          # Euclidean distances to all prototypes, reduced to one L2 norm per sample
          distances = neural_pathways.distances_to_prototypes(inputs.detach())
          distances = torch.norm(distances, p=2, dim=1).unsqueeze(1)

          # Scale by row i of W (distances from prototype i to every prototype)
          # and take the smallest non-zero entry per sample
          find_minimum_of_tensor = distances*W[i]
          min_value, _ = torch.min(torch.where(find_minimum_of_tensor == 0, infTensor, find_minimum_of_tensor), dim=1)

          # Penalty weights exp(-min_value), restricted to the outside samples
          penalty_weight = torch.exp(-min_value)[~mask_inside_radius]

        # Weighted cross entropy of the outside samples, scaled by lambda
        loss_outside_radius = lambd*w_crossentropy(outputs_outside_radius,targets_outside_radius,penalty_weight)
        loss = loss + loss_outside_radius

      # Backward pass and optimization
      loss.backward()
      optimizer.step()

  # Save the pathway (trunk, extension and head) for the evaluation stage
  torch.save(pathway1.state_dict(), os.path.join(directory, "model_%d.pt" % i))
print("training done")



# Initialize dataloader for test set
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Running count of correctly classified test samples. Despite its name this
# is not a loss; the name is kept because later code reads it.
total_loss = 0.0

# Defined up front so the summary below also works when there are no pathways
nb_test = 0

# Copy the prototypes into one (num_mlps, latent_size) matrix, one row each
prototype_parameters = neural_pathways.prototypes
prototypes = torch.zeros(num_mlps, latent_size).to(device)
for i in range(num_mlps):
    prototypes[i, :] = prototype_parameters[i].view(-1)

# Test loop: every pathway classifies the test samples that fall in its cell
for i in range(num_mlps):
    # Counted afresh for each pathway; every pass sees the whole test set
    nb_test = 0

    # Rebuild the pathway architecture and load the weights saved for pathway i
    pathway = Pathway(
        proto_neural_pathway=neural_pathways.mlps[i],
        output_layer=neural_pathways.output_layer[i],
        proto_hidden_size=hidden_layer_size,
        hidden_size=hidden_layer_size,
    ).to(device)
    pathway.load_state_dict(torch.load(os.path.join(directory, "model_%d.pt" % i)))
    pathway.eval()

    # Evaluation only: no autograd graph is needed, which saves memory on
    # these full-size batches
    with torch.no_grad():
        for inputs, targets in test_loader:

            # Move inputs and targets to device
            targets = targets.to(device).view(-1, nb_categories)
            inputs = inputs.to(device)
            nb_test = nb_test + targets.shape[0]

            # Samples whose nearest prototype is prototype i
            distance_to_prototype = torch.cdist(inputs, prototypes).clone()
            mask_inside_radius = torch.argmin(distance_to_prototype, dim=1) == i

            # Only those samples are passed to the model, to avoid
            # unnecessary computations
            outputs_inside_radius = pathway(inputs[mask_inside_radius])
            targets_inside_radius = targets[mask_inside_radius]

            # Number of correct predictions inside the cell
            loss_inside_radius = torch.sum(
                torch.argmax(outputs_inside_radius, dim=1)
                == torch.argmax(targets_inside_radius, dim=1)
            )

            # Add to the running count
            total_loss = total_loss + loss_inside_radius

# Every test sample belongs to exactly one cell, so the count is already the
# total over the test set and is not divided by the number of pathways
total_loss_ours = total_loss

print("number of test examples = %d" % nb_test)
# Printed value is the accuracy in percent
print("total loss = {:.15f}".format(100.0 * total_loss / nb_test))


# Hidden width used when building the baseline network below; this repeats
# the value already assigned to hidden_layer_size earlier in the script.
hidden_layer_size = 1000


class PathwayTest(nn.Module):
    """Single wide MLP used as the baseline against the set of pathways.

    Every hidden layer has ``ceil(sqrt(num_mlps)) * hidden_size`` units.

    Args:
        num_mlps: number of pathways of the model being compared against;
            only used to scale the layer width.
        input_size: dimension of the input features.
        hidden_size: base width, multiplied by ``ceil(sqrt(num_mlps))``.
        output_size: number of outputs.
    """

    def __init__(self, num_mlps, input_size = 1, hidden_size=10, output_size=1):
        super().__init__()

        number_of_leaves = math.ceil(math.sqrt(num_mlps))
        width = number_of_leaves * hidden_size

        # Trunk: Linear -> PReLU, then three Linear -> BatchNorm blocks
        # separated by PReLU activations (none after the last block)
        trunk_layers = [nn.Linear(input_size, width), nn.PReLU(init=1.0)]
        for block_idx in range(3):
            if block_idx > 0:
                trunk_layers.append(nn.PReLU(init=1.0))
            trunk_layers.append(nn.Linear(width, width))
            trunk_layers.append(nn.BatchNorm1d(width))
        self.mlps = nn.Sequential(*trunk_layers)

        # Extension: one more hidden block followed by the output projection
        self.MLP_extension = nn.Sequential(
            nn.PReLU(init=1.0),
            nn.Linear(width, width),
            nn.BatchNorm1d(width),
            nn.PReLU(init=1.0),
            nn.Linear(width, output_size),
        )

    def forward(self, x):
        """Return the network output for ``x`` (trunk followed by extension)."""
        hidden = self.mlps(x)
        return self.MLP_extension(hidden)
        
        

# The baseline is trained with a mean-squared-error loss (standard for
# regression tasks) against the target vectors
criterion = nn.MSELoss()

# Baseline network; its width is scaled by the number of pathways
neural_pathways_test = PathwayTest(
    num_mlps=num_mlps,
    input_size=latent_size,
    hidden_size=hidden_layer_size,
    output_size=nb_categories,
).to(device)

# Single parameter group with every parameter whose name does not contain
# 'prototypes', optimized at step_size
baseline_parameters = [
    param
    for name, param in neural_pathways_test.named_parameters()
    if 'prototypes' not in name
]
optimizer = torch.optim.AdamW([
    {'params': baseline_parameters, 'lr': step_size},
])

# Training metrics of the baseline
training_loss_list = []
total_loss_list = []
w_penalty_list = []

# Baseline training loop
for epoch in tqdm(range(num_epochs_baseline)):
    for inputs, targets in train_loader:
        # Bring the batch onto the device; targets become one row per sample
        inputs = inputs.to(device)
        targets = targets.to(device).view(batch_size, -1)

        optimizer.zero_grad()

        # Forward pass and loss
        outputs = neural_pathways_test.forward(inputs)
        loss = criterion(outputs, targets)
        total_loss = loss

        # Backward pass and optimization
        total_loss.backward()
        optimizer.step()

        # Store training metrics
        training_loss_list.append(loss.item())

# Loss of the last batch processed
print(total_loss)

print("end of training")

### Testing the baseline on the test set

# Running count of correctly classified test samples. Despite its name this
# is not a loss; the name is kept because later code reads it.
total_loss = 0.0
neural_pathways_test.eval()
nb_test = 0.0

# Evaluation only: no autograd graph is needed, which saves memory on these
# full-size batches
with torch.no_grad():
    for inputs, targets in test_loader:
        targets = targets.to(device).view(-1, nb_categories)
        inputs = inputs.to(device)
        outputs = neural_pathways_test.forward(inputs)

        nb_test = nb_test + targets.shape[0]

        # A prediction is correct when its arg-max matches that of the target
        total_loss = total_loss + torch.sum(
            torch.argmax(outputs, dim=1) == torch.argmax(targets, dim=1)
        )

# Number of correct predictions
print("total loss = {:.15f}".format(total_loss))

# Accuracy in percent
print(total_loss / nb_test * 100)

total_loss_baseline = total_loss


# Write both accuracies (as fractions of nb_test) to a timestamped text file:
# first line is the pathway model, second line is the baseline.
weighted_string = "weighted" if weighted_version else "unweighted"
filepath = "results/results_dyno_%s_%d_" % (dataset_string, nb_categories)

# `now` is the timestamp taken when the checkpoint directory was created, so
# the results file and the checkpoints share the same stamp
result_file = filepath + weighted_string + now.strftime("_%Y_%m_%d_%H_%M_%S.txt")

# Create the output directory on first use; open() does not create it
os.makedirs(os.path.dirname(result_file), exist_ok=True)

# The context manager closes the file even if formatting or writing fails
with open(result_file, "w") as f:
    f.write("{:.15f}\n".format(total_loss_ours  / nb_test))
    f.write("{:.15f}\n".format(total_loss_baseline  / nb_test))


                    
