#!/usr/bin/python3
"""
ResNet regression model for tabular data.

Author(s):
    Anonymized Authors @anonymized-authors

Citation(s):
    [1] Gorishniy Y, Rubachev I, Khrulkov V, Babenko A. Revisiting deep
        learning models for tabular data. Proc NeurIPS. (2021). URL:
        https://openreview.net/forum?id=i_Q1yrOegLY

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
from __future__ import annotations
import os
import pandas as pd
import torch
import torch.nn as nn
import torch_frame.nn as fnn  # type: ignore
from pandas.api.types import is_object_dtype  # type: ignore
from pathlib import Path
from tqdm import tqdm
from torch_frame import stype  # type: ignore
from torch_frame.data import DataLoader, Dataset, StatType  # type: ignore
from typing import Any, Dict, Final, List, NamedTuple, Union


class TabularResNet(fnn.ResNet):
    """
    Single-output regressor built on torch_frame's `fnn.ResNet`, trained with
    an MSE loss and AdamW. Categorical columns are encoded with an
    `EmbeddingEncoder` and numerical columns with a `LinearEncoder`.
    """

    def __init__(
        self,
        target_name: str,
        col_stats: Dict[str, Dict[StatType, Any]],
        col_names_dict: Dict[stype, List[str]],
        channels: int = 256,
        num_layers: int = 4,
        lr: float = 1e-4,
        epochs: int = 100,
        batch_size: int = 512,
        random_state: int = 2025,
        device: str = "auto",
        **kwargs: Any
    ):
        """
        Args:
            target_name: name of the target column.
            col_stats: a dictionary that maps column names into stats.
            col_names_dict: a dictionary that maps stype to a list of column
                names.
            channels: hidden channel dimensionality.
            num_layers: number of layers.
            lr: learning rate.
            epochs: number of training epochs.
            batch_size: batch size.
            random_state: random seed.
            device: device to place the model on. "auto" selects CUDA when
                available and falls back to the CPU otherwise.
            kwargs: any extra keyword arguments are accepted and ignored.
        """
        del kwargs
        self.target_name: Final[str] = target_name
        self.lr: Final[float] = lr
        self.epochs: Final[int] = epochs
        self.batch_size: Final[int] = batch_size
        self.seed: Final[int] = random_state
        # Constructor arguments needed to rebuild the model in `load()`.
        # The device is intentionally excluded; `load()` takes its own.
        self.__hparams: Final[Dict[str, Any]] = {
            "target_name": target_name,
            "col_stats": col_stats,
            "col_names_dict": col_names_dict,
            "channels": channels,
            "num_layers": num_layers,
            "lr": lr,
            "epochs": epochs,
            "batch_size": batch_size,
            "random_state": random_state
        }
        # Seed before the parent constructor so weight initialization is
        # reproducible.
        torch.manual_seed(self.seed)

        super(TabularResNet, self).__init__(
            channels=channels,
            out_channels=1,
            num_layers=num_layers,
            col_stats=col_stats,
            col_names_dict=col_names_dict,
            stype_encoder_dict={
                stype.categorical: fnn.EmbeddingEncoder(),
                stype.numerical: fnn.LinearEncoder()
            }
        )

        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.to(self.device)
        # Column statistics used by `predict()`; replaced by the training
        # dataset's statistics in `fit()`.
        self.col_stats: Dict[str, Dict[StatType, Any]] = col_stats
        self.__fit = False

    @torch.no_grad()
    def predict(self, x: List[NamedTuple]) -> List[float]:
        """
        Forward pass through the model.
        Input:
            x: a list of the B input designs, where B is the batch size. Each
                element must provide a `to_dict()` method mapping column
                names to values.
        Returns:
            A list of the B predicted scores (a list even when B == 1; an
            empty list for empty input).
        Raises:
            ValueError: if the model has not been fit or loaded.
        """
        if not self.__fit:
            raise ValueError("The model must be fit or loaded before predict().")
        columns = getattr(self, "columns", None)
        if columns is None:
            raise ValueError("The model has no recorded training columns.")
        if not x:
            return []

        df = pd.DataFrame([pt.to_dict() for pt in x])  # type: ignore
        df.columns = df.columns.str.strip()
        df = df[columns]
        if self.target_name in df.columns:
            df = df.drop(columns=[self.target_name])

        # Fill columns that are *entirely* missing with the training mean.
        # An all-None column would otherwise have object dtype and be cast
        # to `category` below. Partially-missing values are left unchanged
        # here.
        for feature, stats in self.col_stats.items():
            if StatType.MEAN not in stats or feature not in df.columns:
                continue
            if pd.isna(df[feature]).all():
                df[feature] = stats[StatType.MEAN]

        for col in df.columns:
            if is_object_dtype(df[col]):
                df[col] = df[col].astype("category")
        dataset = Dataset(df, col_to_stype=getattr(self, "col_to_stype", None))
        # Reuse the training statistics. Without them, materialize()
        # recomputes statistics (including categorical vocabularies) from
        # this batch alone. The resulting category indices need not match
        # those the embeddings were trained on.
        dataset.materialize(col_stats=self.col_stats)
        y = self(dataset.tensor_frame.to(self.device)).detach().cpu().numpy()
        # Use reshape(-1) rather than squeeze() so that a single design still
        # yields a list, not a bare float.
        return y.reshape(-1).tolist()

    def fit(self, tf: Dataset) -> TabularResNet:
        """
        Fit the model to the training data.
        Input:
            tf: a materialized tabular dataset of the training data. It is not
                modified by this method.
        Returns:
            The fitted model (in eval mode).
        """
        assert not self.__fit, "The model has already been fit."
        loss_fn = nn.MSELoss()
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        train_loader = DataLoader(tf, batch_size=self.batch_size, shuffle=True)
        self.train()
        self.col_stats = tf.col_stats
        with tqdm(range(self.epochs)) as pbar:
            for _ in pbar:
                loss_accum = 0.0
                total_count = 0
                for batch in train_loader:
                    optimizer.zero_grad()
                    # Always move the batch. The model may live on a non-CUDA
                    # accelerator, or on the CPU while CUDA is present.
                    batch = batch.to(self.device)
                    ypred = self(batch)
                    loss = loss_fn(ypred.view(-1), batch.y.view(-1))
                    loss.backward()
                    n_rows = len(batch.y)
                    loss_accum += float(loss) * n_rows
                    total_count += n_rows
                    optimizer.step()
                # max(..., 1) guards against an empty dataset.
                pbar.set_postfix(train_loss=(loss_accum / max(total_count, 1)))
        self.__fit = True
        self.columns = tf.df.columns.str.strip()
        # Copy so that dropping the target below does not mutate the caller's
        # dataset.
        self.col_to_stype = dict(tf.col_to_stype)
        if tf.target_col is not None:
            self.columns = self.columns.drop([tf.target_col])
            self.col_to_stype.pop(tf.target_col)
        return self.eval()

    def save(self, path: Union[Path, str]) -> None:
        """
        Saves the fitted model file to a local path.
        Input:
            path: the path to save the fitted model to. Missing parent
                directories are created.
        Returns:
            None.
        """
        assert self.__fit, "Only a fitted model can be saved."
        # Path(...).parent is "." for a bare filename. The previous
        # os.makedirs(os.path.dirname(path)) call failed on the empty string.
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        data = {
            "state_dict": self.state_dict(),
            "hparams": self.__hparams,
            "columns": self.columns,
            "col_to_stype": self.col_to_stype,
            # Statistics recorded by fit(); predict() relies on them.
            "col_stats": self.col_stats
        }
        torch.save(data, path)

    @classmethod
    def load(
        cls, path: Union[Path, str], device: str = "auto"
    ) -> TabularResNet:
        """
        Loads the fitted model file from a local path.
        Input:
            path: the path to load the fitted model from.
            device: the device to load the model to.
        Returns:
            The loaded model, in eval mode with gradients disabled.
        """
        assert os.path.exists(path), f"No model file found at {path}."
        # NOTE(security): weights_only=False unpickles arbitrary Python objects
        # (pandas Index, stype enums). Only load files from trusted sources.
        data = torch.load(path, weights_only=False, map_location="cpu")
        model = cls(**data["hparams"], device=device)
        model.load_state_dict(data["state_dict"])
        model.__fit = True
        model.columns = data["columns"]
        model.col_to_stype = data["col_to_stype"]
        # Older files without a "col_stats" entry keep the constructor-time
        # statistics stored in hparams.
        model.col_stats = data.get("col_stats", model.col_stats)
        for parameter in model.parameters():
            parameter.requires_grad = False
        return model.eval()
