import pdb
import numpy as np
import torch.utils.data as data
import utils
from loss import Total_loss
from options import *
from config import *
from train import *
from test import *
from model import *
from search import *
from tensorboard_logger import Logger
from thumos_features import *
from tqdm import tqdm


def seed_worker(worker_id):
    """Seed NumPy's global RNG inside a DataLoader worker process.

    ``torch.utils.data.DataLoader`` calls ``worker_init_fn(worker_id)`` once
    in every worker it starts. Inside a worker, ``torch.initial_seed()`` is
    the per-worker seed PyTorch assigned (base seed + worker id). Each worker
    therefore gets a distinct NumPy seed that is still reproducible from the
    main process's torch RNG state; this is the recipe from the PyTorch
    reproducibility notes.

    The function is defined at module level, outside the ``__main__`` guard,
    so it can still be pickled when workers are spawned rather than forked.

    Args:
        worker_id: Index of the calling worker. It is not used directly
            because PyTorch already folds it into ``torch.initial_seed()``.
    """
    # NumPy accepts seeds in [0, 2**32); torch seeds are 64-bit.
    np.random.seed(torch.initial_seed() % 2 ** 32)


if __name__ == "__main__":
    # Training entry point. Steps:
    #   1. Parse CLI args and load the config.
    #   2. Build the model, the point-supervised train/test loaders, the loss
    #      and the optimizer.
    #   3. Run num_iters training steps, with a periodic sequence search and
    #      periodic evaluation/checkpoint comparison.
    args = parse_args()
    if args.debug:
        pdb.set_trace()

    if os.path.isfile(args.config):
        cfg = load_config(args.config)
    else:
        raise ValueError("Config file does not exist.")
    print(cfg)

    cfg = init_args(cfg)

    config = Config(cfg)
    worker_init_fn = None

    if config.seed >= 0:
        utils.set_seed(config.seed)
        # NOTE(review): utils.set_seed may already seed NumPy. This explicit
        # call keeps the main-process NumPy RNG state unchanged.
        np.random.seed(config.seed)
        # DataLoader needs the function itself (it calls it in each worker),
        # not the result of calling a seeding function here.
        worker_init_fn = seed_worker

    utils.save_config(config, os.path.join(config.output_path, "config.txt"))

    net = Model(config.len_feature, config.num_classes, config.r_act)
    net = net.cuda()

    train_loader = data.DataLoader(
        build_dataset(dataset_name=config.dataset_name, data_path=config.data_path, mode='train',
                      modal=config.modal, feature_fps=config.feature_fps,
                      num_segments=config.num_segments, sampling=config.sampling_type,
                      supervision='point', seed=config.seed),
        batch_size=config.real_batch_size,
        shuffle=True, num_workers=config.num_workers,
        worker_init_fn=worker_init_fn)

    test_loader = data.DataLoader(
        build_dataset(dataset_name=config.dataset_name, data_path=config.data_path, mode='test',
                      modal=config.modal, feature_fps=config.feature_fps,
                      num_segments=-1, sampling='uniform',
                      supervision='point', seed=config.seed),
        batch_size=1,
        shuffle=False, num_workers=config.num_workers,
        worker_init_fn=worker_init_fn)

    dataset_name = config.dataset_name
    test_info = utils.build_test_info(dataset_name)

    # Choose the per-step training routine by substring match, checked in
    # this order. THUMOS14, GTEA and BEOID share train_thumos, which is given
    # a persistent iterator over train_loader. ActivityNet1.3 is given the
    # loader itself.
    if 'thumos14' in dataset_name:
        train_fn, feed_iterator = train_thumos, True
    elif 'activitynet13' in dataset_name:
        train_fn, feed_iterator = train_activity, False
    elif 'GTEA' in dataset_name or 'BEOID' in dataset_name:
        train_fn, feed_iterator = train_thumos, True
    else:
        # Fail fast instead of running num_iters steps with no training.
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    loader_iter = None
    if feed_iterator:
        # Number of steps between fresh iterators over train_loader. It is
        # constant, so compute it once outside the loop.
        reset_period = len(train_loader) // config.batch_size
        if reset_period == 0:
            # A zero period would raise ZeroDivisionError in the modulo below.
            raise ValueError(
                f"len(train_loader)={len(train_loader)} is smaller than "
                f"batch_size={config.batch_size}; iterator reset period would be zero.")

    best_mAP = -1

    criterion = Total_loss(config.lambdas, config.th_similar_min, config.th_different_max)

    optimizer = torch.optim.Adam(net.parameters(), lr=config.lr,
                                 betas=(0.9, 0.999), weight_decay=config.weight_decay)

    logger = Logger(config.log_path)
    for step in tqdm(
            range(1, config.num_iters + 1),
            total=config.num_iters,
            dynamic_ncols=True, desc='It\'s a train', position=0):
        if feed_iterator:
            # Start a new pass over the data at step 1 and every reset_period steps.
            if (step - 1) % reset_period == 0:
                loader_iter = iter(train_loader)
            batches = loader_iter
        else:
            batches = train_loader
        train_fn(net, config, batches, optimizer, criterion, logger, step)

        if step % config.search_freq == 0:
            optimal_sequence_search(net, config, logger, train_loader, step)

        if step % config.val_per_epoch == 0:
            test(net, config, logger, test_loader, test_info, step)

            best_mAP = utils.compare_save_results(config, net, dataset_name, test_info, best_mAP)