import math
import torch
from torch import Tensor
import networkx as nx
from networkx.algorithms.shortest_paths.unweighted import single_source_shortest_path_length
import torch_geometric.transforms as T
from torch_geometric.utils import to_networkx
from torch_geometric.data import Data
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from typing import Dict
from models.gpn.layers import CertaintyDiffusion
from models.gpn.utils import apply_mask
from models.gpn.utils import Prediction
from models.gpn.utils import RunConfiguration, DataConfiguration
from models.gpn.utils import ModelConfiguration, TrainingConfiguration
from models.gpn.utils import Storage, ModelNotFoundError


class Model(nn.Module):
    """Shared base class for the models in this package.

    It bundles the plumbing every model needs: wrapping the logits returned by
    ``forward_impl`` into a ``Prediction``, flags describing the current
    training phase, a default cross-entropy loss, a default Adam optimizer,
    and saving/loading of checkpoints either to an explicit file or through a
    ``Storage`` object.
    """

    def __init__(self, params: ModelConfiguration):
        super().__init__()

        # keep a private copy of the configuration; for ``None`` no
        # ``params`` attribute is created at all
        if params is not None:
            self.params = params.clone()

        # training-phase flags, read and written through the accessors below
        self._expects_training = True
        self._is_warming_up = False
        self._is_finetuning = False

        # persistence state, populated by ``create_storage``
        self.storage = None
        self.storage_params = None
        self.model_file_path = None

        # most recent soft prediction, refreshed on every ``forward`` call
        self.cached_y = None

    def forward(self, data: Data, *_, **__) -> Prediction:
        """Run ``forward_impl`` and wrap its logits into a ``Prediction``.

        Additional positional and keyword arguments are accepted and ignored.
        The softmax probabilities are also stored in ``self.cached_y``.
        """
        logits = self.forward_impl(data)
        log_probs = F.log_softmax(logits, dim=-1)
        probs = log_probs.exp()
        top_prob, top_class = probs.max(dim=-1)

        # remember the soft prediction: SGCN may use this model
        # (e.g. a GAT/GCN) as its teacher
        self.cached_y = probs

        return Prediction(
            soft=probs,
            log_soft=log_probs,
            hard=top_class,
            logits=logits,
            # confidence of prediction
            prediction_confidence_aleatoric=top_prob,
            prediction_confidence_epistemic=None,
            prediction_confidence_structure=None,
            # confidence of sample
            sample_confidence_aleatoric=top_prob,
            sample_confidence_epistemic=None,
            sample_confidence_features=None,
            sample_confidence_structure=None
        )

    def expects_training(self) -> bool:
        """Return the ``expects_training`` flag."""
        return self._expects_training

    def is_warming_up(self) -> bool:
        """Return the warm-up flag."""
        return self._is_warming_up

    def is_finetuning(self) -> bool:
        """Return the fine-tuning flag."""
        return self._is_finetuning

    def set_expects_training(self, flag: bool) -> None:
        """Set the ``expects_training`` flag."""
        self._expects_training = flag

    def set_warming_up(self, flag: bool) -> None:
        """Set the warm-up flag."""
        self._is_warming_up = flag

    def set_finetuning(self, flag: bool) -> None:
        """Set the fine-tuning flag."""
        self._is_finetuning = flag

    def get_num_params(self) -> int:
        """Return the total number of elements over all parameters."""
        total = 0
        for parameter in self.parameters():
            total += parameter.numel()
        return total

    def forward_impl(self, data: Data, *args, **kwargs):
        """Compute the logits for ``data``; must be provided by subclasses."""
        raise NotImplementedError

    def loss(self, prediction: Prediction, data: Data) -> Dict[str, torch.Tensor]:
        """Default training loss: delegates to ``CE_loss``."""
        return self.CE_loss(prediction, data)

    def warmup_loss(self, prediction: Prediction, data: Data) -> Dict[str, torch.Tensor]:
        """Loss for the warm-up phase; not implemented in the base class."""
        raise NotImplementedError

    def fintetune_loss(self, prediction: Prediction, data: Data) -> Dict[str, torch.Tensor]:
        """Loss for the fine-tuning phase; not implemented in the base class."""
        raise NotImplementedError

    def CE_loss(self, prediction: Prediction, data: Data, reduction='mean') -> Dict[str, torch.Tensor]:
        """Negative log-likelihood of ``prediction.log_soft`` on the 'train' split.

        Log-probabilities and targets are obtained from ``apply_mask`` with
        ``split='train'``; the result is returned under the key ``'CE'``.
        """
        masked_log_probs, targets = apply_mask(data, prediction.log_soft, split='train')
        cross_entropy = F.nll_loss(masked_log_probs, targets, reduction=reduction)
        return {'CE': cross_entropy}

    def save_to_file(self, model_path: str) -> None:
        """Write the state dict and the cached soft prediction to ``model_path``."""
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'cached_y': self.cached_y
        }
        torch.save(checkpoint, model_path)

    def load_from_file(self, model_path: str) -> None:
        """Restore state dict and cached soft prediction from ``model_path``.

        Without CUDA the checkpoint is mapped onto the CPU; otherwise
        ``torch.load`` is left at its default device handling.
        """
        # ``map_location=None`` is the default of ``torch.load``
        map_location = None if torch.cuda.is_available() else torch.device('cpu')
        checkpoint = torch.load(model_path, map_location=map_location)

        self.load_state_dict(checkpoint['model_state_dict'])
        self.cached_y = checkpoint['cached_y']

    def get_optimizer(self, lr: float, weight_decay: float) -> optim.Adam:
        """Return an Adam optimizer over all parameters of the model."""
        return optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)

    def get_warmup_optimizer(self, lr: float, weight_decay: float) -> optim.Adam:
        """Optimizer for the warm-up phase; not implemented in the base class."""
        raise NotImplementedError

    def get_finetune_optimizer(self, lr: float, weight_decay: float) -> optim.Adam:
        """Optimizer for the fine-tuning phase; not implemented in the base class."""
        raise NotImplementedError

    def create_storage(self, run_cfg: RunConfiguration, data_cfg: DataConfiguration,
                       model_cfg: ModelConfiguration, train_cfg: TrainingConfiguration):
        """Set up ``self.storage`` and the parameters used to match checkpoints.

        ``self.storage_params`` is the merge of the model, data and training
        configuration dictionaries (later ones win on duplicate keys).
        """
        # training jobs, and evaluation jobs without an explicit evaluation
        # name, use the experiment name of the run itself
        use_own_name = run_cfg.job == 'train' or (
            run_cfg.job == 'evaluate' and run_cfg.eval_experiment_name is None
        )
        if use_own_name:
            run_cfg.set_values(eval_experiment_name=run_cfg.experiment_name)

        checkpoint_storage = Storage(
            run_cfg.experiment_directory,
            experiment_name=run_cfg.eval_experiment_name,
        )

        matching_params = dict(model_cfg.to_dict(ignore=model_cfg.default_ignore()))
        matching_params.update(data_cfg.to_dict())
        matching_params.update(train_cfg.to_dict())

        # ignore ood parameters for matching in an evasion setting
        if run_cfg.job == 'evaluate' and data_cfg.ood_flag and data_cfg.ood_setting == 'evasion':
            matching_params = {
                key: value
                for key, value in matching_params.items()
                if not key.startswith('ood_')
            }

        self.storage = checkpoint_storage
        self.storage_params = matching_params

    def load_from_storage(self) -> None:
        """Look up the checkpoint path in the storage and load the model from it.

        Raises ``ModelNotFoundError`` when no storage has been created.
        """
        if self.storage is None:
            raise ModelNotFoundError('Error on loading model, storage does not exist!')

        checkpoint_path = self.storage.retrieve_model_file_path(
            self.storage_params['model_name'],
            self.storage_params,
            init_no=self.params.init_no,
        )
        self.load_from_file(checkpoint_path)

    def save_to_storage(self) -> None:
        """Ask the storage for a checkpoint path and save the model to it.

        Raises ``ModelNotFoundError`` when no storage has been created.
        """
        if self.storage is None:
            raise ModelNotFoundError('Error on storing model, storage does not exist!')

        checkpoint_path = self.storage.create_model_file_path(
            self.storage_params['model_name'],
            self.storage_params,
            init_no=self.params.init_no,
        )
        self.save_to_file(checkpoint_path)



class GDK(Model):
    """simple parameterless baseline for node classification based on the Graph Dirichlet Kernel

    The model has nothing to train: on the first forward pass it computes
    per-class evidence with ``compute_kde``, turns it into Dirichlet
    parameters ``alpha = 1 + evidence`` and caches them. Later forward passes
    reuse the cached ``alpha`` regardless of the data passed in.
    """

    def __init__(self, params: ModelConfiguration):
        super().__init__(params)
        # Dirichlet parameters; filled by the first forward pass or by
        # ``load_from_file``
        self.cached_alpha = None

    def forward(self, data: Data, *_, **__) -> Prediction:
        """Return the prediction of ``forward_impl``.

        Extra positional and keyword arguments are accepted and ignored so
        that the signature stays compatible with ``Model.forward``.
        """
        return self.forward_impl(data)

    def forward_impl(self, data: Data) -> Prediction:
        """Build a ``Prediction`` from the (cached) Dirichlet parameters.

        Unlike the base class contract this returns a full ``Prediction``
        rather than logits; ``forward`` is overridden accordingly.
        """
        if self.cached_alpha is None:
            distance_evidence = compute_kde(data, self.params.num_classes, sigma=1.0)
            alpha = 1.0 + distance_evidence
            self.cached_alpha = alpha

        else:
            alpha = self.cached_alpha
            distance_evidence = alpha - 1.0

        # class probabilities: alpha normalised over the class dimension
        soft = alpha / alpha.sum(-1, keepdim=True)
        max_soft, hard = soft.max(-1)

        # (row, predicted class) pairs, used to pick the score of the
        # predicted class for every node. A tuple index is used for both
        # lookups: indexing with a *list* of tensors is deprecated in PyTorch
        # and is slated to be interpreted as a single tensor index.
        rows = torch.arange(hard.size(0), device=hard.device)
        predicted = (rows, hard)

        # ---------------------------------------------------------------------------------
        pred = Prediction(
            # prediction and intermediary scores
            soft=soft,
            hard=hard,
            alpha=alpha,

            # prediction confidence scores
            prediction_confidence_aleatoric=max_soft,
            prediction_confidence_epistemic=alpha[predicted],
            prediction_confidence_structure=distance_evidence[predicted],

            # sample confidence scores
            sample_confidence_aleatoric=max_soft,
            sample_confidence_epistemic=alpha.sum(-1),
            sample_confidence_features=None,
            sample_confidence_structure=distance_evidence.sum(-1),
        )
        # ---------------------------------------------------------------------------------

        return pred

    def expects_training(self) -> bool:
        """Always ``False``: the model has no trainable parameters."""
        return False

    def save_to_file(self, model_path: str) -> None:
        """Save the cached Dirichlet parameters to ``model_path``.

        A forward pass (or ``load_from_file``) must have populated
        ``cached_alpha`` beforehand.
        """
        assert self.cached_alpha is not None, 'cached_alpha is not set, nothing to save'
        torch.save(self.cached_alpha, model_path)

    def load_from_file(self, model_path: str) -> None:
        """Load the Dirichlet parameters from ``model_path`` into the cache.

        Without CUDA the tensor is mapped onto the CPU; otherwise
        ``torch.load`` is left at its default device handling.
        """
        # ``map_location=None`` is the default of ``torch.load``
        map_location = None if torch.cuda.is_available() else torch.device('cpu')
        self.cached_alpha = torch.load(model_path, map_location=map_location)


def kernel_distance(x: Tensor, sigma: float = 1.0) -> Tensor:
    """Evaluate a zero-mean Gaussian density with standard deviation ``sigma``.

    Computes ``exp(-x^2 / (2 * sigma^2)) / (sigma * sqrt(2 * pi))``
    element-wise on ``x``.
    """
    normaliser = 1.0 / (sigma * math.sqrt(2 * math.pi))
    two_variance = 2 * sigma * sigma
    unnormalised = x.square().neg().div(two_variance).exp()
    return unnormalised.mul(normaliser)


def compute_kde(data: Data, num_classes: int, sigma: float = 1.0) -> Tensor:
    """Accumulate per-class kernel evidence from the training nodes.

    For every node marked in ``data.splits['train']`` the shortest-path
    distance to all nodes is computed (on the undirected graph, up to a
    cutoff of 10 hops), passed through ``kernel_distance`` and added to the
    evidence column of that training node's label.

    Args:
        data: graph with labels ``y`` and a ``splits['train']`` node mask.
        num_classes: number of evidence columns.
        sigma: width of the distance kernel.

    Returns:
        Tensor of shape ``(n_nodes, num_classes)`` on the device of ``data.y``.
    """
    transform = T.AddSelfLoops()
    data = transform(data)
    n_nodes = data.y.size(0)
    device = data.y.device

    # Take column 0 of the (k, 1) nonzero result instead of squeeze():
    # with exactly one training node squeeze() yields a 0-dim tensor whose
    # tolist() is a bare int, which cannot be iterated.
    idx_train = torch.nonzero(data.splits['train'], as_tuple=False)[:, 0].tolist()
    evidence = torch.zeros((n_nodes, num_classes), device=device)
    G = to_networkx(data, to_undirected=True)

    for idx_t in idx_train:
        reachable = single_source_shortest_path_length(G, source=idx_t, cutoff=10)
        # nodes not reached within the cutoff get a huge distance, so their
        # kernel value vanishes for any reasonable sigma
        distances = torch.Tensor(
            [reachable.get(n, 1e10) for n in range(n_nodes)]).to(device)

        # Always index with a 1-D label tensor: a 0-dim label (1-D ``y``)
        # would select an (n,) view into which the (n, 1) update below cannot
        # be added in place. Labels stored as (n, 1) behave as before.
        labels = data.y[idx_t].reshape(-1)
        evidence[:, labels] += kernel_distance(distances, sigma=sigma).unsqueeze(1)

    return evidence
