from argparse import ArgumentParser
from pathlib import Path
import random
import shutil

import numpy as np
import torch as pt

from object_centric_bench.datum import DataLoader
from object_centric_bench.learn import MetricWrap
from object_centric_bench.model import ModelWrap2
from object_centric_bench.utils import Config, build_from_config


def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with *seed*.

    ``pt.manual_seed`` seeds the PyTorch generator on all devices (CPU and
    CUDA), so no separate CUDA seeding call is needed.

    :param seed: integer seed; must be accepted by ``np.random.seed``
        (i.e. in ``[0, 2**32)``).
    """
    # Same order as always: Python, then NumPy, then PyTorch.
    seeders = (random.seed, np.random.seed, pt.manual_seed)
    for seed_rng in seeders:
        seed_rng(seed)


def _seed_worker(worker_id):
    """Seed the Python and NumPy RNGs of one DataLoader worker process.

    Used as ``worker_init_fn``. Assuming ``DataLoader`` follows
    ``torch.utils.data.DataLoader`` semantics (TODO confirm for the project
    wrapper), each worker's torch RNG is already seeded with
    ``base_seed + worker_id``, where ``base_seed`` is drawn from the loader's
    ``generator`` whenever iteration starts. Deriving the Python/NumPy seeds
    from ``pt.initial_seed()`` therefore keeps runs reproducible while giving
    every worker and every epoch its own random stream. Reseeding all workers
    with one fixed seed would instead repeat the same stream everywhere.

    It is a module-level function rather than a lambda so it can be pickled
    under the ``spawn`` multiprocessing start method.

    :param worker_id: worker index; unused because ``pt.initial_seed()``
        already incorporates it.
    """
    worker_seed = pt.initial_seed() % 2**32  # np.random.seed needs < 2**32
    random.seed(worker_seed)
    np.random.seed(worker_seed)


def main(args):
    """Build everything described by a Python config file and run training.

    Steps, in order:

    1. Seed all RNGs and enable deterministic algorithms.
    2. Build the train/val datasets and loaders, using ``args.data_dir`` as
       ``base_dir``.
    3. Build the model, wrap it in ``ModelWrap2`` and move it to CUDA.
       Optionally load ``args.ckpt_file`` and freeze parameters.
    4. Build the optimizer, loss/metric wrappers and callbacks.
    5. Build ``cfg.loop`` from all of the above and call it.

    The config file is copied to ``<save_dir>/<cfg_name>/``. ``AverageLog``
    callbacks get ``log_file=<save_dir>/<cfg_name>/<seed>.txt``, and
    ``SaveModel`` callbacks get ``save_dir=<save_dir>/<cfg_name>/<seed>``.

    :param args: namespace from ``parse_args()`` with ``seed``, ``cfg_file``,
        ``data_dir``, ``save_dir`` and ``ckpt_file`` (may be ``None``).
    :raises ValueError: if ``cfg_file`` does not end with ``.py``.
    :raises FileNotFoundError: if ``cfg_file`` is not an existing file.
    """
    print(args)

    cfg_file = Path(args.cfg_file)
    data_path = Path(args.data_dir)
    ckpt_file = Path(args.ckpt_file) if args.ckpt_file else None

    # Explicit checks rather than ``assert``, which ``python -O`` strips.
    if not cfg_file.name.endswith(".py"):
        raise ValueError(f"config file must be a .py file: {cfg_file}")
    if not cfg_file.is_file():
        raise FileNotFoundError(f"config file not found: {cfg_file}")
    # ``stem`` strips only the final ``.py``. Splitting on the first "." would
    # map e.g. ``a.b.py`` and ``a.c.py`` to the same save directory.
    cfg_name = cfg_file.stem
    cfg = Config.fromfile(args.cfg_file)

    save_path = Path(args.save_dir) / cfg_name / str(args.seed)
    save_path.mkdir(parents=True, exist_ok=True)
    shutil.copy(args.cfg_file, save_path.parent)

    set_seed(args.seed)  # for reproducibility
    pt.backends.cudnn.benchmark = False  # XXX True: 20% faster but stochastic
    pt.backends.cudnn.deterministic = True  # for cuda devices
    pt.use_deterministic_algorithms(True, warn_only=True)  # for all devices

    ## datum init

    # Seeded generator for shuffling and (presumably) the per-worker base seeds.
    # NOTE(review): the val loader shares it, so each validation pass likely
    # advances its state too. Runs stay deterministic, but the train order then
    # depends on the validation schedule.
    rng = pt.Generator()
    rng.manual_seed(args.seed)

    cfg.dataset_t.base_dir = cfg.dataset_v.base_dir = data_path

    dataset_t = build_from_config(cfg.dataset_t)
    dataload_t = DataLoader(
        dataset_t,
        cfg.batch_size_t,  # NOTE(review): an old TODO considered ``// 2`` here
        shuffle=True,
        num_workers=cfg.num_work,
        pin_memory=True,
        worker_init_fn=_seed_worker,
        generator=rng,
    )
    dataset_v = build_from_config(cfg.dataset_v)
    dataload_v = DataLoader(
        dataset_v,
        cfg.batch_size_v,
        num_workers=cfg.num_work,
        shuffle=False,
        pin_memory=True,
        worker_init_fn=_seed_worker,
        generator=rng,
    )

    ## model init

    model = build_from_config(cfg.model)
    print(model)

    # Wrap with the config's input/output key maps (presumably mapping batch
    # keys to model arguments and model outputs to named results; confirm in
    # ModelWrap2), then move to the GPU.
    model = ModelWrap2(model, cfg.model_imap, cfg.model_omap)
    model = model.cuda()

    if ckpt_file:
        model.load(ckpt_file, cfg.ckpt_map)

    if cfg.freez:
        model.freez(cfg.freez)

    # model.compile()  # in ModelWrap2, MetricWrap

    ## learn init

    # Use per-group parameters (e.g. different learning rates) when configured.
    if cfg.param_groups:
        cfg.optimiz.params = model.group_params(cfg.param_groups)
    else:
        cfg.optimiz.params = model.parameters()
    optimiz = build_from_config(cfg.optimiz)
    # Attached as attributes; presumably consumed by the training loop as
    # gradient scaling / clipping helpers.
    optimiz.gscale = build_from_config(cfg.gscale)
    optimiz.gclip = build_from_config(cfg.gclip)

    loss_fn = MetricWrap(**build_from_config(cfg.loss_fn))
    metric_fn = MetricWrap(detach=True, **build_from_config(cfg.metric_fn))

    # Point file-related callbacks at this run's outputs before building them:
    # ``AverageLog`` gets ``log_file=<save_path>.txt`` and ``SaveModel`` gets
    # ``save_dir=save_path``.
    for cb in cfg.callback_t + cfg.callback_v:
        if cb.type == "AverageLog":
            cb.log_file = f"{save_path}.txt"
        elif cb.type == "SaveModel":
            cb.save_dir = save_path
    callback_t = build_from_config(cfg.callback_t)
    callback_v = build_from_config(cfg.callback_v)

    ## learn loop

    cfg.loop.dataset_t = dataload_t
    cfg.loop.dataset_v = dataload_v
    cfg.loop.model = model
    cfg.loop.optimiz = optimiz
    cfg.loop.loss_fn = loss_fn
    cfg.loop.metric_fn = metric_fn
    cfg.loop.callback_t = callback_t
    cfg.loop.callback_v = callback_v

    loop = build_from_config(cfg.loop)
    loop()


def parse_args(argv=None):
    """Parse the command-line options of this training script.

    :param argv: list of argument strings to parse. ``None`` (the default)
        parses ``sys.argv[1:]``, as before. Passing a list lets tests and
        notebooks call this without touching ``sys.argv``.
    :return: ``argparse.Namespace`` with ``seed`` (int), ``cfg_file``,
        ``data_dir``, ``save_dir`` and ``ckpt_file`` (str, ``None`` when not
        given).
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        # default=np.random.randint(2**32),
        help="random seed for Python, NumPy, PyTorch and the data loaders",
    )
    parser.add_argument(
        "--cfg_file",
        type=str,
        default="config-ms-c4/slotdiffuz_r_vqvae-clevrtex-ms.py",
        help="path to the Python config file describing the experiment",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="/media/GeneralZ/Storage/Static/datasets",
        help="root directory used as base_dir of the train and val datasets",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="save",
        help="output root; results go under <save_dir>/<cfg_name>/",
    )
    parser.add_argument(
        "--ckpt_file",
        type=str,
        # default="archive-ms-c256/vqvae-clevrtex/best.pth",
        help="optional checkpoint to load into the model before training",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Suppress torch dynamo compilation errors instead of raising them
    # (TODO XXX: one_hot, interplolate fail to compile).
    pt._dynamo.config.suppress_errors = True
    # To hunt NaNs, run the call below inside ``pt.autograd.detect_anomaly()``.
    cli_args = parse_args()
    main(cli_args)
