# Feed-forward network[s].
# - Add rec-sup pretraining objective
# - better logging
# - add feature whitening

import math
import statistics
from collections import defaultdict
from pathlib import Path
from typing import Any

import delu
import numpy as np
import rtdl_num_embeddings
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.tensorboard
from loguru import logger
from torch import Tensor
from tqdm import tqdm
from typing_extensions import Callable, NotRequired, TypedDict
from sklearn.decomposition import PCA

import lib
from lib import KWArgs, PartKey
from lib.distributed_shampoo.distributed_shampoo import DistributedShampoo

# Return type of an evaluation function:
# (metrics per part, predictions per part, the evaluation batch size that was used).
EvalOut = tuple[dict[PartKey, Any], dict[PartKey, np.ndarray], int]

# Select the MAGMA backend for CUDA linear algebra.
# This works around failures observed with the Shampoo optimizer.
torch.backends.cuda.preferred_linalg_library('magma')

class Model(nn.Module):
    def __init__(
        self,
        *,
        n_num_features: int,
        n_bin_features: int,
        cat_cardinalities: list[int],
        n_classes: None | int,
        bins: None | list[Tensor],
        num_embeddings: None | dict = None,
        backbone: dict,
        aug_type: str,
        
    ) -> None:
        assert n_num_features or n_bin_features or cat_cardinalities
        super().__init__()

        self.flat = backbone['type'] != 'FTTransformerBackbone'
        self.aug_type = aug_type
        self.aug_p = 0.0
        self.pretrain = True
        self.cat_cardinalities = cat_cardinalities

        if not self.flat and n_bin_features > 0:
            self.m_bin = lib.deep.make_module("LinearEmbeddings", n_bin_features, backbone['d_block'])
        else:
            self.m_bin = None

        if num_embeddings is None:
            assert bins is None
            self.m_num = None
            d_num = n_num_features
        else:
            if not self.flat:
                num_embeddings['d_embedding'] = backbone['d_block']

            assert n_num_features > 0
            if num_embeddings['type'] in (
                rtdl_num_embeddings.PiecewiseLinearEmbeddings.__name__,
                rtdl_num_embeddings.PiecewiseLinearEncoding.__name__,
            ):
                assert bins is not None
                self.m_num = lib.deep.make_module(**num_embeddings, bins=bins)
                d_num = (
                    sum(len(x) - 1 for x in bins)
                    if num_embeddings['type'].startswith(
                        rtdl_num_embeddings.PiecewiseLinearEncoding.__name__
                    )
                    else n_num_features * num_embeddings['d_embedding']
                )
            else:
                assert bins is None
                self.m_num = lib.deep.make_module(
                    **num_embeddings, n_features=n_num_features
                )
                d_num = n_num_features * num_embeddings['d_embedding']

        if backbone['type'] in ['DCNv2', 'FTTransformerBackbone']:
            d_cat_embedding = backbone.pop('d_cat_embedding') if self.flat else backbone['d_block']
            self.m_cat = (
                lib.deep.CategoricalEmbeddings1d(cat_cardinalities, d_cat_embedding)
                if cat_cardinalities
                else None
            )
            d_cat = len(cat_cardinalities) * d_cat_embedding
        else:
            self.m_cat = (
                lib.deep.OneHotEncoding0d(cat_cardinalities)
                if cat_cardinalities
                else None
            )
            d_cat = sum(cat_cardinalities)


        if self.flat:
            d_ind = n_num_features if 'indicator' in self.aug_type else 0
            backbone['d_in'] = d_num + n_bin_features + d_cat + d_ind
        else:
            self.cls_embedding = lib.deep.CLSEmbedding(backbone['d_block'])

        self.backbone = lib.deep.make_module(
            **backbone,
            d_out=None,
            
        )
        self.supervised_out = nn.Linear(backbone['d_block'], lib.deep.get_d_out(n_classes))
        self.reconstruction_out = nn.Linear(backbone['d_block'], n_num_features + sum(cat_cardinalities))

    def forward(
        self,
        *,
        x_num: None | Tensor = None,
        x_bin: None | Tensor = None,
        x_cat: None | Tensor = None,
    ) -> Tensor:
        x = []
        if x_num is not None:
            mask = torch.empty_like(x_num).bernoulli_(p=self.aug_p).bool()

            if self.aug_type == 'indicator-v1':
                x_num[mask] = 0
                x_ind = (mask * 2 - 1).float()
                x.append(x_ind)
            elif self.aug_type == 'indicator-v2':
                x_num[mask] = 0
                x_ind = mask.float()
                x.append(x_ind)
            elif self.aug_type == 'dropout':
                x_num[mask] = 0
            elif self.aug_type == 'dropout-v2':
                x_num[mask] = 1
            elif self.aug_type == 'dropout-v3':
                values = torch.empty_like(x_num).bernoulli_(p=0.5).bool()
                values = (2 * values - 1).float()
                x_num[mask] = values[mask]
            elif self.aug_type == 'dropout-v4':
                values = torch.empty_like(x_num).bernoulli_(p=0.5).bool()
                values = (2 * values - 1).float() * (1-self.aug_p)
                x_num[mask] = values[mask]

            
            x.append(x_num if self.m_num is None else self.m_num(x_num))
        if x_bin is not None:
            x.append(x_bin if self.m_bin is None else self.m_bin(x_bin))
        if x_cat is None:
            assert self.m_cat is None
        else:
            assert self.m_cat is not None
            mask = torch.empty_like(x_cat).bernoulli_(p=self.aug_p).bool()
            cat_card = torch.tensor(self.cat_cardinalities, device=x_cat.device)[None].repeat(x_cat.shape[0], 1).long()
            x_cat[mask] = cat_card[mask]
            x.append(self.m_cat(x_cat))

        if self.flat:
            x = torch.column_stack([x_.flatten(1, -1) for x_ in x])
        else:
            x = torch.cat([self.cls_embedding(x[0].shape[:1])] + x, dim=1)

        x = self.backbone(x)

        if self.pretrain:
            return torch.cat([
                self.reconstruction_out(x),
                self.supervised_out(x),
            ], dim=-1)
        else:
            return self.supervised_out(x)


class Config(TypedDict):
    # Schema of the experiment configuration accepted by `main`.
    # Keys marked `NotRequired` may be omitted.

    # Seed for `delu.random.seed` and for the batch-sampling generator.
    seed: int
    # Keyword arguments of `lib.data.build_dataset`.
    data: KWArgs
    # Keyword arguments of `rtdl_num_embeddings.compute_bins`; bins are computed only if present.
    bins: NotRequired[KWArgs]
    # Keyword arguments of `Model` (in addition to the dataset-derived ones).
    model: KWArgs
    # Keyword arguments of `lib.deep.make_optimizer`.
    optimizer: KWArgs
    # If present, a linear learning-rate warmup is used.
    n_lr_warmup_epochs: NotRequired[int]
    # Input corruption probability used for pretraining steps.
    aug_p: float
    # Weight of the supervised loss in the pretraining objective;
    # the reconstruction losses are weighted by (1 - alpha).
    alpha: float
    # Whether numerical reconstruction targets are multiplied by a ZCA whitening matrix.
    do_zca: bool
    batch_size: int
    # NOTE(review): not read by the visible code in this file.
    batch_repeats: NotRequired[int]
    # Early-stopping patience of the finetuning phase
    # (and of the pretraining phase unless `patience_pretrain` is given).
    patience: int
    # Passed as `n_steps` to the finetuning loop (-1 disables the limit).
    n_epochs: int
    # Passed as `n_steps` to the pretraining loop (-1 disables the limit).
    n_steps_pretrain: int
    # If present, gradients are clipped to this norm before every optimizer step.
    gradient_clipping_norm: NotRequired[float]
    # Whether to log parameter statistics; defaults to `seed == 1`.
    parameter_statistics: NotRequired[bool]
    # Early-stopping patience of the pretraining phase; defaults to `patience`.
    patience_pretrain: NotRequired[int]
    amp: NotRequired[bool]  # Automatic mixed precision in bfloat16.

def main(
    config: Config, output: str | Path, *, force: bool = False
) -> None | lib.JSONDict:
    """Pretrain a ``Model`` with a joint objective, finetune it, and report metrics.

    The run consists of two phases that share the model, the optimizer and
    the optional learning-rate warmup scheduler:

    1. ``"pretrain"``: inputs are corrupted with probability
       ``config['aug_p']`` and the loss mixes the supervised loss with the
       feature reconstruction losses, weighted by ``config['alpha']``.
    2. ``"finetune"``: the best pretraining checkpoint is reloaded (best
       effort) and training continues with the supervised loss only.

    Both phases use early stopping on the validation score.

    Args:
        config: Experiment configuration. Must contain every required
            ``Config`` key and no unknown keys.
        output: Output directory used for checkpoints, predictions, the
            report and the TensorBoard logs.
        force: Forwarded to ``lib.start``.

    Returns:
        The final report, or ``None`` if ``lib.start`` declines to start.
    """
    # >>> start
    assert set(config) >= Config.__required_keys__
    assert set(config) <= Config.__required_keys__ | Config.__optional_keys__
    if not lib.start(output, force=force):
        return None

    lib.show_config(config)  # type: ignore[code]
    output = Path(output)
    delu.random.seed(config['seed'])
    device = lib.get_device()
    report = lib.create_report(config)  # type: ignore[code]

    # >>> dataset
    dataset = lib.data.build_dataset(**config['data'])
    if dataset.task.is_regression:
        dataset.data['y'], regression_label_stats = lib.data.standardize_labels(
            dataset.data['y']
        )
    else:
        regression_label_stats = None

    dataset = dataset.to_torch(device)
    Y_train = dataset.data['y']['train'].to(
        torch.long if dataset.task.is_multiclass else torch.float
    )

    # >>> decorrelate the features
    # NOTE: the model inputs are not whitened here. The matrix is applied only
    # to the numerical reconstruction targets in `pretrain_loss_fn`.
    if config['do_zca']:
        sigma = torch.cov(dataset.data['x_num']['train'].T)
        U, S, _ = torch.linalg.svd(sigma)
        zca_matrix = U @ (torch.diag(1 / torch.sqrt(S + 1e-5)) @ U.T)
    else:
        zca_matrix = torch.eye(dataset.n_num_features, device=device)

    # >>> model
    if 'bins' in config:
        compute_bins_kwargs = (
            {
                'y': Y_train.to(
                    torch.long if dataset.task.is_classification else torch.float
                ),
                'regression': dataset.task.is_regression,
                'verbose': True,
            }
            if 'tree_kwargs' in config['bins']
            else {}
        )
        bin_edges = rtdl_num_embeddings.compute_bins(
            dataset['x_num']['train'], **config['bins'], **compute_bins_kwargs
        )
        logger.info(f'Bin counts: {[len(x) - 1 for x in bin_edges]}')
    else:
        bin_edges = None

    model = Model(
        n_num_features=dataset.n_num_features,
        n_bin_features=dataset.n_bin_features,
        cat_cardinalities=dataset.compute_cat_cardinalities(),
        n_classes=dataset.task.try_compute_n_classes(),
        **config['model'],
        bins=bin_edges,
    )

    report['n_parameters'] = lib.deep.get_n_parameters(model)
    logger.info(f'n_parameters = {report["n_parameters"]}')
    report['prediction_type'] = 'labels' if dataset.task.is_regression else 'logits'
    model.to(device)
    # Keep a handle to the unwrapped module. The flags `aug_p`, `aug_type` and
    # `pretrain` live on `Model` itself, and `nn.DataParallel` does not forward
    # attribute reads or writes to the module it wraps.
    core_model = model
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # >>> training
    optimizer = lib.deep.make_optimizer(
        **config['optimizer'], params=lib.deep.make_parameter_groups(model)
    )
    loss_fn = lib.deep.get_loss_fn(dataset.task.type_)
    gradient_clipping_norm = config.get('gradient_clipping_norm')

    batch_size = config['batch_size']
    report['epoch_size'] = epoch_size = math.ceil(dataset.size('train') / batch_size)
    eval_batch_size = 32768
    chunk_size = None
    generator = torch.Generator(device).manual_seed(config['seed'])

    report['metrics'] = {'val': {'score': -math.inf}}
    if 'n_lr_warmup_epochs' in config:
        # At most 10000 warmup steps, rounded down to whole epochs
        # (but never less than one epoch).
        n_warmup_steps = min(10000, config['n_lr_warmup_epochs'] * epoch_size)
        n_warmup_steps = max(1, math.trunc(n_warmup_steps / epoch_size)) * epoch_size
        logger.info(f'{n_warmup_steps=}')
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=0.01, total_iters=n_warmup_steps
        )
    else:
        lr_scheduler = None

    timer = delu.tools.Timer()
    parameter_statistics = config.get('parameter_statistics', config['seed'] == 1)
    training_log = []
    writer = torch.utils.tensorboard.SummaryWriter(output)  # type: ignore[code]

    amp_enabled = (
        config.get('amp', False)
        and device.type == 'cuda'
        and torch.cuda.is_bf16_supported()
    )
    logger.info(f'AMP enabled: {amp_enabled}')

    # Per-batch values written by `pretrain_loss_fn` and collected by
    # `train_loop`. It stays empty during finetuning.
    log_dict: dict[str, float] = {}

    @torch.autocast(  # type: ignore[code]
        device.type, enabled=amp_enabled, dtype=torch.bfloat16 if amp_enabled else None
    )
    def apply_model(part: PartKey, idx: Tensor, aug=False) -> Tensor:
        """Apply the model to the objects `idx` of `part`.

        With `aug=True` the inputs are corrupted with probability
        `config['aug_p']`; otherwise the corruption probability is zero.
        """

        def permute(x):
            "shuffle p of inputs (add noise)"
            x_batch = x[idx]

            if core_model.aug_type == 'shuffle':
                # Every masked value is replaced by the value of the same
                # feature taken from an independently sampled object.
                perm_idx = torch.randint_like(
                    x_batch, dataset.size(part), dtype=torch.int64
                )
                mask = torch.empty_like(x_batch).bernoulli_(p=core_model.aug_p).bool()
                x_perm = x.gather(0, perm_idx)
                x_batch[mask] = x_perm[mask]
            if core_model.aug_type == 'cutmix':
                # NOTE: `permute` is currently called only for 'shuffle'
                # (see below), so this branch is not reached.
                perm_idx = torch.randint_like(idx, dataset.size(part), dtype=torch.int64)
                mask = torch.empty_like(x_batch).bernoulli_(p=core_model.aug_p).bool()
                x_perm = x[perm_idx]
                x_batch[mask] = x_perm[mask]

            return x_batch

        core_model.aug_p = config['aug_p'] if aug else 0.0

        model_input = {}
        for key in ['x_num', 'x_bin', 'x_cat']:
            if key not in dataset:  # type: ignore[index]
                continue
            if key == 'x_num' and core_model.aug_type == 'shuffle':
                model_input[key] = permute(dataset.data[key][part])
            else:
                model_input[key] = dataset.data[key][part][idx]

        return model(**model_input).squeeze(-1).float()

    @torch.inference_mode()
    def evaluate_finetune(
        parts: list[PartKey], eval_batch_size: int
    ) -> EvalOut:
        """Compute supervised predictions and metrics for `parts`.

        The evaluation batch size is halved after every out-of-memory error.

        Returns:
            Metrics, predictions, and the batch size that finally worked.

        Raises:
            RuntimeError: If even a batch size of one does not fit in memory,
                or on any non-OOM runtime error.
        """
        model.eval()
        predictions: dict[PartKey, np.ndarray] = {}
        for part in parts:
            while eval_batch_size:
                try:
                    predictions[part] = (
                        torch.cat(
                            [
                                apply_model(part, idx)
                                for idx in torch.arange(
                                    len(dataset.data['y'][part]),
                                    device=device,
                                ).split(eval_batch_size)
                            ]
                        )
                        .cpu()
                        .numpy()
                    )
                except RuntimeError as err:
                    if not lib.is_oom_exception(err):
                        raise
                    eval_batch_size //= 2
                    logger.warning(f'eval_batch_size = {eval_batch_size}')
                else:
                    break
            if not eval_batch_size:
                raise RuntimeError('Not enough memory even for eval_batch_size=1')
        if regression_label_stats is not None:
            # Map the predictions back to the original label scale.
            predictions = {
                k: v * regression_label_stats.std + regression_label_stats.mean
                for k, v in predictions.items()
            }
        metrics = (
            dataset.task.calculate_metrics(predictions, report['prediction_type'])
            if lib.are_valid_predictions(predictions)
            else {x: {'score': -999999.0} for x in predictions}
        )
        return metrics, predictions, eval_batch_size

    def train_loop(
        *,
        step_fn: Callable[[Tensor], Tensor],
        eval_fn: Callable[..., tuple],
        n_steps: int,
        patience: int,
        report_key: str,
        chunk_size=None, eval_batch_size=eval_batch_size
    ):
        """Train with `step_fn` until `n_steps` is reached or early stopping fires.

        The data is processed in whole epochs. After every epoch the model is
        evaluated on "val" and "test" with `eval_fn`, and a checkpoint is
        saved whenever the validation score improves. The best metrics and the
        best step are stored in `report[report_key]`.

        Args:
            step_fn: Maps batch indices to the training loss.
            eval_fn: Called as `eval_fn(parts, eval_batch_size)`.
            n_steps: Step budget; `-1` disables the limit.
            patience: Early-stopping patience.
            report_key: Name of the phase; `"pretrain"` enables the
                pretraining output of the model during training.
            chunk_size: Initial chunk size for the forward/backward pass.
            eval_batch_size: Initial evaluation batch size.

        Returns:
            The final `(chunk_size, eval_batch_size)`.
        """

        def save_checkpoint(step) -> None:
            lib.dump_checkpoint(
                output,
                {
                    'step': step,
                    'model': model.state_dict(),
                    'optimizer': (
                        optimizer.distributed_state_dict(model.named_parameters())
                        if isinstance(optimizer, DistributedShampoo)
                        else optimizer.state_dict()
                    ),
                    'generator': generator.get_state(),
                    'random_state': delu.random.get_state(),
                    'early_stopping': early_stopping,
                    'report': report,
                    'timer': timer,
                    'training_log': training_log,
                }
                | (
                    {}
                    if lr_scheduler is None
                    else {'lr_scheduler': lr_scheduler.state_dict()}
                ),
            )
            lib.dump_report(output, report)
            lib.backup_output(output)

        step = 0
        early_stopping = delu.tools.EarlyStopping(patience, mode='max')
        report[report_key] = {'metrics': {'val': {'score': -math.inf}}}

        while n_steps == -1 or step < n_steps:
            print(f'[...] {output} | {timer}')

            # >>>
            model.train()
            core_model.pretrain = report_key == "pretrain"
            epoch_losses = []
            logs_train = defaultdict(list)

            for batch_idx in tqdm(
                torch.randperm(
                    len(dataset.data['y']['train']), generator=generator, device=device
                ).split(batch_size),
                desc=f'Epoch {step // epoch_size} Step {step}',
            ):
                loss, new_chunk_size = lib.deep.zero_grad_forward_backward(
                    optimizer,
                    step_fn,
                    batch_idx,
                    chunk_size or batch_size,
                )

                # `log_dict` is filled by the step function (if it logs anything).
                for k, v in log_dict.items():
                    logs_train[k].append(v)

                if parameter_statistics and (
                    step % epoch_size == 0  # The first batch of the epoch.
                    or step // epoch_size == 0  # The first epoch.
                ):
                    for k, v in lib.deep.compute_parameter_stats(model).items():
                        writer.add_scalars(f'{report_key}/{k}', v, step, timer.elapsed())

                if gradient_clipping_norm is not None:
                    nn.utils.clip_grad.clip_grad_norm_(
                        model.parameters(), gradient_clipping_norm
                    )
                optimizer.step()

                if lr_scheduler is not None:
                    lr_scheduler.step()
                step += 1
                epoch_losses.append(loss.detach())
                if new_chunk_size and new_chunk_size < (chunk_size or batch_size):
                    chunk_size = new_chunk_size
                    logger.warning(f'chunk_size = {chunk_size}')

            epoch_losses = torch.stack(epoch_losses).tolist()
            mean_loss = statistics.mean(epoch_losses)

            # Evaluation always uses the supervised output only.
            core_model.pretrain = False
            metrics, predictions, eval_batch_size = eval_fn(
                ['val', 'test'], eval_batch_size
            )
            metrics['train'] = {}
            for k, v in logs_train.items():
                metrics['train'][k] = np.mean(v).item()

            training_log.append(
                {'epoch-losses': epoch_losses, 'metrics': metrics, 'time': timer.elapsed()}
            )
            lib.print_metrics(mean_loss, metrics)
            writer.add_scalars(
                f'{report_key}/loss', {'train': mean_loss}, step, timer.elapsed()
            )
            for part in metrics:
                for k in metrics[part].keys():
                    if k not in ['score', 'loss_sup', 'loss_num', 'loss_cat']:
                        continue
                    writer.add_scalars(
                        f'{report_key}/{k}', {part: metrics[part][k]}, step, timer.elapsed()
                    )

            if metrics['val']['score'] > report[report_key]['metrics']['val']['score']:
                print('🌸 New best epoch! 🌸')
                report[report_key]['best_step'] = step
                report[report_key]['metrics'] = metrics
                save_checkpoint(step)
                lib.dump_predictions(output, predictions)

            early_stopping.update(metrics['val']['score'])
            if early_stopping.should_stop() or not lib.are_valid_predictions(predictions):
                break

            print()
        return chunk_size, eval_batch_size

    # Widths of the segments of the pretraining output, in output order:
    # numerical reconstruction, one block of logits per categorical feature,
    # and the supervised output.
    output_sizes = {
        'x_num': dataset.n_num_features,
    } | {
        f'x_cat_{i}': v for i, v in enumerate(dataset.compute_cat_cardinalities())
    } | {
        'y': lib.deep.get_d_out(dataset.task.try_compute_n_classes())
    }

    def pretrain_loss_fn(model_out: Tensor, targets: dict[str, Tensor], log_dict=None):
        """Compute the joint pretraining loss.

        The result is `alpha * loss_sup + (1 - alpha) * (loss_num + loss_cat)`,
        where `loss_num` is the MSE of the numerical reconstruction and
        `loss_cat` is the cross-entropy averaged over categorical features.
        Missing feature types contribute zero.

        Args:
            model_out: Pretraining output of the model, split by `output_sizes`.
            targets: Uncorrupted batch data with the key "y" and, when the
                corresponding features exist, "x_num" and "x_cat".
            log_dict: If given, receives the three loss components as floats.
        """
        # Split the output into named segments and drop zero-width segments
        # (e.g. "x_num" when there are no numerical features).
        model_out_dict = {
            k: v
            for k, v in zip(
                output_sizes.keys(),
                model_out.split(tuple(output_sizes.values()), dim=1),
            )
            if v.shape[1]
        }
        zero = model_out.new_zeros(())

        loss_sup = loss_fn(
            # Only the last dimension is squeezed, so that a batch of size one
            # keeps its batch dimension.
            model_out_dict['y'].squeeze(-1),
            targets['y'].to(
                torch.long if dataset.task.is_multiclass else torch.float
            ),
        )

        if 'x_num' in model_out_dict:
            num_target = targets['x_num']
            if config['do_zca']:
                num_target = num_target @ zca_matrix
            loss_num = F.mse_loss(model_out_dict['x_num'], num_target)
        else:
            loss_num = zero

        cat_logits = [
            model_out_dict[k] for k in output_sizes if k.startswith('x_cat_')
        ]
        if cat_logits:
            loss_cat = torch.stack(
                [
                    F.cross_entropy(logits, targets['x_cat'][:, i].long())
                    for i, logits in enumerate(cat_logits)
                ]
            ).mean()
        else:
            loss_cat = zero

        if log_dict is not None:
            log_dict['loss_sup'] = loss_sup.item()
            log_dict['loss_num'] = loss_num.item()
            log_dict['loss_cat'] = loss_cat.item()

        return config['alpha'] * loss_sup + (1 - config['alpha']) * (loss_num + loss_cat)

    def pretrain_step(idx):
        """Joint loss on a corrupted training batch with clean targets."""
        part_data = {k: v[idx] for k, v in dataset.part_data('train').items()}
        return pretrain_loss_fn(
            apply_model('train', idx, aug=True), part_data, log_dict=log_dict
        )

    def finetune_step(idx):
        """Supervised loss on an uncorrupted training batch."""
        return loss_fn(apply_model('train', idx), Y_train[idx])

    # >>> pretrain
    print('Pretraining')
    timer.run()
    pretrain_patience = config.get('patience_pretrain', config['patience'])
    chunk_size, eval_batch_size = train_loop(
        step_fn=pretrain_step, eval_fn=evaluate_finetune,
        n_steps=config['n_steps_pretrain'],
        patience=pretrain_patience,
        report_key="pretrain",
        chunk_size=chunk_size,
        eval_batch_size=eval_batch_size,
    )

    # Finetuning steps do not log loss components.
    log_dict.clear()

    # >>> finetune
    core_model.pretrain = False
    print('Finetuning')
    # Best effort: there is no checkpoint if pretraining never improved the
    # validation score (or did not run at all).
    try:
        model.load_state_dict(lib.load_checkpoint(output)['model'])
    except Exception as e:
        print('Failed loading checkpoint')
        print(e)

    chunk_size, eval_batch_size = train_loop(
        step_fn=finetune_step, eval_fn=evaluate_finetune,
        n_steps=config['n_epochs'],
        patience=config['patience'],
        report_key="finetune",
        chunk_size=chunk_size,
        # Reuse the batch size that is already known to fit in memory.
        eval_batch_size=eval_batch_size,
    )
    report['time'] = str(timer)

    # >>> finish
    model.load_state_dict(lib.load_checkpoint(output)['model'])
    report['metrics'], predictions, _ = evaluate_finetune(
        ['train', 'val', 'test'], eval_batch_size
    )
    report['chunk_size'] = chunk_size
    report['eval_batch_size'] = eval_batch_size
    lib.dump_predictions(output, predictions)
    lib.dump_summary(output, lib.summarize(report))
    lib.finish(output, report)
    return report


# Script entry point: configure the libraries, then run `main` through the project CLI helper.
if __name__ == '__main__':
    lib.configure_libraries()
    lib.run_MainFunction_cli(main)
