import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy


class LSTMCell(nn.Module):
    """Single LSTM step assembled from two linear maps.

    The input map (with bias) and the bias-free recurrent map each emit
    ``4 * hidden_size`` pre-activations. Their sum is split into the candidate
    input and the input, forget and output gate logits.

    Args:
        input_size: feature size of ``inputs``.
        hidden_size: size of the hidden (output) and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size, self.hidden_size = input_size, hidden_size
        gate_width = 4 * hidden_size
        self.input_map = nn.Linear(input_size, gate_width, bias=True)
        self.recurrent_map = nn.Linear(hidden_size, gate_width, bias=False)
        self.tanh, self.sigmoid = nn.Tanh(), nn.Sigmoid()
        self.init_weights()

    def init_weights(self):
        """Re-initialise both linear maps in place.

        Every 1-D parameter (bias) is drawn from U(-0.1, 0.1). The input weight
        matrix uses Xavier-uniform and the recurrent weight matrix is
        orthogonal.
        """
        per_layer_matrix_init = (
            (self.input_map, torch.nn.init.xavier_uniform_),
            (self.recurrent_map, torch.nn.init.orthogonal_),
        )
        for layer, init_matrix in per_layer_matrix_init:
            for param in layer.parameters():
                if param.dim() != 1:
                    init_matrix(param)
                else:
                    torch.nn.init.uniform_(param, -0.1, 0.1)

    def forward(self, inputs, states):
        """Advance the cell by one step.

        Args:
            inputs: input batch fed to ``input_map``.
            states: ``(h, c)`` pair of previous output and cell states.

        Returns:
            ``(h_next, c_next)``.
        """
        h_prev, c_prev = states

        pre_activations = self.input_map(inputs) + self.recurrent_map(h_prev)
        candidate, in_logits, forget_logits, out_logits = pre_activations.chunk(4, 1)

        # The +3.0 shift pushes the forget gate towards 1, i.e. towards keeping
        # the previous cell state.
        c_next = c_prev * self.sigmoid(forget_logits + 3.0) + self.tanh(
            candidate
        ) * self.sigmoid(in_logits)
        h_next = self.tanh(c_next) * self.sigmoid(out_logits)

        return h_next, c_next

# mmRNNCell uses an LSTM as memory element and a neural ODE for the time-continuous pathway
class mmRNNCell(nn.Module):
    """LSTM cell whose new hidden state is further evolved by a neural ODE.

    Each call applies ``LSTMCell`` and then integrates the resulting hidden
    state through the vector field ``f_node`` for each sample's elapsed time
    ``ts``. The cell state is returned unchanged from the LSTM step.

    Args:
        input_size: size of the input fed to the LSTM.
        hidden_size: size of the LSTM hidden/cell state (and of the ODE state).
        solver_type: a torchdyn solver name such as ``"dopri5"``, or one of
            the built-in fixed-step solvers ``"fixed_euler"``,
            ``"fixed_heun"`` or ``"fixed_rk4"``. ``torchdyn`` is only
            imported for non-fixed solvers.

    Raises:
        ValueError: if ``solver_type`` starts with ``"fixed_"`` but is not one
            of the built-in fixed-step solvers.
    """

    def __init__(self, input_size, hidden_size, solver_type="dopri5"):
        super(mmRNNCell, self).__init__()
        self.solver_type = solver_type
        self.fixed_step_solver = solver_type.startswith("fixed_")
        self.lstm = LSTMCell(input_size, hidden_size)
        # 1 hidden layer NODE: the vector field dy/dt = f_node(y)
        self.f_node = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.input_size = input_size
        self.hidden_size = hidden_size
        if not self.fixed_step_solver:
            # Only import this complex module if really needed
            from torchdyn.models import NeuralDE

            self.node = NeuralDE(self.f_node, solver=solver_type)
        else:
            options = {
                "fixed_euler": self.euler,
                "fixed_heun": self.heun,
                "fixed_rk4": self.rk4,
            }
            if solver_type not in options:
                raise ValueError("Unknown solver type '{:}'".format(solver_type))
            # one-step update: node(y, delta_t) -> y after a step of size delta_t
            self.node = options[solver_type]

    def forward(self, input, hx, ts):
        """Run one LSTM step, then evolve the new hidden state through the ODE.

        Args:
            input: input batch for the LSTM step.
            hx: ``(h, c)`` pair of previous hidden and cell states.
            ts: elapsed time per batch element, one value per sample.

        Returns:
            ``(new_h, new_c)``; only ``new_h`` is ODE-evolved.
        """
        new_h, new_c = self.lstm(input, hx)
        if self.fixed_step_solver:
            new_h = self.solve_fixed(new_h, ts)
        else:
            new_h = self._solve_adaptive(new_h, ts)
        return (new_h, new_c)

    def _solve_adaptive(self, x, ts):
        """Integrate ``x`` with the torchdyn solver and read each sample at its own time."""
        batch_size = ts.size(0)
        order = torch.argsort(ts)
        s_sort = ts[order]
        # HACK: Make sure no two points are equal (strictly increasing offsets
        # added to a non-decreasing sequence).
        s_sort = s_sort + torch.linspace(0, 1e-4, batch_size, device=ts.device)
        # NOTE(review): if torchdyn's trajectory starts integrating at s_span[0]
        # (as torchdiffeq's odeint does), sample b is evolved for
        # ts[b] - min(ts) rather than ts[b] as in solve_fixed — confirm.
        trajectory = self.node.trajectory(x, s_sort)
        return self._states_at_own_times(trajectory, order)

    @staticmethod
    def _states_at_own_times(trajectory, order):
        """Pick, for every batch element, the state at that element's own time.

        ``trajectory[k, b]`` is sample ``b``'s state at the k-th *sorted* time,
        and ``order`` is the ``argsort`` that produced the sorting. Sample ``b``
        sits at rank ``inverse[b]`` of the sorted span, where ``inverse`` is the
        inverse permutation of ``order``. Indexing with ``order`` itself would
        be wrong unless the permutation happens to be self-inverse.
        """
        batch_size = order.size(0)
        batch_idx = torch.arange(batch_size, device=order.device)
        inverse = torch.empty_like(order)
        inverse[order] = batch_idx
        return trajectory[inverse, batch_idx]

    def solve_fixed(self, x, ts):
        """Integrate ``x`` over ``ts`` with three equal steps of the fixed-step solver."""
        # Column vector so each sample's step size broadcasts over its features.
        step = ts.view(-1, 1) * (1.0 / 3)
        for _ in range(3):  # 3 unfolds
            x = self.node(x, step)
        return x

    def euler(self, y, delta_t):
        """One explicit Euler step of size ``delta_t``."""
        dy = self.f_node(y)
        return y + delta_t * dy

    def heun(self, y, delta_t):
        """One Heun (explicit trapezoidal) step of size ``delta_t``."""
        k1 = self.f_node(y)
        k2 = self.f_node(y + delta_t * k1)
        return y + delta_t * 0.5 * (k1 + k2)

    def rk4(self, y, delta_t):
        """One classical fourth-order Runge-Kutta step of size ``delta_t``."""
        k1 = self.f_node(y)
        k2 = self.f_node(y + k1 * delta_t * 0.5)
        k3 = self.f_node(y + k2 * delta_t * 0.5)
        k4 = self.f_node(y + k3 * delta_t)

        return y + delta_t * (k1 + 2 * k2 + 2 * k3 + k4) / 6.0


class mmRNN(nn.Module):
    """Sequence model that unrolls an ``mmRNNCell`` and projects each hidden state.

    Args:
        in_features: size of each per-step input ``x[:, t]``.
        hidden_size: size of the LSTM hidden/cell state.
        out_feature: size of the linear read-out of the hidden state.
        return_sequences: if True, ``forward`` returns the read-out of every
            step; otherwise only the last (masked) read-out.
        solver_type: ODE solver forwarded to ``mmRNNCell`` (e.g. ``"dopri5"``
            or ``"fixed_euler"`` / ``"fixed_heun"`` / ``"fixed_rk4"``).
    """

    def __init__(
        self,
        in_features,
        hidden_size,
        out_feature,
        return_sequences=True,
        solver_type="dopri5",
    ):
        super(mmRNN, self).__init__()
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.out_feature = out_feature
        self.return_sequences = return_sequences

        self.rnn_cell = mmRNNCell(in_features, hidden_size, solver_type=solver_type)
        self.fc = nn.Linear(self.hidden_size, self.out_feature)

    def forward(self, x, timespans, mask=None):
        """Unroll the cell over the sequence and project each hidden state.

        Args:
            x: inputs indexed as ``x[:, t]``; dim 0 is the batch and dim 1 the
                step (presumably ``(batch, seq_len, in_features)`` — confirm).
            timespans: elapsed time per step; ``timespans[:, t]`` must hold
                exactly one value per batch element (e.g. ``(batch, seq_len)``
                or ``(batch, seq_len, 1)``).
            mask: optional per-step weights; ``mask[:, t]`` is reshaped to
                ``(batch, 1)``. A value of 1 replaces the running last output
                with the current one and 0 keeps the previous value. Bool/int
                masks are cast to the output dtype.

        Returns:
            ``(batch, seq_len, out_feature)`` if ``return_sequences`` is True,
            otherwise the ``(batch, out_feature)`` last (masked) output.
        """
        device = x.device
        batch_size = x.size(0)
        seq_len = x.size(1)
        hidden_state = (
            torch.zeros((batch_size, self.hidden_size), device=device),
            torch.zeros((batch_size, self.hidden_size), device=device),
        )
        outputs = []
        last_output = torch.zeros((batch_size, self.out_feature), device=device)
        for t in range(seq_len):
            inputs = x[:, t]
            # reshape rather than squeeze: a batch of one must stay a (1,) tensor
            ts = timespans[:, t].reshape(batch_size)
            # call the module (not .forward) so nn.Module hooks run
            hidden_state = self.rnn_cell(inputs, hidden_state, ts)
            current_output = self.fc(hidden_state[0])
            if self.return_sequences:
                outputs.append(current_output)
            if mask is not None:
                cur_mask = mask[:, t].reshape(batch_size, 1).to(current_output.dtype)
                last_output = cur_mask * current_output + (1.0 - cur_mask) * last_output
            else:
                last_output = current_output
        if self.return_sequences:
            return torch.stack(outputs, dim=1)  # return entire sequence
        return last_output  # only last item


class IrregularSequenceLearner(pl.LightningModule):
    """Lightning wrapper that trains a sequence model on irregularly sampled data.

    Batches are ``(x, t, y)`` or ``(x, t, y, mask)``. They are fed to the
    wrapped model as ``model(x, t, mask)``, with ``mask=None`` for 3-element
    batches. Predictions are flattened to ``(-1, n_outputs)`` before the loss:
    MSE for regression, cross-entropy plus accuracy for classification.

    Args:
        model: module called as ``model(x, t, mask)``.
        lr: learning rate for the optimizer.
        opt: ``"adam"`` or ``"rmsprop"``.
        is_regression: use MSE (True) or cross-entropy/accuracy (False).
    """

    def __init__(self, model, lr=0.005, opt="adam", is_regression=False):
        super().__init__()
        self.model = model
        self._opt_name = opt
        self.is_regression = is_regression
        self.lr = lr

    @staticmethod
    def _unpack_batch(batch):
        """Return ``(x, t, y, mask)``, with ``mask=None`` when the batch carries no mask."""
        if len(batch) == 4:
            x, t, y, mask = batch
            return x, t, y, mask
        x, t, y = batch
        return x, t, y, None

    def _predict_flat(self, batch):
        """Run the model on ``batch`` and return ``(y_hat, y)`` reshaped for the loss."""
        x, t, y, mask = self._unpack_batch(batch)
        # call the module (not .forward) so nn.Module hooks run
        y_hat = self.model(x, t, mask)
        y_hat = y_hat.reshape(-1, y_hat.size(-1))
        if self.is_regression:
            y = y.reshape(y_hat.size())  # Same shape
        else:
            y = y.reshape(-1)  # Flatten labels
        return y_hat, y

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss (and accuracy for classification)."""
        y_hat, y = self._predict_flat(batch)
        if self.is_regression:
            loss = nn.MSELoss()(y_hat, y)
            self.log("train_loss", loss, prog_bar=True)
            return {"loss": loss}
        loss = nn.CrossEntropyLoss()(y_hat, y)
        preds = torch.argmax(y_hat.detach(), dim=-1)
        acc = accuracy(preds, y)
        self.log("train_acc", acc, prog_bar=True)
        self.log("train_loss", loss, prog_bar=True)
        return {"loss": loss, "acc": acc}

    def validation_step(self, batch, batch_idx):
        """Log validation metrics; returns the loss (regression) or accuracy (classification)."""
        y_hat, y = self._predict_flat(batch)
        if self.is_regression:
            loss = nn.MSELoss()(y_hat, y)
            self.log("val_loss", loss, prog_bar=True)
            return loss
        loss = nn.CrossEntropyLoss()(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        acc = accuracy(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return acc

    def test_step(self, batch, batch_idx):
        """Reuse ``validation_step``; test metrics are therefore logged under the ``val_*`` keys."""
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """Build the optimizer named by ``opt`` over the wrapped model's parameters.

        Raises:
            ValueError: if ``opt`` is neither ``"adam"`` nor ``"rmsprop"``.
        """
        if self._opt_name == "adam":
            return torch.optim.Adam(self.model.parameters(), lr=self.lr)
        if self._opt_name == "rmsprop":
            return torch.optim.RMSprop(self.model.parameters(), lr=self.lr)
        raise ValueError("Unknown optimizer: " + str(self._opt_name))
