# Feed-forward network[s] on top of XGBoost-derived threshold encodings of the input features.

import math
import statistics
from pathlib import Path
from typing import Any

import delu
from joblib import Memory
import numpy as np
import rtdl_num_embeddings
import torch
import torch.nn as nn
import torch.utils.tensorboard
import torchvision.transforms as transforms 
from loguru import logger
from torch import Tensor
from tqdm import tqdm
from typing_extensions import NotRequired, TypedDict
import xgboost as xgb

import lib
import lib.data
import lib.deep
from lib import KWArgs, PartKey

# On-disk joblib cache (stored under ./joblib_tmp, relative to the working
# directory). It is used below to memoize `get_xgb_model`.
memory = Memory(location='./joblib_tmp', verbose=2)

# Tree to Vector
class TreeToVector:
    """Encode raw features as binary ``feature > threshold`` indicators.

    Every output column corresponds to one (feature, threshold) pair of the
    converter: it is 1 when the feature value is strictly greater than the
    threshold and 0 otherwise. NaNs in the input are replaced with -1 first.

    Args:
        xgbTree: object exposing ``multiply_matrix`` with shape
            ``(n_features, n_encode)`` and ``offset_vector`` with shape
            ``(1, n_encode)``, e.g. a ``TreeToVectorConverter``.
        dtype: floating dtype used for the computation and for the output.
    """

    def __init__(self, xgbTree, dtype=torch.float):
        self.xgbTree = xgbTree
        self.dtype = dtype

    def __call__(self, tensor):
        return self.tree_encoder(tensor)

    def tree_encoder(self, tensor):
        """Return the ``(n_rows, n_encode)`` indicator matrix for ``tensor``."""
        # Cast *before* anything else: torch.matmul does not promote dtypes, so
        # double / integer inputs would otherwise fail against the float
        # matrices of the converter.
        tensor = tensor.to(self.dtype)
        # fill nan with -1
        tensor = torch.nan_to_num(tensor, nan=-1.0)
        return self.postprocessing(
            tensor, self.xgbTree.multiply_matrix, self.xgbTree.offset_vector
        )

    def postprocessing(self, x, multiply_matrix, offset_vector):
        """Select one feature per column, subtract its threshold and binarize."""
        x = x.to(self.dtype)
        multiply_matrix = multiply_matrix.to(device=x.device, dtype=self.dtype)
        offset_vector = offset_vector.to(device=x.device, dtype=self.dtype)

        # Each column of `multiply_matrix` picks a single feature, so the
        # product is "feature value - threshold" per encoded column.
        margin = torch.matmul(x, multiply_matrix) - offset_vector
        # Strictly positive margin -> 1.0, everything else (including 0) -> 0.0.
        return (margin > 0).to(self.dtype)
    
    
class TreeToVectorConverter:
    """Turn the split thresholds of an XGBoost model into a threshold encoding.

    All split values of all trees are gathered per feature, rounded and
    de-duplicated. Every remaining (feature, threshold) pair becomes one
    column of ``multiply_matrix`` / ``offset_vector``.

    Args:
        xgb_model: object providing ``trees_to_dataframe()`` (columns ``Tree``,
            ``Feature`` and ``Split`` are read).
        num_variable: number of input features, i.e. rows of ``multiply_matrix``.
    """

    def __init__(self, xgb_model, num_variable):
        self.df_trees = self.load_xgb_model(xgb_model)
        self.num_variable = num_variable
        self.num_tree = self.df_trees['Tree'].nunique()
        self.tree_dict = self.get_tree_dict()
        self.num_encode = self.get_num_encode()
        self.multiply_matrix, self.offset_vector = self.get_encode_element()

    def load_xgb_model(self, xgb_model):
        """Return the node table of ``xgb_model``."""
        return xgb_model.trees_to_dataframe()

    def get_tree_dict(self):
        """Return ``{feature name: unique thresholds}`` over all trees."""
        return self.dedup_tree(extract_trees(self.df_trees))

    # get compact tree with duplicated threshold removed (given precision)
    def dedup_tree(self, trees, precision=2):
        """Merge per-tree split lists into unique, sorted thresholds per feature.

        Args:
            trees: iterable of ``{feature name: [split values]}`` dicts.
            precision: number of decimals the split values are rounded to
                before duplicates are removed.

        Returns:
            Dict mapping each feature (in order of first appearance) to its
            ascending list of unique rounded thresholds.
        """
        unique_splits = {}
        for tree in trees:
            for variable, split_values in tree.items():
                unique_splits.setdefault(variable, set()).update(
                    round(value, precision) for value in split_values
                )
        # Sort so the column layout does not depend on set iteration order;
        # the backbone weights are tied to this layout.
        return {
            variable: sorted(values) for variable, values in unique_splits.items()
        }

    def get_num_encode(self):
        """Return the total number of (feature, threshold) columns."""
        return sum(len(splits) for splits in self.tree_dict.values())

    def get_encode_element(self):
        """Build the selection matrix and the threshold vector.

        Returns:
            ``multiply_matrix`` of shape ``(num_variable, num_encode)`` with a
            single 1 per column marking the feature, and ``offset_vector`` of
            shape ``(1, num_encode)`` holding the matching thresholds.
        """
        multiply_matrix = torch.zeros((self.num_variable, self.num_encode))
        offset_vector = torch.zeros((1, self.num_encode))

        col_idx = 0
        for feature, splits in self.tree_dict.items():
            # Feature names are expected to look like 'f<index>'; the index is
            # the row (input column) the threshold applies to.
            row_idx = int(feature[1:])
            for split in splits:
                multiply_matrix[row_idx, col_idx] = 1.0
                offset_vector[0, col_idx] = split
                col_idx += 1
        return multiply_matrix, offset_vector

    
# Tree to Token
class TreeToToken:
    """Encode raw features as one token (row) per tree.

    For every tree the token has ``num_encode`` entries: split nodes become
    1/0 depending on ``feature > threshold``, and the tree's padding vector is
    added on top (in ``TreeToTokenConverter`` that is 0.5 for leaf slots and -1
    for slots past the end of the tree).

    Args:
        xgbTree: object exposing ``num_tree``, ``num_encode`` and the per-tree
            lists ``multiply_list``, ``offset_list`` and ``padding_list``,
            e.g. a ``TreeToTokenConverter``.
        dtype: floating dtype used for the computation and for the output.
    """

    def __init__(self, xgbTree, dtype=torch.float):
        self.xgbTree = xgbTree
        self.dtype = dtype

    def __call__(self, tensor):
        shape = (tensor.size(0), self.xgbTree.num_tree, self.xgbTree.num_encode)
        # Allocate on the input's device so GPU batches are not silently
        # returned as CPU tensors.
        output = torch.zeros(shape, dtype=self.dtype, device=tensor.device)
        return self.tree_encoder(tensor, output)

    def tree_encoder(self, tensor, output):
        """Fill ``output[:, i, :]`` with the encoding of tree ``i`` and return it."""
        # Cast first: torch.matmul does not promote dtypes, so double / integer
        # inputs would otherwise fail against the float matrices.
        tensor = tensor.to(self.dtype)
        # fill nan with -1
        tensor = torch.nan_to_num(tensor, nan=-1.0)
        for i in range(self.xgbTree.num_tree):
            output[:, i, :] = self.postprocessing(
                tensor,
                self.xgbTree.multiply_list[i],
                self.xgbTree.offset_list[i],
                self.xgbTree.padding_list[i],
            )
        return output

    def postprocessing(self, x, multiply_matrix, offset_vector, padding_vector):
        """Binarize ``x @ multiply_matrix - offset_vector`` and add the padding."""
        x = x.to(self.dtype)
        multiply_matrix = multiply_matrix.to(device=x.device, dtype=self.dtype)
        offset_vector = offset_vector.to(device=x.device, dtype=self.dtype)
        padding_vector = padding_vector.to(device=x.device, dtype=self.dtype)

        margin = torch.matmul(x, multiply_matrix) - offset_vector
        # Strictly positive margin -> 1.0, everything else (including 0) -> 0.0.
        return (margin > 0).to(self.dtype) + padding_vector

    
class TreeToTokenConverter:
    """Build per-tree encoding matrices from a saved XGBoost model.

    Each tree gets its own ``(multiply_matrix, offset_vector, padding_vector)``
    triple. All triples share the width ``num_encode``, which is the node
    count of the largest tree.

    Args:
        xgb_model_path: path handed to ``xgb.Booster.load_model``.
        num_variable: number of input features (rows of every multiply matrix).
    """

    def __init__(self, xgb_model_path, num_variable):
        self.df_trees = self.load_xgb_model(xgb_model_path)
        self.num_variable = num_variable
        self.num_tree = self.df_trees['Tree'].nunique()
        self.num_encode = self.get_num_encode()
        self.tree_list = self.get_tree_list()
        self.multiply_list, self.offset_list, self.padding_list = self.get_encode_element()

    def load_xgb_model(self, model_path):
        """Load a booster from ``model_path`` and return its node table."""
        booster = xgb.Booster()
        booster.load_model(model_path)
        return booster.trees_to_dataframe()

    def get_tree_list(self):
        """Return one re-indexed node table per tree id ``0 .. num_tree - 1``."""
        tree_ids = self.df_trees['Tree']
        return [
            # reset_index() keeps the old row labels in an 'index' column.
            self.df_trees.loc[tree_ids == tree_id].reset_index()
            for tree_id in range(self.num_tree)
        ]

    def get_num_encode(self):
        """Return the node count of the largest tree (0 when there are no trees)."""
        tree_ids = self.df_trees['Tree']
        return max(
            (
                len(self.df_trees.loc[tree_ids == tree_id])
                for tree_id in range(self.num_tree)
            ),
            default=0,
        )

    def get_encode_element(self):
        """Encode every tree and return three parallel lists of tensors."""
        encoded = [self.one_tree_encoder(tree) for tree in self.tree_list]
        multiply_list = [multiply for multiply, _, _ in encoded]
        offset_list = [offset for _, offset, _ in encoded]
        padding_list = [padding for _, _, padding in encoded]
        return multiply_list, offset_list, padding_list

    def one_tree_encoder(self, tree):
        """Encode a single tree's node table.

        Column ``i`` describes node ``i`` of the tree. A split node sets a 1 in
        the row of its feature and stores its threshold in the offset vector.
        A leaf gets 0.5 in the padding vector. Columns past the last node get
        -1 in the padding vector.
        """
        n_rows, width = self.num_variable, self.num_encode
        n_nodes = len(tree)

        multiply_matrix = torch.zeros((n_rows, width))
        offset_vector = torch.zeros((1, width))
        padding_vector = torch.zeros((1, width))

        for col, (feature, split) in enumerate(zip(tree['Feature'], tree['Split'])):
            if feature == 'Leaf':
                padding_vector[0, col] = 0.5
                continue
            # Feature names are expected to look like 'f<index>'.
            row = int(feature[1:])
            threshold = float(split)
            multiply_matrix[row, col] = 1.0
            offset_vector[0, col] = threshold

        # Mark the unused tail of trees shorter than the widest one.
        padding_vector[0, n_nodes:] = -1
        return multiply_matrix, offset_vector, padding_vector
    
    
# helper function 
def extract_trees(xgb_model):
    """Group the split thresholds of every tree by feature.

    Despite the parameter name, ``xgb_model`` is the node table (a DataFrame
    with ``Tree``, ``Feature`` and ``Split`` columns, as produced by
    ``trees_to_dataframe()``), not the model object itself.

    Returns:
        A list with one dict per tree id ``0 .. n_trees - 1``. Each dict maps
        a feature name to the list of its split values in row order. Rows
        whose feature is ``'Leaf'`` are skipped, so a tree without splits
        yields an empty dict.
    """
    frame = xgb_model
    is_split_node = frame['Feature'] != 'Leaf'

    trees = []
    for tree_id in range(frame['Tree'].nunique()):
        nodes = frame.loc[
            (frame['Tree'] == tree_id) & is_split_node, ['Feature', 'Split']
        ]
        splits_by_feature = {}
        for feature, split in zip(nodes['Feature'], nodes['Split']):
            splits_by_feature.setdefault(feature, []).append(split)
        trees.append(splits_by_feature)
    return trees


# hyperparameter
# Booster parameters passed as `params` to `xgb.train` in `train_model`.
# The objective and the metric are specific to binary classification.
xgb_params = {
    'eval_metric': 'auc',
    'objective': 'binary:logistic',
    'tree_method': 'approx', 
    'verbosity': '2'
}

# Training-loop options read by `train_model`.
# NOTE: the key is spelled 'early_stopping_round' (singular); `train_model`
# maps it to the `early_stopping_rounds` argument of `xgb.train`.
xgb_option = {
    'num_boost_round': 50,
    'early_stopping_round': 10
}


def convert_to_DMatrix(X_data, Y_data):
    """Wrap a feature matrix and its labels (cast to ``int``) in a ``xgb.DMatrix``."""
    integer_labels = Y_data.astype(int)
    return xgb.DMatrix(data=X_data, label=integer_labels)

def train_model(train_data, valid_data, xgb_params, xgb_option):
    """Train a booster on ``train_data`` while monitoring ``valid_data``.

    Args:
        train_data: training ``DMatrix``.
        valid_data: ``DMatrix`` registered under the name ``'valid'`` for
            evaluation and early stopping.
        xgb_params: booster parameters forwarded as ``params``.
        xgb_option: dict with the keys ``'num_boost_round'`` and
            ``'early_stopping_round'``.

    Returns:
        The value returned by ``xgb.train``.
    """
    watchlist = [(valid_data, 'valid')]
    patience = xgb_option['early_stopping_round']
    n_rounds = xgb_option['num_boost_round']
    return xgb.train(
        params=xgb_params,
        dtrain=train_data,
        evals=watchlist,
        early_stopping_rounds=patience,
        num_boost_round=n_rounds,
    )

@memory.cache
def get_xgb_model(x_num, x_cat, y):
    """Train (or load from the joblib cache) the XGBoost model for a dataset.

    Args:
        x_num: ``{part: array}`` of numerical features, or ``None``.
        x_cat: ``{part: array}`` of categorical features, or ``None``.
        y: ``{part: array}`` of labels; the ``'train'`` and ``'val'`` parts
            are used.

    Returns:
        The booster returned by ``train_model``.

    NOTE(review): binary features (``x_bin``) are not part of the XGBoost
    input, while ``Model.forward`` stacks ``x_num``, ``x_bin`` and ``x_cat``.
    If a dataset has binary features, the ``f<i>`` indices of the trees will
    not line up with the stacked columns - confirm with the callers.
    """

    def get_X(split):
        # Column-wise concatenation of whichever feature groups are present.
        parts = [group[split] for group in (x_num, x_cat) if group]
        return np.concatenate(parts, 1)

    train_data = convert_to_DMatrix(get_X('train'), y['train'])
    # The validation set must come from the 'val' part. It was previously
    # built from the training part, so early stopping monitored training AUC.
    valid_data = convert_to_DMatrix(get_X('val'), y['val'])

    return train_model(train_data, valid_data, xgb_params, xgb_option)

    
class Model(nn.Module):
    """Backbone network applied to the XGBoost threshold encoding of the input.

    The raw features are stacked column-wise, converted by ``TreeToVector``
    into binary threshold indicators and then fed to the backbone created by
    ``lib.deep.make_module``.
    """

    def __init__(
        self,
        *,
        n_num_features: int,
        n_bin_features: int,
        cat_cardinalities: list[int],
        n_classes: None | int,
        backbone: dict,
        xgb_model,
    ) -> None:
        assert n_num_features or n_bin_features or cat_cardinalities
        super().__init__()

        # One input column per numerical, binary and categorical feature.
        n_categorical = len(cat_cardinalities)
        converter = TreeToVectorConverter(
            xgb_model, n_num_features + n_bin_features + n_categorical
        )
        self.transform_batch = transforms.Compose([TreeToVector(converter)])

        # The backbone consumes the encoded vector, so its input size is the
        # number of encoded columns. Note that this writes into the dict
        # passed by the caller.
        backbone['d_in'] = converter.num_encode
        d_out = lib.deep.get_d_out(n_classes)
        self.backbone = lib.deep.make_module(**backbone, d_out=d_out)

    def forward(
        self,
        *,
        x_num: None | Tensor = None,
        x_bin: None | Tensor = None,
        x_cat: None | Tensor = None,
    ) -> Tensor:
        """Encode the provided feature groups and run the backbone on them."""
        present = [part for part in (x_num, x_bin, x_cat) if part is not None]
        # Flatten everything after the batch dimension, then stack the groups
        # in the order num, bin, cat.
        stacked = torch.column_stack([part.flatten(1, -1) for part in present])
        encoded = self.transform_batch(stacked)
        return self.backbone(encoded)


class Config(TypedDict):
    """Schema of the configuration dict accepted by `main`.

    `main` asserts that every required key is present and that no key outside
    of this schema is passed.
    """

    # Seed for `delu.random.seed` and for the batch-shuffling generator.
    seed: int
    # Keyword arguments for `lib.data.build_dataset`.
    data: KWArgs
    # Extra keyword arguments for `Model` (e.g. `backbone`).
    model: KWArgs
    # Keyword arguments for `lib.deep.make_optimizer`.
    optimizer: KWArgs
    # If present, enables a linear learning-rate warmup of roughly this many
    # epochs (`main` caps it at 10000 steps and rounds to whole epochs).
    n_lr_warmup_epochs: NotRequired[int]
    # Training batch size.
    batch_size: int
    # Patience of the early stopping on the validation score.
    patience: int
    # Maximum number of epochs; -1 means no limit.
    n_epochs: int
    # If present, gradients are clipped to this norm before the optimizer step.
    gradient_clipping_norm: NotRequired[float]
    # Whether to log parameter statistics to TensorBoard; `main` defaults it
    # to `seed == 1`.
    parameter_statistics: NotRequired[bool]


def main(
    config: Config, output: str | Path, *, force: bool = False
) -> None | lib.JSONDict:
    """Train and evaluate the tree-encoded model described by ``config``.

    The function builds the dataset, fits (or loads from cache) the XGBoost
    model that defines the threshold encoding, then trains the backbone with
    early stopping on the validation score. Finally it reloads the best
    checkpoint and evaluates it on all parts.

    Args:
        config: run configuration, see ``Config``.
        output: output directory of the run (checkpoints, reports, logs).
        force: forwarded to ``lib.start``.

    Returns:
        The final report, or ``None`` when ``lib.start`` declines to start
        the run.

    Raises:
        RuntimeError: if evaluation runs out of memory even after the
            evaluation batch size has been halved down to zero.
    """
    # >>> start
    print(config)
    assert set(config) >= Config.__required_keys__
    assert set(config) <= Config.__required_keys__ | Config.__optional_keys__
    if not lib.start(output, force=force):
        return None

    lib.show_config(config)  # type: ignore[code]
    output = Path(output)
    delu.random.seed(config['seed'])
    device = lib.get_device()
    report = lib.create_report(main, config)  # type: ignore[code]

    # >>> dataset
    dataset = lib.data.build_dataset(**config['data'])
    # The XGBoost model is trained with a binary objective (see `xgb_params`).
    assert dataset.task.compute_n_classes() == 2

    if dataset.task.is_regression:
        dataset.data['y'], regression_label_stats = lib.data.standardize_labels(
            dataset.data['y']
        )
    else:
        regression_label_stats = None
    dataset = dataset.to_torch(device)
    Y_train = dataset.data['y']['train'].to(
        torch.long if dataset.task.is_multiclass else torch.float
    )

    # >>> XGBoost model that defines the threshold encoding
    # NumPy copies of the data are passed so that the result can be memoized
    # by `get_xgb_model`.
    # NOTE(review): `x_bin` is not passed to XGBoost, while `Model.forward`
    # stacks x_num, x_bin, x_cat. With binary features present, the tree
    # feature indices would not match the stacked columns - confirm.
    x_num = dataset.data['x_num'] if 'x_num' in dataset.data else None
    x_cat = dataset.data['x_cat'] if 'x_cat' in dataset.data else None
    x_num = {k: v.detach().cpu().numpy() for k, v in x_num.items()} if x_num is not None else None
    x_cat = {k: v.detach().cpu().numpy() for k, v in x_cat.items()} if x_cat is not None else None
    y = {k: v.detach().cpu().numpy() for k, v in dataset.data['y'].items()}
    xgb_model = get_xgb_model(x_num, x_cat, y)

    model = Model(
        n_num_features=dataset.n_num_features,
        n_bin_features=dataset.n_bin_features,
        cat_cardinalities=dataset.compute_cat_cardinalities(),
        n_classes=dataset.task.try_compute_n_classes(),
        xgb_model=xgb_model,
        **config['model'],
    )
    report['n_parameters'] = lib.deep.get_n_parameters(model)
    logger.info(f'n_parameters = {report["n_parameters"]}')
    report['prediction_type'] = 'labels' if dataset.task.is_regression else 'logits'
    model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # >>> training
    optimizer = lib.deep.make_optimizer(
        **config['optimizer'], params=lib.deep.make_parameter_groups(model)
    )
    loss_fn = lib.deep.get_loss_fn(dataset.task.type_)
    gradient_clipping_norm = config.get('gradient_clipping_norm')

    step = 0
    batch_size = config['batch_size']
    report['epoch_size'] = epoch_size = math.ceil(dataset.size('train') / batch_size)
    eval_batch_size = 32768
    # `chunk_size` stays None until `zero_grad_forward_backward` reports a
    # smaller chunk size than the current one.
    chunk_size = None
    generator = torch.Generator(device).manual_seed(config['seed'])

    report['metrics'] = {'val': {'score': -math.inf}}
    if 'n_lr_warmup_epochs' in config:
        # Cap the warmup at 10000 steps, then round it to a whole number of
        # epochs (at least one).
        n_warmup_steps = min(10000, config['n_lr_warmup_epochs'] * epoch_size)
        n_warmup_steps = max(1, math.trunc(n_warmup_steps / epoch_size)) * epoch_size
        logger.info(f'{n_warmup_steps=}')
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=0.01, total_iters=n_warmup_steps
        )
    else:
        lr_scheduler = None
    timer = delu.tools.Timer()
    early_stopping = delu.tools.EarlyStopping(config['patience'], mode='max')
    parameter_statistics = config.get('parameter_statistics', config['seed'] == 1)
    training_log = []
    writer = torch.utils.tensorboard.SummaryWriter(output)  # type: ignore[code]

    def apply_model(part: PartKey, idx: Tensor) -> Tensor:
        """Run the model on the rows ``idx`` of ``part``."""
        return model(
            **{
                key: dataset.data[key][part][idx]  # type: ignore[index]
                for key in ['x_num', 'x_bin', 'x_cat']
                if key in dataset  # type: ignore[index]
            }
        ).squeeze(-1)

    @torch.inference_mode()
    def evaluate(
        parts: list[PartKey], eval_batch_size: int
    ) -> tuple[dict[PartKey, Any], dict[PartKey, np.ndarray], int]:
        """Predict on ``parts`` and compute metrics.

        On an out-of-memory error the batch size is halved and the part is
        retried. The batch size that finally worked is returned so that later
        calls can reuse it.
        """
        model.eval()
        predictions: dict[PartKey, np.ndarray] = {}
        for part in parts:
            while eval_batch_size:
                try:
                    predictions[part] = (
                        torch.cat(
                            [
                                apply_model(part, idx)
                                for idx in torch.arange(
                                    len(dataset.data['y'][part]),
                                    device=device,
                                ).split(eval_batch_size)
                            ]
                        )
                        .cpu()
                        .numpy()
                    )
                except RuntimeError as err:
                    if not lib.is_oom_exception(err):
                        raise
                    eval_batch_size //= 2
                    logger.warning(f'eval_batch_size = {eval_batch_size}')
                else:
                    break
            if not eval_batch_size:
                # The exception used to be constructed but never raised, so
                # the failure was silently ignored.
                raise RuntimeError('Not enough memory even for eval_batch_size=1')
        if regression_label_stats is not None:
            # Undo the label standardization applied to the dataset above.
            predictions = {
                k: v * regression_label_stats.std + regression_label_stats.mean
                for k, v in predictions.items()
            }
        metrics = (
            dataset.task.calculate_metrics(predictions, report['prediction_type'])
            if lib.are_valid_predictions(predictions)
            else {x: {'score': -999999.0} for x in predictions}
        )
        return metrics, predictions, eval_batch_size

    def save_checkpoint() -> None:
        """Dump the full training state and the report, then back up the output."""
        lib.dump_checkpoint(
            output,
            {
                'step': step,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'generator': generator.get_state(),
                'random_state': delu.random.get_state(),
                'early_stopping': early_stopping,
                'report': report,
                'timer': timer,
                'training_log': training_log,
            }
            | (
                {}
                if lr_scheduler is None
                else {'lr_scheduler': lr_scheduler.state_dict()}
            ),
        )
        lib.dump_report(output, report)
        lib.backup_output(output)

    print()
    timer.run()
    while config['n_epochs'] == -1 or step // epoch_size < config['n_epochs']:
        print(f'[...] {output} | {timer}')

        # >>>
        model.train()
        epoch_losses = []
        for batch_idx in tqdm(
            torch.randperm(
                len(dataset.data['y']['train']), generator=generator, device=device
            ).split(batch_size),
            desc=f'Epoch {step // epoch_size} Step {step}',
        ):
            loss, new_chunk_size = lib.deep.zero_grad_forward_backward(
                optimizer,
                lambda idx: loss_fn(apply_model('train', idx), Y_train[idx]),
                batch_idx,
                chunk_size or batch_size,
            )

            if parameter_statistics and (
                step % epoch_size == 0  # The first batch of the epoch.
                or step // epoch_size == 0  # The first epoch.
            ):
                for k, v in lib.deep.compute_parameter_stats(model).items():
                    writer.add_scalars(k, v, step, timer.elapsed())
                    del k, v

            if gradient_clipping_norm is not None:
                nn.utils.clip_grad.clip_grad_norm_(
                    model.parameters(), gradient_clipping_norm
                )
            optimizer.step()

            if lr_scheduler is not None:
                lr_scheduler.step()
            step += 1
            epoch_losses.append(loss.detach())
            if new_chunk_size and new_chunk_size < (chunk_size or batch_size):
                chunk_size = new_chunk_size
                logger.warning(f'chunk_size = {chunk_size}')

        epoch_losses = torch.stack(epoch_losses).tolist()
        mean_loss = statistics.mean(epoch_losses)
        metrics, predictions, eval_batch_size = evaluate(
            ['val', 'test'], eval_batch_size
        )

        training_log.append(
            {'epoch-losses': epoch_losses, 'metrics': metrics, 'time': timer.elapsed()}
        )
        lib.print_metrics(mean_loss, metrics)
        writer.add_scalars('loss', {'train': mean_loss}, step, timer.elapsed())
        for part in metrics:
            writer.add_scalars(
                'score', {part: metrics[part]['score']}, step, timer.elapsed()
            )

        # Checkpoint only when the validation score improves.
        if metrics['val']['score'] > report['metrics']['val']['score']:
            print('🌸 New best epoch! 🌸')
            report['best_step'] = step
            report['metrics'] = metrics
            save_checkpoint()
            lib.dump_predictions(output, predictions)

        early_stopping.update(metrics['val']['score'])
        if early_stopping.should_stop() or not lib.are_valid_predictions(predictions):
            break

        print()
    report['time'] = str(timer)

    # >>> finish
    # Restore the best checkpoint and evaluate it on every part.
    model.load_state_dict(lib.load_checkpoint(output)['model'])
    report['metrics'], predictions, _ = evaluate(
        ['train', 'val', 'test'], eval_batch_size
    )
    report['chunk_size'] = chunk_size
    report['eval_batch_size'] = eval_batch_size
    lib.dump_predictions(output, predictions)
    lib.dump_summary(output, lib.summarize(report))
    save_checkpoint()
    lib.finish(output, report)
    return report


# Script entry point: configure the libraries, then hand `main` to the
# project's CLI runner.
if __name__ == '__main__':
    lib.configure_libraries()
    lib.run_MainFunction_cli(main)
