import time

import torch
from torch_geometric.data import DataLoader

from src.evaluation.eval_measurements import Measurements
from src.evaluation.eval_scores import Scorer
from src.logger.logger import Logger


class Evaluator:
    """
    Run a model wrapper over a test set, score the predictions and log results.

    ``evaluate`` is the entry point: it moves the model wrapper to the target
    device, runs ``test_step`` on the test loader, passes predictions and
    targets to the scorer and forwards scores, optional extra measurements,
    inference time and peak GPU memory to the logger.

    Annotations are written as strings so they are not evaluated when the
    class is created.
    """

    def __init__(self,
                 test_loader: "DataLoader",
                 model_wrapper,
                 device,
                 logger: "Logger",
                 scorer: "Scorer",
                 measurements: "Measurements",
                 train_loader: "DataLoader" = None,
                 val_loader: "DataLoader" = None,
                 **kwargs):
        """
        Store the evaluation collaborators and reset the GPU peak-memory stats.

        Args:
            test_loader: Iterable of batches to evaluate on.
            model_wrapper: Object providing ``to(device)``, ``eval()`` and
                ``calc_batch(batch)``; the latter must return a
                ``(predictions, targets, embeddings)`` triple.
            device: Device the model wrapper is moved to.
            logger: Receives ``log_performance``, ``log_test_score`` and
                ``log_measure`` calls.
            scorer: Callable taking ``targets=`` and ``predictions=`` and
                returning ``(criterion, scores)`` where ``scores`` maps
                score name -> {class label -> value}.
            measurements: Optional callable ``(batch, embeddings) -> dict``.
                The dict must contain an ``'n'`` entry that is used as the
                weight when averaging the remaining entries. ``None``
                disables extra measurements.
            train_loader: Optional loader that is additionally measured when
                embeddings are available.
            val_loader: Optional loader that is additionally measured when
                embeddings are available.
            **kwargs: Stored on the instance as ``self.kwargs``.
        """
        # Reset peak memory stats so evaluate() reports the peak of this run only
        cuda_device = self._cuda_device(device)
        if cuda_device is not None:
            torch.cuda.reset_peak_memory_stats(cuda_device)

        # Set evaluation parameters
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.model_wrapper = model_wrapper
        self.device = device
        self.logger = logger
        self.kwargs = kwargs
        self.scorer = scorer
        self.measurements = measurements
        print("[TRAINER]: Trainer was successfully set up.")

    @staticmethod
    def _cuda_device(device):
        """Return *device* as a ``torch.device`` if it is a CUDA device, else ``None``."""
        # Accept "cuda" / "cuda:N" strings as well as torch.device objects.
        if isinstance(device, str) and device.startswith("cuda"):
            device = torch.device(device)
        # Compare the device *type* so indexed devices such as cuda:0 match too.
        if isinstance(device, torch.device) and device.type == "cuda":
            return device
        return None

    def evaluate(self) -> "tuple[str, float | None]":
        """
        Evaluate the model on the test loader and log all results.

        Returns:
            ``(finish_reason, criterion)``. ``criterion`` is the first value
            returned by the scorer, or ``None`` if the run was interrupted
            before scoring finished.
        """
        self.model_wrapper.to(self.device)
        criterion = None

        try:
            start_time = time.time()
            test_predictions, test_targets, extra_measures = self.test_step(self.test_loader)
            inference_time = time.time() - start_time

            # Time spent on extra measurements is not part of model inference
            if extra_measures is not None:
                inference_time -= extra_measures['measurement_time']
            self.logger.log_performance(name="inference_time_seconds", value=inference_time, epoch=0)

            criterion, test_scores = self.scorer(targets=test_targets, predictions=test_predictions)

            for score, score_dict in test_scores.items():
                for class_label, value in score_dict.items():
                    self.logger.log_test_score(value=value, epoch=0, class_label=class_label, score=score)

            if extra_measures is not None:
                for measure, value in extra_measures.items():
                    self.logger.log_measure(name=measure, value=value, epoch=0)

        except KeyboardInterrupt:
            finish_reason = "Training interrupted by user input."
        else:
            finish_reason = "Training was normally completed."

        # Log max GPU memory usage; on non-CUDA devices the placeholder 1 is logged
        cuda_device = self._cuda_device(self.device)
        if cuda_device is not None:
            peak_memory_MB = torch.cuda.max_memory_allocated(cuda_device) / (1024 ** 2)
        else:
            peak_memory_MB = 1
        self.logger.log_performance(name="peak_memory_mb", value=peak_memory_MB, epoch=0)

        return finish_reason, criterion

    def test_step(self, test_loader) -> "tuple[torch.Tensor, torch.Tensor, dict | None]":
        """
        Run the model over ``test_loader`` and collect predictions and measures.

        Args:
            test_loader: Iterable of batches. If a batch has a ``test_mask``
                attribute, predictions and targets are restricted to it.

        Returns:
            ``(predictions, targets, measures)``. Predictions and targets are
            concatenated over all batches. ``measures`` is ``None`` when no
            ``measurements`` callable is configured; otherwise it holds the
            ``'n'``-weighted mean of every measured quantity plus a
            ``'measurement_time'`` entry (seconds spent measuring).

        Raises:
            RuntimeError: From ``torch.cat`` if ``test_loader`` yields no batches.
        """
        self.model_wrapper.eval()
        test_predictions = []
        test_targets = []
        extra_measures = []
        measurement_time = 0.0
        measuring = self.measurements is not None
        # Embedding of the most recent batch; stays None for an empty loader
        last_emb = None

        with torch.no_grad():
            for batch in test_loader:
                pred, targ, emb = self.model_wrapper.calc_batch(batch)
                last_emb = emb

                if hasattr(batch, 'test_mask'):
                    test_mask = batch.test_mask
                    pred = pred[test_mask]
                    targ = targ[test_mask]

                # Make measurements for test set
                if measuring and emb is not None:
                    start_time = time.time()
                    extra_measures.append(self.measurements(batch, emb))
                    measurement_time += time.time() - start_time

                test_predictions.append(pred)
                test_targets.append(targ)

        # Calc embedding for training and validation as well, if not the same
        start_time = time.time()
        if measuring and last_emb is not None:
            with torch.no_grad():
                for loader in (self.train_loader, self.val_loader):
                    # Both loaders are optional
                    if loader is None:
                        continue
                    if loader.sampler.data_source.indices != self.test_loader.sampler.data_source.indices:
                        for batch in loader:
                            _, _, emb = self.model_wrapper.calc_batch(batch)
                            extra_measures.append(self.measurements(batch, emb))
        measurement_time += time.time() - start_time

        test_predictions = torch.cat(test_predictions)
        test_targets = torch.cat(test_targets)

        if not measuring:
            return test_predictions, test_targets, None

        # Average measures, weighting every entry by its sample count 'n'
        start_time = time.time()
        if extra_measures:
            total_n = sum(d['n'] for d in extra_measures)
            measure_keys = [key for key in extra_measures[0] if key != 'n']
            measures_means = {
                m: sum(d[m] * d['n'] for d in extra_measures) / total_n
                for m in measure_keys
            }
        else:
            # Measurements were requested but nothing could be measured
            measures_means = {}
        measurement_time += time.time() - start_time
        measures_means['measurement_time'] = measurement_time

        return test_predictions, test_targets, measures_means
