import time

import torch
from torch import nn
import torch.optim as optim
from torch_geometric.data import DataLoader
from tqdm import tqdm

from src.evaluation.eval_scores import Scorer
from src.utils.initialise_service import init_optimizer
from src.utils.model_service import ModelService
from src.training.loss_functions import WeightedMSELoss


def calculate_class_weights(train_loader):
    """
    Calculate "balanced" (inverse-frequency) class weights for the loss function.

    Every label occurring in the training data gets the weight
    ``total_samples / (n_classes * class_count)``, so rare classes are weighted up and
    frequent classes are weighted down. If a batch has a ``train_mask`` attribute, only
    the masked entries of ``batch.y`` are counted (consistent with ``Trainer.train_step``).

    Args:
        train_loader: Iterable of batches, each exposing a label tensor ``y`` and optionally
            a ``train_mask`` selecting the training entries along the first dimension.

    Returns:
        torch.Tensor: float32 weights, ordered by ascending label value (so index ``c`` holds
        the weight of class ``c`` when the labels are ``0..C-1`` and all classes occur).
        Empty if the loader yields no labels.
    """
    class_counts = {}

    for batch in train_loader:
        labels = batch.y
        mask = getattr(batch, 'train_mask', None)
        if mask is not None:
            # Select the training entries instead of multiplying, which would turn masked-out
            # samples into label 0 and count them.
            labels = labels[mask.bool()]

        values, counts = torch.unique(labels, return_counts=True)
        for value, count in zip(values.tolist(), counts.tolist()):
            class_counts[value] = class_counts.get(value, 0) + count

    if not class_counts:
        return torch.empty(0, dtype=torch.float32)

    # Sort by label so that the tensor position matches the class index.
    classes = sorted(class_counts)
    total_samples = sum(class_counts.values())
    n_classes = len(classes)
    weights = [total_samples / (n_classes * class_counts[c]) for c in classes]

    return torch.tensor(weights, dtype=torch.float32)


class Trainer:
    """
    Base class for training models.

    Sets up the loss, optimizer and optional learning-rate scheduler, then runs a training
    loop with periodic validation, early stopping, checkpointing of the best model and
    logging of losses, scores, learning rates and performance statistics.
    """

    optimiser: optim.Optimizer

    def __init__(self,
                 epochs: int,
                 learning_rate: float,
                 training_loader,
                 validation_loader,
                 model_wrapper,
                 device,
                 logger,
                 scorer: Scorer,
                 loss: str = "mse",
                 optimizer: str = "adam",
                 lr_scheduler_params: dict = None,
                 weight_decay: float = 0,
                 loss_class_weighting: str = None,
                 momentum: float = 0,
                 seed: int = None,
                 patience: int = 50,
                 log_image_frequency: int = None,
                 validation_frequency: int = 1,
                 min_lr: float = 0.0,
                 **kwargs):
        """
        Set up loss, optimizer and scheduler and store the training configuration.

        Args:
            epochs: Maximum number of training epochs.
            learning_rate: Initial learning rate passed to ``init_optimizer``.
            training_loader: Loader iterated once per epoch by ``train_step``.
            validation_loader: Loader iterated by ``validation_step`` every
                ``validation_frequency`` epochs.
            model_wrapper: Object exposing ``model``, ``calc_batch``, ``train``, ``eval``,
                ``to`` and ``save_model``.
            device: Device the model and the loss are moved to.
            logger: Logger receiving losses, scores, learning rates and performance stats.
            scorer: Callable returning ``(criterion, scores)`` for validation predictions;
                a ``None`` criterion falls back to the validation loss.
            loss: Name of the loss function; unknown names fall back to MSE.
            optimizer: Optimizer name passed to ``init_optimizer``.
            lr_scheduler_params: Scheduler configuration passed to ``init_optimizer``.
            weight_decay: Weight decay passed to ``init_optimizer``.
            loss_class_weighting: ``"balanced"`` weights the loss with
                ``calculate_class_weights``; anything else disables class weighting.
            momentum: Momentum passed to ``init_optimizer``.
            seed: Stored as ``self.seed``; not used inside this class.
            patience: Number of epochs without improvement before early stopping;
                ``<= 0`` disables early stopping.
            log_image_frequency: Epoch interval for visualisations; ``None`` disables them.
            validation_frequency: Epoch interval between validation runs.
            min_lr: Training stops early once the learning rate reaches this value (checked
                only on non-improving epochs with ``patience > 0``). Replaced by
                ``lr_scheduler_params['min_lr']`` when a ReduceLROnPlateau scheduler is used.
            **kwargs: Stored as ``self.kwargs``.
        """
        if loss_class_weighting == "balanced":
            class_weights_tensor = calculate_class_weights(training_loader)
        else:
            class_weights_tensor = None

        # Setting up the loss
        if loss == "mse":
            loss_instance = nn.MSELoss()
        elif loss == "crossentropy":
            loss_instance = nn.CrossEntropyLoss(weight=class_weights_tensor)
        elif loss == "nll":
            loss_instance = nn.NLLLoss(weight=class_weights_tensor)
        elif loss == "weighted_mse":
            print(f"[TRAINER]: Class weights are not available, defaulting to MSELoss")
            loss_instance = nn.MSELoss()
        elif loss == "binarycrossentropy":
            # NOTE(review): for the BCE losses ``weight`` rescales each element and must broadcast
            # against the prediction shape; it is not a per-class weight as in CrossEntropyLoss.
            loss_instance = nn.BCELoss(weight=class_weights_tensor)
        elif loss == "logitBCE":
            loss_instance = nn.BCEWithLogitsLoss(weight=class_weights_tensor)
        else:
            print(f"[TRAINER]: Loss {loss} is not valid, defaulting to MSELoss")
            loss_instance = nn.MSELoss()
        print(f"[TRAINER]: Using {loss} loss function.")

        # Setting up the optimizer
        optimizer_instance, lr_scheduler = init_optimizer(epochs=epochs,
                                                          learning_rate=learning_rate,
                                                          lr_scheduler_params=lr_scheduler_params,
                                                          model=model_wrapper.model,
                                                          momentum=momentum,
                                                          optimizer=optimizer,
                                                          weight_decay=weight_decay)
        # Use the plateau scheduler's LR floor as the early-stopping LR threshold.
        if (isinstance(lr_scheduler, optim.lr_scheduler.ReduceLROnPlateau)
                and lr_scheduler_params is not None
                and "min_lr" in lr_scheduler_params):
            min_lr = lr_scheduler_params['min_lr']

        # Reset peak memory stats to measure GPU peak RAM
        if self._is_cuda(device):
            torch.cuda.reset_peak_memory_stats(device)

        # Set training parameters
        self.epochs = epochs
        self.loss = loss_instance
        self.optimizer = optimizer_instance
        self.lr_scheduler = lr_scheduler
        self.min_lr = min_lr
        self.seed = seed
        self.validation_frequency = validation_frequency
        # Counted in epochs: the counter in ``train_model`` grows on every non-improving epoch.
        self.patience = patience
        self.log_image_frequency = log_image_frequency
        self.training_loader = training_loader
        self.validation_loader = validation_loader
        self.model_wrapper = model_wrapper
        self.device = device
        self.logger = logger
        self.scorer = scorer
        self.kwargs = kwargs
        print("[TRAINER]: Trainer was successfully set up.")

    @staticmethod
    def _is_cuda(device) -> bool:
        """Return True if ``device`` (a ``torch.device`` or device string) is a CUDA device."""
        if device is None:
            return False
        return torch.device(device).type == "cuda"

    def _log_val_scores(self, val_scores: dict, epoch: int) -> None:
        """Log every ``val_scores[score][class_label]`` value for ``epoch``."""
        for score, score_dict in val_scores.items():
            for class_label, value in score_dict.items():
                self.logger.log_test_score(value=value, epoch=epoch, class_label=class_label, score=score)

    def start_training(self) -> float:
        """
        Entry point that runs the complete training process.

        Moves the model wrapper and the loss to ``self.device``, runs ``train_model`` on the
        stored training and validation loaders and logs the reason training finished. The
        logger is closed afterwards, also when training raises.

        Returns:
            float: The minimum validation criterion reached during training.
        """
        self.model_wrapper.to(self.device)
        self.loss.to(self.device)

        # Perform model training
        self.logger.write_training_start()
        try:
            finish_reason, min_crit = self.train_model(self.training_loader, self.validation_loader)
            self.logger.write_training_end(finish_reason)
        finally:
            self.logger.close()

        return min_crit

    def train_model(self, train_loader: DataLoader, validation_loader: DataLoader) -> tuple[str, float]:
        """
        Train the model for up to ``self.epochs`` epochs.

        Every epoch runs ``train_step`` and logs the training loss and epoch duration. Every
        ``validation_frequency`` epochs ``validation_step`` and the scorer are run; epochs in
        between reuse the values of the most recent validation (epoch 0 always validates).
        Whenever the criterion improves, the model is saved. Training stops early after
        ``patience`` non-improving epochs or once the learning rate reaches ``min_lr``.

        A ``KeyboardInterrupt`` during an epoch ends training with the finish reason
        ``"Training interrupted by user input."``.

        Args:
            train_loader (DataLoader): DataLoader for the training graphs.
            validation_loader (DataLoader): DataLoader for the validation graphs.

        Returns:
            str: The reason the training ended.
            float: The minimum value of criterion achieved on validation set during training.
        """
        # Setup for early stopping
        min_crit = float('inf')
        cur_patience = 0
        cur_lr = self.optimizer.param_groups[0]['lr']
        # Suffix of the per-epoch progress line. Initialised here so the first epoch can print
        # even if its criterion does not improve on ``inf`` (e.g. a NaN criterion).
        tqdm_string = ""

        finish_reason = "Training terminated before training loop ran through."

        for epoch in tqdm(range(self.epochs)):
            try:
                epoch_start_time = time.time()

                train_loss, train_results = self.train_step(train_loader)
                self.logger.log_loss(train_loss, epoch, "1_train")

                # Measure training epoch time (validation is not included)
                epoch_time = time.time() - epoch_start_time
                self.logger.log_performance(name="epoch_duration_seconds", value=epoch_time, epoch=epoch)

                if epoch % self.validation_frequency == 0:
                    val_loss, val_results = self.validation_step(validation_loader)
                    self.logger.log_loss(val_loss, epoch, "2_validation")

                    # Calculate scores
                    criterion, val_scores = self.scorer(targets=val_results[1], predictions=val_results[0])

                    # Use validation loss as early stopping criterion if criterion is None
                    if criterion is None:
                        criterion = val_loss

                    self._log_val_scores(val_scores, epoch)

                # Visualise results
                if (self.log_image_frequency is not None) and (epoch % self.log_image_frequency == 0):
                    self.visualize(train_results[0], train_results[1], 'train', epoch)
                    if validation_loader is not None:
                        self.visualize(val_results[0], val_results[1], "val", epoch)

                # Step for ReduceLROnPlateau schedule is done with validation loss.
                # NOTE(review): other scheduler types are never stepped in this class; confirm they
                # are stepped elsewhere (e.g. in a subclass) if they are used.
                if isinstance(self.lr_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                    self.lr_scheduler.step(val_loss)

                # Logging the learning rate
                if self.lr_scheduler is not None:
                    try:
                        cur_lr = self.lr_scheduler.get_last_lr()[0]
                    except AttributeError:
                        # Older ReduceLROnPlateau versions have no get_last_lr().
                        cur_lr = self.lr_scheduler.optimizer.param_groups[0]['lr']
                self.logger.log_lr(lr=cur_lr, epoch=epoch)

                # Early stopping
                if min_crit > criterion:
                    # Save model and reset patience if criterion is optimal
                    min_crit = criterion
                    cur_patience = 0

                    self.save_model()
                    self.logger.set_optimal_epoch(epoch)
                    tqdm_string = ", NEW BEST MODEL SAVED!"
                elif self.patience > 0:
                    cur_patience += 1

                    # ``<=`` because ReduceLROnPlateau clamps the LR at its ``min_lr``, so the LR
                    # never drops strictly below that floor.
                    if cur_patience >= self.patience or cur_lr <= self.min_lr:
                        print(f"Early stopping at epoch {epoch}.")
                        finish_reason = "Training finished because of early stopping."

                        # Log same values for remaining epochs
                        for i in range(epoch + 1, self.epochs):
                            self.logger.log_loss(train_loss, i, "1_train")
                            if validation_loader is not None:
                                self.logger.log_loss(val_loss, i, "2_validation")
                            self._log_val_scores(val_scores, i)

                        break

                tqdm.write(
                    f"[TRAINER]: Training Loss = {train_loss:.5f},  Validation Loss = {val_loss:.5f}{tqdm_string}")
                tqdm_string = ""

            except KeyboardInterrupt:
                finish_reason = "Training interrupted by user input."
                break

        # Overwrite finish reason if training was not finished due to early
        # stopping or user input
        if finish_reason == "Training terminated before training loop ran through.":
            finish_reason = "Training was normally completed."

        # Log max GPU memory usage (placeholder value 1 on non-CUDA devices)
        if self._is_cuda(self.device):
            peak_memory_mb = torch.cuda.max_memory_allocated(self.device) / (1024 ** 2)
            peak_memory_mb_reserved = torch.cuda.max_memory_reserved(self.device) / (1024 ** 2)
        else:
            peak_memory_mb = 1
            peak_memory_mb_reserved = 1
        self.logger.log_performance(name="peak_memory_mb", value=peak_memory_mb, epoch=self.epochs)
        self.logger.log_performance(name="peak_memory_mb_reserved", value=peak_memory_mb_reserved, epoch=self.epochs)

        return finish_reason, min_crit

    def train_step(self, dataloader) -> tuple[float, tuple[torch.Tensor, torch.Tensor]]:
        """
        Run one training epoch over ``dataloader``.

        For each batch the gradients are reset and ``model_wrapper.calc_batch`` provides
        predictions and targets. These are restricted to ``batch.train_mask`` if the batch has
        one. The loss is then back-propagated and the optimizer takes a step.

        Args:
            dataloader: Iterable of batches accepted by ``model_wrapper.calc_batch``.

        Returns:
            float: The training loss averaged over batches.
            tuple[torch.Tensor, torch.Tensor]: Concatenated, detached predictions and targets.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        self.model_wrapper.train()
        total_train_loss: float = 0.0
        step_count: int = 0
        train_predictions = []
        train_targets = []

        for batch in dataloader:
            # Reset optimizer
            self.optimizer.zero_grad()

            # Only predictions and targets are used; any further outputs are ignored.
            pred, targ, *_ = self.model_wrapper.calc_batch(batch)

            if hasattr(batch, 'train_mask'):
                train_mask = batch.train_mask
                pred = pred[train_mask]
                targ = targ[train_mask]
            loss = self.loss(pred, targ)

            # Backpropagation
            loss.backward()
            self.optimizer.step()

            total_train_loss += loss.item()
            step_count += 1

            # Detach so the autograd graph of each batch is freed instead of being kept
            # alive until the end of the epoch.
            train_predictions.append(pred.detach())
            train_targets.append(targ.detach())

        if step_count == 0:
            raise ValueError("[TRAINER]: The training dataloader yielded no batches.")

        return total_train_loss / step_count, (torch.cat(train_predictions), torch.cat(train_targets))

    def validation_step(self, validation_loader) -> tuple[float, tuple[torch.Tensor, torch.Tensor]]:
        """
        Evaluate the model on ``validation_loader`` without gradient tracking.

        Predictions and targets from ``model_wrapper.calc_batch`` are restricted to
        ``batch.val_mask`` if the batch has one. The loss is computed per batch.

        Args:
            validation_loader: Iterable of batches accepted by ``model_wrapper.calc_batch``.

        Returns:
            float: The validation loss averaged over batches.
            tuple[torch.Tensor, torch.Tensor]: Concatenated predictions and targets.

        Raises:
            ValueError: If ``validation_loader`` yields no batches.
        """
        self.model_wrapper.eval()
        total_val_loss: float = 0.0
        step_count: int = 0
        val_predictions = []
        val_targets = []

        with torch.no_grad():
            for batch in validation_loader:
                # Only predictions and targets are used; any further outputs are ignored.
                pred, targ, *_ = self.model_wrapper.calc_batch(batch)

                if hasattr(batch, 'val_mask'):
                    val_mask = batch.val_mask
                    pred = pred[val_mask]
                    targ = targ[val_mask]

                val_loss = self.loss(pred, targ)

                total_val_loss += val_loss.item()
                step_count += 1

                val_predictions.append(pred)
                val_targets.append(targ)

        if step_count == 0:
            raise ValueError("[TRAINER]: The validation dataloader yielded no batches.")

        return total_val_loss / step_count, (torch.cat(val_predictions), torch.cat(val_targets))

    def visualize(self, predictions, targets, set: str, epoch: int) -> None:
        """
        Log a confusion matrix for the given predictions and targets.

        Args:
            predictions: Model predictions (tensor).
            targets: Ground-truth targets (tensor).
            set: Split name (``'train'``, ``'val'`` or ``'test'``). It is logged as
                ``'<idx>_<set>'`` with idx 1, 2, 3 respectively and 4 for any other name.
            epoch: Epoch number.
        """
        set_idx = {'train': 1, 'val': 2, 'test': 3}.get(set, 4)

        # Transform to np.array
        predictions = predictions.detach().cpu().numpy()
        targets = targets.detach().cpu().numpy()

        # MSE-type losses are treated as regression and get ``continuous=True``.
        use_continuous_matrix = isinstance(self.loss, (nn.MSELoss, WeightedMSELoss))

        # NOTE(review): assumes ``training_loader.dataset`` wraps the full dataset (e.g. a Subset)
        # that exposes ``num_classes``; confirm for other loader set-ups.
        self.logger.save_confusion_matrix(targets,
                                          predictions,
                                          labels=self.training_loader.dataset.dataset.num_classes,
                                          epoch=epoch,
                                          continuous=use_continuous_matrix,
                                          set=f'{set_idx}_{set}')

    def save_model(self, new_version: bool = False) -> None:
        """
        Save the model via ``model_wrapper.save_model`` and log the returned path.

        The returned path is stored in ``self.model_path``.

        Args:
            new_version: Currently unused; kept for interface compatibility.
        """
        # TODO avoid new saves for crossvald and grid search
        self.model_path = self.model_wrapper.save_model()
        self.logger.log_model_path(model_path=self.model_path)