import time
import datetime
import argparse
import warnings
from sklearn.exceptions import ConvergenceWarning

import torch
from torch.utils.data import DataLoader

from algorithms.wrapper import get_algorithm
from utils import Logger
from data import SelfSupDataset, EpisodicDataset, CustomTransform

def main(args):
    """Run unsupervised meta-learning training with periodic episodic evaluation.

    Builds the self-supervised training loader and the 'mini' meta-val /
    meta-test episodic datasets, creates a Logger whose output goes to
    ``./exp/<exp_name>/<timestamp>``, and delegates the training loop to the
    algorithm returned by ``get_algorithm(args.algorithm)``.

    Args:
        args: argparse namespace produced by the script's parser. It is
            mutated in place: ``transform`` and ``save_dir`` are added, and
            ``--debug`` overrides ``training_epochs``, ``exp_name``,
            ``val_episodes`` and ``test_episodes``. The whole namespace is
            also handed to the Logger as ``wandb_config``.
    """
    # Resolve the algorithm first so an unknown name fails before any data
    # loading or logger setup happens.
    algo = get_algorithm(args.algorithm)

    # Use the requested GPU when CUDA is present; otherwise fall back to CPU
    # instead of handing the algorithm a device it cannot use.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{args.gpu_id}")
    else:
        warnings.warn("CUDA is not available; falling back to CPU.")
        device = torch.device("cpu")

    if args.debug:
        # Shrink the run so a debug invocation finishes quickly.
        args.training_epochs = 1
        args.exp_name = 'debug'
        args.val_episodes = 10
        args.test_episodes = 10

    # data
    train_loader = DataLoader(
        SelfSupDataset(),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers,
        # Pinned host memory only benefits host-to-GPU copies.
        pin_memory=device.type == 'cuda',
    )
    meta_val_ds = EpisodicDataset('mini', 'meta_val')
    meta_test_ds = EpisodicDataset('mini', 'meta_test')
    args.transform = CustomTransform('mini', args.img_size)

    # Timestamped output directory, e.g. ./exp/<exp_name>/2024-01-31,12:00:00.
    # NOTE(review): ':' is not a legal path character on Windows.
    date = datetime.datetime.now().strftime('%Y-%m-%d,%H:%M:%S')
    args.save_dir = f'./exp/{args.exp_name}/{date}'

    # Logger
    logger = Logger(
        args.exp_name,
        save_dir=args.save_dir,
        print_every=1,
        save_every=100,
        total_step=args.training_epochs,
        print_to_stdout=True,
        wandb_project_name=f'K-um-5shot-{args.model}',
        wandb_config=args
    )

    # outer loop
    logger.start()
    try:
        algo.run(args, train_loader, meta_val_ds, meta_test_ds, device, logger)
    finally:
        # Finish the logger even if the algorithm raises, so it is not left open.
        logger.finish()

def build_parser():
    """Return the command-line parser for this training script.

    ``--exp-name`` and ``--algorithm`` are required; every other option has
    a default. ``main`` reads the parsed namespace (and mutates it in debug
    mode).
    """
    # The title is the parser *description*; passing it positionally would
    # set ``prog`` and replace the program name in the usage line.
    parser = argparse.ArgumentParser(description='Unsupervised Meta-learning')

    # Directory Argument
    parser.add_argument('--exp-name', type=str, required=True)

    # Data Argument
    # NOTE(review): --pre-img-size is not read in this file; presumably
    # consumed by the algorithm via ``args`` — confirm.
    parser.add_argument('--pre-img-size', type=int, default=224)
    parser.add_argument('--img-size', type=int, default=84)

    # Algorithm Argument
    parser.add_argument('--algorithm', type=str, required=True)
    parser.add_argument('--repeat-augmentations', type=int, default=8)

    # Model Argument (the default was previously 'conv5_64')
    parser.add_argument('--model', type=str, default='resnet18')
    parser.add_argument('--set-model', type=str, default='SAB(ln=False),DeepPooler')
    parser.add_argument('--num-heads', type=int, default=4)
    parser.add_argument('--beta', type=float, default=1.)

    # Training Argument
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--training-epochs', type=int, default=400)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lr-scheduling', action='store_true')

    # Evaluation Argument
    parser.add_argument('--way', type=int, default=5)
    parser.add_argument('--query', type=int, default=15)
    parser.add_argument('--test-every', type=int, default=10)
    parser.add_argument('--val-episodes', type=int, default=100)
    parser.add_argument('--test-episodes', type=int, default=1000)

    # System Argument
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--gpu-id', type=int, default=0)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    # Silence sklearn ConvergenceWarnings (presumably raised by estimators
    # fit during evaluation — confirm).
    warnings.filterwarnings(action='ignore', category=ConvergenceWarning)

    main(args)