import os
from collections import deque

import numpy as np
import torch
import torch.nn as nn
from sklearn.kernel_ridge import KernelRidge
from tqdm import tqdm

# import wandb

# Pick the compute device. When CUDA is available, a non-empty GPU environment
# variable overrides the device string (presumably something like "cuda:1");
# otherwise the default "cuda" device is used. Without CUDA everything runs on
# the CPU, regardless of GPU.
device = (
    (os.environ.get("GPU") or "cuda") if torch.cuda.is_available() else "cpu"
)

# Create all tensors and parameters below in double precision by default.
torch.set_default_dtype(torch.float64)


class InnerNode:
    """Splitting node shared by all ``n_tree`` trees of a soft-tree ensemble.

    The node owns a bias-free linear layer ``fc`` (``input_dim -> n_tree``,
    weights drawn from N(0, 1)), giving one split per tree. Children are built
    recursively: while ``depth < config["max_depth"]`` they are inner nodes,
    otherwise they are leaves. With ``asym=True`` every right child is a leaf,
    so only the left branch keeps splitting.

    ``calc_prob`` caches ``prob``, ``path_prob`` and ``leaf_accumulator`` on the
    node. ``reset`` clears the accumulated leaf lists of the whole subtree.
    """

    def __init__(self, config, depth, asym=False):
        self.config = config
        self.leaf = False
        self.fc = nn.Linear(
            self.config["input_dim"], self.config["n_tree"], bias=False
        ).to(device)
        nn.init.normal_(self.fc.weight, 0.0, 1.0)  # mean: 0.0, std: 1.0
        self.prob = None
        self.path_prob = None
        self.left = None
        self.right = None
        self.leaf_accumulator = []
        self.penalties = []  # also (re)set by reset()
        self.asym = asym

        self.build_child(depth)

    def build_child(self, depth):
        """Create the left/right children for a node sitting at ``depth``."""
        if depth < self.config["max_depth"]:
            self.left = InnerNode(self.config, depth + 1, asym=self.asym)
            if self.asym:
                # Asymmetric variant: the right branch terminates immediately.
                self.right = LeafNode(self.config)
            else:
                self.right = InnerNode(self.config, depth + 1, asym=self.asym)
        else:
            self.left = LeafNode(self.config)
            self.right = LeafNode(self.config)

    def forward(self, x):  # decision function
        """Return the soft probability of routing ``x`` to the right child.

        Computes ``0.5 * erf(scale * fc(x)) + 0.5``, which lies in [0, 1].
        """
        return (
            0.5 * torch.erf(self.config["scale"] * self.fc(x)) + 0.5
        )  # -> [batch_size, n_tree]

    def calc_prob(self, x, path_prob):
        """Propagate ``path_prob`` through this subtree and collect its leaves.

        Args:
            x: Input batch fed to every split layer.
            path_prob: Probability of reaching this node, ``[batch_size, n_tree]``.

        Returns:
            A new list of ``[path_prob, Q]`` pairs, one per leaf of this
            subtree, with the left subtree's leaves first.
        """
        self.prob = self.forward(x)  # probability of selecting right node
        path_prob = path_prob.to(device)  # path_prob: [batch_size, n_tree]
        self.path_prob = path_prob
        # Build a fresh list on every call instead of extending the list kept on
        # ``self``. A call that is not followed by ``reset()`` (e.g. because a
        # previous forward pass raised) must not leak stale leaves into the
        # next prediction.
        left_leaves = self.left.calc_prob(x, path_prob * (1 - self.prob))
        right_leaves = self.right.calc_prob(x, path_prob * self.prob)
        self.leaf_accumulator = left_leaves + right_leaves
        return self.leaf_accumulator

    def reset(self):
        """Clear the per-call leaf lists of this node and all its descendants."""
        self.leaf_accumulator = []
        self.penalties = []
        self.left.reset()
        self.right.reset()


class LeafNode:
    """Terminal node holding one learnable response vector per tree.

    ``param`` has shape ``[output_dim, n_tree]`` and is initialised from
    N(0, 1). It is exposed as an ``nn.Parameter`` so the owning model can
    register it.
    """

    def __init__(self, config):
        self.config = config
        self.leaf = True
        out_dim, n_tree = config["output_dim"], config["n_tree"]
        initial_value = torch.randn(out_dim, n_tree).to(device)
        self.param = nn.Parameter(initial_value)  # [n_class, n_tree]
        # To freeze the leaves, set: self.param.requires_grad = False

    def forward(self):
        """Return the raw leaf parameter ``[output_dim, n_tree]``."""
        return self.param

    def calc_prob(self, x, path_prob):
        """Pair the incoming path probability with this leaf's broadcast response.

        ``x`` is unused; it is accepted so leaves and inner nodes share a
        signature.
        """
        reach = path_prob.to(device)  # [batch_size, n_tree]
        target_shape = (
            reach.size(0),
            self.config["output_dim"],
            self.config["n_tree"],
        )
        response = self.forward().expand(target_shape)  # [batch, n_class, n_tree]
        return [[reach, response]]

    def reset(self):
        """Leaves keep no per-call state, so there is nothing to clear."""
        return None


class SoftTree(nn.Module):
    """Ensemble of ``n_tree`` soft decision trees evaluated in parallel.

    Every split node applies a bias-free linear layer followed by an erf-based
    gate. Every leaf holds a ``[output_dim, n_tree]`` parameter. The output is
    the path-probability-weighted sum of leaf values over all leaves and trees,
    divided by ``sqrt(n_tree)`` (NTK scaling).

    Args:
        input_dim: Number of input features; inputs are reshaped to this width.
        output_dim: Number of outputs (e.g. classes).
        max_depth: Tree depth; the root is built at depth 1.
        scale: Multiplier applied to the split pre-activation inside ``erf``.
        n_tree: Number of trees in the ensemble.
        asym: If True, every right child is a leaf (asymmetric trees).
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        max_depth: int,
        scale: float,
        n_tree: int,
        asym: bool = False,
    ):
        super(SoftTree, self).__init__()
        config = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "max_depth": max_depth,
            "scale": scale,
            "n_tree": n_tree,
        }
        self.config = config
        self.root = InnerNode(config, depth=1, asym=asym)
        self.collect_parameters()

    def collect_parameters(self):
        """Register every split layer and leaf parameter with this module.

        Nodes are visited breadth-first, with the right child queued before the
        left child. This makes the registration order (and hence ``state_dict``
        keys and optimizer parameter order) deterministic.
        """
        self.module_list = nn.ModuleList()
        self.param_list = nn.ParameterList()
        queue = deque([self.root])
        while queue:
            node = queue.popleft()  # O(1), unlike list.pop(0)
            if node.leaf:
                self.param_list.append(node.param)
            else:
                queue.append(node.right)
                queue.append(node.left)
                self.module_list.append(node.fc)

    def forward(self, x):
        """Return ensemble predictions of shape ``[batch_size, output_dim]``."""
        x = torch.squeeze(x, 1).reshape(x.shape[0], self.config["input_dim"])
        n_tree = self.config["n_tree"]

        # Every sample starts at the root of every tree with probability 1.
        path_prob_init = torch.ones(x.shape[0], n_tree, device=device)

        try:
            leaf_accumulator = self.root.calc_prob(x, path_prob_init)
            pred = torch.zeros(
                x.shape[0], self.config["output_dim"], device=device
            )
            for path_prob, q in leaf_accumulator:  # one iteration per leaf
                # [batch, 1, n_tree] * [batch, n_class, n_tree], summed over trees.
                pred += torch.sum(path_prob.unsqueeze(1) * q, dim=2)
        finally:
            # Always clear per-call state cached on the nodes, even if the pass
            # raised, so that the next call starts from a clean tree.
            self.root.reset()

        return pred / np.sqrt(n_tree)  # NTK scaling


def finite_soft_trees(
    X1: np.array,
    X2: np.array,
    y1: np.array,
    y2: np.array,
    alpha: float,
    lr: float,
    depth: int,
    classes: int,
    asym: bool = False,
    n_tree: int = 512,
    epochs: int = 2000,
) -> float:
    """Train a finite soft-tree ensemble and return its validation accuracy.

    The model is trained full-batch with SGD on an MSE loss against centred
    one-hot targets (``eye(classes)[y1] - 1/classes``), the same target
    encoding used by the kernel ridge regression helpers in this module.

    Args:
        X1: Training inputs.
        X2: Validation inputs.
        y1: Integer class labels for ``X1``.
        y2: Integer class labels for ``X2``.
        alpha: Split scale passed to ``SoftTree(scale=...)``.
        lr: SGD learning rate.
        depth: Tree depth passed to ``SoftTree(max_depth=...)``.
        classes: Number of classes (model output dimension).
        asym: Build the asymmetric tree variant.
        n_tree: Number of trees in the ensemble (default 512).
        epochs: Number of full-batch SGD steps (default 2000).

    Returns:
        Fraction of validation samples whose argmax prediction equals ``y2``.
    """
    n_val = len(X2)
    one_hot_label = torch.Tensor(np.eye(classes)[y1] - 1.0 / classes).to(device)
    X1 = torch.Tensor(X1).to(device)
    X2 = torch.Tensor(X2).to(device)
    y2 = torch.Tensor(y2).to(device)

    model = SoftTree(
        input_dim=len(X1[0]),
        output_dim=classes,
        max_depth=depth,
        scale=alpha,
        n_tree=n_tree,
        asym=asym,
    )
    criterion = torch.nn.modules.loss.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in tqdm(range(epochs), leave=False, desc="epoch-loop"):
        outputs = model(X1)
        loss = criterion(outputs, one_hot_label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # To log per-epoch validation accuracy (e.g. with wandb), evaluate here.

    # Evaluate once after training. no_grad avoids building an unused
    # autograd graph over the validation set.
    with torch.no_grad():
        z = model(X2).argmax(axis=1)
    return 1.0 * torch.sum(z == y2).tolist() / n_val


# ----------


def precomputed_kernel_ridge_regression(
    K1: np.array, K2: np.array, y1: np.array, y2: np.array, alpha: float, classes: int
) -> float:
    """Fit kernel ridge regression on precomputed kernels and score it.

    Targets are centred one-hot vectors (``eye(classes)[y1] - 1/classes``).
    The predicted class is the argmax of the regressed target.

    Args:
        K1: Kernel matrix between training samples (passed to ``fit``).
        K2: Kernel matrix between validation and training samples (2-D;
            passed to ``predict``).
        y1: Integer class labels for the training samples.
        y2: Integer class labels for the validation samples.
        alpha: Ridge regularisation strength.
        classes: Number of classes.

    Returns:
        Classification accuracy on the validation samples.
    """
    n_val, _ = K2.shape
    regressor = KernelRidge(kernel="precomputed", alpha=alpha)
    targets = np.eye(classes)[y1] - 1.0 / classes
    regressor.fit(K1, targets)
    predicted = np.argmax(regressor.predict(K2), axis=1)
    n_correct = np.sum(predicted == y2)
    return 1.0 * n_correct / n_val


def kernel_ridge_regression(
    X1: np.array,
    X2: np.array,
    y1: np.array,
    y2: np.array,
    alpha: float,
    gamma: float,
    classes: int,
    kernel: str,
) -> float:
    """Fit RBF or Laplacian kernel ridge regression and return its accuracy.

    Targets are centred one-hot vectors (``eye(classes)[y1] - 1/classes``).
    The predicted class is the argmax of the regressed target.

    Args:
        X1: Training inputs.
        X2: Validation inputs.
        y1: Integer class labels for ``X1``.
        y2: Integer class labels for ``X2``.
        alpha: Ridge regularisation strength.
        gamma: Kernel bandwidth parameter forwarded to ``KernelRidge``.
        classes: Number of classes.
        kernel: Either ``"rbf"`` or ``"laplacian"``.

    Returns:
        Classification accuracy on ``X2``.
    """
    assert kernel in ("rbf", "laplacian")
    n_val = len(X2)
    regressor = KernelRidge(kernel=kernel, alpha=alpha, gamma=gamma)
    targets = np.eye(classes)[y1] - 1.0 / classes
    scores = regressor.fit(X1, targets).predict(X2)
    n_correct = np.sum(np.argmax(scores, axis=1) == y2)
    return 1.0 * n_correct / n_val
