import argparse
import numpy as np
import torch
import os
import re
import sys

import torch.distributed
sys.path.append('.')

import deepspeed

from src.configs import oxe_video_tokenizer_config, oxe_dynamic_model_config, dynamic_deepspeed_config
from src.models import VideoTokenizer, PolicyEmbedModel, DynamicModel
from src.utils.oxe_data import TrajDataset, DistributedWeightedRandomSampler
from torch.utils.data import DataLoader
from src.utils.logger import Logger, make_log_dirs
from src.utils.utils import set_seed, DictToObject, get_optimizer_params
from src.utils.deepspeed_trainer import DeepSpeedTrainer


def dataloader_to_step(dataloader, step):
    """Advance ``dataloader`` by up to ``step`` items, discarding each one.

    The argument is consumed with ``next``, so it must be an iterator; an
    object that only implements ``__iter__`` raises ``TypeError`` here.

    Args:
        dataloader: Iterator to fast-forward in place.
        step: Maximum number of items to skip. Zero or a negative value
            skips nothing.

    Returns:
        None. Exhaustion of the iterator before ``step`` items were skipped
        is not an error; the function simply stops.
    """
    try:
        for _ in range(step):
            next(dataloader)
    except StopIteration:
        # Fewer than `step` items were available: nothing left to skip.
        return
        

def get_config():
    """Parse the command line into the training configuration.

    A fixed set of options (model name, paths, seed, ...) is registered first,
    then one ``--<key>`` option is added for every entry of
    ``oxe_dynamic_model_config`` with that entry's value as default.

    Boolean entries are parsed from text (``true``/``false``, ``yes``/``no``,
    ``on``/``off``, ``1``/``0``) instead of through ``bool(str)``, which
    treats every non-empty string -- including ``"False"`` -- as ``True``.
    All other entries keep ``type(default)`` as their converter.

    NOTE(review): entries whose default is a container (e.g. a dict-valued
    sub-config) still use ``type(default)``, which presumably cannot parse a
    command-line string in a useful way -- confirm whether they need to be
    overridable.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    truthy = ('true', 't', 'yes', 'y', 'on', '1')
    # The empty string stays falsy, matching what bool('') returned before.
    falsy = ('false', 'f', 'no', 'n', 'off', '0', '')

    def _parse_bool(text):
        """Convert a command-line token to ``bool`` or reject it."""
        token = text.strip().lower()
        if token in truthy:
            return True
        if token in falsy:
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {text!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', type=str, default='Dyanmic(256_16)')
    parser.add_argument('--log', type=str, default='OXE')
    parser.add_argument('--data-path', type=str, default='/infinite/common/rlds-action-dataset-new')
    parser.add_argument('--indices-path', type=str, default='/infinite/common/rlds-action-dataset/indices.json')
    parser.add_argument('--video-tokenizer-path', type=str, default='models/oxe_tokenizer.pth')
    parser.add_argument('--ckpt-path', type=str, default=None)
    parser.add_argument('--embed-type', type=int, default=-1)

    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--info', type=str, default=None)

    # Register the model-config options before any parsing happens so that
    # `--help` lists them as well.
    for arg_key, default_value in oxe_dynamic_model_config.items():
        if isinstance(default_value, bool):
            arg_type = _parse_bool
        else:
            arg_type = type(default_value)
        parser.add_argument(f'--{arg_key}', default=default_value, type=arg_type)

    return parser.parse_args()


def _latest_run_dir(logdir_name):
    """Return the path of the newest ``seed_*&timestamp_*`` folder in ``logdir_name``.

    Entries whose name does not match the run-folder pattern are ignored.

    Raises:
        FileNotFoundError: If no entry matches the pattern.
    """
    timestamp_regex = re.compile(r'seed_\d+&timestamp_(\d{2})-(\d{4})-(\d{6})')
    candidates = []
    for folder in os.listdir(logdir_name):
        match = timestamp_regex.search(folder)
        if match is not None:
            candidates.append((tuple(map(int, match.groups())), folder))
    if not candidates:
        raise FileNotFoundError(f'no run folder found in {logdir_name!r}')
    # Largest timestamp tuple wins, i.e. the most recently created run folder.
    _, newest = max(candidates, key=lambda item: item[0])
    return os.path.join(logdir_name, newest)


def main(config=None):
    """Train the dynamic model with DeepSpeed.

    Sets up the distributed backend, dataset and logger, builds the video
    tokenizer (and optionally the policy-embedding model) from saved weights,
    wraps the resulting ``DynamicModel`` in a DeepSpeed engine, optionally
    restores a checkpoint, and runs the trainer.

    Args:
        config: Namespace as returned by ``get_config``. When ``None`` (the
            default) the command line is parsed at call time, so importing
            this module no longer parses ``sys.argv``.

    Raises:
        RuntimeError: If ``config.ckpt_path`` is set but the checkpoint could
            not be loaded.
    """
    if config is None:
        config = get_config()

    # setup
    deepspeed.init_distributed()
    local_rank = int(os.environ["LOCAL_RANK"])
    set_seed(config.seed)

    # dataset
    dataset = TrajDataset(config.data_path, config.seq_len, config.indices_path)
    generator = torch.Generator().manual_seed(config.seed)
    sampler = DistributedWeightedRandomSampler(dataset.weights, len(dataset), generator=generator, replacement=True)
    dataloader = DataLoader(dataset, batch_size=config.batch_size, sampler=sampler, num_workers=4)

    # logger
    output_config = {
        "consoleout_backup": "stdout",
        "training_progress": "csv",
        "tb": "tensorboard"
    }

    # Rank 0 creates the run directory; the other ranks wait at the barrier
    # and then attach to the newest run folder.
    # NOTE(review): the check uses the node-local rank, so with several nodes
    # each node's local rank 0 would create a directory -- confirm this script
    # is only run on a single node.
    if local_rank == 0:
        log_path = make_log_dirs(config.log, config.model_name, config.seed, vars(config))
        logger = Logger(log_path, output_config)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        logdir_name = os.path.dirname(
            make_log_dirs(config.log, config.model_name, config.seed, vars(config), mkdir=False)
        )
        log_path = _latest_run_dir(logdir_name)
        logger = Logger(log_path, output_config)

    # init video tokenizer
    video_tokenizer = VideoTokenizer(
        DictToObject(oxe_video_tokenizer_config['encoder']),
        DictToObject(oxe_video_tokenizer_config['decoder']),
        DictToObject(oxe_video_tokenizer_config['codebook']),
        oxe_video_tokenizer_config['img_size'], oxe_video_tokenizer_config['patch_size'], oxe_video_tokenizer_config['seq_len']
    )
    # Load onto the CPU so that every rank does not allocate the weights on
    # the device they were saved from; load_state_dict copies them into the
    # module's own parameters.
    video_tokenizer.load_state_dict(torch.load(config.video_tokenizer_path, map_location='cpu'))
    video_tokenizer.eval()

    config.decoder.update({'action_dim': dataset.action_dim})
    config.policy.update({'action_dim': dataset.action_dim})

    if config.embed_type > 0:
        config.posterior.update({'cat_size': config.embed_type, 'class_size': config.embed_type})
        embed_model = PolicyEmbedModel(
            DictToObject(config.posterior),
            DictToObject(config.prior),
            DictToObject(config.policy),
            DictToObject(config.kl),
            config.img_size, config.patch_size, config.seq_len
        )
        embed_weights = os.path.join('models', f'oxe_embed_{config.embed_type}.pth')
        embed_model.load_state_dict(torch.load(embed_weights, map_location='cpu'))
        embed_model.eval()
    else:
        embed_model = None

    dynamic_model = DynamicModel(
        video_tokenizer,
        embed_model,
        DictToObject(config.decoder)
    )

    optim_params = get_optimizer_params(dynamic_model, weight_decay=1e-4)

    model_engine, _, _, _ = deepspeed.initialize(
        model=dynamic_model,
        model_parameters=optim_params,
        config=dynamic_deepspeed_config,
    )

    if config.ckpt_path is not None:
        load_path, client_sd = model_engine.load_checkpoint(config.ckpt_path)
        if load_path is None:
            # Fail with a clear message instead of a TypeError on the lookup below.
            raise RuntimeError(f'Failed to load checkpoint from {config.ckpt_path!r}')
        # NOTE(review): the restored step is read but nothing below consumes
        # it (the trainer is not told where to resume) -- confirm intent.
        step = client_sd['step']

    trainer = DeepSpeedTrainer(model_engine, dataloader, logger)
    trainer.train(train_steps=config.n_train_steps)


# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    main()