import torch
from torch import nn

import numpy as np


def weight_reset(m):
    """Re-initialise ``m`` in place if it is a ``Conv2d`` or ``Linear`` layer.

    Modules of any other type are left untouched. Written so it can be handed
    to ``nn.Module.apply`` to reset every matching layer of a network.
    """
    resettable_types = (nn.Conv2d, nn.Linear)
    if not isinstance(m, resettable_types):
        return
    m.reset_parameters()


class FCN(nn.Module):
    """Fully connected network ending in a single output unit.

    The network is a stack of ``Linear -> ReLU`` pairs, one per entry of
    ``hidden_dims``, followed by a final ``Linear`` layer that maps to one
    output with no activation applied.
    """

    def __init__(self, input_dim, hidden_dims=(16, 32), device="cpu", dtype=torch.float32):
        """Build the layer stack.

        Args:
            input_dim: Number of input features.
            hidden_dims: Width of each hidden layer, in order.
            device: Device on which the linear layers are created.
            dtype: Data type of the linear layers' parameters.
        """
        super().__init__()

        self.n_layers = len(hidden_dims)

        factory_kwargs = {"device": device, "dtype": dtype}
        widths = [input_dim, *hidden_dims]

        stack = []
        # Walk consecutive (in, out) width pairs to create the hidden layers.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out, **factory_kwargs))
            stack.append(nn.ReLU())
        # Output head: one unit, no activation.
        stack.append(nn.Linear(widths[-1], 1, **factory_kwargs))

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return its output."""
        return self.model(x)

    def reset_parameters(self):
        """Re-initialise the layers of the stack via ``weight_reset``."""
        self.model.apply(weight_reset)


def train_sample_classifier(p_samples, q_samples, layer_dims=(16, 32), lr=0.01,
                            weight_decay=1e-5, epochs=100, return_losses=False, model=None,
                            stop_eps=1e-4, stop_window=5, stop_early=True,
                            tensor_kwargs=None):
    """Train a binary classifier that separates samples of p(x) from samples of q(x).

    Samples from p(x) are labelled 1 and samples from q(x) are labelled 0, and
    the model is fit with ``BCEWithLogitsLoss``. The trained model therefore
    outputs logits: high for inputs that look like p(x), low for inputs that
    look like q(x).

    Every epoch is one full-batch gradient step over all samples (in a freshly
    shuffled order). The model is put in training mode and is returned in that
    mode.

    Args:
        p_samples: Tensor of samples from p(x); dimension 0 indexes samples and
            all remaining dimensions are flattened into features.
        q_samples: Tensor of samples from q(x), laid out like ``p_samples``.
        layer_dims: Hidden layer widths of the ``FCN`` that is created when
            ``model`` is None. Ignored when a model is supplied.
        lr: Learning rate of the Adam optimizer.
        weight_decay: Weight decay of the Adam optimizer.
        epochs: Maximum number of epochs.
        return_losses: If True, also return the list of per-epoch losses.
        model: Optional module to train. If None, a new ``FCN`` is built.
        stop_eps: Early-stopping threshold on the absolute slope of the loss.
        stop_window: Number of most recent losses used to estimate the slope.
        stop_early: Whether to stop once the loss slope flattens out.
        tensor_kwargs: Keyword arguments (``device``, ``dtype``) used for the
            label tensor and for a newly created ``FCN``. Defaults to
            ``{"device": "cpu", "dtype": torch.float32}``.

    Returns:
        The trained model, or ``(model, losses)`` when ``return_losses`` is True.
    """
    # Build the default per call instead of sharing one dict across calls.
    if tensor_kwargs is None:
        tensor_kwargs = {"device": "cpu", "dtype": torch.float32}

    Np = p_samples.size(0)
    Nq = q_samples.size(0)
    p_samples = p_samples.reshape(Np, -1)
    q_samples = q_samples.reshape(Nq, -1)
    D = p_samples.size(-1)

    X_train = torch.cat([p_samples, q_samples])
    y_train = torch.cat([torch.ones(Np, **tensor_kwargs), torch.zeros(Nq, **tensor_kwargs)])

    # If the network was not provided, initialize it with the requested widths.
    if model is None:
        model = FCN(D, hidden_dims=layer_dims, **tensor_kwargs)

    criterion = nn.BCEWithLogitsLoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 weight_decay=weight_decay)

    model.train()
    losses = []

    if stop_early:
        # Least-squares design matrix for fitting y = a*x + b to the last
        # `stop_window` losses; its pseudo-inverse yields [a, b] directly.
        A = np.stack([np.arange(stop_window), np.ones(stop_window)], axis=1)
        A_inv = np.linalg.inv(A.T.dot(A)).dot(A.T)

    for ep in range(epochs):
        shuffle_idx = torch.randperm(Np + Nq)

        output = model(X_train[shuffle_idx])
        loss = criterion(output.squeeze(), y_train[shuffle_idx])
        losses.append(loss.item())

        if stop_early and len(losses) >= stop_window:
            # Slope of the fitted line over the most recent losses.
            slope = A_inv.dot(np.array(losses[-stop_window:]))[0]
            if np.abs(slope) <= stop_eps:
                # Loss has plateaued; stop before taking another step.
                break

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if return_losses:
        return model, losses
    return model
