import argparse
import os
import deepspeed
import torch
import re

from src.configs import video_tokenizer_config, deepspeed_config
from src.models import VideoTokenizer
from src.utils.data import TrajDataset
from src.utils.logger import Logger, make_log_dirs
from src.utils.utils import set_seed, DictToObject, get_optimizer_params
from src.utils.deepspeed_trainer import DeepSpeedTrainer


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``yes``/``no``/``1``/``0``.

    Using ``type=bool`` directly is wrong for argparse: ``bool("False")`` is True,
    so every non-empty string would be accepted as True.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean spelling.
    """
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _literal_of(expected_type):
    """Return an argparse converter that parses a Python literal of *expected_type*.

    ``dict("...")`` / ``list("...")`` cannot turn command-line text into a container,
    so container-valued defaults are parsed with ``ast.literal_eval`` instead.
    """
    import ast

    def convert(value):
        try:
            parsed = ast.literal_eval(value)
        except (ValueError, SyntaxError) as err:
            # argparse does not catch SyntaxError; re-raise as its own error type.
            raise argparse.ArgumentTypeError(
                f'expected a {expected_type.__name__} literal, got {value!r}'
            ) from err
        if not isinstance(parsed, expected_type):
            raise argparse.ArgumentTypeError(
                f'expected a {expected_type.__name__} literal, got {value!r}'
            )
        return parsed

    return convert


def _arg_type_for(default_value):
    """Pick the argparse ``type=`` converter for an option whose default is *default_value*."""
    # bool must be checked before anything else: it would otherwise map to type=bool.
    if isinstance(default_value, bool):
        return _str2bool
    for container in (dict, list, tuple):
        if isinstance(default_value, container):
            return _literal_of(container)
    if default_value is None:
        # type(None) is NoneType, which rejects every value; keep the raw string instead.
        return None
    return type(default_value)


def get_config(argv=None):
    """Build and parse the command-line configuration for VideoTokenizer training.

    Besides the fixed script options, one ``--<key>`` option is generated for every
    entry of ``video_tokenizer_config``, defaulting to that entry's value.

    Args:
        argv: optional list of argument strings; ``None`` parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with all parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', type=str, default='VideoTokenizer')
    parser.add_argument('--log', type=str, default='MetaWorld')
    parser.add_argument('--train-data-path', type=str, default='/infinite/common/metaworld-64/train')
    parser.add_argument('--ckpt-path', type=str, default=None)

    # Passed by the DeepSpeed launcher.
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--info', type=str, default=None)

    for arg_key, default_value in video_tokenizer_config.items():
        parser.add_argument(f'--{arg_key}', default=default_value, type=_arg_type_for(default_value))

    return parser.parse_args(argv)


# Run folders are located by this pattern; its two captured integer groups are used as the
# sort key to find the newest one.
_RUN_DIR_PATTERN = re.compile(r'seed_\d+&timestamp_(\d{2})-(\d{4})')


def _latest_run_dir(logdir_name):
    """Return the path of the newest run folder inside *logdir_name*.

    Entries are ranked by the integers captured by ``_RUN_DIR_PATTERN``; entries that do
    not match are skipped, so stray files cannot crash the lookup. Among entries with an
    equal key, the first one in ``os.listdir`` order wins.

    Raises:
        FileNotFoundError: if no entry matches the run-folder pattern.
    """
    candidates = []
    for name in os.listdir(logdir_name):
        match = _RUN_DIR_PATTERN.search(name)
        if match is not None:
            candidates.append((tuple(map(int, match.groups())), name))
    if not candidates:
        raise FileNotFoundError(
            f'no run folder matching {_RUN_DIR_PATTERN.pattern!r} in {logdir_name}'
        )
    _, latest = max(candidates, key=lambda candidate: candidate[0])
    return os.path.join(logdir_name, latest)


def _setup_logger(config, local_rank, output_config):
    """Create the run Logger; local rank 0 (or a non-distributed run) creates the log dir first.

    Non-zero local ranks block on a barrier until rank 0 has created its directory, then
    attach to the newest run folder under the same parent directory.
    """
    if local_rank > 0:
        torch.distributed.barrier()
        # NOTE(review): make_log_dirs is called here only to learn the parent directory; if it
        # also creates a new timestamped folder, this rank could select its own folder rather
        # than rank 0's — confirm against make_log_dirs.
        logdir_name = os.path.dirname(make_log_dirs(config.log, config.model_name, config.seed, vars(config)))
        return Logger(_latest_run_dir(logdir_name), output_config)

    # local_rank <= 0: this process owns directory creation (also covers -1 = no launcher rank).
    log_path = make_log_dirs(config.log, config.model_name, config.seed, vars(config))
    logger = Logger(log_path, output_config)
    # Release the other ranks waiting in the branch above.
    torch.distributed.barrier()
    return logger


def main(config=None):
    """Train the VideoTokenizer with DeepSpeed.

    Args:
        config: argument namespace as returned by ``get_config()``. When omitted, the
            command line is parsed at call time rather than at import time.
    """
    if config is None:
        config = get_config()

    # setup
    deepspeed.init_distributed()
    # Prefer the launcher-provided LOCAL_RANK; fall back to --local_rank instead of KeyError.
    local_rank = int(os.environ.get("LOCAL_RANK", config.local_rank))
    set_seed(config.seed)

    # dataset — use config.seq_len (same value the model receives below) so a CLI override
    # of seq_len applies to both data and model.
    # NOTE(review): every rank passes the same config.device (default 'cuda:0'); if
    # TrajDataset moves data onto that device, non-zero ranks may want f'cuda:{local_rank}'.
    dataset = TrajDataset(
        config.train_data_path,
        config.seq_len,
        os.path.join(config.train_data_path, 'train_indices.json'),
        device=config.device,
    )

    # logger
    output_config = {
        "consoleout_backup": "stdout",
        "training_progress": "csv",
        "tb": "tensorboard"
    }
    logger = _setup_logger(config, local_rank, output_config)

    # init video tokenizer
    video_tokenizer = VideoTokenizer(
        DictToObject(config.encoder),
        DictToObject(config.decoder),
        DictToObject(config.codebook),
        config.img_size, config.patch_size, config.seq_len
    )
    optim_params = get_optimizer_params(video_tokenizer, weight_decay=0)

    model_engine, optimizer, dataloader, __ = deepspeed.initialize(
        model=video_tokenizer,
        model_parameters=optim_params,
        training_data=dataset,
        config=deepspeed_config,
    )

    step = 0
    if config.ckpt_path is not None:
        load_path, client_sd = model_engine.load_checkpoint(config.ckpt_path)
        if load_path is None:
            # DeepSpeed returns (None, None) instead of raising when nothing could be loaded.
            raise FileNotFoundError(f'no DeepSpeed checkpoint could be loaded from {config.ckpt_path}')
        step = (client_sd or {}).get('step', 0)
        # NOTE(review): `step` is recovered but not passed to DeepSpeedTrainer (its interface is
        # not visible here), so a resumed run may restart its step counter — confirm.

    trainer = DeepSpeedTrainer(model_engine, dataloader, logger)
    trainer.train(train_steps=100000)


if __name__ == "__main__":
    # Script entry point: parse the command line and hand the resulting config to main().
    main(config=get_config())