import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from fbm import fbm, fgn, times
import matplotlib.pyplot as plt
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
import math
from datetime import datetime
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"device = {device}")

# Size passed to the fBm generator, and the learning rate used by the pathway and
# baseline optimizers later in the script.
num_samples = 10000
step_size = 1e-5

# Sample times (length 1) and one fractional Brownian motion realization at them.
# hurst controls the amount of randomness; earlier experiments used hurst=0.9
# and/or added np.cos(10*xdata) to the target.
xdata = times(n=num_samples, length=1)
ydata = fbm(n=num_samples, hurst=0.2, length=1, method='daviesharte')

# float32 CPU tensors; batches are moved to `device` inside the training loops.
xdata = torch.as_tensor(xdata, dtype=torch.float32)
ydata = torch.as_tensor(ydata, dtype=torch.float32)

# One (x, y) pair of scalar tensors per sample.
dataset = list(zip(xdata, ydata))

# Map-style dataset over an in-memory sequence of (x, y) pairs.
class CustomDataset(Dataset):
    """Wrap an indexable sequence of ``(x, y)`` pairs as a torch ``Dataset``.

    NOTE(review): not referenced elsewhere in the visible script; the split below
    is applied to the raw list directly.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Unpacking enforces that every stored item is exactly a pair.
        first, second = self.data[index]
        return (first, second)

# Random 80/20 split of the (x, y) pairs into train and test subsets.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

print(f"train size = {train_size}, test size = {test_size}")

# Materialize both subsets, then transpose them into x / y tuples for plotting.
train_data, test_data = list(train_dataset), list(test_dataset)
train_x, train_y = zip(*train_data)
test_x, test_y = zip(*test_data)


class Pathway(nn.Module):
    """A prototype MLP, a near-identity extension, then the prototype's output head.

    Args:
        proto_neural_pathway: module applied first (a pathway's feature MLP).
        proto_neural_pathway_output: module applied last (that pathway's output head).
        hidden_size: width of the extension; must match the feature width produced by
            ``proto_neural_pathway`` and consumed by ``proto_neural_pathway_output``.

    The two prototype modules are stored by reference, not copied, so training a
    ``Pathway`` also updates the modules it was built from.
    """

    def __init__(self, proto_neural_pathway, proto_neural_pathway_output, hidden_size=10):
        super().__init__()

        # Borrowed prototype sub-networks (registered as submodules, shared not copied).
        self.MLP = proto_neural_pathway
        self.MLP_output = proto_neural_pathway_output

        # Extension inserted between them: PReLU -> Linear -> PReLU -> Linear.
        # PReLU(init=1.0) starts as the identity, so together with near-identity
        # Linear layers the whole extension starts close to an identity map.
        stages = [
            nn.PReLU(init=1.0),
            nn.Linear(hidden_size, hidden_size),
            nn.PReLU(init=1.0),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.MLP_extension = nn.Sequential(*stages)

        for layer in filter(lambda mod: isinstance(mod, nn.Linear), self.MLP_extension):
            self._init_near_identity(layer)

    @staticmethod
    def _init_near_identity(linear, noise_scale=1e-2):
        """Set ``linear`` to identity weights and zero bias, plus small Gaussian noise.

        The noise breaks the exact identity so gradients can flow.
        """
        with torch.no_grad():
            nn.init.eye_(linear.weight)
            nn.init.zeros_(linear.bias)
            linear.weight.add_(noise_scale * torch.randn_like(linear.weight))
            linear.bias.add_(noise_scale * torch.randn_like(linear.bias))

    def forward(self, x):
        """Return ``MLP_output(MLP_extension(MLP(x)))``."""
        features = self.MLP(x)
        extended = self.MLP_extension(features)
        return self.MLP_output(extended)
        
class Neural_Pathways(nn.Module):
    """Mixture of shallow MLPs ("neural pathways"), each tied to a learnable prototype.

    Pathway ``k`` computes ``mlps_output[k](mlps[k](x))``.  Prototype ``k`` is a
    ``(1, input_size)`` parameter; inputs are soft-assigned to pathways by their
    distance to the prototypes (see ``predict_discover_prototypes``).

    Args:
        input_size: dimensionality of inputs and prototypes.
        hidden_size: width of each pathway's hidden layers.
        output_size: output width of each pathway head.  The prototype-weighted
            combination in ``predict_discover_prototypes`` assumes ``output_size == 1``.
        num_mlps: number of pathways (and prototypes).
        init_offset: prototypes start evenly spaced on
            ``[init_offset, 1 - init_offset]``, with every coordinate of a given
            prototype set to the same value.
    """

    def __init__(self,
                 input_size=1,
                 hidden_size=50,
                 output_size=1,
                 num_mlps=3,
                 init_offset=0.05
                 ):

        super(Neural_Pathways, self).__init__()

        # Feature extractor of each pathway: Linear -> PReLU -> Linear.
        self.mlps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_size, hidden_size),
                nn.PReLU(),
                nn.Linear(hidden_size, hidden_size),
            ) for _ in range(num_mlps)
        ])

        # Output head of each pathway: PReLU -> Linear(hidden_size, output_size).
        self.mlps_output = nn.ModuleList([
            nn.Sequential(
                nn.PReLU(),
                nn.Linear(hidden_size, output_size)
            ) for _ in range(num_mlps)
        ])

        # Prototypes, initialized evenly spaced on [init_offset, 1 - init_offset].
        prototype_initializations = torch.linspace(0 + init_offset, 1 - init_offset, steps=num_mlps)
        self.prototypes = nn.ParameterList([
            nn.Parameter(torch.ones(size=(1, input_size)).fill_(prototype_initializations[i])) for i in range(num_mlps)
        ])

    def distances_to_prototypes(self, x):
        """Return Euclidean (not squared) distances from each row of ``x`` to every prototype.

        Args:
            x: ``(batch, input_size)`` inputs.

        Returns:
            ``(batch, num_mlps)`` tensor; column ``k`` is the distance to prototype ``k``.
        """
        return torch.cat([torch.cdist(x, prototype) for prototype in self.prototypes], dim=1)

    def predict_discover_prototypes(self, x, inv_temperature=0.01):
        """Blend all pathway predictions, weighting each by proximity to its prototype.

        The weights are ``softmax(-distance * inv_temperature)`` over pathways, so closer
        prototypes receive larger weights.  A small ``inv_temperature`` makes the
        weights close to uniform.

        Args:
            x: ``(batch, input_size)`` inputs.
            inv_temperature: sharpness of the distance-based softmax.

        Returns:
            ``(batch,)`` combined prediction (assumes ``output_size == 1``).
        """
        distances = self.distances_to_prototypes(x)

        # Negative distances: nearby prototypes get higher probability.
        weightings = F.softmax(-distances * inv_temperature, dim=1)

        # Iterate over this instance's own pathways rather than a global count, so the
        # method works for whatever num_mlps the model was constructed with.
        predictions = torch.cat(
            [head(mlp(x)) for mlp, head in zip(self.mlps, self.mlps_output)], dim=1
        )

        return torch.sum(weightings * predictions, dim=1)

    def calculate_W(self):
        """Return the ``(num_mlps, num_mlps)`` matrix of pairwise prototype distances."""
        # (num_mlps, input_size): one row per prototype.
        prototypes = torch.stack([prototype[0] for prototype in self.prototypes])
        return torch.cdist(prototypes, prototypes)

    def W_penalty(self):
        """Return ``exp(-||W||_F)``; it shrinks as the prototypes spread apart."""
        W = self.calculate_W()
        frobenius_norm = torch.norm(W)
        return torch.exp(-frobenius_norm)
        
        
print("starting training")

# Number of neural pathways (and therefore prototypes).
num_mlps = 4

hidden_layer_size = 1000

# Full-batch training: a single batch holding the entire training split.
batch_size = len(train_dataset)  # Adjust as needed
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Standard loss for regression tasks
criterion = nn.MSELoss()

# Training loop epochs
num_epochs = 1000

# Neural Pathways model used to learn the prototype locations.
neural_pathways = Neural_Pathways(input_size=1, hidden_size=hidden_layer_size, output_size=1, num_mlps=num_mlps).to(device)

optimizer = optim.Adam(neural_pathways.parameters(), lr=1e-4)

# Weight of the W penalty exp(-||W||_F). The penalty shrinks as the prototypes spread
# apart, so it imposes a prior that evenly distributed prototypes are better (which
# may not necessarily be the case). 0.0 disables it.
gamma = 0.0  # 0.001

# Per-step training metrics.
training_loss_list = []
total_loss_list = []
w_penalty_list = []

# prototype_locations['prototype_<k>_location'] is prototype k's trajectory: its
# initial value followed by one entry per optimizer step.
prototype_locations = {
    f'prototype_{prototype_idx}_location': [neural_pathways.prototypes[prototype_idx].item()]
    for prototype_idx in range(num_mlps)
}

# Whether to clip prototypes to the x range (here [0, 1]) after every step.
clip_in_range = False

# Prototypes training loop
for epoch in range(num_epochs):
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        optimizer.zero_grad()

        # (batch,) -> (batch, 1) inputs; the prediction comes back as (batch,).
        outputs = neural_pathways.predict_discover_prototypes(inputs.unsqueeze(1))

        W_penalty = neural_pathways.W_penalty()

        loss = criterion(outputs, targets.to(device))
        total_loss = loss + gamma * W_penalty

        # Backpropagate the full objective so the W penalty actually influences the
        # prototypes when gamma > 0 (with gamma == 0 this equals loss.backward()).
        total_loss.backward()
        optimizer.step()

        if clip_in_range:
            with torch.no_grad():
                for prototype in neural_pathways.prototypes:
                    prototype.clamp_(0, 1)

        training_loss_list.append(loss.item())
        total_loss_list.append(total_loss.item())
        w_penalty_list.append(gamma * W_penalty.item())

        for prototype_idx in range(num_mlps):
            prototype_locations[f'prototype_{prototype_idx}_location'].append(neural_pathways.prototypes[prototype_idx].item())
    if not (epoch % 100):
        print(epoch, loss)
print("training done")



# Scalar +inf on the training device, used as a "no value" sentinel in minimum searches.
infTensor = torch.tensor(float('inf')).to(device)

# Pairwise Euclidean distances between the trained prototypes, (num_mlps, num_mlps).
W = neural_pathways.calculate_W()
print(W)

# "Radius" of each prototype: distance to its nearest distinct prototype, i.e. the
# smallest non-zero entry of its row of W (zeros are the diagonal and any coincident
# prototypes). Result is a (num_mlps, 1) column. A row with no non-zero entry gets
# inf instead of crashing on an empty min(); this covers num_mlps == 1 and the case
# where all prototypes coincide.
with torch.no_grad():
    radii = torch.where(W > 0, W, infTensor).min(dim=1, keepdim=True).values
print(radii)


# Standard loss for regression tasks (mean over the points a pathway owns).
criterion = nn.MSELoss()

# Training epochs per pathway.
num_epochs = 10000

# Weight of the optional penalty on points owned by *other* pathways (0 disables it).
lambd = 0.0  # 0.01

# Every pathway's fine-tuned weights are saved into a fresh timestamped directory.
now = datetime.now()
directory = "Model_1d_" + now.strftime("%Y_%m_%d_%H_%M_%S")
os.mkdir(directory)

# (num_mlps, 1) snapshot of the prototype locations. It is only used to decide which
# pathway owns each input, so it is detached from the prototype parameters' graph.
prototype_parameters = neural_pathways.prototypes
prototypes = torch.stack([p.detach().view(-1) for p in prototype_parameters]).to(device)

# Full-batch loader over the training split (loop-invariant, so built once).
batch_size = int(len(train_dataset))  # Adjust as needed
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

for i in range(num_mlps):
    print("leaf %d/%d" % (i+1, num_mlps))

    # The pathway shares (does not copy) pathway i's MLP and output head, so
    # fine-tuning here also updates neural_pathways.mlps[i] / mlps_output[i].
    pathway1 = Pathway(proto_neural_pathway=neural_pathways.mlps[i], proto_neural_pathway_output=neural_pathways.mlps_output[i], hidden_size=hidden_layer_size).to(device)

    optimizer = optim.Adam(pathway1.parameters(), lr=step_size)

    # Per-epoch training metrics (taken from the last batch of each epoch).
    training_loss_pathway = []
    training_loss_inside_radius = []
    training_loss_outside_radius = []

    # Neural Pathway training loop
    for epoch in tqdm(range(num_epochs)):
        for inputs, targets in train_loader:
            optimizer.zero_grad()

            # (batch,) -> (batch, 1) to match the pathway output and cdist's input.
            targets = targets.to(device).unsqueeze(1)
            inputs = inputs.to(device).unsqueeze(1)

            outputs = pathway1(inputs)

            # Pathway i owns the inputs whose nearest prototype is prototype i.
            distance_to_prototype = torch.cdist(inputs, prototypes)
            mask_inside_radius = torch.argmin(distance_to_prototype, dim=1) == i

            outputs_inside_radius = outputs[mask_inside_radius]
            targets_inside_radius = targets[mask_inside_radius]
            # NOTE(review): if no training input is nearest to prototype i this is the
            # mean over an empty tensor (NaN).
            loss_inside_radius = criterion(outputs_inside_radius, targets_inside_radius)

            loss = loss_inside_radius

            if lambd:
                outputs_outside_radius = outputs[~mask_inside_radius]
                targets_outside_radius = targets[~mask_inside_radius]
                with torch.no_grad():
                    # (batch, num_mlps) distances to all prototypes, reduced to one
                    # (batch, 1) norm per input.
                    distances = neural_pathways.distances_to_prototypes(inputs.detach())
                    distances = torch.norm(distances, p=2, dim=1).unsqueeze(1)

                    # Scale row i of W (prototype i's distances to every prototype) per
                    # input, and take the smallest non-zero product.
                    find_minimum_of_tensor = distances * W[i]
                    min_value, _ = torch.min(torch.where(find_minimum_of_tensor == 0, infTensor, find_minimum_of_tensor), dim=1)

                    # (n_outside,) weights decaying with min_value.
                    penalty_weight = torch.exp(-min_value)[~mask_inside_radius]

                # Squeeze the trailing size-1 dim so the per-point squared errors are
                # (n_outside,) like penalty_weight. Multiplying (n,) by (n, 1) would
                # broadcast to an (n, n) matrix.
                squared_error = (outputs_outside_radius - targets_outside_radius).squeeze(1) ** 2
                loss_outside_radius = lambd * torch.mean(penalty_weight * squared_error)
                loss = loss + loss_outside_radius

            loss.backward()
            optimizer.step()

        training_loss_pathway.append(loss.item())
        training_loss_inside_radius.append(loss_inside_radius.item())

        if lambd:
            training_loss_outside_radius.append(loss_outside_radius.item())

    torch.save(pathway1.state_dict(), os.path.join(directory, "model_%d.pt" % i))
print("training done")

# Evaluate the fine-tuned pathways on the test split. Each test point is scored only
# by the pathway whose prototype is nearest to it.
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Sum (not mean) of squared errors, so the per-pathway contributions add up.
criterion = nn.MSELoss(reduction='sum')

total_loss = 0.0

# Detached (num_mlps, 1) prototype locations, used only for the ownership masks.
prototype_parameters = neural_pathways.prototypes
prototypes = torch.stack([p.detach().view(-1) for p in prototype_parameters]).to(device)

for i in range(num_mlps):

    # Rebuild pathway i's architecture and load the weights saved during training.
    pathway = Pathway(proto_neural_pathway=neural_pathways.mlps[i], proto_neural_pathway_output=neural_pathways.mlps_output[i], hidden_size=hidden_layer_size).to(device)
    pathway.load_state_dict(torch.load(os.path.join(directory, "model_%d.pt" % i)))
    pathway.eval()

    # Inference only: no autograd graph needs to be recorded.
    with torch.no_grad():
        for inputs, targets in test_loader:
            targets = targets.to(device).view(-1, 1)
            inputs = inputs.to(device).view(-1, 1)

            distance_to_prototype = torch.cdist(inputs, prototypes)
            mask_inside_radius = torch.argmin(distance_to_prototype, dim=1) == i

            # Only evaluate the inputs this pathway owns, avoiding unnecessary work.
            outputs_inside_radius = pathway(inputs[mask_inside_radius])
            targets_inside_radius = targets[mask_inside_radius]
            loss_inside_radius = criterion(outputs_inside_radius, targets_inside_radius)

            total_loss = total_loss + loss_inside_radius

# argmin assigns every test point to exactly one pathway, so the masks partition the
# test set. total_loss is therefore already the test-set sum of squared errors, and
# no division by the number of pathways is needed.
total_loss_ours = total_loss

print("total loss = {:.15f}".format(total_loss))


class PathwayTest(nn.Module):
    """Monolithic baseline network used for comparison with the neural pathways.

    Every hidden layer has width ``ceil(sqrt(num_mlps)) * hidden_size``.

    Args:
        num_mlps: number of pathways in the compared model; only used to size the width.
        input_size: input dimensionality.
        hidden_size: per-pathway hidden width.
        output_size: output dimensionality.
    """

    def __init__(self, num_mlps, input_size = 1, hidden_size=10, output_size=1):
        super().__init__()

        width = math.ceil(math.sqrt(num_mlps)) * hidden_size

        # Trunk: Linear -> PReLU -> Linear.
        self.mlps = nn.Sequential(
            nn.Linear(input_size, width),
            nn.PReLU(init=1.0),
            nn.Linear(width, width),
        )

        # Head: three (PReLU -> Linear) stages, the last projecting to output_size.
        head_layers = []
        for fan_out in (width, width, output_size):
            head_layers.append(nn.PReLU(init=1.0))
            head_layers.append(nn.Linear(width, fan_out))
        self.MLP_extension = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return ``MLP_extension(mlps(x))``."""
        hidden = self.mlps(x)
        return self.MLP_extension(hidden)



# Baseline: train a single monolithic network (PathwayTest) on the same training data.
criterion = nn.MSELoss()

num_epochs = 10000

neural_pathways_test = PathwayTest(input_size=1, hidden_size=hidden_layer_size, output_size=1, num_mlps=num_mlps).to(device)

# One Adam parameter group at lr=step_size. The name filter excludes prototype
# parameters; PathwayTest defines none, so every parameter is trained.
baseline_params = [
    param
    for name, param in neural_pathways_test.named_parameters()
    if 'prototypes' not in name
]
optimizer = torch.optim.Adam([{'params': baseline_params, 'lr': step_size}])

training_loss_list = []
total_loss_list = []
w_penalty_list = []

for epoch in tqdm(range(num_epochs)):
    for inputs, targets in train_loader:
        optimizer.zero_grad()

        outputs = neural_pathways_test(inputs.to(device).view(-1, 1))
        flat_targets = targets.to(device).view(-1)

        loss = criterion(outputs.view(-1), flat_targets)
        total_loss = loss

        total_loss.backward()
        optimizer.step()

        training_loss_list.append(loss.item())


### testing on the test set
# Sum of squared errors of the baseline over the whole test split.
criterion = nn.MSELoss(reduction='sum')
total_loss = 0.0
neural_pathways_test.eval()

# Inference only: disable autograd so no computation graph is recorded.
with torch.no_grad():
    for inputs, targets in test_loader:
        targets = targets.to(device).view(-1)
        inputs = inputs.to(device).view(-1, 1)
        outputs = neural_pathways_test(inputs).view(-1)
        total_loss = total_loss + criterion(outputs, targets)

print("total loss = {:.15f}".format(total_loss))

total_loss_baseline = total_loss


# Record this run: line 1 is the pathway model's test SSE, line 2 the baseline's.
now = datetime.now()
result_file = "results/results_1D_" + now.strftime("%Y_%m_%d_%H_%M_%S.txt")

# The results directory is not created earlier in this script; ensure it exists so
# open() does not raise FileNotFoundError on a fresh checkout.
os.makedirs("results", exist_ok=True)

# Context manager guarantees the file is closed even if a write fails.
with open(result_file, "w") as f:
    f.write("{:.15f}\n".format(total_loss_ours))
    f.write("{:.15f}\n".format(total_loss_baseline))
