import argparse
import datetime
import os
import random

import numpy as np
import torch
import yaml

from data_clevrtex import clevrtex
from model import get_model
from run_model import train_model
from torch.utils.data import DataLoader
from calflops import calculate_flops
import pdb

def get_datasets(config, batch_size):
    """Build the training and validation data loaders.

    Args:
        config: Mapping providing 'dataset' (an expression that evaluates to
            the dataset factory), 'path_data' and 'image_size'.
        batch_size: Batch size used by both loaders.

    Returns:
        A ``(train_loader, val_loader)`` tuple of ``DataLoader`` objects.
    """
    # NOTE(review): config['dataset'] is handed to eval(), which executes
    # arbitrary expressions. Only use trusted config files / CLI values.
    splits = {}
    for phase in ('train', 'valid'):
        dataset_factory = eval(config['dataset'])
        splits[phase] = dataset_factory(
            root=config['path_data'],
            phase=phase,
            img_size=config['image_size'],
        )

    # Options shared by both loaders. drop_last is enabled for validation as
    # well, so a trailing partial batch is skipped in both splits.
    shared_kwargs = {
        'batch_size': batch_size,
        'num_workers': 0,
        'pin_memory': True,
        'drop_last': True,
    }
    # Only the training split is shuffled.
    train_loader = DataLoader(splits['train'], sampler=None, shuffle=True, **shared_kwargs)
    val_loader = DataLoader(splits['valid'], sampler=None, shuffle=False, **shared_kwargs)
    return train_loader, val_loader
def get_config():
    """Parse command-line arguments and merge them into the YAML config.

    The YAML file given by ``--path_config`` is loaded first. Every parsed
    command-line argument is then copied into the config when the key is
    missing from the YAML or when the argument value is not ``None``.

    NOTE: ``store_true`` flags default to ``False`` and ``--file_ckpt`` /
    ``--file_model`` have string defaults, so their values are never ``None``
    and always replace whatever the YAML file specifies for those keys.

    Derived settings applied after the merge:
        * ``debug`` forces ``ckpt_intvl`` to 1.
        * ``resume`` implies ``train``.
        * a missing ``timestamp`` is filled with the current local time.
        * ``use_timestamp`` appends the timestamp to ``folder_log`` and
          ``folder_out``.
        * a missing ``seed`` is drawn at random from [0, 0xffffffff].

    Returns:
        dict: The merged configuration.
    """
    parser = argparse.ArgumentParser()
    # The YAML file is the base of the configuration, so it is mandatory;
    # argparse now reports a clear error instead of failing in open(None).
    parser.add_argument('--path_config', required=True)
    parser.add_argument('--path_data')
    parser.add_argument('--path_pretrain')
    parser.add_argument('--folder_log')
    parser.add_argument('--folder_out')
    parser.add_argument('--model_name')
    parser.add_argument('--dataset')
    parser.add_argument('--timestamp')
    parser.add_argument('--num_tests', type=int)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--resume', default=False, action='store_true')
    parser.add_argument('--use_timestamp', action='store_true')
    parser.add_argument('--file_ckpt', default='ckpt.pth')
    parser.add_argument('--file_model', default='model.pth')
    parser.add_argument('--use_dp', default=False, action='store_true')
    args = parser.parse_args()

    with open(args.path_config) as f:
        # An empty YAML document loads as None; fall back to an empty mapping
        # so the merge below works on CLI values alone.
        config = yaml.safe_load(f) or {}
    for key, val in vars(args).items():
        if key not in config or val is not None:
            config[key] = val
    if config['debug']:
        config['ckpt_intvl'] = 1
    if config['resume']:
        config['train'] = True
    if config['timestamp'] is None:
        config['timestamp'] = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    if config['use_timestamp']:
        for key in ['folder_log', 'folder_out']:
            config[key] = os.path.join(config[key], config['timestamp'])
    if 'seed' not in config:
        config['seed'] = random.randint(0, 0xffffffff)
    print('seed',config['seed'])

    return config


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also turns off cuDNN benchmark mode and turns on its deterministic flag.

    Args:
        seed: Seed value passed unchanged to each generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def main():
    """Build the config, data loaders and model, then train if requested.

    Steps:
        1. Load the merged configuration.
        2. For a fresh (non-resumed) training run, create ``folder_log`` and
           ``folder_out`` if they do not exist yet.
        3. Seed the random number generators.
        4. Build the train/validation loaders (``image_size`` defaults to 128).
        5. Write the effective configuration to ``<folder_out>/config.yaml``
           (overwritten on every run).
        6. Build the model, print its parameter count and, when
           ``config['train']`` is set, start training.
    """
    config = get_config()
    if config['train'] and not config['resume']:
        for key in ('folder_log', 'folder_out'):
            if os.path.exists(config[key]):
                print('{} has been created'.format(config[key]))
            else:
                os.makedirs(config[key])
    set_seed(config['seed'])

    config.setdefault('image_size', 128)
    train_loader, val_loader = get_datasets(config, config['batch_size'])
    data_loaders = {'train': train_loader, 'valid': val_loader}
    config['image_shape'] = [3, config['image_size'], config['image_size']]

    # config.yaml is written on every run, but the folders above are only
    # created for fresh training runs. Make sure folder_out exists for
    # non-training runs too; a resumed run is expected to find its folder in
    # place, so a missing one still fails loudly there.
    if not config['resume']:
        os.makedirs(config['folder_out'], exist_ok=True)
    with open(os.path.join(config['folder_out'], 'config.yaml'), 'w') as f:
        yaml.safe_dump(config, f)
    net = get_model(config)

    # Total number of elements across all model parameters.
    model_size = sum(param.numel() for param in net.parameters())
    print('Model params: %.0f ' % (model_size))

    if config['train']:
        train_model(config, data_loaders, net)


# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
