import torch
import torch.nn as nn
from transformers import (
    GPT2Model,
    GPT2Config,
    GPTNeoModel,
    AutoTokenizer,
    GPTNeoXModel,
    GPTNeoForCausalLM,
    AutoModelForCausalLM,
    GPTNeoConfig,
    AutoModelForCausalLM,
)
from typing import List, Tuple, Dict, Any, Optional
from tqdm import tqdm
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression, Lasso
import warnings
from sklearn import tree
from sklearn.decomposition import PCA
import xgboost as xgb
from sklearn.decomposition import PCA

from base_models import NeuralNetwork, ParallelNetworks
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoBlock
import numpy as np


def build_model(conf):
    """Instantiate the model described by ``conf``.

    Args:
        conf: Configuration object exposing ``family``, ``n_dims``,
            ``n_positions``, ``n_embd``, ``n_layer``, ``n_head``, ``n_y`` and
            ``model_name``. For the ``"gpt-neo"`` family ``lr_solver_head`` is
            read too and treated as ``False`` when the attribute is absent.

    Returns:
        A ``TransformerModel`` for ``family == "gpt2"`` or a
        ``TransformerLanguageModel`` for ``family == "gpt-neo"``.

    Raises:
        NotImplementedError: If ``conf.family`` is not a supported family.
    """
    family = conf.family
    # Check the family first so an unsupported family is reported as such,
    # rather than as a missing-attribute error on the remaining fields.
    if family not in ("gpt2", "gpt-neo"):
        raise NotImplementedError(f"Unsupported model family: {family!r}")

    common_kwargs = dict(
        n_dims=conf.n_dims,
        n_positions=conf.n_positions,
        n_embd=conf.n_embd,
        n_layer=conf.n_layer,
        n_head=conf.n_head,
        n_y=conf.n_y,
        model_name=conf.model_name,
    )
    if family == "gpt2":
        return TransformerModel(**common_kwargs)
    return TransformerLanguageModel(
        **common_kwargs,
        lr_solver_head=getattr(conf, "lr_solver_head", False),
    )


def get_relevant_baselines(task_name):
    """Instantiate the baseline models evaluated alongside ``task_name``.

    Args:
        task_name: Name of the task, e.g. ``"linear_regression"`` or
            ``"decision_tree"``.

    Returns:
        A new list of baseline model instances, one per registered
        ``(model_class, kwargs)`` pair, in registration order.

    Raises:
        KeyError: If no baselines are registered for ``task_name``.
    """
    # Least squares, 3-nearest-neighbours and averaging: the default set.
    default_baselines = [
        (LeastSquaresModel, {}),
        (NNModel, {"n_neighbors": 3}),
        (AveragingModel, {}),
    ]
    lasso_sweep = [
        (LassoModel, {"alpha": alpha}) for alpha in [1, 0.1, 0.01, 0.001, 0.0001]
    ]
    # Gradient-descent-trained NeuralNetwork; note in_size is hard-coded to 20.
    gd_network = (
        GDModel,
        {
            "model_class": NeuralNetwork,
            "model_class_args": {
                "in_size": 20,
                "hidden_size": 100,
                "out_size": 1,
            },
            "opt_alg": "adam",
            "batch_size": 100,
            "lr": 5e-3,
            "num_steps": 100,
        },
    )

    registry = {
        "linear_regression": default_baselines,
        "probabilistic_tanh": default_baselines,
        "probabilistic_logistic_regression": default_baselines,
        "crf": default_baselines,
        "three_nodes": default_baselines,
        "crf_ising": default_baselines,
        # Classification drops the least-squares baseline.
        "linear_classification": default_baselines[1:],
        "sparse_linear_regression": default_baselines + lasso_sweep,
        "relu_2nn_regression": default_baselines + [gd_network],
        "decision_tree": [
            (LeastSquaresModel, {}),
            (NNModel, {"n_neighbors": 3}),
            (DecisionTreeModel, {"max_depth": 4}),
            (DecisionTreeModel, {"max_depth": None}),
            (XGBoostModel, {}),
            (AveragingModel, {}),
        ],
    }

    return [model_cls(**kwargs) for model_cls, kwargs in registry[task_name]]


from transformers import GPTNeoPreTrainedModel, GPTNeoForCausalLM


class TransformerLanguageModel(nn.Module):
    """Pretrained causal language model evaluated on tokenized in-context examples.

    ``__init__`` loads a Pythia, OPT or GPT-Neo checkpoint (chosen from
    ``model_name`` and the ``adaptor`` / ``lr_solver_head`` flags) and freezes
    some or all of its parameters. ``forward`` returns the backbone's logits,
    the logits after three extra ``GPTNeoBlock`` layers (``adaptor``), or the
    output of a small ``TransformerModel`` run over pooled hidden states
    (``lr_solver_head``).
    """

    def __init__(
        self,
        n_dims,
        n_positions,
        n_embd=128,
        n_layer=12,
        n_head=4,
        n_y=1,
        model_name="EleutherAI/gpt-neo-125M",
        adaptor=False,
        config=None,
        lr_solver_head=False,
    ):
        """Load the tokenizer and backbone and set up the optional heads.

        Args:
            n_dims: Feature dimension; used as the PCA size in ``_run_pca``.
            n_positions: Number of in-context points; forwarded to the
                ``lr_solver_head`` regression transformer.
            n_embd, n_layer, n_head: Only used to build ``self.name``.
            n_y: Number of label tokens per point.
            model_name: Checkpoint name passed to ``from_pretrained``. Names
                containing ``"pythia"`` or ``"opt"`` (substring match) select
                those families; anything else is loaded as GPT-Neo.
            adaptor: Freeze the GPT-Neo base model and append three new
                ``GPTNeoBlock`` layers before the LM head.
            config: Not read; kept for interface compatibility.
            lr_solver_head: Freeze the whole GPT-Neo model and predict with a
                ``TransformerModel`` over pooled hidden states.

        NOTE(review): ``adaptor`` / ``lr_solver_head`` only change the backbone
        for GPT-Neo checkpoints, yet ``forward`` dispatches on the stored flags,
        so combining them with a Pythia/OPT ``model_name`` fails in ``forward``.
        """
        super(TransformerLanguageModel, self).__init__()
        print(model_name)
        self.name = f"gpt2_embd={n_embd}_layer={n_layer}_head={n_head}"
        self.n_positions = n_positions
        self.n_dims = n_dims
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
        # First token id of " positive" / " negative" (leading space included);
        # forward() compares input ids against these to locate label tokens.
        self.positive_token_id_space = self._tokenizer(" positive").input_ids[0]
        self.negative_token_id_space = self._tokenizer(" negative").input_ids[0]

        self.adaptor = adaptor
        self.lr_solver_head = lr_solver_head
        if "pythia" in model_name:
            self._backbone = AutoModelForCausalLM.from_pretrained(
                model_name,
            )
            # Freeze the input and output embedding matrices.
            for param in self._backbone.gpt_neox.embed_in.parameters():
                param.requires_grad = False
            for param in self._backbone.embed_out.parameters():
                param.requires_grad = False
        elif "opt" in model_name:
            self._backbone = AutoModelForCausalLM.from_pretrained(
                model_name,
            )
            # Freeze the input token embeddings and the LM head.
            for param in self._backbone.model.decoder.embed_tokens.parameters():
                param.requires_grad = False
            for param in self._backbone.lm_head.parameters():
                param.requires_grad = False
        elif adaptor:
            clm = GPTNeoForCausalLM.from_pretrained(
                model_name,
            )
            config = GPTNeoConfig.from_pretrained(
                model_name,
            )

            self._backbone = clm.base_model
            self.lm_head = clm.lm_head
            # Freeze the whole pretrained base model.
            for param in self._backbone.parameters():
                param.requires_grad = False
            # Three freshly initialised blocks (trainable by default).
            self.extra_layer_1 = GPTNeoBlock(config=config, layer_id=11)
            self.extra_layer_2 = GPTNeoBlock(config=config, layer_id=11)
            self.extra_layer_3 = GPTNeoBlock(config=config, layer_id=11)

        elif lr_solver_head:
            print("using lr solver head")
            self._backbone = GPTNeoForCausalLM.from_pretrained(
                model_name,
            )

            self.lr_head = TransformerModel(
                n_dims=self._backbone.config.hidden_size, n_positions=n_positions
            )

            # Freeze every backbone parameter (embeddings and lm_head included);
            # only lr_head keeps trainable parameters.
            for param in self._backbone.parameters():
                param.requires_grad = False

        else:
            print("loading GPT NEO MODEL")

            self._backbone = GPTNeoForCausalLM.from_pretrained(
                model_name,
            )
            # Freeze the token embeddings and the LM head.
            for param in self._backbone.transformer.wte.parameters():
                param.requires_grad = False
            for param in self._backbone.lm_head.parameters():
                param.requires_grad = False

        self.y_step_size = n_y + 1
        self.n_y = n_y

    def _tokenize(self, batch):
        """Tokenize a batch of example sequences into a tensor of token ids.

        Each sequence is rendered as ``"a0 b0 a1 b1 ..."`` from the first two
        (string) items of each of its samples, then tokenized.
        ``torch.tensor`` requires every sequence to tokenize to the same length.
        """
        tok_batch = []
        for sequence in batch:
            text = " ".join(" ".join([sample[0], sample[1]]) for sample in sequence)
            tok_batch.append(self._tokenizer(text.strip()).input_ids)
        return torch.tensor(tok_batch)

    def _normalize_norm(self, embeds):
        """Standardize each vector along dim 2, then scale it to unit L2 norm.

        Args:
            embeds: Tensor whose dim 2 holds the features.

        Returns:
            Tensor of the same shape. A vector with zero standard deviation
            yields NaNs (no epsilon is added).
        """
        # Per-vector mean / standard deviation over the feature dimension.
        mean = embeds.mean(dim=2, keepdim=True)
        std = embeds.std(dim=2, keepdim=True)
        embeds = (embeds - mean) / std

        norm = torch.norm(embeds, p=2, dim=2, keepdim=True)
        return embeds / norm

    def _normalize(self, embeds):
        """Whiten each batch element's embeddings independently.

        For every ``embeds[b]`` (rows = samples, columns = features) the rows
        are centred and multiplied by ``V diag(1 / sqrt(lambda + 1e-9)) V^T``,
        where ``lambda, V`` are the eigenpairs of the feature covariance
        computed with ``np.cov``.

        Args:
            embeds: Tensor indexed as ``embeds[b, sample, feature]``.

        Returns:
            Float32 tensor with the same shape as ``embeds`` on ``embeds.device``.
        """
        embeds_white = []
        for embed_i in embeds:
            embed_np = embed_i.detach().cpu().numpy()
            embed_m0 = embed_np - embed_np.mean(axis=0)

            embed_cov = np.cov(embed_m0, rowvar=False)
            eigenvalues, eigenvectors = np.linalg.eigh(embed_cov)
            # Small epsilon keeps near-zero eigenvalues from blowing up.
            D = np.diag(1.0 / np.sqrt(eigenvalues + 1e-9))
            embed_white = (eigenvectors @ D @ eigenvectors.T @ embed_m0.T).T
            embeds_white.append(torch.Tensor(embed_white))
        # Return the whole whitened batch (previously only the last element's
        # NumPy array was returned by mistake).
        return torch.stack(embeds_white, dim=0).to(embeds.device)

    def _run_pca(self, embeds):
        """PCA-reduce each batch element to ``self.n_dims`` components and whiten it.

        Args:
            embeds: Array-like (e.g. a NumPy array) where ``embeds[b]`` is a
                (samples, features) matrix.

        Returns:
            Float32 tensor of shape ``(batch, samples, self.n_dims)``, moved
            to CUDA.
        """
        embeds_list = []
        for embed_i in embeds:
            pca = PCA(n_components=self.n_dims)
            pca.fit(embed_i)
            X_tr_pca_cor = pca.transform(embed_i)

            X_tr_pca_cor_m0 = X_tr_pca_cor - X_tr_pca_cor.mean(axis=0)

            cov_X_cor = np.cov(X_tr_pca_cor_m0, rowvar=False)
            eigenvalues, eigenvectors = np.linalg.eigh(cov_X_cor)
            # Same epsilon as _normalize to avoid division by zero on
            # rank-deficient inputs.
            D = np.diag(1.0 / np.sqrt(eigenvalues + 1e-9))
            X_tr_pca_cor_white = (
                eigenvectors @ D @ eigenvectors.T @ X_tr_pca_cor_m0.T
            ).T
            embeds_list.append(torch.Tensor(X_tr_pca_cor_white))
        embeds_list = torch.stack(embeds_list, dim=0)
        return embeds_list.cuda()

    def forward(
        self,
        xs,
        ys,
    ):
        """Run the model on a batch of token ids.

        Args:
            xs: Token-id tensor fed to the backbone as ``input_ids``; indexed
                as ``xs[row, position]`` in the ``lr_solver_head`` path.
            ys: Labels. Only used with ``lr_solver_head``: its last dimension
                sets how many label positions are pooled, and it is passed on
                to ``lr_head``.

        Returns:
            LM-head logits (default and ``adaptor`` modes) or the output of
            ``lr_head`` (``lr_solver_head`` mode).
        """
        if self.adaptor:
            hidden = self._backbone(input_ids=xs)[0]
            for block in (self.extra_layer_1, self.extra_layer_2, self.extra_layer_3):
                hidden = block(hidden)[0]
            return self.lm_head(hidden)

        if self.lr_solver_head:
            outputs = self._backbone(input_ids=xs, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]

            # Compare against the plain integer ids: device-agnostic, and the
            # same result as the former float round-trip through .cuda().item().
            is_label = (xs == self.positive_token_id_space) | (
                xs == self.negative_token_id_space
            )
            # Column index of each label token. Only the first ys.shape[-1]
            # hits are kept, which presumably assumes every row places its
            # labels at the same positions as row 0 — TODO confirm.
            label_positions = torch.where(is_label)[1][: ys.shape[-1]].tolist()

            # Mean-pool the hidden states strictly between consecutive label
            # tokens, giving one vector per in-context example.
            pooled = []
            start = 0
            for end in label_positions:
                pooled.append(hidden_states[:, start:end, :].mean(dim=1))
                start = end + 1
            embed_xs = torch.stack(pooled, dim=1)

            embed_white = self._normalize_norm(embed_xs)
            return self.lr_head(embed_white, ys)

        return self._backbone(input_ids=xs).logits


class TransformerModel(nn.Module):
    """GPT-2 transformer over interleaved sequences of real-valued x and y tokens.

    Every point contributes its ``x`` vector followed by ``n_y`` label tokens,
    each label stored in the first coordinate of an otherwise-zero vector of
    size ``n_dims``. Tokens are projected to ``n_embd`` by ``_read_in``, passed
    through a ``GPT2Model`` built from a config with dropout disabled, and
    mapped to one scalar per position by ``_read_out``.
    """

    def __init__(
        self,
        n_dims,
        n_positions,
        n_embd=128,
        n_layer=12,
        n_head=4,
        n_y=1,
        model_name=None,
    ):
        """Build the read-in / GPT-2 / read-out stack.

        Args:
            n_dims: Size of each x vector (and of the zero-padded y tokens).
            n_positions: Maximum number of points per sequence; the backbone
                is configured for ``(n_y + 1) * n_positions`` tokens.
            n_embd, n_layer, n_head: GPT-2 width, depth and number of heads.
            n_y: Number of label tokens per point.
            model_name: Not read; accepted for interface compatibility.
        """
        super(TransformerModel, self).__init__()
        configuration = GPT2Config(
            n_positions=(n_y + 1) * n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            resid_pdrop=0.0,
            embd_pdrop=0.0,
            attn_pdrop=0.0,
            use_cache=False,
        )
        self.name = f"gpt2_embd={n_embd}_layer={n_layer}_head={n_head}"

        self.n_positions = n_positions
        self.n_dims = n_dims
        self._read_in = nn.Linear(n_dims, n_embd)
        self._backbone = GPT2Model(configuration)
        self._read_out = nn.Linear(n_embd, 1)
        self.y_step_size = n_y + 1
        self.n_y = n_y

    @staticmethod
    def _combine(xs_b, ys_b):
        """Interleave xs and ys into ``[x_0, y_0, x_1, y_1, ...]``.

        ``xs_b`` is ``(batch, points, dim)``; ``ys_b`` must hold
        ``batch * points`` values, each zero-padded to ``dim``.
        Returns a ``(batch, 2 * points, dim)`` tensor.
        """
        bsize, points, dim = xs_b.shape
        ys_b_wide = torch.cat(
            (
                ys_b.view(bsize, points, 1),
                torch.zeros(bsize, points, dim - 1, device=ys_b.device),
            ),
            axis=2,
        )
        zs = torch.stack((xs_b, ys_b_wide), dim=2)
        zs = zs.view(bsize, 2 * points, dim)
        return zs

    def _combine_gen(self, xs_b, ys_b):
        """Interleave xs with ``n_y`` label streams: ``[x_0, y1_0, ..., yn_0, x_1, ...]``.

        ``ys_b[i]`` is the i-th label stream and must hold ``batch * points``
        values. Returns a ``(batch, (n_y + 1) * points, dim)`` tensor.
        """
        bsize, points, dim = xs_b.shape
        ys_list = []
        for i in range(self.n_y):
            ys_b_i = ys_b[i, ::]
            ys_b_i_wide = torch.cat(
                (
                    ys_b_i.view(bsize, points, 1),
                    torch.zeros(bsize, points, dim - 1, device=ys_b.device),
                ),
                axis=2,
            )
            ys_list.append(ys_b_i_wide)
        zs = torch.stack((xs_b, *ys_list), dim=2)
        zs = zs.view(bsize, (self.n_y + 1) * points, dim)

        return zs

    @staticmethod
    def _resolve_inds(inds, n_points):
        """Return ``inds`` as a long tensor, defaulting to every point index.

        Raises:
            ValueError: If any index falls outside ``[0, n_points)``.
        """
        if inds is None:
            return torch.arange(n_points)
        inds = torch.as_tensor(inds, dtype=torch.long)
        if inds.numel() > 0 and (inds.max() >= n_points or inds.min() < 0):
            raise ValueError("inds contain indices where xs and ys are not defined")
        return inds

    def forward(self, xs, ys, inds=None):
        """Predict labels for the points selected by ``inds``.

        Args:
            xs: ``(batch, points, n_dims)`` inputs.
            ys: ``(batch, points)`` labels, or a stack of ``n_y`` label streams
                (``ys[i]`` is stream ``i``) when ``ys`` has more than 2 dims.
            inds: Point indices to return predictions for; all by default.

        Returns:
            ``(batch, len(inds))`` predictions read at the x tokens, or, in the
            multi-label case, a list of ``n_y`` such tensors where entry ``i``
            is read at offset ``i`` of each point's token group (i.e. the
            position preceding label ``i + 1``).

        Raises:
            ValueError: If ``inds`` contains an out-of-range index.
        """
        if len(ys.shape) > 2:
            # Honor caller-supplied inds here too (previously overwritten).
            inds = self._resolve_inds(inds, ys.shape[-1])
            zs = self._combine_gen(xs, ys)
            embeds = self._read_in(zs)

            output = self._backbone(inputs_embeds=embeds).last_hidden_state
            prediction = self._read_out(output)

            return [
                prediction[:, i :: self.y_step_size, 0][:, inds]
                for i in range(self.n_y)
            ]

        inds = self._resolve_inds(inds, ys.shape[1])
        zs = self._combine(xs, ys)
        embeds = self._read_in(zs)
        output = self._backbone(inputs_embeds=embeds).last_hidden_state
        prediction = self._read_out(output)
        return prediction[:, ::2, 0][:, inds]  # predict only on xs


class NNModel:
    """k-nearest-neighbour baseline that predicts each point from the points before it.

    The prediction at index ``i`` averages the labels of the
    ``min(i, n_neighbors)`` earlier points closest to ``xs[:, i]`` in Euclidean
    distance. With ``weights="uniform"`` every neighbour counts equally; any
    other value weights neighbours by inverse distance. Index 0 is predicted
    as 0.
    """

    def __init__(self, n_neighbors, weights="uniform"):
        # should we be picking k optimally
        self.n_neighbors = n_neighbors
        self.weights = weights
        self.name = f"NN_n={n_neighbors}_{weights}"

    def __call__(self, xs, ys, inds=None):
        n_points = ys.shape[1]
        if inds is None:
            inds = range(n_points)
        elif max(inds) >= n_points or min(inds) < 0:
            raise ValueError("inds contain indices where xs and ys are not defined")

        columns = []
        for idx in inds:
            if idx == 0:
                # No context yet: predict zero for the first point.
                columns.append(torch.zeros_like(ys[:, 0]))
                continue

            context_x, context_y = xs[:, :idx], ys[:, :idx]
            query = xs[:, idx : idx + 1]
            distances = (context_x - query).square().sum(dim=2).sqrt()

            if self.weights == "uniform":
                weight = torch.ones_like(distances)
            else:
                weight = 1.0 / distances
                # A zero distance gives an infinite weight; in such rows keep
                # only the exactly matching neighbours (weight 1, others 0).
                is_exact = torch.isinf(weight).float()
                exact_rows = torch.any(is_exact, dim=1)
                weight[exact_rows] = is_exact[exact_rows]

            nearest = distances.argsort()[:, : min(idx, self.n_neighbors)]
            columns.append(
                torch.stack(
                    [
                        (w_row[nbrs] * y_row[nbrs]).sum() / w_row[nbrs].sum()
                        for y_row, w_row, nbrs in zip(context_y, weight, nearest)
                    ]
                )
            )

        return torch.stack(columns, dim=1)


# xs and ys must be on the CPU for this method (__call__ moves them there). Otherwise the output may be off when train_xs is not full rank, due to the implementation of torch.linalg.lstsq.
class LeastSquaresModel:
    """Ordinary least-squares baseline refitted on every prefix of the context.

    The prediction at index ``i`` is ``xs[:, i] @ w`` where ``w`` solves
    ``xs[:, :i] @ w ~= ys[:, :i]`` with ``torch.linalg.lstsq`` on the CPU.
    Index 0 is predicted as 0.
    """

    def __init__(self, driver=None):
        # driver is forwarded to torch.linalg.lstsq (None = library default).
        self.driver = driver
        self.name = f"OLS_driver={driver}"

    def __call__(self, xs, ys, inds=None):
        xs = xs.cpu()
        ys = ys.cpu()
        n_points = ys.shape[1]
        if inds is None:
            inds = range(n_points)
        elif max(inds) >= n_points or min(inds) < 0:
            raise ValueError("inds contain indices where xs and ys are not defined")

        out = []
        for idx in inds:
            if idx == 0:
                # No context yet: predict zero for the first point.
                out.append(torch.zeros_like(ys[:, 0]))
                continue
            weights = torch.linalg.lstsq(
                xs[:, :idx], ys[:, :idx].unsqueeze(2), driver=self.driver
            ).solution
            out.append((xs[:, idx : idx + 1] @ weights)[:, 0, 0])

        return torch.stack(out, dim=1)


class AveragingModel:
    """Baseline predicting ``x_i . mean_{j<i}(y_j * x_j)`` for each index ``i``.

    Index 0 has no context and is predicted as 0.
    """

    def __init__(self):
        self.name = "averaging"

    def __call__(self, xs, ys, inds=None):
        n_points = ys.shape[1]
        if inds is None:
            inds = range(n_points)
        elif max(inds) >= n_points or min(inds) < 0:
            raise ValueError("inds contain indices where xs and ys are not defined")

        out = []
        for idx in inds:
            if idx == 0:
                # No context yet: predict zero for the first point.
                out.append(torch.zeros_like(ys[:, 0]))
                continue
            label_weighted = xs[:, :idx] * ys[:, :idx].unsqueeze(dim=-1)
            w_avg = label_weighted.mean(dim=1).unsqueeze(dim=-1)
            out.append((xs[:, idx : idx + 1] @ w_avg)[:, 0, 0])

        return torch.stack(out, dim=1)


# Lasso regression (for sparse linear regression).
# Seems to take more time as we decrease alpha.
class LassoModel:
    """Per-sequence Lasso regression baseline.

    For every requested index ``i > 0`` and every batch element, a scikit-learn
    ``Lasso`` (no intercept) is fitted on that element's first ``i`` points and
    used to predict point ``i``. Index 0 is predicted as 0.
    """

    def __init__(self, alpha, max_iter=100000):
        # alpha multiplies the L1 penalty.
        self.alpha = alpha
        self.max_iter = max_iter
        self.name = f"lasso_alpha={alpha}_max_iter={max_iter}"

    def _fit_predict(self, train_x, train_y, test_x, i, j):
        """Fit Lasso on one sequence prefix and return the prediction for ``test_x``.

        Any warning emitted while fitting (e.g. non-convergence) is promoted to
        an error, reported with its ``(i, j)`` position and re-raised.
        """
        regressor = Lasso(alpha=self.alpha, fit_intercept=False, max_iter=self.max_iter)
        with warnings.catch_warnings():
            warnings.filterwarnings("error")
            try:
                regressor.fit(train_x, train_y)
            except Warning:
                print(f"lasso convergence warning at i={i}, j={j}.")
                raise
        coef = torch.from_numpy(regressor.coef_).unsqueeze(1)
        return (test_x @ coef.float()).squeeze(1)[0]

    def __call__(self, xs, ys, inds=None):
        """Return ``(batch, len(inds))`` predictions; ``inds`` defaults to all points."""
        xs, ys = xs.cpu(), ys.cpu()
        n_points = ys.shape[1]
        if inds is None:
            inds = range(n_points)
        elif max(inds) >= n_points or min(inds) < 0:
            raise ValueError("inds contain indices where xs and ys are not defined")

        columns = []
        for idx in inds:
            column = torch.zeros_like(ys[:, 0])
            if idx > 0:
                for row in range(ys.shape[0]):
                    column[row] = self._fit_predict(
                        xs[row, :idx], ys[row, :idx], xs[row, idx : idx + 1], idx, row
                    )
            columns.append(column)

        return torch.stack(columns, dim=1)


# Gradient Descent and variants.
# Example usage: gd_model = GDModel(NeuralNetwork, {'in_size': 50, 'hidden_size':400, 'out_size' :1}, opt_alg = 'adam', batch_size = 100, lr = 5e-3, num_steps = 200)
class GDModel:
    """Baseline that trains a fresh network by gradient descent for each prefix.

    Example usage::

        gd_model = GDModel(
            NeuralNetwork,
            {"in_size": 50, "hidden_size": 400, "out_size": 1},
            opt_alg="adam", batch_size=100, lr=5e-3, num_steps=200,
        )
    """

    def __init__(
        self,
        model_class,
        model_class_args,
        opt_alg="sgd",
        batch_size=1,
        num_steps=1000,
        lr=1e-3,
        loss_name="squared",
    ):
        """Store the training hyper-parameters.

        Args:
            model_class: ``torch.nn`` model class, passed to ``ParallelNetworks``
                together with the batch size.
            model_class_args: Keyword arguments for ``model_class``.
            opt_alg: ``"sgd"`` or ``"adam"``.
            batch_size: Number of context points sampled per optimisation step.
            num_steps: Optimisation steps per predicted index.
            lr: Learning rate.
            loss_name: Only ``"squared"`` (MSE) is supported.
        """
        self.model_class = model_class
        self.model_class_args = model_class_args
        self.opt_alg = opt_alg
        self.lr = lr
        self.batch_size = batch_size
        self.num_steps = num_steps
        self.loss_name = loss_name

        self.name = f"gd_model_class={model_class}_model_class_args={model_class_args}_opt_alg={opt_alg}_lr={lr}_batch_size={batch_size}_num_steps={num_steps}_loss_name={loss_name}"

    def __call__(self, xs, ys, inds=None, verbose=False, print_step=100):
        """Predict ``ys[:, i]`` for each requested ``i`` from the first ``i`` points.

        Args:
            xs: bsize X npoints X ndim inputs.
            ys: bsize X npoints labels.
            inds: Point indices to predict; all points by default.
            verbose: Print train/test loss every ``print_step`` steps.
            print_step: Logging interval used when ``verbose`` is set.

        Returns:
            Tensor of shape ``(bsize, len(inds))``; index 0 is predicted as 0.

        Raises:
            ValueError: If ``inds`` contains an out-of-range index, or the model
                does not output exactly one value per test point.
            NotImplementedError: For an unknown ``opt_alg`` or ``loss_name``.
        """
        # Prefer the GPU but fall back to the CPU when CUDA is unavailable.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        xs, ys = xs.to(device), ys.to(device)

        if inds is None:
            inds = range(ys.shape[1])
        else:
            if max(inds) >= ys.shape[1] or min(inds) < 0:
                raise ValueError("inds contain indices where xs and ys are not defined")

        preds = []
        for i in tqdm(inds):
            if i == 0:
                # No context yet: predict zero (no network is built for it).
                preds.append(torch.zeros_like(ys[:, 0]))
            else:
                preds.append(
                    self._train_and_predict(xs, ys, i, device, verbose, print_step)
                )

        return torch.stack(preds, dim=1)

    def _train_and_predict(self, xs, ys, i, device, verbose, print_step):
        """Train a new model on points ``[0, i)`` and return its prediction for point ``i``."""
        model = ParallelNetworks(ys.shape[0], self.model_class, **self.model_class_args)
        model.to(device)

        train_xs, train_ys = xs[:, :i], ys[:, :i]
        test_xs, test_ys = xs[:, i : i + 1], ys[:, i : i + 1]

        if self.opt_alg == "sgd":
            optimizer = torch.optim.SGD(model.parameters(), lr=self.lr)
        elif self.opt_alg == "adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
        else:
            raise NotImplementedError(f"{self.opt_alg} not implemented.")

        if self.loss_name == "squared":
            loss_criterion = nn.MSELoss()
        else:
            raise NotImplementedError(f"{self.loss_name} not implemented.")

        for j in range(self.num_steps):
            # Mini-batch: a random subset of at most batch_size context points,
            # kept in their original order.
            mask = torch.zeros(i).bool()
            perm = torch.randperm(i)
            mask[perm[: self.batch_size]] = True
            mask = mask.to(device)
            train_xs_cur, train_ys_cur = train_xs[:, mask, :], train_ys[:, mask]

            if verbose and j % print_step == 0:
                model.eval()
                with torch.no_grad():
                    outputs = model(train_xs_cur)
                    loss = loss_criterion(outputs[:, :, 0], train_ys_cur).detach()
                    outputs_test = model(test_xs)
                    test_loss = loss_criterion(outputs_test[:, :, 0], test_ys).detach()
                    print(
                        f"ind:{i},step:{j}, train_loss:{loss.item()}, test_loss:{test_loss.item()}"
                    )

            optimizer.zero_grad()

            model.train()
            outputs = model(train_xs_cur)
            loss = loss_criterion(outputs[:, :, 0], train_ys_cur)
            loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            pred = model(test_xs)

        if pred.shape[1] != 1 or pred.shape[2] != 1:
            raise ValueError(
                f"expected model output of shape (batch, 1, 1), got {tuple(pred.shape)}"
            )
        return pred[:, 0, 0]


class DecisionTreeModel:
    """Per-sequence decision-tree regression baseline (scikit-learn).

    For every requested index ``i > 0`` and every batch element, a
    ``DecisionTreeRegressor`` is fitted on that element's first ``i`` points
    and used to predict point ``i``. Index 0 is predicted as 0.
    """

    def __init__(self, max_depth=None):
        self.max_depth = max_depth
        self.name = f"decision_tree_max_depth={max_depth}"

    def __call__(self, xs, ys, inds=None):
        """Return ``(batch, len(inds))`` predictions; ``inds`` defaults to all points."""
        xs, ys = xs.cpu(), ys.cpu()
        n_points = ys.shape[1]
        if inds is None:
            inds = range(n_points)
        elif max(inds) >= n_points or min(inds) < 0:
            raise ValueError("inds contain indices where xs and ys are not defined")

        columns = []
        for idx in inds:
            column = torch.zeros_like(ys[:, 0])
            if idx > 0:
                for row in range(ys.shape[0]):
                    regressor = tree.DecisionTreeRegressor(max_depth=self.max_depth)
                    regressor.fit(xs[row, :idx], ys[row, :idx])
                    column[row] = regressor.predict(xs[row, idx : idx + 1])[0]
            columns.append(column)

        return torch.stack(columns, dim=1)


class XGBoostModel:
    """Per-sequence gradient-boosted-trees baseline (``xgboost.XGBRegressor``).

    For every requested index ``i > 0`` and every batch element, a default
    ``XGBRegressor`` is fitted on that element's first ``i`` points and used
    to predict point ``i``. Index 0 is predicted as 0.
    """

    def __init__(self):
        self.name = "xgboost"

    def __call__(self, xs, ys, inds=None):
        """Return ``(batch, len(inds))`` predictions; ``inds`` defaults to all points."""
        xs, ys = xs.cpu(), ys.cpu()
        n_points = ys.shape[1]
        if inds is None:
            inds = range(n_points)
        elif max(inds) >= n_points or min(inds) < 0:
            raise ValueError("inds contain indices where xs and ys are not defined")

        columns = []
        for idx in tqdm(inds):
            column = torch.zeros_like(ys[:, 0])
            if idx > 0:
                for row in range(ys.shape[0]):
                    booster = xgb.XGBRegressor()
                    booster.fit(xs[row, :idx], ys[row, :idx])
                    column[row] = booster.predict(xs[row, idx : idx + 1])[0].item()
            columns.append(column)

        return torch.stack(columns, dim=1)
