import argparse
import logging
import os

import time
import random

import numpy as np
import torch.cuda
from omegaconf import OmegaConf
from model.model import PatchTraj
from tool.tools import generate_pad, get_dct_matrix
from tool.evaluate import val_model
from tool.evaluate_nba import val_nba
from tool.evaluate_jrdb import val_jrdb

# Pin this process to GPU 0 when CUDA is present. An unconditional
# set_device() raises at import time on hosts without a CUDA build/device.
if torch.cuda.is_available():
    torch.cuda.set_device(0)


def str2bool(value):
    """Convert a command-line string to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is ``True`` for
    every non-empty string (including ``"False"`` and ``"0"``). This converter
    parses the text instead. The empty string maps to ``False``, as ``bool('')``
    did.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def prepare_seed(rand_seed):
    """Seed Python, NumPy and torch (CPU and all CUDA devices) RNGs with ``rand_seed``.

    Also forces cuDNN onto deterministic algorithms and disables its benchmark
    autotuner, trading speed for run-to-run reproducibility.
    """
    np.random.seed(rand_seed)
    random.seed(rand_seed)
    torch.manual_seed(rand_seed)
    torch.cuda.manual_seed_all(rand_seed)
    torch.backends.cudnn.deterministic = True  # Add deterministic algorithm
    torch.backends.cudnn.benchmark = False  # Turn off benchmark optimization


def main():
    """Build PatchTraj from a YAML config plus CLI overrides and run evaluation.

    Only ``--mode test`` is handled. It loads ``--model_ckpt`` and dispatches to
    the evaluator that matches ``cfg.dataset_type`` (ETH, SDD, NBA or JRDB).
    The merged config is saved to ``<log_dir>/<dataset_type>/configs.yaml``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path',          type=str,       default='./data')
    parser.add_argument('--mode',               type=str,       default='', help=" train or test ")
    parser.add_argument('--log_dir',            type=str,       default='./logs')
    parser.add_argument('--manual_seed',        type=int,       default=0)
    # str2bool rather than bool: bool("False") would be True.
    parser.add_argument('--vis_control',        type=str2bool,  default=False)
    parser.add_argument('--cfg',                type=str,       default='', help="eth, sdd, nba or jrdb")
    parser.add_argument('--model_ckpt',         type=str,       default='')
    args = parser.parse_args()

    # Seed before anything random happens (model init included).
    prepare_seed(args.manual_seed)

    # CLI values override/extend whatever the YAML config contains.
    cfg = OmegaConf.load(args.cfg)
    cfg.data_path = args.data_path
    cfg.mode = args.mode
    cfg.log_dir = args.log_dir
    cfg.manual_seed = args.manual_seed
    cfg.vis_control = args.vis_control
    cfg.idx_pad, cfg.zero_index = generate_pad(cfg.padding, cfg.obs_len, cfg.pred_len)

    # DCT / inverse-DCT matrices sized to the full (observed + predicted) sequence.
    dct_m, idct_m = get_dct_matrix(cfg.obs_len + cfg.pred_len)

    model = PatchTraj(
        input_feats=2 * 3,  # 2 means x, y; NOTE(review): factor 3 not explained here — confirm
        obs_len=cfg.obs_len,
        pred_len=cfg.pred_len,
        patch_size=cfg.patch_size,
        patch_list=cfg.patch_list,
        stride=cfg.stride,
        num_frames=cfg.n_pre,
        num_layers=cfg.num_layers,
        num_heads=cfg.num_heads,
        latent_dim=cfg.latent_dims,
        num_experts=cfg.num_experts,
        patch_embed=cfg.patch_embed,
        dropout=cfg.dropout,
        num_sample=cfg.num_sample,
        dynamic_patch=cfg.dynamic_patch
    )

    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('>>>>>> Total trainable parameters are: {:.2f}M'.format(n_trainable / 1000000.0))

    log_path = os.path.join(cfg.log_dir, cfg.dataset_type)
    os.makedirs(log_path, exist_ok=True)

    # `with` guarantees the handle is closed. NOTE(review): nothing in this
    # script writes to `log`; it is kept so log.txt is still created as before.
    with open(os.path.join(log_path, 'log.txt'), 'a+') as log:
        OmegaConf.save(cfg, os.path.join(log_path, 'configs.yaml'))

        if cfg.mode != 'test':
            # Previously a silent no-op; make it visible.
            print('>>>>>> Mode {!r} is not handled by this script; use --mode test.'.format(cfg.mode))
            return

        if not args.model_ckpt:
            parser.error('--model_ckpt is required when --mode test')

        # torch.load unpickles the file: only load trusted checkpoints.
        ckpt = torch.load(args.model_ckpt, map_location='cpu')
        model.load_state_dict(ckpt['model_dict'])

        if cfg.dataset_type == 'ETH':
            # Full benchmark: ['eth', 'hotel', 'univ', 'zara1', 'zara2'];
            # restricted to 'eth' for visualization.
            scene_names = ['eth']
            for scene in scene_names:
                # val_model receives a one-element list of scene names.
                val_model(cfg, model, dct_m, idct_m, scene.split())
        elif cfg.dataset_type == 'SDD':
            val_model(cfg, model, dct_m, idct_m)
        elif cfg.dataset_type == 'NBA':
            val_nba(cfg, model)
        elif cfg.dataset_type == 'JRDB':
            val_jrdb(cfg, model)
        else:
            raise ValueError(
                'Unsupported dataset_type {!r}; expected ETH, SDD, NBA or JRDB'.format(cfg.dataset_type)
            )


if __name__ == '__main__':
    main()

