import copy

import torch
import torch.nn.functional as F
import torchdiffeq
from ogb.nodeproppred import Evaluator
from torchdiffeq._impl.dopri5 import _DORMAND_PRINCE_SHAMPINE_TABLEAU, DPS_C_MID
from torchdiffeq._impl.interp import _interp_evaluate
from torchdiffeq._impl.misc import _check_inputs, _flat_to_shape
from torchdiffeq._impl.rk_common import RKAdaptiveStepsizeODESolver, rk4_alt_step_func
from torchdiffeq._impl.solvers import FixedGridODESolver

from poi import PoissonIntegration


def run_evaluator(evaluator, data, y_pred):
    """Score predictions on the train, validation and test splits.

    For each of ``data.train_mask``, ``data.val_mask`` and ``data.test_mask``
    (in that order) the labels ``data.y`` and the predictions ``y_pred`` are
    indexed with the mask and handed to ``evaluator.eval``; the ``'acc'``
    entry of each result is collected.

    :param evaluator: object exposing ``eval(dict)`` that returns a mapping
        containing an ``'acc'`` key.
    :param data: object exposing ``y``, ``train_mask``, ``val_mask`` and
        ``test_mask``.
    :param y_pred: predictions, indexable by the same masks as ``data.y``.
    :return: tuple ``(train_acc, valid_acc, test_acc)``.
    """

    def split_accuracy(mask_name):
        # The mask is looked up lazily so attribute access and evaluation
        # stay interleaved split by split.
        mask = getattr(data, mask_name)
        scores = evaluator.eval({
            'y_true': data.y[mask],
            'y_pred': y_pred[mask],
        })
        return scores['acc']

    split_names = ('train_mask', 'val_mask', 'test_mask')
    return tuple(split_accuracy(name) for name in split_names)


class EarlyStopDopri5(RKAdaptiveStepsizeODESolver):
    """Adaptive dopri5 solver that scores the state after every step.

    After each adaptive step the current state is passed through ``self.m2``
    and scored on the train/val/test masks of ``self.data``. Whenever the
    validation accuracy beats ``self.best_val`` the accuracies and the time of
    that step are stored in ``best_train``, ``best_val``, ``best_test`` and
    ``best_time``.

    ``self.m2`` and ``self.data`` start as ``None`` and must be assigned
    (directly or via ``set_m2`` / ``set_data``) before integrating.
    """
    order = 5
    tableau = _DORMAND_PRINCE_SHAMPINE_TABLEAU
    mid = DPS_C_MID

    def __init__(self, func, y0, rtol, atol, opt, **kwargs):
        """
        :param func: ODE right-hand side, forwarded to the base solver.
        :param y0: initial state, forwarded to the base solver.
        :param rtol: relative tolerance, forwarded to the base solver.
        :param atol: absolute tolerance, forwarded to the base solver.
        :param opt: options mapping; only ``opt['dataset']`` is read here.
        :param kwargs: extra keyword arguments forwarded to the base solver.
        """
        super(EarlyStopDopri5, self).__init__(func, y0, rtol, atol, **kwargs)

        # `lf` is no longer used inside this class; the attribute is kept so
        # external code that reads it keeps working.
        self.lf = torch.nn.CrossEntropyLoss()
        self.m2 = None
        self.data = None
        # All four "best" fields are initialised so they can be read before
        # the first improvement is recorded.
        self.best_train = 0
        self.best_val = 0
        self.best_test = 0
        self.best_time = 0
        self.ode_test = self.test_OGB if opt['dataset'] == 'ogbn-arxiv' else self.test
        self.dataset = opt['dataset']
        if opt['dataset'] == 'ogbn-arxiv':
            self.lf = torch.nn.functional.nll_loss
            self.evaluator = Evaluator(name=opt['dataset'])

    def set_accs(self, train, val, test, time):
        """Record a new best result.

        :param train: training accuracy at the best step.
        :param val: validation accuracy at the best step.
        :param test: test accuracy at the best step.
        :param time: object with an ``item()`` method (e.g. a 0-d tensor)
            giving the time of the best step; stored as a Python number.
        """
        self.best_train = train
        self.best_val = val
        self.best_test = test
        self.best_time = time.item()

    def integrate(self, t):
        """Integrate over the time points ``t``, scoring the state after every step.

        :param t: 1-D tensor of time points; ``t[0]`` is the initial time.
        :return: ``(new_t, solution)`` where ``solution[i]`` is the state at
            ``t[i]`` and ``new_t`` is the last requested time (or ``t``
            itself, cast to the solver dtype, when ``t`` has one element).
        """
        solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device)
        solution[0] = self.y0
        t = t.to(self.dtype)
        self._before_integrate(t)
        new_t = t
        for i in range(1, len(t)):
            new_t, y = self.advance(t[i])
            solution[i] = y
        return new_t, solution

    def advance(self, next_t):
        """
        Takes steps dt to get to the next user specified time point next_t. In practice this goes past next_t and then interpolates
        :param next_t:
        :return: The pair (next_t, x(next_t))
        """
        n_steps = 0
        while next_t > self.rk_state.t1:
            assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps)
            self.rk_state = self._adaptive_step(self.rk_state)
            n_steps += 1
            train_acc, val_acc, test_acc = self.evaluate(self.rk_state)
            if val_acc > self.best_val:
                # The scored state is rk_state.y1, which lives at rk_state.t1.
                # Recording next_t here would store the same target time for
                # every step of this call.
                self.set_accs(train_acc, val_acc, test_acc, self.rk_state.t1)
        new_t = next_t
        return (new_t, _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t))

    @torch.no_grad()
    def test(self, logits):
        """Compute accuracy on the train, val and test masks of ``self.data``.

        :param logits: per-node class scores; the arg-max over dim 1 is the
            predicted class.
        :return: list ``[train_acc, val_acc, test_acc]`` of Python floats.
        """
        accs = []
        for _, mask in self.data('train_mask', 'val_mask', 'test_mask'):
            pred = logits[mask].max(1)[1]
            acc = pred.eq(self.data.y[mask]).sum().item() / mask.sum().item()
            accs.append(acc)
        return accs

    @torch.no_grad()
    def test_OGB(self, logits):
        """Compute accuracies with the OGB evaluator stored on this solver.

        :param logits: per-node class scores; arg-max over the last dim (with
            the dim kept) is passed to the evaluator as ``y_pred``.
        :return: list ``[train_acc, valid_acc, test_acc]``.
        """
        y_pred = logits.argmax(dim=-1, keepdim=True)
        train_acc, valid_acc, test_acc = run_evaluator(self.evaluator, self.data, y_pred)
        return [train_acc, valid_acc, test_acc]

    @torch.no_grad()
    def evaluate(self, rkstate):
        """Decode the state ``rkstate.y1`` with ``self.m2`` and score it.

        :param rkstate: solver state; only its ``y1`` attribute is read.
        :return: tuple ``(train_acc, val_acc, test_acc)``.
        """
        z = rkstate.y1
        if self.m2.in_features != z.shape[1]:  # system has been augmented
            # Keep only the first chunk of width m2.in_features.
            z = torch.split(z, self.m2.in_features, dim=1)[0]
        # Activation followed by the decoder.
        z = self.m2(F.relu(z))
        if self.dataset == 'ogbn-arxiv':
            # Retained from the previous implementation so the scored tensor
            # is unchanged for this dataset.
            z = z.log_softmax(dim=-1)
        # The loss that used to be computed here was never used or returned,
        # so it has been dropped to avoid extra work on every solver step.
        train_acc, val_acc, test_acc = self.ode_test(z)
        return train_acc, val_acc, test_acc

    def set_m2(self, m2):
        """Store a deep copy of the decoder module ``m2``."""
        self.m2 = copy.deepcopy(m2)

    def set_data(self, data):
        """Store ``data`` unless a dataset has already been set."""
        if self.data is None:
            self.data = data


class EarlyStopRK4(FixedGridODESolver):
    """Fixed-grid RK4 solver that scores the state after every grid step.

    After each step the new state is passed through ``self.m2`` and scored on
    the train/val/test masks of ``self.data``. Whenever the validation
    accuracy beats ``self.best_val`` the accuracies and a step time are stored
    in ``best_train``, ``best_val``, ``best_test`` and ``best_time``.

    ``self.m2`` and ``self.data`` start as ``None`` and must be assigned
    (directly or via ``set_m2`` / ``set_data``) before integrating.
    """
    order = 4

    def __init__(self, func, y0, rtol, atol, opt, eps=0, **kwargs):
        """
        :param func: ODE right-hand side, forwarded to the base solver.
        :param y0: initial state, forwarded to the base solver.
        :param rtol: accepted for signature parity with the adaptive solver; unused.
        :param atol: accepted for signature parity with the adaptive solver; unused.
        :param opt: options mapping; ``opt['step_size']`` and ``opt['dataset']`` are read.
        :param eps: offset applied to each step interval in ``_step_func``.
        :param kwargs: accepted and ignored.
        """
        super(EarlyStopRK4, self).__init__(func, y0, step_size=opt['step_size'])
        self.eps = torch.as_tensor(eps, dtype=self.dtype, device=self.device)
        # `lf` is no longer used inside this class; the attribute is kept so
        # external code that reads it keeps working.
        self.lf = torch.nn.CrossEntropyLoss()
        self.m2 = None
        self.data = None
        # All four "best" fields are initialised so they can be read before
        # the first improvement is recorded.
        self.best_train = 0
        self.best_val = 0
        self.best_test = 0
        self.best_time = 0
        self.ode_test = self.test_OGB if opt['dataset'] == 'ogbn-arxiv' else self.test
        self.dataset = opt['dataset']
        if opt['dataset'] == 'ogbn-arxiv':
            self.lf = torch.nn.functional.nll_loss
            self.evaluator = Evaluator(name=opt['dataset'])

    def _step_func(self, func, t, dt, y):
        """Return the RK4 increment over ``[t + eps, t + dt - eps]``."""
        return rk4_alt_step_func(func, t + self.eps, dt - 2 * self.eps, y)

    def set_accs(self, train, val, test, time):
        """Record a new best result.

        :param train: training accuracy at the best step.
        :param val: validation accuracy at the best step.
        :param test: test accuracy at the best step.
        :param time: object with an ``item()`` method (e.g. a 0-d tensor);
            stored as a Python number.
        """
        self.best_train = train
        self.best_val = val
        self.best_test = test
        self.best_time = time.item()

    def integrate(self, t):
        """Integrate over the time points ``t``, scoring the state after every grid step.

        :param t: 1-D tensor of time points; ``t[0]`` is the initial time.
        :return: ``(t1, solution)`` where ``solution[i]`` is the state at
            ``t[i]`` (linearly interpolated between grid points) and ``t1`` is
            the last grid time reached.
        """
        time_grid = self.grid_constructor(self.func, self.y0, t)
        assert time_grid[0] == t[0] and time_grid[-1] == t[-1]

        solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device)
        solution[0] = self.y0

        j = 1
        y0 = self.y0
        # Primed so the return value is defined even when the grid has a
        # single point and the loop body never runs.
        t1 = time_grid[0]
        for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
            dy = self._step_func(self.func, t0, t1 - t0, y0)
            y1 = y0 + dy
            train_acc, val_acc, test_acc = self.evaluate(y1, t0, t1)
            if val_acc > self.best_val:
                # NOTE(review): y1 is the state at t1 but the step's start
                # time t0 is what gets recorded -- confirm this is intended.
                self.set_accs(train_acc, val_acc, test_acc, t0)

            # Fill every requested time point covered by this step.
            while j < len(t) and t1 >= t[j]:
                solution[j] = self._linear_interp(t0, t1, y0, y1, t[j])
                j += 1
            y0 = y1

        return t1, solution

    @torch.no_grad()
    def test(self, logits):
        """Compute accuracy on the train, val and test masks of ``self.data``.

        :param logits: per-node class scores; the arg-max over dim 1 is the
            predicted class.
        :return: list ``[train_acc, val_acc, test_acc]`` of Python floats.
        """
        accs = []
        for _, mask in self.data('train_mask', 'val_mask', 'test_mask'):
            pred = logits[mask].max(1)[1]
            acc = pred.eq(self.data.y[mask]).sum().item() / mask.sum().item()
            accs.append(acc)
        return accs

    @torch.no_grad()
    def test_OGB(self, logits):
        """Compute accuracies with the OGB evaluator stored on this solver.

        :param logits: per-node class scores; arg-max over the last dim (with
            the dim kept) is passed to the evaluator as ``y_pred``.
        :return: list ``[train_acc, valid_acc, test_acc]``.
        """
        y_pred = logits.argmax(dim=-1, keepdim=True)
        train_acc, valid_acc, test_acc = run_evaluator(self.evaluator, self.data, y_pred)
        return [train_acc, valid_acc, test_acc]

    @torch.no_grad()
    def evaluate(self, z, t0, t1):
        """Decode the state ``z`` with ``self.m2`` and score it.

        :param z: state tensor to score.
        :param t0: start time of the step; unused, kept for the call signature.
        :param t1: end time of the step; unused, kept for the call signature.
        :return: tuple ``(train_acc, val_acc, test_acc)``.
        """
        if self.m2.in_features != z.shape[1]:  # system has been augmented
            # Keep only the first chunk of width m2.in_features.
            z = torch.split(z, self.m2.in_features, dim=1)[0]
        # Activation followed by the decoder.
        z = self.m2(F.relu(z))
        if self.dataset == 'ogbn-arxiv':
            # Retained from the previous implementation so the scored tensor
            # is unchanged for this dataset.
            z = z.log_softmax(dim=-1)
        # The loss that used to be computed here was never used or returned,
        # so it has been dropped to avoid extra work on every solver step.
        train_acc, val_acc, test_acc = self.ode_test(z)
        return train_acc, val_acc, test_acc

    def set_m2(self, m2):
        """Store a deep copy of the decoder module ``m2``."""
        self.m2 = copy.deepcopy(m2)

    def set_data(self, data):
        """Store ``data`` unless a dataset has already been set."""
        if self.data is None:
            self.data = data


# Registry of early-stopping solvers keyed by method name. It is handed to
# torchdiffeq's `_check_inputs` and indexed with the configured method.
SOLVERS = dict(
    dopri5=EarlyStopDopri5,
    rk4=EarlyStopRK4,
)


class EarlyStopInt(torch.nn.Module):
    """Module wrapper that integrates an ODE with one of the early-stopping solvers.

    The integration interval is fixed at construction time to
    ``[0, opt['earlystopxT'] * t]``. ``self.data`` and ``self.m2`` start as
    ``None``; they are copied onto the freshly built solver in ``forward`` and
    must therefore be assigned by the caller beforehand.
    """

    def __init__(self, t, opt, device=None):
        """
        :param t: base end time; multiplied by ``opt['earlystopxT']``.
        :param opt: options mapping; ``opt['earlystopxT']`` is read here and
            ``opt['method']`` in ``forward``. The mapping is also passed to
            the solver.
        :param device: device the time tensor is moved to.
        """
        super(EarlyStopInt, self).__init__()
        self.device = device
        self.solver = None
        self.data = None
        self.m2 = None
        self.opt = opt
        self.t = torch.tensor([0, opt['earlystopxT'] * t], dtype=torch.float).to(self.device)

    @staticmethod
    def _version_tuple(version):
        """Parse the leading ``major.minor.patch`` numbers of a version string.

        Non-numeric suffixes inside a component (e.g. ``'5rc1'``) are dropped
        and missing components count as 0, so ``'0.2.10'`` -> ``(0, 2, 10)``
        and ``'0.2'`` -> ``(0, 2, 0)``.
        """
        numbers = []
        for piece in version.split('.')[:3]:
            digits = ''
            for ch in piece:
                if ch not in '0123456789':
                    break
                digits += ch
            numbers.append(int(digits) if digits else 0)
        while len(numbers) < 3:
            numbers.append(0)
        return tuple(numbers)

    def forward(self, func, y0, t, method=None, rtol=1e-7, atol=1e-9,
                adjoint_method="dopri5", adjoint_atol=1e-9, adjoint_rtol=1e-7, options=None):
        """Integrate ``dy/dt = func(t, y)`` from ``y0`` over the interval stored in ``self.t``.

        A new solver is built on every call, given ``self.data`` and
        ``self.m2``, and kept in ``self.solver`` so its ``best_*`` fields can
        be read afterwards.

        Args:
            func: Function that maps a Tensor holding the state `y` and a scalar Tensor
                `t` into a Tensor of state derivatives with respect to time.
            y0: Tensor giving the starting value of `y`.
            t: Ignored; the time points stored in ``self.t`` are used instead.
            method: Ignored; ``self.opt['method']`` is used instead.
            rtol: upper bound on relative error, forwarded to the solver.
            atol: upper bound on absolute error, forwarded to the solver.
            adjoint_method: Unused; accepted for signature compatibility.
            adjoint_atol: Unused; accepted for signature compatibility.
            adjoint_rtol: Unused; accepted for signature compatibility.
            options: optional dict of extra keyword arguments for the solver.

        Returns:
            Tensor whose first dimension indexes the time points in ``self.t``,
            with the initial value `y0` as the first element.

        Raises:
            AssertionError: if ``self.opt['method']`` is not ``'rk4'`` or ``'dopri5'``.
        """
        method = self.opt['method']
        assert method in ['rk4', 'dopri5'], "Only dopri5 and rk4 implemented with early stopping"

        # `_check_inputs` gained the event_fn / t_is_reversed values in
        # torchdiffeq 0.2.2. The version is compared as a tuple of ints so
        # multi-digit components such as 0.2.10 or 0.10.0 are ordered correctly.
        if self._version_tuple(torchdiffeq.__version__) >= (0, 2, 2):
            event_fn = None
            shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed = _check_inputs(
                func, y0, self.t, rtol, atol, method, options, event_fn, SOLVERS)
        else:
            shapes, func, y0, t, rtol, atol, method, options = _check_inputs(
                func, y0, self.t, rtol, atol, method, options, SOLVERS)

        self.solver = SOLVERS[method](func, y0, rtol=rtol, atol=atol, opt=self.opt, **options)
        if self.solver.data is None:
            self.solver.data = self.data
        self.solver.m2 = self.m2
        # The solvers return the final time as their first value. It is not
        # bound to `t`, because the time grid's length is still needed below
        # and that final time may be a 0-d tensor with no len().
        _, solution = self.solver.integrate(t)
        if shapes is not None:
            solution = _flat_to_shape(solution, (len(t),), shapes)
        return solution


class EarlyStopIntPoi(EarlyStopInt):
    """Variant of ``EarlyStopInt`` whose forward pass goes through ``PoissonIntegration``."""

    def __init__(self, *args, **kwargs):
        # Construction is identical to the parent class.
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Call ``PoissonIntegration`` with the parent ``forward`` as its first argument.

        All positional and keyword arguments are passed through unchanged,
        with ``wrap_dim=1`` added; whatever ``PoissonIntegration`` returns is
        returned.
        """
        parent_forward = super().forward
        return PoissonIntegration(parent_forward, *args, wrap_dim=1, **kwargs)