# -*- coding: utf-8 -*-
import platform

import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.distributed as dist
import cnn
import os


import builtins
from track import TrackEpochs
# Create one TrackEpochs instance and publish it as the builtin name `tracker`,
# which makes it resolvable from any module without an explicit import.
# NOTE(review): this runs before the project imports below, presumably so those
# modules can reference `tracker` at import time — confirm before reordering.
_track = TrackEpochs()
builtins.tracker = _track


from arguments import get_args, log_args, get_cnn_args, get_lstm_args, tutorial_args
from cnn.utils.log import log, configure_log
from cnn.utils.set_conf import set_conf
from cnn.models.create_model import create_model
from cnn.runs.distributed_running import train_and_validate as train_val_op
from cnn.runs.distributed_running import do_validate
from torch.multiprocessing import Process
import pdb
from cnn.dataset.data import create_dataset



def main(args):
    """Dispatch to the workflow selected by ``args.type``.

    Each branch re-parses its own, mode-specific arguments rather than
    reusing ``args``:

    * ``'getting_started'`` -- tutorial run via ``tutorial_args()``.
    * ``'cnn'``             -- one training process per rank (``run``).
    * ``'lstm'``            -- LSTM training via ``get_lstm_args()``.
    * ``'eval_cnn'``        -- one evaluation process per rank (``run_eval``).

    Any other ``args.type`` is a no-op.

    Args:
        args: parsed top-level arguments; only ``args.type`` is read.

    Raises:
        RuntimeError: if a spawned worker process exits with a non-zero code.
    """
    # NOTE(review): anomaly detection is enabled globally for every mode. It
    # adds substantial autograd overhead and looks like a debugging leftover.
    torch.autograd.set_detect_anomaly(True)
    if args.type == 'getting_started':
        args = tutorial_args()
        # NOTE(review): `resnet18_cifar10` is neither defined nor imported in
        # this file, so this branch raises NameError unless it is provided
        # elsewhere (e.g. via builtins) — confirm.
        resnet18_cifar10(args)
    elif args.type == 'cnn':
        args = get_cnn_args()
        _spawn_workers(args.world_size, run)
    elif args.type == 'lstm':
        args = get_lstm_args()
        # NOTE(review): `train_lstm` is likewise not defined or imported here.
        train_lstm(args)
    elif args.type == 'eval_cnn':
        args = get_cnn_args()
        _spawn_workers(args.world_size, run_eval)


def _spawn_workers(size, fn):
    """Run ``fn`` in ``size`` processes (ranks ``0..size-1``) and wait for all.

    Each process executes ``init_processes(rank, size, fn)``.

    Raises:
        RuntimeError: if any worker exits with a non-zero exit code. Without
            this check, a crashed rank would go unnoticed and the launcher
            would still exit successfully.
    """
    processes = []
    for rank in range(size):
        p = Process(target=init_processes, args=(rank, size, fn))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    failed = [
        'rank {}: exit code {}'.format(rank, p.exitcode)
        for rank, p in enumerate(processes)
        if p.exitcode != 0
    ]
    if failed:
        raise RuntimeError(
            'worker process(es) failed: {}'.format(', '.join(failed)))

def run(rank, size):
    """Per-rank training entry point (called as ``fn(rank, size)``).

    Re-parses the CNN arguments and applies ``set_conf`` to them. It then
    builds the model, criterion and optimizer, sets up logging, logs which
    rank/device this process uses, and hands off to the train/validate loop.
    ``rank`` and ``size`` are accepted to fit the ``fn(rank, size)`` calling
    convention but are not read here.
    """
    args = get_cnn_args()
    set_conf(args)
    print('set_conf...')

    # Model, loss function and optimizer for this process.
    model, criterion, optimizer = create_model(args)

    # Logging must be configured before anything is written with log().
    configure_log(args)
    print('configure_log...')
    log_args(args)
    print('log_args...')

    if args.device != "cpu":
        device = 'GPU-' + str(torch.cuda.current_device())
    else:
        device = "cpu"
    log('Rank {} {}'.format(args.cur_rank, device))

    train_val_op(args, model, criterion, optimizer)

def run_eval(rank, size):
    """Per-rank evaluation entry point (called as ``fn(rank, size)``).

    Reads the checkpoint path from the CNN command-line arguments. It then
    replaces those arguments with the ``'arguments'`` entry stored in the
    checkpoint and rebuilds the model/criterion/optimizer from them. Finally it
    runs one validation pass with ``do_validate(..., save=True)``.
    ``rank`` and ``size`` fit the ``fn(rank, size)`` calling convention but
    are not read here.
    """
    cli_args = get_cnn_args()
    # Only the saved arguments are used below, so map every stored tensor to
    # CPU. Otherwise each rank would materialize the whole checkpoint on the
    # device it was saved from, and GPU-saved checkpoints would fail to load
    # on CPU-only hosts.
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # checkpoints. On newer PyTorch releases, where weights_only defaults to
    # True, loading a pickled arguments object may require weights_only=False.
    checkpoint = torch.load(cli_args.checkpoint, map_location='cpu')
    args = checkpoint['arguments']
    # NOTE(review): unlike run(), set_conf() is not applied to these saved
    # arguments, so per-process fields such as cur_rank presumably reflect the
    # run that saved the checkpoint rather than this process — confirm.
    # NOTE(review): the checkpoint's model weights are not loaded in this
    # function; confirm create_model restores them, otherwise validation runs
    # on freshly initialized weights.
    model, criterion, optimizer = create_model(args)

    configure_log(args)
    print('configure_log...')
    log_args(args)
    print('log_args...')

    device = 'GPU-' + str(torch.cuda.current_device()) if args.device != "cpu" else "cpu"
    log('Rank {} {}'.format(args.cur_rank, device))

    _train_loader, val_loader = create_dataset(args)
    do_validate(args, val_loader, model, optimizer, criterion, save=True)


def init_processes(rank, size, fn, backend='gloo',
                   master_addr='127.0.0.1', master_port='29500'):
    """Initialize the default process group for this rank, then run ``fn``.

    Args:
        rank: rank of this process within the group.
        size: world size (total number of processes).
        fn: callable invoked as ``fn(rank, size)`` once the group is ready.
        backend: ``torch.distributed`` backend name.
        master_addr: rendezvous host address, exported as ``MASTER_ADDR``.
        master_port: rendezvous port (str or int), exported as ``MASTER_PORT``.

    The process group is destroyed when ``fn`` returns or raises.
    """
    # init_process_group with no init_method uses env:// rendezvous, which
    # reads these two variables.
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(master_port)
    dist.init_process_group(backend, rank=rank, world_size=size)
    try:
        fn(rank, size)
    finally:
        # Guarded in case fn already tore the group down itself.
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    # Parse the top-level CLI (which selects the run mode) and dispatch on it.
    main(get_args())
