import os
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the per-channel mean and std of a dataset.
    - init_params: net parameter initialization.
    - progress_bar: progress bar mimicking xlua.progress.
'''
import math
import shutil
import sys
import time

import torch.nn as nn
import torch.nn.init as init


def get_mean_and_std(dataset, num_workers=2):
    '''Compute the per-channel mean and std value of a dataset.

    Samples are loaded one at a time (batch_size=1). For each sample the mean
    and (unbiased) std are taken over every axis except the channel axis
    (dim 1), and the per-sample results are averaged. The returned std is
    therefore the mean of per-sample stds, not the std over all pixels.

    Args:
        dataset: map-style dataset whose items unpack as ``(input, target)``;
            assumes inputs batch to ``(1, C, ...)`` -- TODO confirm for new data.
        num_workers: DataLoader worker processes (0 = load in this process).

    Returns:
        ``(mean, std)``: two 1-D tensors with one entry per channel.

    Raises:
        ValueError: if the dataset yields no samples.
    '''
    # An average does not depend on order, so there is no need to shuffle.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=num_workers)
    mean = None
    std = None
    count = 0
    print('==> Computing mean and std..')
    for inputs, _ in dataloader:
        # mean()/std() are undefined for integer tensors (e.g. uint8 images).
        if not inputs.is_floating_point():
            inputs = inputs.float()
        # Reduce over everything except the channel axis, whatever C is.
        reduce_dims = tuple(d for d in range(inputs.dim()) if d != 1)
        sample_mean = inputs.mean(dim=reduce_dims)
        sample_std = inputs.std(dim=reduce_dims)
        if mean is None:
            mean = torch.zeros_like(sample_mean)
            std = torch.zeros_like(sample_std)
        mean += sample_mean
        std += sample_std
        count += 1
    if count == 0:
        raise ValueError('get_mean_and_std: dataset is empty')
    mean.div_(count)
    std.div_(count)
    return mean, std

def init_params(net):
    '''Init layer parameters of every supported submodule of ``net`` in place.

    - ``nn.Conv2d``: Kaiming-normal weight (``mode='fan_out'``), zero bias.
    - ``nn.BatchNorm2d``: weight 1, bias 0 (skipped when ``affine=False``).
    - ``nn.Linear``: normal weight with std 1e-3, zero bias.

    Other module types are left untouched.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # Test for presence with `is not None`: a multi-element bias tensor
            # has no truth value, so `if m.bias:` raises RuntimeError.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # With affine=False the BatchNorm has no weight/bias parameters.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


# Terminal width in columns; progress_bar() uses it for padding and cursor
# positioning. shutil returns the fallback width when no terminal is attached
# (redirected output, CI, IDE consoles) instead of failing at import time.
term_width = shutil.get_terminal_size(fallback=(80, 24)).columns

TOTAL_BAR_LENGTH = 65.
# Timestamps that progress_bar() updates to report per-step and total time.
last_time = time.time()
begin_time = last_time

def progress_bar(current, total, msg=None):
    '''Draw or refresh a one-line progress bar on stdout (mimics xlua.progress).

    Args:
        current: zero-based index of the current step; displayed as current+1.
        total: total number of steps. The line ends with a carriage return (so
            the next call redraws it) until current reaches total - 1, and with
            a newline afterwards. A non-positive total draws a full bar instead
            of raising ZeroDivisionError.
        msg: optional text appended after the step and total timings.

    Reads the module globals ``term_width`` and ``TOTAL_BAR_LENGTH`` and
    updates ``last_time`` / ``begin_time`` (begin_time resets when current == 0).
    '''
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    bar_width = int(TOTAL_BAR_LENGTH)
    # Avoid dividing by zero, and clamp so an overshooting `current` cannot
    # draw a bar wider than TOTAL_BAR_LENGTH.
    if total > 0:
        cur_len = int(TOTAL_BAR_LENGTH * current / total)
    else:
        cur_len = bar_width
    cur_len = min(max(cur_len, 0), bar_width)
    rest_len = bar_width - cur_len - 1  # one column is used by the '>' head

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    info = '  Step: %s | Tot: %s' % (format_time(step_time), format_time(tot_time))
    if msg:
        info += ' | ' + msg

    # Pad toward term_width so leftovers from a longer, previously drawn line
    # are overwritten.
    padding = ' ' * (term_width - bar_width - len(info) - 3)
    # Backspace to roughly the middle of the bar, where the counter is drawn.
    backspaces = '\b' * (term_width - int(TOTAL_BAR_LENGTH / 2) + 2)
    counter = ' %d/%d ' % (current + 1, total)
    ending = '\r' if current < total - 1 else '\n'

    # Emit the whole line in one write instead of one call per character.
    sys.stdout.write(''.join([
        ' [', '=' * cur_len, '>', '.' * rest_len, ']',
        info, padding, backspaces, counter, ending,
    ]))
    sys.stdout.flush()

def format_time(seconds):
    '''Render a duration in seconds as a compact string such as '1h2m' or '500ms'.

    The value is split into days (D), hours (h), minutes (m), whole seconds (s)
    and milliseconds (ms). Only the first two components that are > 0 are
    shown, in that order. If none is > 0 (e.g. zero or negative input),
    the result is '0ms'.
    '''
    remaining = seconds
    days = int(remaining / 3600 / 24)
    remaining -= days * 86400
    hours = int(remaining / 3600)
    remaining -= hours * 3600
    minutes = int(remaining / 60)
    remaining -= minutes * 60
    whole_secs = int(remaining)
    remaining -= whole_secs
    millis = int(remaining * 1000)

    components = ((days, 'D'), (hours, 'h'), (minutes, 'm'),
                  (whole_secs, 's'), (millis, 'ms'))
    shown = [str(value) + unit for value, unit in components if value > 0]
    return ''.join(shown[:2]) or '0ms'
def set_random(seed):
    '''Seed NumPy's global RNG, Python's ``random`` and PyTorch (CPU and CUDA).'''
    seeders = (
        lambda s: np.random.seed(seed=s),
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def accuracy(output, target, topk=(1,)):
    '''Compute classification accuracy, in percent, for one batch.

    Args:
        output: model scores of shape (B, C). When C == 1 the column is treated
            as a single binary logit: a sample is correct when
            ``output * (2 * target - 1) > 0``, i.e. positive for target 1 and
            negative for target 0 (targets presumably in {0, 1}).
        target: class indices with B elements, shape (B,); the binary case
            also accepts (B, 1).
        topk: k values for top-k accuracy; ignored in the binary case.

    Returns:
        List of 0-dim tensors holding percentages: one per entry of ``topk``,
        or a single entry in the binary case. An empty batch yields 0.
    '''
    batch_size = target.size(0)
    # Percentage scale; an empty batch reports 0 instead of dividing by zero.
    scale = 100.0 / batch_size if batch_size > 0 else 0.0

    if output.shape[1] == 1:
        # Flatten both sides: a (B, 1) output times a (B,) target would
        # otherwise broadcast to (B, B) and average over all pairs.
        margin = output.reshape(-1) * (2 * target.reshape(-1) - 1)
        return [(margin > 0.0).float().sum() * scale]

    maxk = max(topk)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, B): row j holds every sample's j-th best class
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` is non-contiguous after the transpose.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(scale))
    return res


class AverageMeter(object):
    '''Track a running, count-weighted average of a stream of values.

    Attributes:
        sum: accumulated ``val * n`` over all updates.
        cnt: accumulated weight ``n``.
        avg: ``sum / cnt``; left at its previous value while ``cnt`` <= 0.
    '''

    def __init__(self):
        self.reset()

    def reset(self):
        '''Zero every statistic.'''
        self.sum = 0
        self.cnt = 0
        self.avg = 0

    def update(self, val, n=1):
        '''Fold in ``val`` with weight ``n`` and refresh the running average.'''
        self.sum += val * n
        self.cnt += n
        if self.cnt <= 0:
            return  # nothing to average yet; avoid dividing by zero
        self.avg = self.sum / self.cnt

def string2int(t):
    '''Parse a bracketed, comma-separated list of integers such as "[1, 2, 3]".

    Spaces and surrounding whitespace are removed, then the first and last
    characters (the brackets) are dropped, so "(1,2)" also works. An empty
    list ("[]") and empty items (e.g. a trailing comma) are skipped rather
    than raising ValueError.

    Returns:
        list of int.
    '''
    inner = t.replace(" ", "").strip()[1:-1]
    return [int(item) for item in inner.split(",") if item.strip()]

def exp_set_up(cf):
    '''Create the experiment directory tree and record its paths in ``cf``.

    Reads ``cf["name"]``, ``cf["model"]["type"]``,
    ``cf["data"]["label_function"]`` and, only when label_function is not
    None, ``cf["data"]["structure"]``. Creates (relative to the current
    working directory) ``exp/``, ``data/``, ``exp/<name>/`` and its
    ``checkpoint``, ``test`` and figure subdirectories; existing directories
    are reused.

    Side effect: adds ``exp_path``, ``log_path``, ``checkpoint_path``,
    ``test_path`` and ``fig_path`` to ``cf`` (mutated in place).

    Returns:
        Path of ``loss.pth`` inside the experiment directory (not created).
    '''
    # exist_ok=True replaces the racy "check exists, then create" pattern.
    os.makedirs("exp", exist_ok=True)
    os.makedirs("data", exist_ok=True)
    cf["exp_path"] = os.path.join("exp", cf["name"])
    os.makedirs(cf["exp_path"], exist_ok=True)

    cf["log_path"] = os.path.join(cf["exp_path"], cf["name"] + "_log.txt")
    cf["checkpoint_path"] = os.path.join(cf["exp_path"], "checkpoint")
    os.makedirs(cf["checkpoint_path"], exist_ok=True)
    cf["test_path"] = os.path.join(cf["exp_path"], "test")
    os.makedirs(cf["test_path"], exist_ok=True)
    loss_path = os.path.join(cf["exp_path"], "loss.pth")

    # Figure dir: "<model>" or "<model>_<label_function>_[no_]structure".
    fig_dir = cf["model"]["type"]
    label_function = cf["data"]["label_function"]
    if label_function is not None:
        suffix = "_structure" if cf["data"]["structure"] else "_no_structure"
        fig_dir = fig_dir + "_" + label_function + suffix
    cf["fig_path"] = os.path.join(cf["exp_path"], fig_dir)
    os.makedirs(cf["fig_path"], exist_ok=True)

    return loss_path

def print_arg(cf):
    '''Pretty-print an experiment config dict to stdout.

    After a banner and the experiment name, each top-level entry whose value
    is exactly a ``dict`` prints as a ``key:`` header followed by indented
    ``param : value`` rows and a blank gap. Every other entry prints as a
    single row. Keys are padded to 50 characters.
    '''
    print("============================================\n============================================")
    print('----Print Arguments Setting------')
    print("*** Experiment Name: ", cf["name"], "***")
    for key, value in cf.items():
        if type(value) is not dict:
            print('{:50}:{}'.format(key, value))
            continue
        print('{}:'.format(key))
        for para, para_value in value.items():
            print('    {:50}:{}'.format(para, para_value))
        print('\n')
    print()