from typing import Callable, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.optim import Optimizer
from torch_geometric.data import Dataset
from torch_geometric.transforms import GDC

from libs.reporter import (
    ReporterInterface,
    PrintReporter,
    SIMPLE_TEMPLATE,
    SUMMARY_TEMPLATE,
)
from models.grand import GRAND


class NodeClassificationExp:
    """Train and evaluate a node-classification model on a single graph.

    The experiment reads ``train_mask`` / ``val_mask`` / ``test_mask`` from
    ``dataset.data``. When the masks are 2-D, every column is treated as an
    independent split: the model parameters are reset and a full training run
    is performed per column, and the mean / standard deviation of the best
    validation and test accuracies are reported. With 1-D masks a single run
    is performed.

    Early stopping (when enabled) triggers once the current validation loss
    exceeds the mean of the recorded validation losses. The two class
    attributes below control when recording and stopping may begin; override
    them on a subclass or an instance to tune the schedule.
    """

    # First epoch index whose validation loss is added to the running history.
    val_history_start = 10
    # First epoch index at which early stopping is allowed to trigger.
    early_stopping_min_epoch = 50

    def __init__(
        self,
        model: nn.Module,
        lossfn: Callable[[Tensor, Tensor], Tensor],
        optimizer: Optimizer,
        dataset: "Dataset",
        reporter,
        device: str = "cpu",
        num_epochs: int = 500,
        eval_freq: int = 1,
        early_stopping: bool = True,
        model_cfg: Union[dict, None] = None,
    ):
        """Store the experiment components and optionally rewire the graph.

        Args:
            model: Network to train; it is moved to ``device``.
            lossfn: Callable mapping ``(logits, targets)`` to a scalar loss.
            optimizer: Optimizer already bound to ``model``'s parameters.
            dataset: Object exposing the graph as ``dataset.data``.
            reporter: ``None``, a single ``ReporterInterface`` (wrapped into a
                list), or an iterable of callables invoked as
                ``r(metrics, primary_metric)``.
            device: Device the model and tensors are moved to.
            num_epochs: Maximum number of epochs per split.
            eval_freq: Test-set evaluation/reporting period, in epochs.
            early_stopping: Whether to stop once validation loss rises above
                its running mean.
            model_cfg: Optional configuration. If it contains
                ``'digl_method'``, the GDC transform is applied to the data
                using ``'digl_method'``, ``'digl_alpha'`` and ``'digl_eps'``.
        """
        # A fresh dict per call avoids sharing one mutable default object.
        model_cfg = {} if model_cfg is None else model_cfg

        self.model = model
        self.lossfn = lossfn
        self.optimizer = optimizer
        self.data = dataset.data
        self.reporter = reporter
        if isinstance(reporter, ReporterInterface):
            self.reporter = [reporter]
        self.device = device
        self.num_epochs = num_epochs
        self.eval_freq = eval_freq
        self.early_stopping = early_stopping
        self.model.to(self.device)

        if 'digl_method' in model_cfg:
            print("Running GDC")
            print(self.data.edge_index)
            transform = GDC(
                diffusion_kwargs=dict(
                    method=model_cfg['digl_method'],
                    alpha=model_cfg['digl_alpha']
                ),
                sparsification_kwargs=dict(
                    method='threshold',
                    eps=model_cfg['digl_eps']
                ),
            )
            self.data = transform(self.data)
            print(self.data.edge_index)

    def report(self, metrics, primary_metric=None):
        """Forward ``metrics`` to every configured reporter (no-op if none)."""
        if self.reporter is None:
            return
        for r in self.reporter:
            r(metrics, primary_metric)

    def print_status(self, i, metrics):
        """Print one status line for epoch ``i``.

        Uses ``self.template`` when an instance or subclass defines it and
        falls back to ``SIMPLE_TEMPLATE`` otherwise, so the call no longer
        fails when no ``template`` attribute has been set.
        """
        template = getattr(self, "template", SIMPLE_TEMPLATE)
        print(template.format(i=i, **metrics))

    def reset_model(self):
        """Re-initialise the model parameters via ``reset_parameters()``.

        NOTE(review): only the model is reset here; any state held by
        ``self.optimizer`` is left untouched between splits -- confirm this
        is intended.
        """
        self.model.reset_parameters()

    def train(self):
        """Run training over every split and report summary metrics.

        Returns:
            dict with ``mean_val_acc``, ``val_std``, ``mean_test_acc`` and
            ``test_std`` as plain floats. For 1-D masks the standard
            deviations are ``0.0``.
        """
        masks = (self.data.train_mask, self.data.val_mask, self.data.test_mask)

        if len(masks[0].shape) == 2:
            val_accs = []
            test_accs = []
            for split in range(masks[0].shape[1]):
                # initialize a new version of the model each time
                self.reset_model()
                v, t = self._train_one(split, *(m[:, split] for m in masks))
                val_accs.append(v)
                test_accs.append(t)
            metrics = {
                "mean_val_acc": float(np.mean(val_accs)),
                "val_std": float(np.std(val_accs)),
                "mean_test_acc": float(np.mean(test_accs)),
                "test_std": float(np.std(test_accs)),
            }
        else:
            vacc, tacc = self._train_one(0, *masks)
            metrics = {
                "mean_val_acc": vacc,
                "val_std": 0.0,
                "mean_test_acc": tacc,
                "test_std": 0.0,
            }

        self.report(metrics, primary_metric=("mean_val_acc", "maximize"))
        print(SUMMARY_TEMPLATE.format(**metrics))
        return metrics

    def _train_one(self, split, train_mask, val_mask, test_mask):
        """Train on one split and return ``(best_val_acc, best_test_acc)``.

        The "best" pair is the validation/test accuracy observed at the
        evaluated epoch with the highest validation accuracy. Test metrics
        are only computed every ``eval_freq`` epochs and on the epoch where
        early stopping fires.
        """
        device = self.device
        x = self.data.x.to(device)
        edge_index = self.data.edge_index.to(device)
        y = self.data.y.to(device)
        train_mask = train_mask.to(device)
        val_mask = val_mask.to(device)
        test_mask = test_mask.to(device)

        val_loss_history = []
        best_val_acc = 0.0
        best_test_acc = 0.0

        for i in range(self.num_epochs):
            trainloss, trainacc = self.train_step(x, edge_index, y, train_mask)
            vloss, vacc = self.valid_step(x, edge_index, y, val_mask)
            if i >= self.val_history_start:
                val_loss_history.append(vloss)

            # Stop once the current validation loss is worse than the mean of
            # the recorded history (which already includes this epoch).
            stop = bool(
                self.early_stopping
                and i >= self.early_stopping_min_epoch
                and vloss > np.mean(val_loss_history)
            )
            if stop:
                print("Early Stopping")

            if ((i + 1) % self.eval_freq == 0) or stop:
                testloss, testacc = self.test_step(x, edge_index, y, test_mask)
                metrics = {
                    "split": split,
                    "train_loss": trainloss,
                    "train_acc": trainacc,
                    "val_loss": vloss,
                    "val_acc": vacc,
                    "test_loss": testloss,
                    "test_acc": testacc,
                }
                self.report(metrics)
                print(SIMPLE_TEMPLATE.format(i=i, **metrics))
                if vacc > best_val_acc:
                    best_val_acc = vacc
                    best_test_acc = testacc
            if stop:
                break
        return best_val_acc, best_test_acc

    def accuracy(self, y, yhat):
        """Return the fraction of positions where ``y`` equals ``yhat``.

        The comparison is symmetric, so the argument order does not matter.
        """
        return float((y == yhat).sum() / int(len(y)))

    def _forward(self, x, edge_index):
        """Run the model; ``GRAND`` instances are called with ``x`` only."""
        if isinstance(self.model, GRAND):
            return self.model(x)
        return self.model(x, edge_index)

    def _record_nfe(self):
        """For ``GRAND`` models, push the NFE counter to ``fm`` and reset it."""
        if isinstance(self.model, GRAND):
            self.model.fm.update(self.model.getNFE())
            self.model.resetNFE()

    def train_step(self, x, edge_index, yin, train_mask):
        """Do one optimisation step; return ``(loss, accuracy)`` on the mask."""
        self.model.train()
        self.optimizer.zero_grad()
        logits = self._forward(x, edge_index)[train_mask]

        y = yin[train_mask]
        loss = self.lossfn(logits, y)

        # Record the counter accumulated by the forward pass ...
        self._record_nfe()

        loss.backward()
        self.optimizer.step()

        # ... and again after backward/step, as a separate measurement.
        self._record_nfe()

        acc = self.accuracy(logits.argmax(dim=1), y)
        return loss.item(), acc

    def eval_step(self, x, edge_index, yin, mask):
        """Evaluate on ``mask`` without gradients; return ``(loss, accuracy)``."""
        self.model.eval()
        # No backward pass follows, so skip building the autograd graph.
        with torch.no_grad():
            logits = self._forward(x, edge_index)[mask]
            y = yin[mask]
            loss = self.lossfn(logits, y)
            acc = self.accuracy(logits.argmax(dim=1), y)
        return loss.item(), acc

    def valid_step(self, x, edge_index, yin, mask):
        """Evaluate on the validation mask; see ``eval_step``."""
        return self.eval_step(x, edge_index, yin, mask)

    def test_step(self, x, edge_index, yin, mask):
        """Evaluate on the test mask; see ``eval_step``."""
        return self.eval_step(x, edge_index, yin, mask)