import torch
import torch.nn as nn
import numpy as np
from collections import OrderedDict


class NN(nn.Module):
    """Sequential network assembled from a named, ordered collection of layers.

    The layers are stored in an ``nn.ModuleDict`` and applied one after the
    other in insertion order.

    Args:
        layers: Mapping (or iterable of ``(name, module)`` pairs) handed to
            ``nn.ModuleDict``. ``None`` creates an empty container.
        input_shape: Per-sample input shape. ``forward`` reshapes incoming
            data to ``(-1, *input_shape)``. When omitted it is inferred from
            the first layer, which must then be an ``nn.Linear``.
        device: Optional device to move the module to.
        dtype: Optional dtype to cast the module to.

    Raises:
        ValueError: If ``input_shape`` is omitted and cannot be inferred
            (no layers, or the first layer is not an ``nn.Linear``).
    """

    def __init__(self, layers=None, input_shape=None, device=None, dtype=None):
        super().__init__()

        self.layers = nn.ModuleDict(layers) if layers is not None else nn.ModuleDict()

        if input_shape is not None:
            # ``forward`` concatenates this with ``(-1,)``, so it has to be a
            # tuple; accept a bare int or any other sequence for convenience.
            if isinstance(input_shape, int):
                self.input_shape = (input_shape,)
            else:
                self.input_shape = tuple(input_shape)
        elif isinstance(fl := next(iter(self.layers.values()), None), nn.Linear):
            self.input_shape = (fl.in_features,)
        else:
            raise ValueError("Input shape must be provided")

        # Only move/cast when explicitly requested. ``is not None`` (rather
        # than truthiness) keeps a device index of ``0`` from being ignored.
        placement = {}
        if device is not None:
            placement["device"] = device
        if dtype is not None:
            placement["dtype"] = dtype
        if placement:
            self.to(**placement)

        self.trained_on = None

    def _reference_tensor(self):
        """Return the first parameter (or, failing that, buffer), else ``None``."""
        first = next(self.parameters(), None)
        if first is None:
            first = next(self.buffers(), None)
        return first

    @property
    def device(self):
        """Device of the first parameter/buffer; CPU if the module has none."""
        ref = self._reference_tensor()
        return ref.device if ref is not None else torch.device("cpu")

    @property
    def dtype(self):
        """Dtype of the first parameter/buffer; torch's default dtype if none."""
        ref = self._reference_tensor()
        return ref.dtype if ref is not None else torch.get_default_dtype()

    @property
    def num_relus(self):
        """Number of ``nn.ReLU`` modules among the layers."""
        return sum(isinstance(layer, nn.ReLU) for layer in self.layers.values())

    def forward(self, data):
        """Reshape ``data`` to ``(-1, *input_shape)`` and apply every layer in order."""
        x = data.reshape((-1,) + self.input_shape)
        for layer in self.layers.values():
            x = layer(x)
        return x

    def get_all_layer_outputs(self, data, layers=None, verbose=False):
        """Run ``data`` through the layers and collect intermediate outputs.

        Unlike ``forward``, ``data`` is fed to the first layer as-is (it is
        not reshaped to ``input_shape``).

        Args:
            data: Input passed directly to the first layer.
            layers: Optional container of layer names to keep. ``None`` keeps
                the output of every layer.
            verbose: If true, print each layer and its output shape.

        Returns:
            ``OrderedDict`` mapping layer name to that layer's output, in
            execution order.
        """
        outputs = []
        x = data
        for name, layer in self.layers.items():
            if verbose:
                print(f"Layer {name}: {layer}")
            x = layer(x)
            if verbose:
                print(f"    Output shape: {x.shape}")
            if layers is None or name in layers:
                outputs.append((name, x))
        return OrderedDict(outputs)

    def get_grid(self, bounds=2, res=100):
        """Build a square 2-D grid of points over ``[-bounds, bounds]``.

        Args:
            bounds: Half-width of the grid along both axes.
            res: Number of samples per axis.

        Returns:
            Tuple ``(x, y, inputVal)`` where ``x`` and ``y`` are the 1-D axis
            coordinates and ``inputVal`` is a ``(res * res, 2)`` array whose
            rows are the ``(x, y)`` grid points.
        """
        x = np.linspace(-bounds, bounds, res)
        y = np.copy(x)

        X, Y = np.meshgrid(x, y)

        # Flatten both coordinate planes and pair them up row by row.
        inputVal = np.vstack((np.reshape(X, -1), np.reshape(Y, -1))).T
        return x, y, inputVal

    def output_grid(self, bounds=2, res=100):
        """Evaluate every layer on the grid produced by ``get_grid``.

        The grid points are passed to ``get_all_layer_outputs`` as a
        ``(res * res, 2)`` tensor on this module's device and dtype, so the
        first layer must accept two input features.

        Returns:
            Tuple ``(x, y, outs)`` with the axis coordinates and the ordered
            mapping of layer name to output.
        """
        x, y, inputVal = self.get_grid(bounds, res)

        # ``as_tensor`` converts straight to the target dtype instead of going
        # through the legacy float32-only ``torch.Tensor`` constructor.
        grid = torch.as_tensor(inputVal, dtype=self.dtype, device=self.device)
        outs = self.get_all_layer_outputs(grid)

        return x, y, outs


def get_mlp_model(widths):
    """Build a fully connected ReLU network wrapped in ``NN``.

    Consecutive entries of ``widths`` become ``nn.Linear`` layers named
    ``fc0``, ``fc1``, ...; every linear layer except the last is followed by
    an ``nn.ReLU`` named ``relu<i>``.

    Args:
        widths: Sequence of layer sizes, input size first and output size
            last. Must contain at least two entries.

    Returns:
        The constructed ``NN``, with the given ``widths`` attached as
        ``net.widths``.

    Raises:
        ValueError: If ``widths`` has fewer than two entries, since no layer
            could be built.
    """
    if len(widths) < 2:
        raise ValueError(
            f"widths must contain at least an input and an output size, got {widths!r}"
        )

    last = len(widths) - 2  # index of the final linear layer (no ReLU after it)
    layers = []
    for i, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
        layers.append((f"fc{i}", nn.Linear(n_in, n_out)))
        if i < last:
            layers.append((f"relu{i}", nn.ReLU()))

    net = NN(layers=OrderedDict(layers))
    net.widths = widths
    return net


def get_model(name, **kwargs):
    """Construct one of the predefined models by name.

    Args:
        name: Identifier of the model to build.
        **kwargs: Forwarded to ``get_mlp_model`` only when ``name == "mlp"``;
            ignored for every other name.

    Returns:
        The constructed model. For ``"alexnet"`` this is whatever
        ``torch.hub.load`` returns rather than an ``NN`` instance.

    Raises:
        ValueError: If ``name`` is not a known model.
    """
    # Fixed-size MLPs. The table is rebuilt on every call so that each model
    # receives its own fresh ``widths`` list.
    mlp_widths = {
        "mnist_fc": [784, 5, 8, 8, 8, 10],
        "cifar10_fc": [3 * 32 * 32, 64, 128, 128, 128, 10],
        "xor": [2, 2, 1],
        "circle": [2, 8, 8, 8, 1],
        "circle_shallow": [2, 16, 1],
        "california_housing": [8, 128, 128, 64, 1],
        "california_housing_shallow": [8, 128, 1],
        "california_housing_shallow_small": [8, 9, 1],
        "california_housing_reg": [8, 128, 1],
    }
    if isinstance(name, str) and name in mlp_widths:
        return get_mlp_model(mlp_widths[name])

    if name == "mlp":
        return get_mlp_model(**kwargs)

    if name == "alexnet":
        return torch.hub.load("pytorch/vision:v0.10.0", "alexnet", verbose=False)

    if name == "mnist_cnn":
        mnist_layers = [
            ("conv1", nn.Conv2d(1, 8, 3, 1)),
            ("relu1", nn.ReLU()),
            ("conv2", nn.Conv2d(8, 8, 3, 1)),
            ("relu2", nn.ReLU()),
            ("pool1", nn.AvgPool2d(2)),
            ("conv3", nn.Conv2d(8, 8, 3, 1)),
            ("relu3", nn.ReLU()),
            ("pool2", nn.AvgPool2d(2)),
            ("dropout1", nn.Dropout(0.25)),
            ("flatten", nn.Flatten()),
            ("fc1", nn.Linear(200, 150)),
            ("relu4", nn.ReLU()),
            ("fc2", nn.Linear(150, 150)),
            ("relu5", nn.ReLU()),
            ("fc3", nn.Linear(150, 32)),
            ("relu6", nn.ReLU()),
            ("dropout2", nn.Dropout(0.5)),
            ("fc4", nn.Linear(32, 10)),
        ]
        return NN(layers=OrderedDict(mnist_layers), input_shape=(1, 28, 28))

    if name == "cifar10_cnn":
        cifar_layers = [
            ("conv1", nn.Conv2d(3, 6, 5)),
            ("pool1", nn.MaxPool2d(2, 2)),
            ("relu1", nn.ReLU()),
            ("conv2", nn.Conv2d(6, 16, 5)),
            ("pool2", nn.MaxPool2d(2, 2)),
            ("relu2", nn.ReLU()),
            ("flatten", nn.Flatten()),
            ("fc1", nn.Linear(16 * 5 * 5, 10)),
            ("relu3", nn.ReLU()),
            ("fc4", nn.Linear(10, 64)),
            ("relu4", nn.ReLU()),
            ("fC3", nn.Linear(64, 64)),
            ("relu5", nn.ReLU()),
            ("fc3", nn.Linear(64, 10)),
        ]
        return NN(layers=OrderedDict(cifar_layers), input_shape=(3, 32, 32))

    raise ValueError(f"Unknown model: {name}")
