import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from fbm import fbm, fgn, times
import matplotlib.pyplot as plt
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
import math
from datetime import datetime
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device = %s" % device)

# Learning rate handed to the Adam optimizers of the per-pathway and
# baseline training phases further down in this script.
step_size = 1e-5
# Number of grid points per input axis; generate_dataset builds a full
# 3-D mesh from it, so the dataset holds num_samples**3 points.
num_samples = 30
# Number of epochs for the per-pathway and the baseline training loops.
nb_iterations = 10000
# Lower / upper bound of every input coordinate.
min_x, max_x = -1, 1

def ackley_function(x, y, z):
    """Evaluate the three-dimensional Ackley function at ``(x, y, z)``.

    The arguments may be scalars or NumPy arrays; every operation is
    elementwise, so array inputs produce an array of the broadcast shape.
    The function value at the origin is 0 (up to floating-point rounding).
    """
    two_pi = 2 * np.pi

    # Mean of the squared coordinates and mean of the cosine terms (d = 3).
    mean_square = (x**2 + y**2 + z**2) / 3.0
    mean_cosine = (np.cos(two_pi * x) + np.cos(two_pi * y) + np.cos(two_pi * z)) / 3.0

    radial_term = -20 * np.exp(-0.2 * np.sqrt(mean_square))
    oscillation_term = -np.exp(mean_cosine)

    return radial_term + oscillation_term + np.e + 20


def rastrigin_function(x, y, z):
    """Evaluate the three-dimensional Rastrigin function at ``(x, y, z)``.

    Computes ``30 + sum(c**2 - 10 * cos(2 * pi * c))`` over the three
    coordinates. Works elementwise on scalars or NumPy arrays.
    """
    two_pi = 2 * np.pi

    # Accumulate term by term in the same left-to-right order as the
    # closed-form expression.
    total = 30 + x**2
    total = total - 10 * np.cos(two_pi * x)
    total = total + y**2
    total = total - 10 * np.cos(two_pi * y)
    total = total + z**2
    total = total - 10 * np.cos(two_pi * z)
    return total

def generate_dataset(num_samples):
    """Sample the Ackley function on a regular 3-D grid.

    Each axis gets ``num_samples`` evenly spaced points between the
    module-level bounds ``min_x`` and ``max_x``; the full mesh of the
    three axes is used as input locations.

    Returns:
        A pair ``(x_data, y_data)`` of NumPy arrays with shapes
        ``(num_samples**3, 3)`` and ``(num_samples**3, 1)``.
    """
    # All three axes share the same sampling, so one linspace is enough.
    axis = np.linspace(min_x, max_x, num_samples)
    mesh = np.meshgrid(axis, axis, axis)

    # (3, n, n, n) -> (n, n, n, 3) -> one row per grid point.
    x_data = np.array(mesh).T.reshape(-1, 3)

    first, second, third = x_data[:, 0], x_data[:, 1], x_data[:, 2]
    targets = ackley_function(first, second, third)

    # Column vector so every input row has a matching one-element target.
    y_data = targets.reshape(-1, 1)
    return x_data, y_data

# Build the grid dataset as NumPy arrays first so the shapes can be reported.
xdata, ydata = generate_dataset(num_samples)

for label, array in (("x_data shape:", xdata), ("y_data shape:", ydata)):
    print(label, array.shape)

# Switch from NumPy arrays to float32 PyTorch tensors.
xdata, ydata = torch.FloatTensor(xdata), torch.FloatTensor(ydata)

# One (input, target) tuple per data point.
dataset = [(x_row, y_row) for x_row, y_row in zip(xdata, ydata)]

# Define the custom dataset class
class CustomDataset(Dataset):
    """Map-style dataset over an indexable collection of ``(x, y)`` pairs."""

    def __init__(self, data):
        # Keep a reference to the pair container; nothing is copied.
        self.data = data

    def __len__(self):
        """Return the number of stored pairs."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the pair stored at ``index`` as an ``(x, y)`` tuple."""
        features, target = self.data[index]
        return features, target

# Random 80/20 partition of the data points into training and test subsets.
n_points = len(dataset)
train_size = int(0.8 * n_points)
test_size = n_points - train_size
split_lengths = [train_size, test_size]
train_dataset, test_dataset = torch.utils.data.random_split(dataset, split_lengths)

print("train size = %d, test size = %d" % (train_size, test_size))

class Pathway(nn.Module):
    """A pretrained trunk/head pair with a near-identity extension in between.

    The trunk (``proto_neural_pathway``) and the head
    (``proto_neural_pathway_output``) are stored by reference, so their
    parameters are shared with whoever passed them in. The extension is a
    fresh ``PReLU -> Linear -> PReLU -> Linear`` stack of width
    ``hidden_size``.

    Args:
        proto_neural_pathway: Module mapping inputs to ``hidden_size``
            features.
        proto_neural_pathway_output: Module mapping ``hidden_size``
            features to the final output.
        hidden_size: Feature width of the extension layers.
    """

    def __init__(self, proto_neural_pathway, proto_neural_pathway_output, hidden_size=10):
        super().__init__()

        # Pretrained pieces, reused as-is.
        self.MLP = proto_neural_pathway
        self.MLP_output = proto_neural_pathway_output

        # Newly added layers; PReLU with slope 1.0 starts out as the identity.
        self.MLP_extension = nn.Sequential(
            nn.PReLU(init=1.0),
            nn.Linear(hidden_size, hidden_size),
            nn.PReLU(init=1.0),
            nn.Linear(hidden_size, hidden_size),
        )

        # Start every linear layer at identity weights / zero bias, then
        # perturb both with small Gaussian noise so gradients can flow.
        noise_scale = 1e-2
        linear_layers = [layer for layer in self.MLP_extension if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            torch.nn.init.eye_(layer.weight)
            torch.nn.init.zeros_(layer.bias)
            layer.weight.data.add_(noise_scale * torch.randn_like(layer.weight))
            layer.bias.data.add_(noise_scale * torch.randn_like(layer.bias))

    def forward(self, x):
        """Run ``x`` through trunk, extension and head, in that order."""
        hidden = self.MLP(x)
        extended = self.MLP_extension(hidden)
        return self.MLP_output(extended)
        
        
class Neural_Pathways(nn.Module):
    """Distance-weighted mixture of shallow MLPs with learnable prototypes.

    Every sub-network consists of a trunk (``self.mlps[k]``) and an output
    head (``self.mlps_output[k]``) and is paired with one learnable
    prototype location in input space (``self.prototypes[k]``, shape
    ``(1, input_size)``).

    Args:
        input_size: Dimensionality of the inputs and of the prototypes.
        hidden_size: Width of the hidden features of every trunk.
        output_size: Output width of every head. The mixture in
            ``predict_discover_prototypes`` expects one prediction column
            per sub-network, i.e. ``output_size == 1``.
        num_mlps: Number of sub-networks / prototypes.
        init_offset: Unused; kept so existing callers keep working.
        x_min: Lower bound of the uniform prototype initialisation.
        x_max: Upper bound of the uniform prototype initialisation.
    """

    def __init__(self,
                 input_size=2,
                 hidden_size=50,
                 output_size=1,
                 num_mlps=3,
                 init_offset=0.05,
                 x_min = 0,
                 x_max = 1
                 ):

        super().__init__()

        # Shallow trunks used while the prototypes are being discovered.
        self.mlps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_size, hidden_size),
                nn.PReLU(),
                nn.Linear(hidden_size, hidden_size),
            ) for _ in range(num_mlps)
        ])

        # One output head per trunk.
        self.mlps_output = nn.ModuleList([
            nn.Sequential(
                nn.PReLU(),
                nn.Linear(hidden_size, output_size)
            ) for _ in range(num_mlps)
        ])

        # Prototypes, drawn uniformly from [x_min, x_max) in every coordinate.
        self.prototypes = nn.ParameterList([
            nn.Parameter(torch.rand(size=(1, input_size)) * (x_max - x_min) + x_min)
            for _ in range(num_mlps)
        ])

    def distances_to_prototypes(self, x):
        """Return the Euclidean distance of every input to every prototype.

        Args:
            x: Tensor of shape ``(batch, input_size)``.

        Returns:
            Tensor of shape ``(batch, num_mlps)``; column ``k`` holds the
            (non-squared) Euclidean distances to prototype ``k``.
        """
        return torch.cat([torch.cdist(x, prototype) for prototype in self.prototypes], dim=1)

    def predict_discover_prototypes(self, x, inv_temperature = 0.01):
        """Blend the sub-network predictions by prototype proximity.

        Args:
            x: Tensor of shape ``(batch, input_size)``.
            inv_temperature: Scale applied to the distances before the
                softmax; larger values sharpen the assignment.

        Returns:
            Tensor of shape ``(batch,)`` with the weighted prediction.
        """
        distances = self.distances_to_prototypes(x)

        # Negated so that nearby prototypes receive the larger weights.
        weightings = F.softmax(-distances * inv_temperature, dim=1)

        # Iterate over the instance's own sub-networks rather than a
        # module-level ``num_mlps`` global, so the model does not depend on
        # (or silently disagree with) state defined outside the class.
        predictions = torch.cat(
            [head(trunk(x)) for trunk, head in zip(self.mlps, self.mlps_output)],
            dim=1,
        )

        return torch.sum(weightings * predictions, dim=1)

    def calculate_W(self):
        """Return the ``(num_mlps, num_mlps)`` matrix of pairwise Euclidean
        distances between the prototypes."""
        # Each prototype is stored as (1, input_size); drop the leading axis.
        prototypes = torch.stack([prototype[0] for prototype in self.prototypes])
        return torch.cdist(prototypes, prototypes)

    def W_penalty(self):
        """Return ``exp(-||W||)`` for the pairwise prototype distance matrix.

        The value shrinks as the prototypes spread apart, so adding it to a
        loss discourages prototypes from collapsing onto each other.
        """
        frobenius_norm = torch.norm(self.calculate_W())
        return torch.exp(-frobenius_norm)
        
        
print("starting training")

# Number of neural pathways (sub-networks / prototypes)
num_mlps = 4

hidden_layer_size = 1000

# Full-batch training: a single batch holds the whole training split
batch_size = len(train_dataset)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Standard loss for regression tasks
criterion = nn.MSELoss()

# Number of epochs of the prototype-discovery phase
num_epochs = 1000

# Model that learns the prototypes jointly with the shallow sub-networks
neural_pathways = Neural_Pathways(input_size=3, hidden_size=hidden_layer_size, output_size=1, num_mlps=num_mlps, x_min=min_x, x_max=max_x).to(device)

# Optimizer over all parameters (sub-networks and prototypes)
optimizer = optim.Adam(neural_pathways.parameters(), lr=1e-4)

# Weight of the W penalty (imposes a prior that spread-out prototypes are
# better, which may not necessarily be the case); 0.0 disables it
gamma = 0.0 #0.001

# Training metrics, one entry per optimisation step
training_loss_list = []
total_loss_list = []
w_penalty_list = []

# History of prototype locations. Every entry is a detached CPU copy of
# shape (input_size,); storing the Parameter itself would make all entries
# alias the same, final, value.
prototype_locations = {}
for prototype_idx in range(num_mlps):
    initial_location = neural_pathways.prototypes[prototype_idx][0].detach().cpu().clone()
    prototype_locations[f'prototype_{prototype_idx}_location'] = [initial_location]

# Whether to clip the prototypes back into the input range [min_x, max_x]
clip_in_range = False

# Prototype-discovery training loop
for epoch in range(num_epochs):
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device).view(-1)
        optimizer.zero_grad()

        # Forward pass
        outputs = neural_pathways.predict_discover_prototypes(inputs)

        # Prototype-spread penalty
        W_penalty = neural_pathways.W_penalty()

        # Data loss and penalised objective
        loss = criterion(outputs, targets)
        total_loss = loss + gamma * W_penalty

        # Back-propagate the penalised objective so the penalty takes
        # effect whenever gamma is non-zero
        total_loss.backward()
        optimizer.step()

        # Clip prototypes to the range the inputs were sampled from
        if clip_in_range:
            for prototype_idx in range(num_mlps):
                neural_pathways.prototypes[prototype_idx].data = torch.clip(neural_pathways.prototypes[prototype_idx].data, min_x, max_x)

        # Store training metrics
        training_loss_list.append(loss.item())
        total_loss_list.append(total_loss.item())
        w_penalty_list.append(gamma * W_penalty.item())

        # Store a snapshot of the current prototype locations
        for prototype_idx in range(num_mlps):
            current_location = neural_pathways.prototypes[prototype_idx][0].detach().cpu().clone()
            prototype_locations[f'prototype_{prototype_idx}_location'].append(current_location)
    if not (epoch % 100):
        print(epoch,loss)
print("training done")


# Print the learned prototype coordinates, one prototype per line.
for prototype_idx in range(num_mlps):
    learned_location = neural_pathways.prototypes[prototype_idx].detach().cpu().numpy()
    print(learned_location[0])

# Scalar +inf living on the compute device; used as a "no value" marker.
infTensor = torch.tensor(float('inf')).to(device)

# Pairwise prototype distance matrix after the prototypes were trained.
W = neural_pathways.calculate_W()
print(W)

if num_mlps != 1:
    # One row of W per prototype.
    rows = W.unbind(0)

    # Per prototype: smallest non-zero entry of its row, i.e. the distance
    # to the closest prototype at a non-zero distance.
    nearest_distances = []
    for row in rows:
        nonzero_positions = row.nonzero(as_tuple=True)[0]
        nearest_distances.append(row[nonzero_positions].min())
    min_values = torch.tensor(nearest_distances)

    # Column tensor with one radius per prototype.
    radii = min_values.unsqueeze(1)
else:
    # A single prototype has no neighbour, so its radius is unbounded.
    radii = [infTensor]
print(radii)

# Standard loss for regression tasks
criterion = nn.MSELoss()

# Number of epochs for every pathway
num_epochs = nb_iterations

# Lambda weights the loss on points that belong to other prototypes;
# 0.0 switches that term off entirely
lambd = 0.0 # 0.01

# Output directory for the trained pathways, named after the start time
now = datetime.now()
directory = "Model_3d_" + now.strftime("%Y_%m_%d_%H_%M_%S")
os.mkdir(directory)

# Frozen (num_mlps, input_size) matrix of the learned prototype locations.
# The prototypes are not optimised in this phase, so no autograd history is
# needed; concatenating the (1, input_size) parameters also avoids
# hard-coding the input dimensionality.
prototype_parameters = neural_pathways.prototypes
prototypes = torch.cat([p.detach() for p in prototype_parameters], dim=0).to(device)

# Full-batch loader; it does not depend on the pathway, so build it once
batch_size = int(len(train_dataset))
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

for i in range(num_mlps):
    print("leaf %d/%d" % (i+1, num_mlps))

    # The pathway reuses (shares) the pretrained trunk and head of
    # sub-network i and adds a near-identity extension between them
    pathway1 = Pathway(proto_neural_pathway = neural_pathways.mlps[i], proto_neural_pathway_output = neural_pathways.mlps_output[i], hidden_size=hidden_layer_size).to(device)

    # Optimizer
    optimizer = optim.Adam(pathway1.parameters(), lr=step_size)

    # Metric containers; nothing is appended to them in this loop
    training_loss_pathway = []
    training_loss_inside_radius = []
    training_loss_outside_radius = []

    # Neural pathway training loop
    for epoch in tqdm(range(num_epochs)):
        for inputs, targets in train_loader:

            optimizer.zero_grad()
            targets = targets.to(device)
            inputs = inputs.to(device)

            # Forward pass
            outputs = pathway1(inputs)

            # A point belongs to pathway i when prototype i is its nearest
            # prototype; the assignment itself needs no gradient
            with torch.no_grad():
                distance_to_prototype = torch.cdist(inputs, prototypes)
                mask_inside_radius = torch.argmin(distance_to_prototype, dim=1) == i

            # Loss on the points assigned to this pathway
            outputs_inside_radius = outputs[mask_inside_radius]
            targets_inside_radius = targets[mask_inside_radius]
            loss_inside_radius = criterion(outputs_inside_radius, targets_inside_radius)

            loss = loss_inside_radius

            if lambd:
                outputs_outside_radius = outputs[~mask_inside_radius]
                targets_outside_radius = targets[~mask_inside_radius]

                with torch.no_grad():
                    # Per input: norm of its Euclidean distances to all prototypes
                    distances = neural_pathways.distances_to_prototypes(inputs.detach())
                    distances = torch.norm(distances, p=2, dim=1).unsqueeze(1)

                    # Scale by row i of W (distances from prototype i to the
                    # others) and take the smallest non-zero entry per input
                    find_minimum_of_tensor = distances*W[i]
                    min_value, _ = torch.min(torch.where(find_minimum_of_tensor == 0, infTensor, find_minimum_of_tensor), dim=1)

                    # One penalty weight per outside point
                    penalty_weight = torch.exp(-min_value)[~mask_inside_radius]

                # Weighted mean squared error on the outside points. The
                # weights are (n,) while the errors are (n, output_size), so
                # the weights get a trailing axis; multiplying them directly
                # would broadcast to an (n, n) matrix.
                squared_error = (outputs_outside_radius - targets_outside_radius) ** 2
                loss_outside_radius = lambd * torch.mean(penalty_weight.unsqueeze(1) * squared_error)
                loss = loss + loss_outside_radius

            if not (epoch % 200):
                print(epoch,loss)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()
    print(loss)

    torch.save(pathway1.state_dict(), os.path.join(directory, "model_%d.pt" % i))
print("training done")

# Loader over the held-out test split
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Summed (not averaged) squared error: every test point is scored by exactly
# one pathway, so adding the per-pathway sums yields the total squared error
# over the whole test split
criterion = nn.MSELoss(reduction='sum')

# Running total of the squared error
total_loss = 0.0

# Frozen (num_mlps, input_size) matrix of the learned prototype locations
prototype_parameters = neural_pathways.prototypes
prototypes = torch.cat([p.detach() for p in prototype_parameters], dim=0).to(device)

# Test loop: loop over all pathways
for i in range(num_mlps):
    print(i)

    # Rebuild the pathway around the shared trunk/head of sub-network i and
    # restore the weights saved at the end of its training
    pathway = Pathway(proto_neural_pathway = neural_pathways.mlps[i], proto_neural_pathway_output = neural_pathways.mlps_output[i],hidden_size=hidden_layer_size,).to(device)
    pathway.load_state_dict(torch.load(os.path.join(directory, "model_%d.pt" % i)))
    pathway.eval()

    # Evaluation only: no autograd graph is needed
    with torch.no_grad():
        for inputs, targets in test_loader:

            # Move inputs and targets to device
            targets = targets.to(device)
            inputs = inputs.to(device)

            # A point belongs to pathway i when prototype i is its nearest one
            distance_to_prototype = torch.cdist(inputs, prototypes)
            mask_inside_radius = torch.argmin(distance_to_prototype, dim=1) == i

            # Only feed the model the points assigned to this pathway
            outputs_inside_radius = pathway(inputs[mask_inside_radius])
            targets_inside_radius = targets[mask_inside_radius]

            # Summed squared error on those points
            loss_inside_radius = criterion(outputs_inside_radius, targets_inside_radius)

            # Add to total loss
            total_loss = total_loss + loss_inside_radius

total_loss_ours = total_loss

print("total loss = {:.15f}".format(total_loss))



class PathwayTest(nn.Module):
    """Single monolithic MLP whose width grows with the number of pathways.

    Every hidden layer is ``ceil(sqrt(num_mlps)) * hidden_size`` units wide.
    The network is split into a two-layer front part (``self.mlps``) and a
    three-layer tail (``self.MLP_extension``) that ends in ``output_size``
    units.

    Args:
        num_mlps: Number of pathways the width is scaled for.
        input_size: Dimensionality of the inputs.
        hidden_size: Base hidden width before scaling.
        output_size: Dimensionality of the output.
    """

    def __init__(self, num_mlps, input_size = 1, hidden_size=10, output_size=1):
        super().__init__()

        number_of_leaves = math.ceil(math.sqrt(num_mlps))
        width = number_of_leaves * hidden_size

        # Front part: input projection followed by one hidden layer.
        self.mlps = nn.Sequential(
            nn.Linear(input_size, width),
            nn.PReLU(init=1.0),
            nn.Linear(width, width),
        )

        # Tail: two more hidden layers and the output projection.
        self.MLP_extension = nn.Sequential(
            nn.PReLU(init=1.0),
            nn.Linear(width, width),
            nn.PReLU(init=1.0),
            nn.Linear(width, width),
            nn.PReLU(init=1.0),
            nn.Linear(width, output_size),
        )

    def forward(self, x):
        """Apply the front part and then the tail to ``x``."""
        features = self.mlps(x)
        return self.MLP_extension(features)
        
        

# Mean squared error objective for the baseline regression
criterion = nn.MSELoss()

# The baseline gets the same number of epochs as every pathway
num_epochs = nb_iterations

# Monolithic baseline network, sized from the number of pathways
neural_pathways_test = PathwayTest(input_size = 3, hidden_size=hidden_layer_size, output_size=1, num_mlps=num_mlps).to(device)

# One Adam parameter group holding every parameter whose name does not
# contain 'prototypes'
baseline_parameters = [
    param
    for name, param in neural_pathways_test.named_parameters()
    if 'prototypes' not in name
]
optimizer = torch.optim.Adam([
    {'params': baseline_parameters, 'lr': step_size}
])

# Metric containers; only training_loss_list is filled below
training_loss_list = []
total_loss_list = []
w_penalty_list = []

# Baseline training loop
for epoch in tqdm(range(num_epochs)):
    for inputs, targets in train_loader:

        optimizer.zero_grad()

        # Forward pass on the compute device
        device_inputs = inputs.to(device)
        outputs = neural_pathways_test(device_inputs)

        # Compare flattened predictions with flattened targets
        flat_targets = targets.to(device).view(-1)
        loss = criterion(outputs.view(-1), flat_targets)
        total_loss = loss

        # Backward pass and parameter update
        total_loss.backward()
        optimizer.step()

        # Record the loss of this step
        training_loss_list.append(loss.item())


# Dump the predictions, targets and loss of the final step
print(outputs)
print(targets)
print(total_loss)

print("end of training")

### testing on the test set
# Summed squared error, so the figure is directly comparable with the total
# accumulated over the individual pathways
criterion = nn.MSELoss(reduction='sum')
total_loss = 0.0
neural_pathways_test.eval()

# Evaluation only: disable autograd so no graph is built and the accumulated
# loss carries no gradient history
with torch.no_grad():
    for inputs, targets in test_loader:
        targets = targets.to(device).view(-1)
        inputs = inputs.to(device)
        outputs = neural_pathways_test(inputs).view(-1)

        # Add the summed squared error of this batch
        total_loss = total_loss + criterion(outputs, targets)
print("total loss = {:.15f}".format(total_loss))

# Dump the predictions and targets of the last batch and the total loss
print(outputs)
print(targets)
print(total_loss)

total_loss_baseline = total_loss


# Write both test losses (pathway model first, baseline second) to a
# time-stamped text file, one value per line.
now = datetime.now()
result_file = "results/results_" + now.strftime("%Y_%m_%d_%H_%M_%S.txt")

# Create the output directory on demand; otherwise a missing "results"
# folder would raise only after all training has finished.
os.makedirs(os.path.dirname(result_file), exist_ok=True)

# The context manager closes the file even if formatting a value fails.
with open(result_file, "w") as f:
    f.write("{:.15f}\n".format(total_loss_ours))
    f.write("{:.15f}\n".format(total_loss_baseline))
