import itertools
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger

# Process-wide side effect at import time: floating-point tensors created
# without an explicit dtype (including nn.Linear parameters built below)
# default to float64 from here on.
torch.set_default_dtype(torch.float64)


class BaseMLP(nn.Module):
    """Fully connected network with a pluggable hidden-layer activation.

    Subclasses supply the nonlinearity by overriding :meth:`activation`.
    ``BaseMLP`` itself can be constructed, but ``forward`` raises
    ``NotImplementedError`` as soon as it applies the activation.

    Args:
        input_dim: Number of input features after flattening every dimension
            except the first (batch) one.
        output_dim: Width of the final Linear layer (no activation applied).
        depth: Number of hidden Linear layers, each followed by the
            activation. Values below 1 behave like ``depth=1``.
        alpha: Stored in ``self.config`` only; it does not affect the layers
            or ``forward``.
        n_tree: Width of every hidden layer.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        depth,
        alpha,
        n_tree,
    ):
        super().__init__()
        # Constructor arguments, stored verbatim.
        self.config = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "depth": depth,
            "alpha": alpha,
            "n_tree": n_tree,
        }
        # Layers in forward order: input -> hidden, (depth - 1) x hidden ->
        # hidden, hidden -> output. Creation order also fixes the order in
        # which parameter initialisation draws from the RNG and the
        # "layers.<i>" state_dict keys.
        self.layers = nn.ModuleList([nn.Linear(input_dim, n_tree)])
        self.layers.extend(nn.Linear(n_tree, n_tree) for _ in range(depth - 1))
        self.layers.append(nn.Linear(n_tree, output_dim))

    def forward(self, x):
        """Run ``x`` through the network.

        Args:
            x: Tensor with at least 2 dimensions; the first is the batch and
                all remaining dimensions are flattened into the features.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        # flatten(start_dim=1) keeps the batch axis and merges the rest.
        # Unlike reshape(batch, -1) it also handles an empty batch, where -1
        # cannot be inferred from zero elements.
        x = torch.flatten(x, start_dim=1)
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        # The output layer stays linear (no activation).
        return self.layers[-1](x)

    def activation(self, x):
        """Apply the hidden-layer nonlinearity; subclasses must override."""
        raise NotImplementedError("This method should be implemented by subclasses")


class ReLUMLP(BaseMLP):
    """BaseMLP variant whose hidden layers use the ReLU nonlinearity."""

    def activation(self, x):
        # Elementwise max(0, x), applied after every hidden Linear layer.
        return torch.relu(x)


class SigmoidMLP(BaseMLP):
    """BaseMLP variant whose hidden layers use the logistic sigmoid."""

    def activation(self, x):
        # Elementwise 1 / (1 + exp(-x)), applied after every hidden Linear layer.
        return torch.sigmoid(x)


if __name__ == "__main__":
    # Smoke test: print each architecture and its output for one sample.
    input_dim = 2
    output_dim = 1
    alpha = 1.0
    n_tree = 100

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # A single sample with `input_dim` features. torch.tensor uses the
    # current default floating-point dtype, matching the model parameters.
    x = torch.tensor([[1.0, 1.0]], device=device)

    for depth in range(1, 6):
        # For each depth: the ReLU model first, then the sigmoid model.
        for model_cls in (ReLUMLP, SigmoidMLP):
            model = model_cls(input_dim, output_dim, depth, alpha, n_tree).to(device)
            print(model)
            # Call the module rather than .forward() so nn.Module hooks run.
            print(model(x))

    # Parameter counts for a 44-feature, 2-output ReLU MLP with 256-wide
    # hidden layers. 44 is roughly the average input dim:
    # (419+7+24+10+50+7+8+10+10+20+7+20+22+16+54+26)/16 = 710/16 ≈ 44.4
    for depth in range(1, 4):
        model = ReLUMLP(44, 2, depth, alpha, 256).to(device)
        print(sum(p.numel() for p in model.parameters()))
