from typing import List

import joblib
import torch
from torch import nn

from models.LinearModel import LinearModel
from models.regressors.MeanRegressor import MeanRegressor, NetworkLearningMeanRegressor
from models.networks import BaseModel
from xgboost import XGBRegressor  # changed import

class XGBoostRegressor(MeanRegressor):
    """Mean regressor backed by xgboost's ``XGBRegressor``.

    When side features ``z`` are given they are concatenated in front of ``x``
    along the last dimension, and the resulting matrix is fed to the tree model.
    Tensors are converted to NumPy (on CPU) for fitting/prediction, and
    predictions are returned as a tensor on ``self.device``.
    """

    def __init__(self, dataset_name, saved_models_path, x_dim, z_dim, device='cpu',
                 figures_dir=None, seed=0):
        """Create the regressor.

        Args:
            dataset_name: forwarded to ``MeanRegressor.__init__``.
            saved_models_path: forwarded to ``MeanRegressor.__init__``.
            x_dim: not used by this implementation.
            z_dim: not used by this implementation.
            device: torch device that ``predict_mean`` moves its output to.
            figures_dir: forwarded to ``MeanRegressor.__init__``.
            seed: forwarded to the base class and used as the XGBoost
                ``random_state`` for reproducibility.
        """
        super().__init__(dataset_name, saved_models_path, figures_dir=figures_dir, seed=seed)
        # Default XGBRegressor hyper-parameters; only the random seed is fixed.
        self.model = XGBRegressor(random_state=seed)
        self.device = device

    @staticmethod
    def _tensor_to_numpy(tensor):
        """Return ``tensor`` as a NumPy array, detached from autograd and moved to CPU."""
        # detach() first: .numpy() raises on tensors that require grad.
        return tensor.detach().cpu().numpy()

    @staticmethod
    def _join_features(x, z):
        """Concatenate ``z`` in front of ``x`` on the last dim; return ``x`` alone if ``z`` is None."""
        if z is None:
            return x
        return torch.cat([z, x], dim=-1)

    def fit_xy_aux(self, x_train, y_train, deleted_train, x_val, y_val, deleted_val, epochs=1000, batch_size=64,
                   n_wait=20,
                   **kwargs):
        """Fit the tree model on the training and validation rows combined.

        The model is fitted in one call without an evaluation set (no early
        stopping), so the validation split is merged into the training data
        instead of being held out. ``deleted_train``, ``deleted_val``,
        ``epochs``, ``batch_size``, ``n_wait`` and ``kwargs`` are ignored.
        """
        # NOTE(review): all rows are used, including any flagged by
        # deleted_train / deleted_val. If those flags mark unobserved responses
        # they should probably be masked out here -- confirm with the base class.
        features = torch.cat([x_train, x_val], dim=0)
        targets = torch.cat([y_train, y_val], dim=0)
        self.model.fit(self._tensor_to_numpy(features), self._tensor_to_numpy(targets))

    def fit(self, x_train, y_train, deleted_train, x_val, y_val, deleted_val, epochs=1000, batch_size=64, n_wait=20,
            z_train=None, z_val=None, **kwargs):
        """Prepend ``z`` features to ``x`` (when given) and delegate to ``self.fit_xy``.

        ``fit_xy`` is not defined in this class (presumably inherited from
        ``MeanRegressor`` and expected to call ``fit_xy_aux`` -- confirm).
        When ``z_train`` / ``z_val`` is None, the corresponding ``x`` is used
        unchanged instead of failing in ``torch.cat``.
        """
        new_x_train = self._join_features(x_train, z_train)
        new_x_val = self._join_features(x_val, z_val)

        self.fit_xy(new_x_train, y_train, deleted_train, new_x_val, y_val, deleted_val, epochs=epochs,
                    batch_size=batch_size, n_wait=n_wait, **kwargs)

    def predict_mean(self, x, z, **kwargs):
        """Predict with the fitted model and return the result as a tensor on ``self.device``.

        ``z`` (if not None) is prepended to ``x`` exactly as in ``fit``; it must
        match how the model was trained. Extra ``kwargs`` are ignored.
        """
        features = self._join_features(x, z)
        predictions = self.model.predict(self._tensor_to_numpy(features))
        return torch.from_numpy(predictions).to(self.device)

    def eval(self):
        """No-op: the tree model has no train/eval mode to switch."""
        pass

    def store_model(self):
        """Serialize the fitted XGBRegressor with joblib to ``self.get_model_save_path()``."""
        joblib.dump(self.model, self.get_model_save_path())

    def load_model(self):
        """Load the XGBRegressor from ``self.get_model_save_path()``.

        Security: joblib.load unpickles the file, so only load files you trust.
        """
        self.model = joblib.load(self.get_model_save_path())

    @property
    def name(self) -> str:
        """Short identifier for this regressor."""
        return "xgb_reg"
