from pathlib import Path

import gc

import tqdm

import numpy as np

import torch
from torch import optim, nn

import lightning.pytorch as pl

from sklearn.metrics import (
    accuracy_score, 
    jaccard_score,
)

from .locationencoder import (
    get_positional_encoding, 
    get_neural_network_v2, 
    get_loss_fn, 
    get_param,
)

from .time_embedding_functions import (
    time_embedding_functions,
    time_embedding_functions_more_params,
)

# define the LightningModule
class LocationEncoder3D(pl.LightningModule):
    """
    A PyTorch Lightning module for encoding 3D (space + time) locations.

    Input coordinates are embedded by a ``PE_3D`` positional encoder (a space
    branch and a time branch combined according to ``combination_type``) and
    the embedding is passed to a downstream network built by
    ``get_neural_network_v2``.

    The training objective is the task loss from ``get_loss_fn`` plus:
      * ``time_grad_penalty_weight * (||d out / d t||_2 / batch_size) ** 2``,
        where ``out`` is the summed model output and ``t`` the time column;
      * orthogonality regularisers on the final, space and time embeddings,
        weighted by ``ortho_weight``, ``ortho_weight_space`` and
        ``ortho_weight_time``.
    """
    def build_artifact_path(self, legendre_polys, positional_encoder):
        """
        Build the artifact path for saving model artifacts.

        Args:
            legendre_polys (int): Number of Legendre polynomials.
            positional_encoder (PE_3D): Positional encoder instance.

        Returns:
            Path: The artifact path.
        """
        return ( Path(".") /
            f"legendre_{legendre_polys}" /
            positional_encoder.artifact_path
        )
    
    def __init__(self, 
                 positional_embedding_type, 
                 time_embedding_dim, 
                 harmonics_calculation,
                 legendre_polys,
                 combination_type, 
                 time_embedding_type, 
                 combined_encoding_args,
                 spatial_encoding_args,
                 time_encoding_args,
                 number_of_timesteps,
                 ortho_weight,
                 ortho_weight_space,
                 ortho_weight_time,
                 time_grad_penalty_weight,
                 normality_flag,
                 ortho_exponent, 
                 hparams,
                 ):
        """
        Initialize the LocationEncoder3D module.

        Args:
            positional_embedding_type (str): Type of positional embedding, forwarded
                to ``get_positional_encoding``.
            time_embedding_dim (int): Dimension of time embedding (not read here;
                kept for checkpoint/hyperparameter compatibility).
            harmonics_calculation (str): Harmonics calculation method (not read here;
                kept for checkpoint/hyperparameter compatibility).
            legendre_polys (int): Number of Legendre polynomials. Written into
                ``hparams["legendre_polys"]`` before the space encoder is built.
            combination_type (str): How space and time embeddings are combined
                (see ``PE_3D``).
            time_embedding_type (str): Type of time embedding (see ``PE_3D``).
            combined_encoding_args (dict): Arguments for the downstream network;
                must contain ``"input_dim"`` and ``"output_dim"``.
            spatial_encoding_args (dict): Arguments for the spatial encoding network.
            time_encoding_args (dict): Arguments for the time encoding network.
            number_of_timesteps (int): Number of timesteps.
            ortho_weight (float): Orthogonality weight for the final layer.
            ortho_weight_space (float): Orthogonality weight for the space encoder.
            ortho_weight_time (float): Orthogonality weight for the time encoder.
            time_grad_penalty_weight (float): Weight of the time-derivative penalty
                added to the loss in ``common_step``.
            normality_flag (bool): If true, the final-layer regulariser pushes the
                embedding correlation matrix towards the identity; otherwise only
                its off-diagonal entries are penalised.
            ortho_exponent (float): Exponent applied to each orthogonality regulariser.
            hparams (dict): Hyperparameters. NOTE: mutated in place (``"legendre_polys"``
                is set) so that ``get_positional_encoding`` can read it.
        """
        super().__init__()

        self.neural_network = get_neural_network_v2(
            **combined_encoding_args
        )
        self.combined_encoding_args = combined_encoding_args

        self.learning_rate = hparams["optimizer"]["lr"]
        self.weight_decay = hparams["optimizer"]["wd"]
        hparams["legendre_polys"] = legendre_polys
        
        self.regression = get_param(hparams, "regression")
        self.positional_encoding_type = positional_embedding_type
        
        self.loss_fn = get_loss_fn(presence_only=get_param(hparams, "presence_only_loss"), 
                                   loss_weight=get_param(hparams, "loss_weight"),
                                   regression=self.regression)

        self.positional_encoder = PE_3D(
            space_encoder=get_positional_encoding(
                positional_embedding_type, hparams
            ),
            combination_type=combination_type,
            spatial_encoding_args=spatial_encoding_args,
            time_embedding_type=time_embedding_type,
            time_encoding_args=time_encoding_args,
            number_of_timesteps=number_of_timesteps,
        )
        self.artifact_path = self.build_artifact_path(legendre_polys, self.positional_encoder)

        self.size_accumulator = 0
        
        self.ortho_weight = ortho_weight
        self.ortho_weight_space = ortho_weight_space
        self.ortho_weight_time = ortho_weight_time

        self.time_grad_penalty_weight = time_grad_penalty_weight

        self.normality_flag = normality_flag

        self.ortho_exponent = ortho_exponent
        # this enables LocationEncoder.load_from_checkpoint(path)
        self.save_hyperparameters()

        self.out_dim, self.in_dim = self.combined_encoding_args["output_dim"], self.combined_encoding_args["input_dim"]

    def compute_time_grad(self, lat_lon_time):
        """
        Return the derivative of the summed model output w.r.t. the time column.

        ``torch.autograd.grad`` is used instead of ``zero_grad()`` + ``backward()``
        so the parameters' ``.grad`` buffers are left untouched. Calling
        ``zero_grad()`` inside a step would discard gradients accumulated over
        earlier micro-batches. When training with a non-zero penalty weight, the
        derivative's own graph is kept (``create_graph=True``). That way the time
        penalty built from it in ``common_step`` actually back-propagates into
        the parameters.

        Args:
            lat_lon_time (Tensor): Batch of coordinates; column 2 is time.

        Returns:
            Tensor: One derivative value per sample (``grad[:, 2]``).
        """
        coords = lat_lon_time.clone().detach().requires_grad_(True)
        summed_outputs = self.forward(coords).sum()
        # Only build the (costly) double-backward graph when it is actually used.
        create_graph = self.training and bool(self.time_grad_penalty_weight)
        (coords_grad,) = torch.autograd.grad(summed_outputs, coords, create_graph=create_graph)

        return coords_grad[:, 2]

    def common_step(self, batch, batch_idx, mode):
        """
        Common step for training and validation.

        Computes the task loss plus the weighted time-derivative penalty, and logs
        ``f"{mode}_avg_time_derivative_norm"``.

        Args:
            batch (tuple): Batch of data ``(lat_lon_time, label)``.
            batch_idx (int): Batch index.
            mode (str): Prefix used for the logged metric (e.g. ``"train"``, ``"val"``).

        Returns:
            Tensor: Loss value.
        """
        lat_lon_time, label = batch
        
        time_grad = self.compute_time_grad(lat_lon_time)
        time_grad_penalty = 0.0
        if time_grad is not None and time_grad.shape[0] > 0:
            time_grad_norm = time_grad.norm(p=2) / time_grad.shape[0]
            self.log(f"{mode}_avg_time_derivative_norm", time_grad_norm, on_step=True, on_epoch=True)
            time_grad_penalty = time_grad_norm**2

        return self.loss_fn(self, lat_lon_time, label) + self.time_grad_penalty_weight*time_grad_penalty

    def forward(self, lat_lon_time):
        """
        Forward pass of the model.

        Args:
            lat_lon_time (Tensor): Coordinates; the first two columns are spatial
                and column 2 is time (see ``PE_3D.forward``).

        Returns:
            Tensor: Model output.
        """
        embedding = self.positional_encoder(lat_lon_time)

        return self.neural_network(embedding)
    
    def get_last_activations(self, lat_lon_time):
        """
        Get the last activations from the neural network.

        Args:
            lat_lon_time (Tensor): Coordinates; the first two columns are spatial
                and column 2 is time.

        Returns:
            Tensor: Last activations.
        """
        embedding = self.positional_encoder(lat_lon_time)

        return self.neural_network.get_last_activations(embedding)

    @staticmethod
    def _embedding_correlations(embeddings, batch_size):
        """Return ``embeddings.T @ embeddings / batch_size`` (a dim x dim matrix)."""
        return torch.einsum('ai,aj->ij', embeddings, embeddings) / batch_size

    @staticmethod
    def _off_diagonal_norm(correlations):
        """Return the 2-norm of ``correlations`` with its diagonal zeroed out."""
        return torch.norm(correlations - torch.diag(torch.diag(correlations)), p=2)

    def compute_ortho_regularizer(self, lat_lon_time):
        """
        Compute orthogonality regularisers on the final, space and time embeddings.

        For each embedding, the batch correlation matrix ``E.T @ E / batch_size``
        is formed and its off-diagonal 2-norm is taken. For the final embedding,
        when ``normality_flag`` is set, the distance to the identity is used
        instead. Each term is raised to ``ortho_exponent``.

        Args:
            lat_lon_time (Tensor): Batch of coordinates; column 2 is time.

        Returns:
            tuple: ``(final_regularizer, space_regularizer, time_regularizer)``.
        """
        # The regularisers only need gradients w.r.t. parameters, not inputs.
        x = lat_lon_time.clone().detach()
        batch_size = x.size(0)

        time_embeddings = self.positional_encoder.time_encoding_network.get_last_activations(
            self.positional_encoder.encode_time_coords(x[:, 2])
        )
        time_regularizer = self._off_diagonal_norm(
            self._embedding_correlations(time_embeddings, batch_size)
        )

        space_embeddings = self.positional_encoder.space_encoding_network.get_last_activations(
            self.positional_encoder.space_encoder(x[:, :2])
        )
        space_regularizer = self._off_diagonal_norm(
            self._embedding_correlations(space_embeddings, batch_size)
        )

        final_embeddings = self.get_last_activations(x)
        final_embedding_correlations = self._embedding_correlations(final_embeddings, batch_size)

        if self.normality_flag:
            identity = torch.eye(
                final_embedding_correlations.size(0), device=final_embedding_correlations.device
            )
            final_regularizer = torch.norm(final_embedding_correlations - identity, p=2)
        else:
            final_regularizer = self._off_diagonal_norm(final_embedding_correlations)

        return (
            final_regularizer**self.ortho_exponent,
            space_regularizer**self.ortho_exponent,
            time_regularizer**self.ortho_exponent,
        )

    def training_step(self, batch, batch_idx):
        """
        Training step.

        Args:
            batch (tuple): Batch of data.
            batch_idx (int): Batch index.

        Returns:
            Tensor: Loss value including the weighted orthogonality regularisers.
        """
        loss = self.common_step(batch, batch_idx, mode="train")
        if self.regression:
            # NOTE(review): ``mean_var`` is not set in __init__; it is assumed to be
            # assigned externally before training — confirm.
            loss = loss * self.mean_var

        self.log("train_loss", loss, on_step=True, on_epoch=True)

        lat_lon_time, _ = batch
        final_ortho_regularizer, space_ortho_regularizer, time_ortho_regularizer = self.compute_ortho_regularizer(lat_lon_time)

        self.log("final_ortho_regularizer", final_ortho_regularizer, on_step=True, on_epoch=True)
        self.log("space_ortho_regularizer", space_ortho_regularizer, on_step=True, on_epoch=True)
        self.log("time_ortho_regularizer", time_ortho_regularizer, on_step=True, on_epoch=True)

        return (
            loss + self.ortho_weight*final_ortho_regularizer + 
                self.ortho_weight_space*space_ortho_regularizer + self.ortho_weight_time*time_ortho_regularizer
            )

    def validation_step(self, batch, batch_idx):
        """
        Validation step.

        Gradients are enabled locally because ``common_step`` differentiates the
        model output w.r.t. the time coordinate.

        NOTE(review): Lightning runs validation under ``torch.inference_mode()`` by
        default, and ``set_grad_enabled(True)`` may not be enough there — confirm
        that the Trainer is created with ``inference_mode=False``.
        """
        with torch.set_grad_enabled(True):
            loss = self.common_step(batch, batch_idx, mode="val")
            if self.regression:
                loss = loss * self.mean_var

        # No parameter update happens here; drop the graph to free memory early.
        loss = loss.detach()
        self.log("val_loss", loss, on_step=True, on_epoch=True)
    
        return {"val_loss": loss}

    def predict_step(self, batch, batch_idx):
        """
        Prediction step.

        Args:
            batch (tuple): Batch of data.
            batch_idx (int): Batch index.

        Returns:
            tuple: Prediction logits, coordinates, and labels.
        """
        lat_lon_time, label = batch
        prediction_logits = self.forward(lat_lon_time)
        return prediction_logits, lat_lon_time, label
    
    def test_step(self, batch, batch_idx):
        """
        Test step.

        For regression, predictions and labels are de-normalised with
        ``self.stds`` / ``self.means`` and MAE / RMSE are logged. For
        classification, accuracy and IoU (Jaccard) are logged. A single output
        column is treated as binary logits thresholded at 0; otherwise the
        argmax over the last dimension is used.

        Args:
            batch (tuple): Batch of data.
            batch_idx (int): Batch index.

        Returns:
            dict: Test results.
        """
        lat_lon_time, label = batch

        prediction_logits = self.forward(lat_lon_time)
        loss = self.loss_fn(self, lat_lon_time, label)

        if self.regression:
            # Undo label normalisation on CPU, where the metrics are computed.
            prediction_logits = prediction_logits.cpu() * self.stds.cpu() + self.means.cpu()
            label = label.cpu() * self.stds.cpu() + self.means.cpu()
            loss = loss * self.mean_var

        self.log("test_loss", loss, on_step=True, on_epoch=True)

        if self.regression:
            y_pred = prediction_logits.cpu()
            nan_scaling_factor_label = 1.0
            residual = label.cpu() - y_pred.cpu()

            # nanmean ignores missing (NaN) labels per output column.
            MAE = np.mean(np.nanmean(np.abs(residual) / nan_scaling_factor_label, axis=0), axis=0)
            RMSE = np.sqrt(np.nanmean(residual**2 / nan_scaling_factor_label**2, axis=0)).mean()

            self.log("test_MAE", MAE, on_step=True, on_epoch=True)
            self.log("test_RMSE", RMSE, on_step=True, on_epoch=True)

            return {
                "test_loss": loss,
                "test_MAE": MAE,
                "test_RMSE": RMSE,
                "y_pred": y_pred,
                "logits": prediction_logits,
            }

        if prediction_logits.size(1) == 1:
            # squeeze(-1) (not squeeze()) keeps a batch of size 1 one-dimensional;
            # move to CPU because sklearn cannot consume device tensors.
            y_pred = (prediction_logits.squeeze(-1) > 0).cpu()
            average = "binary"
        else:
            y_pred = prediction_logits.argmax(-1).cpu()
            average = "macro"

        accuracy = accuracy_score(y_true=label.cpu(), y_pred=y_pred)
        IoU = jaccard_score(y_true=label.cpu(), y_pred=y_pred, average=average)
        self.log("test_accuracy", accuracy, on_step=False, on_epoch=True)
        self.log("test_IoU", IoU, on_step=False, on_epoch=True)

        return {
            "test_loss": loss,
            "test_accuracy": accuracy,
            "logits": prediction_logits,
            "y_pred": y_pred,
        }

    def configure_optimizers(self):
        """
        Configure the optimizers.

        The downstream network uses ``self.weight_decay``; the positional encoder's
        parameters are excluded from weight decay.

        Returns:
            Optimizer: Configured optimizer.
        """
        optimizer = optim.Adam([{"params": self.neural_network.parameters()},
                                {"params": self.positional_encoder.parameters(), "weight_decay":0}],
                               lr=self.learning_rate,
                               weight_decay=self.weight_decay)
        return optimizer
    
class PE_3D(nn.Module):
    """
    Positional encoder for 3D (space + time) coordinates.

    The first two input columns go through ``space_encoder`` followed by
    ``space_encoding_network``. Column 2 (time) goes through
    ``encode_time_coords`` followed by ``time_encoding_network``. The two
    embeddings are then combined according to ``combination_type``:

    * ``"concatenation"``     -> dim = space_out + time_out
    * ``"outer_product"``     -> dim = space_out * time_out (flattened per sample)
    * ``"hadamard_product"``  -> dim = space_out (requires space_out == time_out)
    * ``"forget_time"``       -> dim = space_out (time is ignored; for debugging)

    ``self.embedding_dim`` holds the resulting dimension, so callers can size the
    downstream network.
    """
    def build_artifact_path(self, time_encoding_args, spatial_encoding_args, combination_type, time_embedding_type):
        """
        Build the artifact path for saving model artifacts.

        Args:
            time_encoding_args (dict): Arguments for time encoding (``"name"`` is read).
            spatial_encoding_args (dict): Arguments for spatial encoding (``"name"`` is read).
            combination_type (str): Type of combination for embeddings.
            time_embedding_type (str): Type of time embedding.

        Returns:
            Path: The artifact path.
        """
        return (Path(".") /
            f"space_{spatial_encoding_args['name']}" /
            f"time_{time_encoding_args['name']}" /
            f"combination_{combination_type}" /
            f"time_embedding_{time_embedding_type}"
        )
    
    def __init__(
            self, 
            space_encoder, 
            combination_type, 
            time_embedding_type, 
            time_encoding_args, 
            spatial_encoding_args,
            number_of_timesteps,
            ):
        """
        Initialize the PE_3D module.

        Args:
            space_encoder (nn.Module): Space encoder module.
            combination_type (str): Type of combination for embeddings.
            time_embedding_type (str): Type of time embedding; must be a key of
                ``time_embedding_functions`` or ``time_embedding_functions_more_params``
                (checked lazily in ``encode_time_coords``).
            time_encoding_args (dict): Arguments for the time encoding network;
                ``"input_dim"`` sets the number of raw time-embedding columns.
            spatial_encoding_args (dict): Arguments for the spatial encoding network.
            number_of_timesteps (int): Number of timesteps.

        Raises:
            ValueError: For ``"hadamard_product"`` with mismatched output dims.
            NotImplementedError: For an unknown ``combination_type``.
        """
        super().__init__()

        self.space_encoder = space_encoder
        self.artifact_path = self.build_artifact_path(
            time_encoding_args=time_encoding_args,
            spatial_encoding_args=spatial_encoding_args,
            combination_type=combination_type,
            time_embedding_type=time_embedding_type
        )

        self.combination_type = combination_type
        self.time_embedding_dim = time_encoding_args["input_dim"]
        self.time_embedding_type = time_embedding_type

        self.time_encoding_network = get_neural_network_v2(
            **time_encoding_args
        )

        self.space_encoding_network = get_neural_network_v2(
            **spatial_encoding_args
        )

        self.number_of_timesteps = number_of_timesteps

        space_out = spatial_encoding_args["output_dim"]
        time_out = time_encoding_args["output_dim"]

        if combination_type == "concatenation":
            self.embedding_dim = space_out + time_out
        elif combination_type == "outer_product":
            # Outer-product-like combination multiplies the embedding dimensions.
            self.embedding_dim = space_out * time_out
        elif combination_type == "hadamard_product":
            if space_out != time_out:
                raise ValueError("The output dimensions of the spatial and time encodings must match for a Hadamard product.")
            self.embedding_dim = space_out
        elif combination_type == "forget_time":
            # can be useful for debugging
            self.embedding_dim = space_out
        else:
            raise NotImplementedError(f"Combination type '{self.combination_type}' not implemented")
        
    def encode_time_coords(self, time_coords):
        """
        Expand raw time values into ``time_embedding_dim`` feature columns.

        Column ``d`` of the result is ``embedding_fun(d, time_coords)``, where
        ``embedding_fun`` is chosen by ``self.time_embedding_type``.

        Args:
            time_coords (Tensor): 1-D tensor of time values.

        Returns:
            Tensor: ``(len(time_coords), time_embedding_dim)`` tensor with the
            dtype and device of ``time_coords``.

        Raises:
            TypeError: If ``time_coords`` is not a tensor.
            NotImplementedError: If the time embedding type is unknown.
        """
        if not isinstance(time_coords, torch.Tensor):
            raise TypeError("time_coords must be a torch.Tensor")
        
        time_embedding = torch.zeros(
            (time_coords.shape[0], self.time_embedding_dim),
            dtype=time_coords.dtype,
            device=time_coords.device
        )

        if self.time_embedding_type in time_embedding_functions:
            embedding_fun = time_embedding_functions[self.time_embedding_type]
        elif self.time_embedding_type in time_embedding_functions_more_params:
            base_fun = time_embedding_functions_more_params[self.time_embedding_type]
            number_of_timesteps = self.number_of_timesteps

            def embedding_fun(degree, t):
                # Adapt the extra-parameter variant to the (degree, t) signature.
                return base_fun(degree=degree, t=t, number_of_timesteps=number_of_timesteps)
        else:
            raise NotImplementedError(f"Time embedding type '{self.time_embedding_type}' not implemented")

        for degree in range(self.time_embedding_dim):
            # Slice assignment casts each column to time_coords' dtype.
            time_embedding[:, degree] = embedding_fun(degree, time_coords)

        return time_embedding

    def forward(self, coords):
        """
        Forward pass of the positional encoder.

        Args:
            coords (Tensor): 2-D tensor; columns ``[:2]`` are spatial coordinates
                and column ``2`` is time.

        Returns:
            Tensor: Combined embeddings of shape ``(batch, self.embedding_dim)``.

        Raises:
            ValueError: For ``"hadamard_product"`` with mismatched embedding widths.
            NotImplementedError: For an unknown ``combination_type``.
        """
        space_coords = coords[:, :2]
        time_coords = coords[:, 2]

        space_embeddings = self.space_encoding_network(self.space_encoder(space_coords))

        if self.combination_type == "forget_time":
            # Debugging mode: ignore the time coordinate entirely.
            return space_embeddings

        time_embeddings = self.time_encoding_network(self.encode_time_coords(time_coords))

        if self.combination_type == "concatenation":
            return torch.cat((space_embeddings, time_embeddings), dim=1)
        if self.combination_type == "outer_product":
            # Per-sample outer product (batch, S, T), flattened to (batch, S*T).
            return torch.einsum('ai,aj -> aij', space_embeddings, time_embeddings).flatten(start_dim=1, end_dim=2)
        if self.combination_type == "hadamard_product":
            if space_embeddings.size(1) != time_embeddings.size(1):
                raise ValueError(f"Space embedding dimension {space_embeddings.size(1)} doesn't match time embedding dimension {time_embeddings.size(1)}")
            return space_embeddings * time_embeddings
        raise NotImplementedError(f"Combination type '{self.combination_type}' not implemented")

