import os
import random

from ussl_train import USSL
from ussl_eval import USSLEval
import torch
from torch import nn

import torch.backends.cudnn as cudnn
import logging
import argparse

import torch.multiprocessing as mp
import socket

import glob
import torch.distributed as dist
import pdb


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because any non-empty string is truthy, so the option could never be
    switched off by passing ``False``.

    Args:
        value: The raw token from the command line (or an actual ``bool``).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string is kept as False: with the former ``type=bool`` it was
    # the only value that produced False, so existing invocations keep working.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def add_learner_params(parser):
    """Register every command-line option of this script on ``parser``.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible object exposing
            ``add_argument``). It is modified in place.

    Returns:
        None.
    """
    parser.add_argument('--problem', default='simclr', help='The problem to train', choices=['simclr', 'linear_eval'])
    parser.add_argument('--root', default="", type=str, help='log path')
    parser.add_argument('--data_root', default="", type=str, help='data root')

    # resume params
    parser.add_argument('--ckpt', default='', help='Optional checkpoint to init the model.')
    parser.add_argument('--resume', action='store_true', help='Resumes from latest checkpoint or ckpt if provided')

    # optimizer params
    parser.add_argument('--encoder_lr_schedule', default='warmup-anneal')
    parser.add_argument('--encoder_opt', default='lars', help='Optimizer to use', choices=['sgd', 'adam', 'lars'])
    # NOTE(review): the flag is named "iters" but the help text says "epochs" - confirm which unit is meant.
    parser.add_argument('--iters', default=48000, type=int, help='The number of epochs')
    parser.add_argument('--encoder_warmup', default=0, type=float, help='The number of warmup iterations in proportion to \'iters\'')
    parser.add_argument('--encoder_lr', default=0.1, type=float, help='Base learning rate')
    parser.add_argument('--encoder_weight_decay', default=1e-4, type=float, dest='encoder_weight_decay')

    parser.add_argument('--classifier_lr_schedule', default='warmup-anneal')
    parser.add_argument('--classifier_opt', default='lars', help='Optimizer to use', choices=['sgd', 'adam', 'lars'])
    parser.add_argument('--classifier_warmup', default=0, type=float, help='The number of warmup iterations in proportion to \'iters\'')
    parser.add_argument('--classifier_lr', default=0.1, type=float, help='Base learning rate')
    parser.add_argument('--classifier_weight_decay', default=1e-4, type=float, dest='classifier_weight_decay')

    # logging params
    parser.add_argument('--save_freq', default=10000000000000000, type=int, help='Frequency to save the model')
    parser.add_argument('--log_freq', default=48, type=int, help='Frequency to log')

    # parallelism params:
    parser.add_argument('--dist', default='dp', type=str, help='dp: DataParallel, ddp: DistributedDataParallel', choices=['dp', 'ddp'])
    parser.add_argument('--dist_address', default='127.0.0.1:1234', type=str, help='the address and a port of the main node in the <address>:<port> format')
    parser.add_argument('--rank', default=0, type=int, help='Rank of the node (script launched): 0 for the main node and 1,... for the others')
    parser.add_argument('--world_size', default=1, type=int, help='the number of nodes (scripts launched)')
    # Parsed with _str2bool so that "--sync_bn False" really yields False.
    parser.add_argument('--sync_bn', default=True, type=_str2bool, help='Syncronises BatchNorm layers between all processes if True')

    # arch params
    parser.add_argument('--arch', default='resnet50', help='Encoder architecture')
    parser.add_argument('--proj_dim', default=128, type=int, help='Projection Dimension')
    parser.add_argument('--ndf', default=16, type=int, help='Decoder hidden filters')
    # A seed of -1 means "do not seed" (see main()).
    parser.add_argument('--seed', default=-1, type=int, help='Random seed')
    parser.add_argument('--use_mlp_classifier', action='store_true', help='Use an MLP head as classifier')
    parser.add_argument('--use_conv_classifier', action='store_true', help='Use an CONV head as classifier')

    # dataloader params
    parser.add_argument('--datasets', nargs = "+", help='train dataset names')
    parser.add_argument('--image_size', default=32, type=int, help='Input image size')
    parser.add_argument('--batch_size', default=1024, type=int, help='The number of unique images in the batch')
    parser.add_argument('--workers', default=2, type=int, help='The number of data loader workers')
    parser.add_argument('--multiplier', default=2, type=int)
    parser.add_argument('--color_dist_s', default=1., type=float, help='Color distortion strength')
    parser.add_argument('--scale_lower', default=0.08, type=float, help='The minimum scale factor for RandomResizedCrop')
    parser.add_argument('--use_color_dist', action='store_true', help='Use color distortion')
    parser.add_argument('--use_rotation', action='store_true', help='Use rotation aug')
    parser.add_argument('--task', help='task to test on', default='')
    # No ``type`` is given, so the values arrive as a list of strings.
    parser.add_argument('--k_shot', nargs = "+", help = 'Number of few-shot samples per dataset. None -> use full dataset')
    parser.add_argument('--combine_datasets', action='store_true', help='Combine datsets during training')
    parser.add_argument('--ignore_domain_labels', action='store_true', help='Do not use domain labels')

    # loss params
    parser.add_argument('--temperature', default=0.1, type=float, help='Temperature in the NTXent loss')
    parser.add_argument('--noise_dim', default=0, type=int, help='noise dimension for generator')
    parser.add_argument('--ussl_reg_aug', default=1, type=float, help='Hierarchical SimCLR reg for aug')
    parser.add_argument('--ussl_reg_dataset', default=0, type=float, help='Hierarchical SimCLR reg for dataset')
    parser.add_argument('--ussl_reg_ood', default=0, type=float, help='Hierarchical SimCLR reg for ood')
    parser.add_argument('--simclr_warmup', default=100000000, type=int, help='warmup for simclr')
    parser.add_argument('--clustering_freq', default=100000000, type=int, help='clustering frequency')
    parser.add_argument('--remove_outliers', action='store_true', help='remove cluster outliers during training')
    parser.add_argument('--restart_training', action='store_true', help='restart training after clusters are found')


def main():
    """Parse the command line, optionally seed the RNGs and start the worker(s).

    With ``--dist ddp`` one process per visible GPU is spawned and
    ``args.world_size`` is multiplied by the number of GPUs on this node;
    otherwise ``main_worker`` is called directly in this process with GPU
    index 0.
    """
    cli = argparse.ArgumentParser(description='PyTorch SimCLR')
    add_learner_params(cli)
    args = cli.parse_args()

    # -1 is the "no seeding" sentinel declared as the default of --seed.
    seed = args.seed
    if seed != -1:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True

    gpu_count = torch.cuda.device_count()

    if args.dist != 'ddp':
        # Single-process path: run the worker inline on GPU 0.
        main_worker(0, gpu_count, args)
        return

    # --world_size counts nodes on the command line; turn it into a process count.
    args.world_size = gpu_count * args.world_size
    mp.spawn(main_worker, nprocs=gpu_count, args=(gpu_count, args))

def set_checkpoint(args):
    """Point ``args.ckpt`` at the most recent checkpoint when none was given.

    If ``args.ckpt`` is already non-empty it is left untouched. Otherwise the
    files matching ``<args.root>/checkpoint*.pth.tar`` are listed and the one
    with the latest modification time is stored in ``args.ckpt``. When no file
    matches, ``args.ckpt`` stays empty.

    Args:
        args: Namespace-like object with string attributes ``ckpt`` and
            ``root``. ``args.ckpt`` may be overwritten in place.

    Returns:
        None.
    """
    if args.ckpt:
        return

    # Escape the directory part so that glob metacharacters in the log path
    # (e.g. "runs/exp[1]") are matched literally instead of as a pattern.
    # NOTE(review): with the default root "" this pattern is
    # "/checkpoint*.pth.tar", i.e. the filesystem root - confirm that is intended.
    pattern = f"{glob.escape(args.root)}/checkpoint*.pth.tar"
    checkpoints = glob.glob(pattern)
    if checkpoints:
        args.ckpt = max(checkpoints, key=os.path.getmtime)

            
def _log_filename(args):
    """Return the log file name: the training log for 'simclr', an eval log otherwise."""
    if args.problem == 'simclr':
        return f'training.log'
    mlp_suffix = '_mlp' if args.use_mlp_classifier else ''
    conv_suffix = '_conv' if args.use_conv_classifier else ''
    return f"eval_{args.task}{mlp_suffix}{conv_suffix}.log"


def main_worker(gpu, ngpus_per_node, args):
    """Run training or linear evaluation on one GPU.

    Args:
        gpu: Index of the CUDA device this worker uses.
        ngpus_per_node: Number of GPUs on this node.
        args: Parsed command-line namespace. It is modified in place: ``gpu``
            is set, and in 'ddp' mode ``rank``, ``batch_size`` and ``workers``
            are rewritten; ``ckpt`` may be filled in by ``set_checkpoint``.

    Returns:
        None.
    """
    args.gpu = gpu

    log_path = os.path.join(args.root, _log_filename(args))
    logging.basicConfig(filename=log_path, level=logging.INFO)
    logging.info(args)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.dist == 'ddp':
        # Turn the node rank into a global process rank before joining the group.
        args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(
            backend='nccl',
            init_method='tcp://%s' % args.dist_address,
            world_size=args.world_size,
            rank=args.rank,
        )

    torch.cuda.set_device(args.gpu)

    # The model is built before batch_size/workers are rescaled below; that
    # order is kept exactly as in the original code.
    simclr_model = USSL(args)

    if args.dist == 'ddp':
        # Split the batch across the GPUs of this node; workers are split with
        # a ceiling division.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

        if args.rank == 0:
            print(f'===> {dist.get_world_size()} GPUs total; batch_size={args.batch_size} per GPU')

        print(f'===> Proc {dist.get_rank()}/{dist.get_world_size()}@{socket.gethostname()}', flush=True)

    cudnn.benchmark = True

    if args.resume:
        set_checkpoint(args)
        if args.ckpt:
            simclr_model.load()

    if args.problem != 'linear_eval':
        simclr_model.train()
        return

    # Linear evaluation: resolve the checkpoint, then train the classifier
    # built around the model.
    set_checkpoint(args)
    classifier = USSLEval(simclr_model, args)
    classifier.train()

# Script entry point: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()