# -*- coding: utf-8 -*-
from os.path import join
from itertools import groupby
from functools import reduce

import torch
import torch.distributed as dist

from cnn.utils.opfiles import build_dirs
import numpy as np
import random
import os 

def set_checkpoint(args):
    """Derive the checkpoint locations for this run and create the directory.

    Writes onto ``args``:
      * ``checkpoint_root``: ``checkpoint/data/arch/<device>/timestamp``,
        where the device component is an empty string when ``args.device``
        is ``None``.
      * ``checkpoint_dir``: ``checkpoint_root`` extended with the string
        form of ``args.cur_rank``.
      * ``save_some_models``: the original comma-separated string replaced
        by the list of its pieces.

    Finally hands ``checkpoint_dir`` to ``build_dirs``.
    """
    path_parts = [args.checkpoint, args.data, args.arch]
    # An empty component is used when no device has been specified.
    path_parts.append('' if args.device is None else args.device)
    path_parts.append(args.timestamp)

    root_dir = join(*path_parts)
    args.checkpoint_root = root_dir
    args.checkpoint_dir = join(root_dir, str(args.cur_rank))

    model_spec = args.save_some_models
    args.save_some_models = model_spec.split(',')

    # Make sure the per-rank checkpoint directory is available.
    build_dirs(args.checkpoint_dir)


def set_lr(args):
    """Populate the learning-rate fields on ``args``.

    Reads ``lr_decay_epochs``, ``batch_size``, ``world_size``, ``lr_scale``
    and ``lr`` from ``args`` and writes:
      * ``lr_change_epochs``: the integers parsed from the comma-separated
        ``lr_decay_epochs`` string, or ``None`` when that option is ``None``.
      * ``learning_rate_per_sample``: the hard-coded ``0.1`` divided by
        ``batch_size``.
      * ``learning_rate``: the per-sample rate multiplied by ``batch_size``
        and ``world_size`` when ``lr_scale`` is truthy, otherwise ``lr``.
      * ``old_learning_rate``: a copy of ``learning_rate``.
    """
    decay_spec = args.lr_decay_epochs
    if decay_spec is None:
        args.lr_change_epochs = None
    else:
        args.lr_change_epochs = list(map(int, decay_spec.split(',')))

    # Learning-rate tuning: start from a per-sample rate.
    per_sample_rate = 0.1 / args.batch_size
    args.learning_rate_per_sample = per_sample_rate

    if args.lr_scale:
        # Scale with the local batch size and the number of workers.
        args.learning_rate = \
            per_sample_rate * args.batch_size * args.world_size
    else:
        args.learning_rate = args.lr

    args.old_learning_rate = args.learning_rate


def set_conf(args):
    """Seed the random number generators and fill in run state on ``args``.

    Steps, in order:
      1. Seed ``torch`` (CPU and, when available, CUDA), NumPy and Python's
         ``random`` from ``args.manual_seed`` and switch cuDNN to its
         deterministic, non-benchmarking mode.
      2. Initialise the bookkeeping fields ``local_index``, ``best_prec1``,
         ``best_epoch`` and ``val_accuracies``.
      3. Record ``ranks`` (``0 .. world_size - 1``) and ``cur_rank``
         (from ``torch.distributed.get_rank()``), and select the CUDA
         device for this process when ``args.device == 'gpu'``.
      4. Call ``set_checkpoint`` and ``set_lr`` to derive the checkpoint
         paths and the learning-rate fields.

    NOTE(review): presumably the default process group must already be
    initialised before calling this, since the rank is queried here --
    confirm against the caller.
    """
    seed = args.manual_seed

    # global conf.
    # NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this
    # assignment is only seen by processes spawned afterwards; it does not
    # change string hashing in the current process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)  # Numpy module.
    random.seed(seed)  # Python random module.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # local conf.
    args.local_index = 0
    args.best_prec1 = 0
    args.best_epoch = []
    args.val_accuracies = []

    # configure world.
    args.ranks = list(range(args.world_size))
    args.cur_rank = dist.get_rank()
    if args.device == 'gpu':
        # The global rank is not necessarily a valid local device index
        # (e.g. several nodes, or more processes than GPUs), so wrap it onto
        # the devices visible to this process. When the rank is already
        # smaller than the device count the selected device is unchanged.
        visible_gpus = torch.cuda.device_count()
        if visible_gpus > 0:
            device_index = args.cur_rank % visible_gpus
        else:
            # No visible device: keep the raw rank so the CUDA runtime
            # reports the problem itself.
            device_index = args.cur_rank
        torch.cuda.set_device(device_index)

    # define checkpoint for logging.
    set_checkpoint(args)

    # define learning rate and learning rate decay scheme.
    set_lr(args)
