import pickle

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset


# Definition of the denoising autoencoder
class DenoisingAutoencoder(nn.Module):
    """Single-layer denoising autoencoder.

    The encoder maps ``input_dim`` features to ``output_dim`` features and the
    decoder maps them back; each is a linear layer followed by a ReLU. In
    training mode the input is randomly masked (elements set to zero) before
    being encoded.

    Args:
        input_dim: Number of features of the input (and of the reconstruction).
        output_dim: Number of features of the encoded representation.
        mask_prob: Probability with which each input element is zeroed while
            the module is in training mode.

    Attributes:
        saved_encoded_coordinates: Encoder output of the most recent
            ``forward`` call, or ``None`` if ``forward`` has not run yet.
    """

    def __init__(self, input_dim, output_dim, mask_prob=0.1):
        super(DenoisingAutoencoder, self).__init__()

        self.mask_prob = mask_prob  # probability of masking input data

        # Defined up front so the attribute always exists; previously reading
        # it before the first forward pass raised AttributeError.
        self.saved_encoded_coordinates = None

        # Encoder: Linear transformation followed by a ReLU activation
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU()
        )
        # Decoder: Linear transformation followed by a ReLU activation
        # (the final ReLU means reconstructions are never negative)
        self.decoder = nn.Sequential(
            nn.Linear(output_dim, input_dim),
            nn.ReLU()
        )

    def mask(self, x):
        """Return ``x`` with each element independently zeroed with probability ``mask_prob``."""
        # Bernoulli(1 - mask_prob) yields 1 for "keep" and 0 for "drop".
        keep = torch.zeros_like(x).bernoulli_(1 - self.mask_prob)
        return x * keep

    def forward(self, x):
        """Encode then decode ``x`` and return the reconstruction.

        The input is masked first when the module is in training mode. The
        encoder output is kept in ``saved_encoded_coordinates`` so callers can
        read the latent representation of the batch that was just processed.
        """
        if self.training:
            x = self.mask(x)  # corrupt the input only while training

        encoded = self.encoder(x)
        self.saved_encoded_coordinates = encoded
        return self.decoder(encoded)


# Definition of the stacked denoising autoencoder
class SDAE(object):
    """Stacked denoising autoencoder trained greedily, one frame at a time.

    ``dim_list`` gives the layer widths; consecutive pairs
    ``(dim_list[i], dim_list[i + 1])`` define one ``DenoisingAutoencoder``
    ("frame") each. All frames are moved to CUDA when it is available,
    otherwise they stay on the CPU.

    Args:
        dim_list: Sequence of layer widths, input width first.
        mask_prob: Masking probability handed to every frame.
        learning_rate: Learning rate of the Adam optimizer used per frame.
        num_epochs: Number of epochs each frame is trained for.
    """

    def __init__(self, dim_list, mask_prob=0.1, learning_rate=0.1, num_epochs=100):
        self.frames = []  # to store the layers of the stacked autoencoder
        self.frame_num = len(dim_list) - 1
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs

        # Use the GPU when one is available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Frames start in eval mode; train() switches one frame at a time
        # into training mode.
        for in_dim, out_dim in zip(dim_list[:-1], dim_list[1:]):
            model = DenoisingAutoencoder(in_dim, out_dim, mask_prob)
            model.eval()
            self.frames.append(model.to(self.device))

    def train(self, train_dataloader):
        """Train every frame in order on the output of the frames before it.

        Each batch of ``train_dataloader`` must unpack into
        ``(coordinates, temperatures)``. For frame ``i`` the coordinates are
        first pushed through the encoders of frames ``0 .. i - 1``; frame ``i``
        then reconstructs that encoding and is optimized with ``MyLoss``.
        After every epoch the loss weights ``k1``/``k2`` are set to the ratio
        of the current to the previous epoch's mean loss1/loss2.
        Per-epoch losses are printed.
        """
        num_batches = len(train_dataloader)

        # Layer-wise training
        for frame_idx in range(self.frame_num):
            model = self.frames[frame_idx]
            model.train()
            criterion = MyLoss().to(self.device)
            optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)

            k1 = 1.0
            k2 = 1.0
            prev_loss1 = -1
            prev_loss2 = -1

            for epoch in range(self.num_epochs):
                running_loss = 0.0
                running_loss1 = 0.0
                running_loss2 = 0.0

                for coordinates, temperatures in train_dataloader:
                    coordinates = coordinates.to(self.device)
                    temperatures = temperatures.to(self.device)

                    # Earlier frames are not being optimized here, so their
                    # output is computed without autograd. This avoids
                    # back-propagating through them on every step and leaves
                    # the gradients of the current frame unchanged.
                    with torch.no_grad():
                        last_coordinates = self._get_encoding(frame_idx, coordinates)

                    decoded_coordinates = model(last_coordinates)
                    loss, loss1, loss2 = criterion(
                        decoded_coordinates, last_coordinates,
                        model.saved_encoded_coordinates, temperatures,
                        k1, k2
                    )
                    running_loss += loss.item()
                    running_loss1 += loss1.item()
                    running_loss2 += loss2.item()

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                running_loss /= num_batches
                running_loss1 /= num_batches
                running_loss2 /= num_batches

                # Adjust the coefficients of the loss function based on the
                # previous loss. Both previous losses are used as divisors, so
                # both must be positive; the weights are also left untouched
                # if the update would make them sum to zero, because the
                # criterion divides by k1 + k2.
                if prev_loss1 > 0 and prev_loss2 > 0:
                    new_k1 = running_loss1 / prev_loss1
                    new_k2 = running_loss2 / prev_loss2
                    if new_k1 + new_k2 > 0:
                        k1, k2 = new_k1, new_k2
                prev_loss1 = running_loss1
                prev_loss2 = running_loss2

                print('Epoch [%d/%d], Loss: %.8f, Loss1: %.8f, Loss2: %.8f, k1=%.2f，k2=%.2f' % (
                    epoch + 1, self.num_epochs, running_loss, running_loss1, running_loss2, k1, k2))

            model.eval()

    def overall_loss(self, train_dataloader):
        """Evaluate the full encode/decode stack on ``train_dataloader``.

        Returns:
            Tuple ``(loss, loss1, loss2)`` of per-batch means, computed with
            ``MyLoss`` using equal weights ``k1 = k2 = 1.0``. The same values
            are printed.
        """
        criterion = MyLoss().to(self.device)

        running_loss = 0.0
        running_loss1 = 0.0
        running_loss2 = 0.0

        # Pure evaluation: only .item() values are used, so no autograd graph
        # is needed.
        with torch.no_grad():
            for coordinates, temperatures in train_dataloader:
                coordinates = coordinates.to(self.device)
                temperatures = temperatures.to(self.device)

                encoded_coordinates = self.encode(coordinates)
                decoded_coordinates = self.decode(encoded_coordinates)

                loss, loss1, loss2 = criterion(
                    decoded_coordinates, coordinates,
                    encoded_coordinates, temperatures,
                    1.0, 1.0
                )
                running_loss += loss.item()
                running_loss1 += loss1.item()
                running_loss2 += loss2.item()

        num_batches = len(train_dataloader)
        running_loss /= num_batches
        running_loss1 /= num_batches
        running_loss2 /= num_batches

        print('Overall Loss: %.8f, Loss1: %.8f, Loss2: %.8f' % (running_loss, running_loss1, running_loss2))
        return running_loss, running_loss1, running_loss2

    def _get_encoding(self, frame_idx, X):
        """Pass ``X`` through the encoders of frames ``0 .. frame_idx - 1``.

        With ``frame_idx == 0`` the input is returned unchanged. The encoders
        are called directly, so no input masking is applied.
        """
        for frame in self.frames[:frame_idx]:
            X = frame.encoder(X)
        return X

    def encode(self, X):
        """Encode ``X`` with every frame's encoder, first frame first."""
        return self._get_encoding(self.frame_num, X)

    def decode(self, X):
        """Decode ``X`` with every frame's decoder, last frame first."""
        for i in range(self.frame_num - 1, -1, -1):
            X = self.frames[i].decoder(X)
        return X


# Definition of the custom loss function
class MyLoss(nn.Module):
    """Weighted sum of a reconstruction loss and a pairwise rate-of-change loss.

    ``loss1`` is the mean squared error between the decoded and the original
    coordinates. ``loss2`` is the mean squared, min-max normalized ratio of
    pairwise temperature difference to pairwise distance between encoded
    coordinates. The two are combined as
    ``k1 / (k1 + k2) * loss1 + k2 / (k1 + k2) * loss2``.
    """

    def __init__(self):
        super(MyLoss, self).__init__()
        self.mse_loss = nn.MSELoss()

    @staticmethod
    def minmax_norm(X):
        """Rescale ``X`` linearly so its minimum maps to 0 and its maximum to 1.

        When every element of ``X`` is equal (``max == min``) the result is
        all zeros instead of the NaN a plain ``0 / 0`` would give.
        """
        min_val = torch.min(X)
        value_range = torch.max(X) - min_val
        # Swap a zero range for 1 so a constant input divides cleanly to 0;
        # any non-zero range is used unchanged.
        safe_range = torch.where(value_range > 0, value_range, torch.ones_like(value_range))
        return (X - min_val) / safe_range

    def forward(self, decoded_coordinates, original_coordinates,
                encoded_coordinates, temperatures, k1, k2):
        """Compute the combined loss and its two components.

        Args:
            decoded_coordinates: Reconstruction produced by the decoder.
            original_coordinates: Reconstruction target, same shape as
                ``decoded_coordinates``.
            encoded_coordinates: Encoded batch, one row per sample (passed to
                ``torch.cdist``).
            temperatures: Per-sample scalar values; assumed to be 1-D with one
                entry per row of ``encoded_coordinates`` so the pairwise
                matrix lines up with the distance matrix -- TODO confirm.
            k1: Weight of the reconstruction loss.
            k2: Weight of the rate-of-change loss.

        Returns:
            Tuple ``(loss, loss1, loss2)``.
        """
        loss1 = self.mse_loss(decoded_coordinates, original_coordinates)

        pairwise_temperature = torch.abs(temperatures.unsqueeze(1) - temperatures.unsqueeze(0))
        pairwise_distance = torch.cdist(encoded_coordinates, encoded_coordinates)

        # Threshold below which a value is treated as 0
        eps = np.finfo(float).eps

        # Calculate pairwise temperature rate of change. Masked assignment
        # keeps zero distances out of the division entirely, so neither the
        # values nor the gradients can become NaN/inf.
        pairwise_rate_of_change = torch.zeros_like(pairwise_temperature)
        nonzero_distance = pairwise_distance > eps
        pairwise_rate_of_change[nonzero_distance] = (
            pairwise_temperature[nonzero_distance] / pairwise_distance[nonzero_distance]
        )

        # Limit rate of change between 0 and 1
        pairwise_rate_of_change = self.minmax_norm(pairwise_rate_of_change)

        # If the distance is 0 but the temperature difference is not,
        # the rate of change is set to 1 (considered as the maximum).
        # `<= eps` complements `> eps`, so a distance exactly equal to eps is
        # no longer skipped by both masks.
        zero_distance = pairwise_distance <= eps
        non_zero_temperature = pairwise_temperature > eps
        pairwise_rate_of_change[zero_distance & non_zero_temperature] = 1.0

        loss2 = (pairwise_rate_of_change * pairwise_rate_of_change).mean()

        # Normalize the weights; fall back to equal weights when they sum to
        # zero rather than dividing by zero.
        weight_sum = k1 + k2
        if weight_sum == 0:
            w1 = w2 = 0.5
        else:
            w1 = k1 / weight_sum
            w2 = k2 / weight_sum

        return w1 * loss1 + w2 * loss2, loss1, loss2


if __name__ == '__main__':
    # Script entry point: load the pickled inputs, train the stacked
    # autoencoder layer by layer, and pickle the trained object.

    def _load_first_entry(path):
        """Unpickle the file at ``path`` and return element 0 of its content."""
        # NOTE: unpickling can execute arbitrary code; only use trusted files.
        with open(path, 'rb') as handle:
            return pickle.load(handle)[0]

    samples_per_batch = 1024
    layer_dims = [709, 256, 128, 64]

    # Per-sample target values, then the coordinates they belong to.
    target_values = _load_first_entry('tmc_shapley_results.pkl')
    sample_coordinates = _load_first_entry('splitted_dataset.pkl')

    coordinate_tensor = torch.tensor(sample_coordinates, dtype=torch.float32)
    target_tensor = torch.tensor(target_values, dtype=torch.float32)
    train_dataloader = torch.utils.data.DataLoader(
        TensorDataset(coordinate_tensor, target_tensor),
        batch_size=samples_per_batch,
        shuffle=True,
        pin_memory=True,
    )

    sdae = SDAE(dim_list=layer_dims, mask_prob=0.2, learning_rate=0.005, num_epochs=100)
    sdae.train(train_dataloader)

    # The whole SDAE object (frames included) is serialized with pickle.
    with open('sdae_model.pkl', 'wb') as model_file:
        pickle.dump(sdae, model_file)
