import os
# Set before torch and numpy are imported below, presumably so that MKL picks the
# value up when it is loaded. NOTE(review): likely a workaround for the MKL
# Intel-vs-GNU (libgomp) threading-layer conflict — confirm before changing.
os.environ['MKL_THREADING_LAYER'] = 'GNU'

import torch
import numpy as np
from collections import defaultdict


class TrainingCallback:
    """Interface for objects notified by the training loop.

    Every hook is a no-op by default, so subclasses only override the
    events they are interested in.
    """

    def before_training(self, model, num_epochs):
        """Invoked once, with the model and the number of epochs, before training."""
        return None

    def before_epoch(self, current, num_iterations):
        """Invoked at the start of epoch ``current`` spanning ``num_iterations`` batches."""
        return None

    def after_batch(self, metrics):
        """Invoked with the metrics of a single training batch."""
        return None

    def after_epoch(self, metrics):
        """Invoked with the aggregated metrics of a finished epoch."""
        return None

    def after_training(self):
        """Invoked once when the training loop has exited."""
        return None


class PredictionCallback:
    """Interface for objects notified while the engine produces predictions.

    All hooks do nothing unless a subclass overrides them.
    """

    def before_predictions(self, model, num_batches):
        """Invoked once, with the model and the number of batches, before predicting."""
        return None

    def after_batch(self, metrics):
        """Invoked after each prediction batch has been processed."""
        return None

    def after_predictions(self):
        """Invoked once after the last prediction batch."""
        return None


class ValueTrainingCallback:
    """Provider of a value that the training loop queries via :meth:`read`.

    Subclasses must supply the actual value.
    """

    def read(self):
        """Return the provider's current value (abstract: must be overridden)."""
        raise NotImplementedError()


class History(TrainingCallback):
    """Training callback that records the metrics of every finished epoch."""

    def __init__(self):
        # metric name -> list of the values reported at the end of each epoch
        self.history = defaultdict(list)

    def after_epoch(self, metrics):
        """Append each metric value of the finished epoch to its series."""
        for name in metrics:
            series = self.history[name]
            series.append(metrics[name])


class CallbackException(Exception):
    """Raised from a callback hook to stop the training loop early.

    Args:
        message: Text describing why training was interrupted.
        verbose: When true, :meth:`print` writes the message to stdout.
    """

    def __init__(self, message, verbose=False):
        super().__init__(message)
        self.verbose = verbose

    def print(self):
        """Write the message to stdout, but only for verbose exceptions."""
        if not self.verbose:
            return
        text = str(self)
        print(text)


def _strip_metrics(metrics):
    """Strip metrics of non-scalar values"""
    if isinstance(metrics, dict):
        return {k: v for k, v in metrics.items() if isinstance(v, (int, float))}
    return metrics


def _recursive_apply(fn, x):
    """Apply a function recursively to all tensors in x"""
    if isinstance(x, torch.Tensor):
        return fn(x)
    elif isinstance(x, (list, tuple)):
        return type(x)(_recursive_apply(fn, y) for y in x)
    elif isinstance(x, dict):
        return {k: _recursive_apply(fn, v) for k, v in x.items()}
    else:
        return x


def gpu_device(gpu):
    """Translate a GPU descriptor into a torch device string.

    ``None`` gives ``'cpu'``, an integer ``i`` gives ``'cuda:i'``, ``'auto'``
    gives ``'cuda'`` when CUDA is available (``'cpu'`` otherwise), and any
    other descriptor is returned as-is.
    """
    if gpu is None:
        device = 'cpu'
    elif isinstance(gpu, int):
        device = f'cuda:{gpu}'
    elif gpu == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = gpu
    return device


def apply_mask(data, mask):
    """Return a shallow copy of ``data`` restricted to ``mask``.

    If ``data`` has a ``mask`` attribute, the copy's ``mask`` is set to
    ``mask``; otherwise the copy's ``x`` and ``y`` are indexed with ``mask``.
    The input object is never modified. Callers such as ``eval_batch`` apply
    several split masks to the same object in turn, and each mask must see the
    original, unmasked tensors.

    Args:
        data: Object exposing either ``mask`` or indexable ``x`` and ``y``.
        mask: Index/boolean mask used to select entries.

    Returns:
        A masked shallow copy of ``data``.
    """
    import copy

    # Shallow copy: tensors are shared, but reassigning attributes on the copy
    # does not affect the caller's object.
    masked = copy.copy(data)
    if hasattr(masked, 'mask'):
        masked.mask = mask
    else:
        masked.x = masked.x[mask]
        masked.y = masked.y[mask]
    return masked


class TransductiveGraphEngine:
    """Training and evaluation engine for transductive graph models.

    Evaluation runs the model once per split name in ``splits`` for which the
    data object carries a non-``None`` attribute ``<split>_mask``.

    Args:
        model: The ``torch.nn.Module`` to train and evaluate.
        splits: Split names to evaluate; each name ``s`` is looked up on the
            data object as ``s + '_mask'``.
    """

    def __init__(self, model, splits=('train', 'test', 'val', 'all')):
        self.model = model
        self.splits = splits
        # Batch counter within the current epoch (reset by before/after_epoch).
        self.current_it = 0
        # Device chosen by train() through _setup_device.
        self.device = None

    def supports_multiple_gpus(self):
        """Return whether multi-GPU execution is supported (it is not)."""
        return False

    def before_epoch(self, current, num_iterations):
        """Reset the per-epoch batch counter."""
        self.current_it = 0

    def after_batch(self, *_):
        """Advance the per-epoch batch counter."""
        if self.current_it is not None:
            self.current_it += 1

    def after_epoch(self, metrics):
        """Reset the per-epoch batch counter."""
        self.current_it = 0

    def _setup_device(self, gpu):
        """Set ``self.device`` from a GPU descriptor (first entry if a list)."""
        self.device = gpu_device(gpu[0] if isinstance(gpu, list) else gpu)

    def _gpu_descriptor(self, gpu):
        """Resolve ``'auto'`` to GPU 0 if CUDA is available, else ``None``."""
        if gpu == 'auto':
            return 0 if torch.cuda.is_available() else None
        return gpu

    def _exec_callbacks(self, callbacks, method, *args, **kwargs):
        """Call ``method(*args, **kwargs)`` on every callback that defines it."""
        for callback in callbacks:
            if hasattr(callback, method):
                getattr(callback, method)(*args, **kwargs)

    def train(self, train_data, val_data=None, epochs=20, eval_every=None,
              eval_train=False, eval_val=True, callbacks=None, metrics=None, gpu='auto', **kwargs):
        """Train ``self.model`` and return the recorded :class:`History`.

        Args:
            train_data: Either a sized iterable, re-iterated every epoch, or an
                unsized iterable consumed continuously. For unsized data,
                ``eval_every`` is required and ``epochs`` counts iterations
                (it is divided by ``eval_every``).
            val_data: Data passed to :meth:`evaluate` after an epoch, or
                ``None`` to skip validation.
            epochs: Number of epochs (iterations for unsized data).
            eval_every: Validation interval in epochs for sized data; batches
                per "epoch" for unsized data.
            eval_train: Also merge the ``'train'`` split metrics into the epoch
                metrics, prefixed with ``train_``.
            eval_val: Accepted for interface compatibility; not read here.
            callbacks: Extra callbacks. ``TrainingCallback`` instances receive
                the training hooks; ``PredictionCallback`` instances are
                forwarded to :meth:`evaluate`.
            metrics: Mapping ``name -> fn(predictions)`` used during validation.
            gpu: Device descriptor: ``'auto'``, ``None``, an int, a device
                string, or a list whose first entry is used.
            **kwargs: Forwarded to :meth:`train_batch`. Keys starting with
                ``eval_`` are instead forwarded, without the prefix, to
                :meth:`evaluate`. ``ValueTrainingCallback`` values are replaced
                by their ``read()`` result (per batch for training, per epoch for
                evaluation). ``TrainingCallback`` values are also registered as
                callbacks.

        Returns:
            History: Metrics recorded at the end of every epoch.

        Raises:
            ValueError: If ``train_data`` has no length and ``eval_every`` is None.
        """
        if metrics is None:
            metrics = {}
        if callbacks is None:
            callbacks = []

        # 1) Setup
        try:
            batch_iterations = len(train_data)
            iterable_data = False
        except (TypeError, NotImplementedError):
            # No length: treat the data as a stream cut into chunks of eval_every batches.
            batch_iterations = eval_every
            iterable_data = True

        if iterable_data:
            if eval_every is None:
                # Without this check, range(None) below fails with an opaque TypeError.
                raise ValueError(
                    'eval_every must be given when train_data does not define a length'
                )
            # Here, epochs are considered iterations
            epochs = epochs // eval_every

        exception = None

        # 1.1) Callbacks: the engine's history first, then user callbacks, then
        # any callbacks passed as keyword arguments.
        history = History()
        callbacks = [history] + callbacks
        callbacks += [v for v in kwargs.values() if isinstance(v, TrainingCallback)]
        train_callbacks = [c for c in callbacks if isinstance(c, TrainingCallback)]
        prediction_callbacks = [c for c in callbacks if isinstance(c, PredictionCallback)]
        self._exec_callbacks(train_callbacks, 'before_training', self.model, epochs)

        # 1.2) Data loading: a stream is iterated once across all epochs.
        if iterable_data:
            train_iterator = iter(train_data)

        # 1.3) GPU support
        gpu = self._gpu_descriptor(gpu)
        self._setup_device(gpu)
        self.model.to(self.device)

        # 1.4) Split kwargs into training and evaluation arguments.
        eval_prefix = 'eval_'
        train_kwargs = {k: v for k, v in kwargs.items() if not k.startswith(eval_prefix)}
        dynamic_train_kwargs = {
            k: v for k, v in train_kwargs.items() if isinstance(v, ValueTrainingCallback)
        }
        eval_kwargs = {
            k[len(eval_prefix):]: v for k, v in kwargs.items() if k.startswith(eval_prefix)
        }
        dynamic_eval_kwargs = {
            k: v for k, v in eval_kwargs.items() if isinstance(v, ValueTrainingCallback)
        }

        # 2) Train for number of epochs
        for current_epoch in range(epochs):
            # 2.1) Prepare
            try:
                self._exec_callbacks(
                    train_callbacks, 'before_epoch', current_epoch, batch_iterations
                )
            except CallbackException as e:
                exception = e
                break

            # 2.2) Train
            self.model.train()

            batch_losses = []
            if not iterable_data:
                train_iterator = iter(train_data)

            for _ in range(batch_iterations):
                train_kwargs_batch = {
                    **train_kwargs,
                    **{k: v.read() for k, v in dynamic_train_kwargs.items()},
                }
                item = self.to_device(self.device, next(train_iterator))
                loss = self.train_batch(item, **train_kwargs_batch)
                batch_losses.append(loss)
                self._exec_callbacks(train_callbacks, 'after_batch', _strip_metrics(loss))

            # 2.3) Validate
            epoch_metrics = self.collate_losses(batch_losses)
            eval_kwargs_batch = {
                **eval_kwargs,
                **{k: v.read() for k, v in dynamic_eval_kwargs.items()},
            }
            do_val = eval_every is None or iterable_data or \
                current_epoch % eval_every == 0 or current_epoch == epochs - 1

            if val_data is not None and do_val:
                # Evaluate on the training device. Passing gpu=None here would
                # move the model to the CPU and leave it there for the next
                # epoch, while batches are still sent to self.device.
                eval_metrics = self.evaluate(
                    val_data, metrics=metrics,
                    callbacks=prediction_callbacks, gpu=gpu, **eval_kwargs_batch
                )

                epoch_metrics = {
                    **epoch_metrics,
                    **{f'val_{k}': v for k, v in eval_metrics['val'].items()},
                }

                if eval_train:
                    epoch_metrics = {
                        **epoch_metrics,
                        **{f'train_{k}': v for k, v in eval_metrics['train'].items()},
                    }

            # 2.4) Finish epoch
            try:
                self._exec_callbacks(train_callbacks, 'after_epoch', epoch_metrics)
            except CallbackException as e:
                exception = e
                break

        # 3) Finish training
        # 3.1) If GPU used
        if gpu is not None:
            self.model.to('cpu', non_blocking=True)
            self.device = None

        # 3.2) Finish callbacks
        self._exec_callbacks(train_callbacks, 'after_training')
        if exception is not None:
            if isinstance(exception, CallbackException):
                exception.print()
            else:
                print(exception)

        return history

    def evaluate(self, data, metrics=None, callbacks=None, gpu='auto', **kwargs):
        """Evaluate ``data`` and return ``{split: {metric_name: value}}``.

        Every split in ``self.splits`` is a key in the result. Splits without
        a mask on ``data`` map to an empty dict.
        """
        if metrics is None:
            metrics = {}

        evals = self._get_evals(data, gpu=gpu, callbacks=callbacks, **kwargs)
        return self._aggregate_metrics(evals, metrics)

    def evaluate_target_and_ood(self, data, data_ood, metrics=None, metrics_ood=None,
                               callbacks=None, gpu='auto', target_as_id=True, **kwargs):
        """Evaluate ``data`` with ``metrics`` and merge in ID-vs-OOD metrics.

        ``metrics_ood`` functions are called as ``fn(id_predictions,
        ood_predictions)`` per split. The ID predictions come from ``data``
        when ``target_as_id`` is true, otherwise from ``data_ood`` evaluated
        with the ``'id'`` prefix.
        """
        if metrics is None:
            metrics = {}
        if metrics_ood is None:
            metrics_ood = {}

        evals = self._get_evals(data, callbacks=callbacks, gpu=gpu, **kwargs)
        evals_ood = self._get_evals(
            data_ood, callbacks=callbacks,
            gpu=gpu, split_prefix='ood', **kwargs)

        if target_as_id:
            # target represents ID values, e.g. when evaluating isolated perturbations
            # or leave-out-class experiments
            evals_id = evals
        else:
            # for usual evasion setting: id values correspond to non-perturbed nodes
            # while ood nodes correspond to perturbed nodes
            # target corresponds to evaluation of all nodes without this distinction
            evals_id = self._get_evals(
                data_ood, callbacks=callbacks,
                gpu=gpu, split_prefix='id', **kwargs)

        results = self._aggregate_metrics(evals, metrics)
        ood_results = self._aggregate_metrics_ood(evals_id, evals_ood, metrics_ood)

        for s in self.splits:
            results[s] = {**results[s], **ood_results[s]}

        return results

    def predict(self, data, callbacks=None, gpu='auto', parallel=True, **kwargs):
        """Predict every item of ``data`` and collate the CPU outputs.

        ``data`` must support ``len``; ``parallel`` is accepted for interface
        compatibility and is not used. The model is left on the chosen device.
        """
        if callbacks is None:
            callbacks = []

        # 1) Set gpu if all is specified
        gpu = self._gpu_descriptor(gpu)

        # 2) Setup data loading
        num_iterations = len(data)

        self._exec_callbacks(callbacks, 'before_predictions', self.model, num_iterations)

        # 3) Now perform predictions sequentially
        device = gpu_device(gpu[0] if isinstance(gpu, list) else gpu)
        self.model.to(device)

        predictions = []

        iterator = iter(data)
        for _ in range(num_iterations):
            item = self.to_device(device, next(iterator))

            with torch.no_grad():
                out = self.predict_batch(item, **kwargs)
            predictions.append(out.to('cpu'))
            self._exec_callbacks(callbacks, 'after_batch', None)

        self._exec_callbacks(callbacks, 'after_predictions')

        return self.collate_predictions(predictions)

    def eval_batch(self, data, split_prefix=None, **kwargs):
        """Predict each split of ``data`` that has a non-``None`` ``<split>_mask``.

        Returns:
            tuple: ``(split_names, predictions)`` in ``self.splits`` order. Names
            are ``f'{split_prefix}_{split}'`` when ``split_prefix`` is given.
        """
        self.model.eval()

        splits = []
        evals = []

        for split in self.splits:
            mask = getattr(data, split + '_mask', None)
            if mask is None:
                continue
            split_data = apply_mask(data, mask)
            with torch.no_grad():
                eval_result = self.predict_batch(split_data, **kwargs)

            splits.append(split if split_prefix is None else f"{split_prefix}_{split}")
            evals.append(eval_result)

        return splits, evals

    def train_batch(self, data, optimizer, loss=None, **kwargs):
        """Run one optimisation step on ``data`` and return ``{'loss': float}``.

        Uses ``loss(out, data.y)`` if a loss function is given, else
        ``self.model.loss(out, data.y)``. Extra kwargs are ignored.
        """
        optimizer.zero_grad()

        out = self.model(data)

        if loss is None:
            # Default to the model's own loss computation.
            loss_val = self.model.loss(out, data.y)
        else:
            loss_val = loss(out, data.y)

        loss_val.backward()
        optimizer.step()

        return {"loss": loss_val.item()}

    def predict_batch(self, data, **kwargs):
        """Run the model on ``data`` in eval mode without gradient tracking."""
        self.model.eval()
        with torch.no_grad():
            return self.model(data, **kwargs)

    def collate_evals(self, evals):
        """Concatenate evaluation outputs along dim 0 (single output returned as-is)."""
        return evals[0] if len(evals) == 1 else torch.cat(evals, dim=0)

    def collate_losses(self, losses):
        """Average each key of a list of per-batch loss dicts."""
        grouped = defaultdict(list)
        for loss in losses:
            for k, v in loss.items():
                grouped[k].append(v)

        return {k: np.mean(v) for k, v in grouped.items()}

    def collate_predictions(self, predictions):
        """Concatenate per-batch predictions along dim 0 (single batch returned as-is)."""
        if len(predictions) == 1:
            return predictions[0]
        return torch.cat(predictions, dim=0)

    def _aggregate_metrics(self, evals, metrics):
        """Apply every metric to each split's predictions.

        Returns ``{split: {metric_name: value}}`` with a key for every split in
        ``self.splits``; split names not in ``self.splits`` are ignored.
        """
        results = {split: {} for split in self.splits}

        for split, eval_result in zip(evals['splits'], evals['evals']):
            if split not in results:
                continue
            for metric_name, metric_fn in metrics.items():
                results[split][metric_name] = metric_fn(eval_result)

        return results

    def _evals_by_base_split(self, evals):
        """Map base split names (with any ``'<prefix>_'`` removed) to predictions."""
        by_split = {}
        for name, preds in zip(evals['splits'], evals['evals']):
            if name not in self.splits:
                # eval_batch names prefixed splits f'{prefix}_{split}'.
                _, _, name = name.partition('_')
            if name in self.splits:
                # Keep the first occurrence, as list.index() would.
                by_split.setdefault(name, preds)
        return by_split

    def _aggregate_metrics_ood(self, evals, evals_ood, metrics_ood):
        """Compute ``fn(id_preds, ood_preds)`` for splits present in both inputs.

        Split names may carry prefixes such as ``'id_'`` or ``'ood_'``. Entries
        are paired by their base split name rather than by list position, so
        the two inputs may contain different subsets or orders of splits.
        """
        results = {split: {} for split in self.splits}
        id_by_split = self._evals_by_base_split(evals)
        ood_by_split = self._evals_by_base_split(evals_ood)

        for metric_name, metric_fn in metrics_ood.items():
            for split in self.splits:
                if split in id_by_split and split in ood_by_split:
                    results[split][metric_name] = metric_fn(
                        id_by_split[split], ood_by_split[split]
                    )

        return results

    def _get_evals(self, data, callbacks=None, gpu='auto', **kwargs):
        """Move model and data to the device for ``gpu`` and run :meth:`eval_batch`.

        Returns ``{'splits': [...], 'evals': [...]}``. ``callbacks`` is accepted
        but not invoked here.
        """
        if callbacks is None:
            callbacks = []

        gpu = self._gpu_descriptor(gpu)
        device = gpu_device(gpu[0] if isinstance(gpu, list) else gpu)
        self.model.to(device)

        data = self.to_device(device, data)

        self.model.eval()
        splits, evals = self.eval_batch(data, **kwargs)

        return {'splits': splits, 'evals': evals}

    def _collate(self, items):
        """Recursively concatenate a list of tensors or of dicts of tensors."""
        if not items:
            return None

        if isinstance(items[0], torch.Tensor):
            return torch.cat(items, dim=0)
        if isinstance(items[0], dict):
            return {key: self._collate([item[key] for item in items]) for key in items[0]}
        return items

    def to_device(self, device, item):
        """Recursively move ``item`` to ``device``.

        Lists, tuples (including namedtuples) and dicts are traversed. Any other
        object exposing ``.to``, such as a tensor or a graph data object, is
        moved with ``obj.to(device)``. Everything else is returned unchanged.
        """
        if isinstance(item, (list, tuple)):
            moved = [self.to_device(device, x) for x in item]
            if isinstance(item, tuple) and hasattr(item, '_fields'):
                return type(item)(*moved)
            return type(item)(moved)
        if isinstance(item, dict):
            return {k: self.to_device(device, v) for k, v in item.items()}
        if hasattr(item, 'to'):
            return item.to(device)
        return item
