import argparse
import numpy as np
import torch
import os
import re
import sys

import torch.distributed
sys.path.append('.')

import deepspeed

from src.configs import oxe_policy_embedding_config, deepspeed_config
from src.models import PolicyEmbedModel
from src.utils.oxe_data import TrajDataset, DistributedWeightedRandomSampler
from torch.utils.data import DataLoader
from src.utils.logger import Logger, make_log_dirs
from src.utils.utils import set_seed, DictToObject, get_optimizer_params
from src.utils.deepspeed_trainer import DeepSpeedTrainer


def dataloader_to_step(dataloader, step):
    """Skip ahead in a data iterator by up to ``step`` items, discarding them.

    Iteration stops early, without raising, if the iterator is exhausted.

    Args:
        dataloader: an *iterator* (an object ``next()`` accepts). A plain
            iterable such as a ``DataLoader`` must be wrapped with ``iter()``
            first, otherwise ``next`` raises ``TypeError``.
        step: number of items to skip; values <= 0 skip nothing.

    Returns:
        The number of items actually consumed. This is less than ``step``
        when the iterator ran out first.
    """
    # Unique sentinel: next() returns it instead of raising StopIteration.
    exhausted = object()
    consumed = 0
    for _ in range(step):
        if next(dataloader, exhausted) is exhausted:
            break
        consumed += 1
    return consumed
        

def _str2bool(value):
    """Convert a command-line string to a bool for boolean config options.

    ``argparse`` with ``type=bool`` would call ``bool("False")``. That is
    ``True`` for every non-empty string, so a boolean option could never be
    switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _arg_type_for(default_value):
    """Pick the argparse ``type`` converter for a config default value."""
    if isinstance(default_value, bool):
        return _str2bool
    if default_value is None:
        # type(None) cannot parse any string, so a None default could never
        # be overridden from the CLI; accept the raw string instead.
        return str
    # NOTE(review): dict/list defaults keep type(default) and therefore cannot
    # be meaningfully overridden from the command line.
    return type(default_value)


def get_config():
    """Build the run configuration from the command line.

    Fixed options are declared first. Then every key of
    ``oxe_policy_embedding_config`` is exposed as ``--<key>``, with the config
    value as its default and a converter chosen from that value's type.

    Returns:
        argparse.Namespace holding every option.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', type=str, default='PolicyEmbed')
    parser.add_argument('--log', type=str, default='OXE')
    parser.add_argument('--data-path', type=str, default='/infinite/common/rlds-action-dataset')
    parser.add_argument('--indices-path', type=str, default='/infinite/common/rlds-action-dataset/indices.json')
    parser.add_argument('--ckpt-path', type=str, default=None)

    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--class-size', type=int, default=16)
    parser.add_argument('--cat-size', type=int, default=16)
    parser.add_argument('--info', type=str, default=None)

    # All options are registered before the single parse, so --help lists
    # the config-derived options too.
    for arg_key, default_value in oxe_policy_embedding_config.items():
        parser.add_argument(f'--{arg_key}', default=default_value, type=_arg_type_for(default_value))

    return parser.parse_args()


def _setup_logger(config, local_rank, output_config):
    """Create (local rank 0) or locate (other ranks) the run's log directory.

    Local rank 0 creates the directory with ``make_log_dirs``, builds its
    ``Logger``, and then enters the barrier, which releases the other ranks.
    Every other rank waits on that barrier. It then logs into the newest
    folder whose name matches ``seed_<n>&timestamp_<dd>-<dddd>`` in the same
    parent directory.

    Args:
        config: parsed configuration namespace. ``log``, ``model_name`` and
            ``seed`` are read, and ``vars(config)`` is forwarded to
            ``make_log_dirs``.
        local_rank: value of the ``LOCAL_RANK`` environment variable.
        output_config: output-format mapping forwarded to ``Logger``.

    Returns:
        The ``Logger`` for this process.

    Raises:
        ValueError: if ``local_rank`` is negative.
        FileNotFoundError: if a non-zero rank finds no matching run folder.
    """
    if local_rank < 0:
        raise ValueError(f'LOCAL_RANK must be non-negative, got {local_rank}')

    if local_rank == 0:
        log_path = make_log_dirs(config.log, config.model_name, config.seed, vars(config))
        logger = Logger(log_path, output_config)
        # Release the other ranks only once the log directory exists.
        torch.distributed.barrier()
        return logger

    # Wait for local rank 0 to create its log directory.
    torch.distributed.barrier()
    # NOTE(review): make_log_dirs is called again here only to obtain the
    # parent directory. If it creates a fresh timestamped folder on every
    # call, the "newest" folder picked below may belong to this rank rather
    # than rank 0. Confirm against make_log_dirs.
    logdir_name = os.path.dirname(make_log_dirs(config.log, config.model_name, config.seed, vars(config)))
    timestamp_regex = re.compile(r'seed_\d+&timestamp_(\d{2})-(\d{4})')
    # Sort key per run folder. Names that do not match the pattern are
    # skipped rather than crashing on ``None.groups()``.
    folder_keys = {}
    for folder in os.listdir(logdir_name):
        match = timestamp_regex.search(folder)
        if match is not None:
            folder_keys[folder] = tuple(map(int, match.groups()))
    if not folder_keys:
        raise FileNotFoundError(f'no seed_*&timestamp_* run folder found in {logdir_name!r}')
    # max() keeps the first of equally-ranked folders in listdir order.
    latest_folder = max(folder_keys, key=folder_keys.get)
    return Logger(os.path.join(logdir_name, latest_folder), output_config)


def main(config=None):
    """Train the policy-embedding model with DeepSpeed.

    Args:
        config: parsed configuration namespace. When ``None`` (the default),
            it is built from the command line with ``get_config()`` at call
            time instead of at import time.
    """
    if config is None:
        config = get_config()

    # setup
    deepspeed.init_distributed()
    local_rank = int(os.environ["LOCAL_RANK"])
    set_seed(config.seed)

    # dataset; the sampler draws with replacement using dataset.weights
    dataset = TrajDataset(config.data_path, config.seq_len, config.indices_path)
    generator = torch.Generator().manual_seed(config.seed)
    sampler = DistributedWeightedRandomSampler(dataset.weights, len(dataset), generator=generator, replacement=True)
    dataloader = DataLoader(dataset, batch_size=config.batch_size, sampler=sampler, num_workers=4)

    # logger
    output_config = {
        "consoleout_backup": "stdout",
        "training_progress": "csv",
        "tb": "tensorboard"
    }
    logger = _setup_logger(config, local_rank, output_config)

    # Sub-configs that depend on the dataset and on CLI options.
    config.policy.update({'action_dim': dataset.action_dim})
    config.posterior.update({'cat_size': config.cat_size, 'class_size': config.class_size})

    # policy-embedding model
    embed_model = PolicyEmbedModel(
        DictToObject(config.posterior),
        DictToObject(config.prior),
        DictToObject(config.policy),
        DictToObject(config.kl),
        config.img_size, config.patch_size, config.seq_len,
    )
    optim_params = get_optimizer_params(embed_model, weight_decay=0)

    model_engine, *_ = deepspeed.initialize(
        model=embed_model,
        model_parameters=optim_params,
        config=deepspeed_config,
    )

    if config.ckpt_path is not None:
        load_path, client_sd = model_engine.load_checkpoint(config.ckpt_path)
        # DeepSpeed reports a failed load with load_path None; fail loudly
        # instead of with a TypeError on client_sd below.
        if load_path is None:
            raise FileNotFoundError(f'could not load a DeepSpeed checkpoint from {config.ckpt_path!r}')
        # NOTE(review): `step` is read but not forwarded to DeepSpeedTrainer
        # or used with dataloader_to_step here. Confirm whether resumed runs
        # should continue from it.
        step = client_sd['step']  # noqa: F841

    trainer = DeepSpeedTrainer(model_engine, dataloader, logger)
    trainer.train(train_steps=config.n_train_steps)


if __name__ == "__main__":
    # Launch training only when this file is executed as a script, not when
    # it is imported as a module.
    main()