from __future__ import print_function
import os, sys
import numpy as np
import chainer
from chainer.backends import cuda
import chainer.functions as F
import matplotlib
# Disable interactive backend
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import random


def set_random_seed(seed):
    """Seed the global random number generators used by the experiments.

    Seeds Python's ``random`` module and NumPy's global RNG. CuPy's global
    RNG is seeded only when Chainer reports the CUDA backend as available.

    Args:
        seed: Seed value passed unchanged to each generator.
    """
    # set Python random seed
    random.seed(seed)
    # set NumPy random seed
    np.random.seed(seed)
    # set Chainer(CuPy) random seed -- guarded because without CuPy,
    # ``cuda.cupy`` is not the real CuPy module and has no ``random``.
    if cuda.available:
        cuda.cupy.random.seed(seed)

def record_setting(out):
    """Record the command line used to launch the current run.

    Creates the output directory (including any missing parent directories)
    and writes ``sys.argv`` joined by single spaces, followed by a newline,
    to ``<out>/command.txt``.

    Args:
        out: Output directory. Only the first whitespace-separated token is
            used; anything after the first whitespace is ignored.

    Raises:
        IndexError: if ``out`` is empty or contains only whitespace.
        OSError: if the directory cannot be created or the file written.
    """
    # Keep only the first whitespace-delimited token (split() already strips).
    out = out.split()[0]
    # makedirs also creates missing parents (e.g. "results/exp1/seed0").
    # Trying first and checking afterwards avoids an exists()/mkdir race when
    # several runs share the same directory.
    try:
        os.makedirs(out)
    except OSError:
        if not os.path.isdir(out):
            raise
    with open(os.path.join(out, "command.txt"), "w") as f:
        f.write(" ".join(sys.argv) + "\n")


def mean_accuracy(logits, y):
    """Return the binary accuracy of ``logits`` measured against ``y``.

    This is a thin wrapper over ``chainer.functions.binary_accuracy``. See
    that function's documentation for how scores are thresholded and which
    label values are ignored.

    Args:
        logits: Predicted scores, passed as the first argument of
            ``F.binary_accuracy`` (presumably raw pre-sigmoid scores, per
            the name; confirm with callers).
        y: Ground-truth labels, passed as the second argument.

    Returns:
        Whatever ``F.binary_accuracy`` returns for these inputs.
    """
    accuracy = F.binary_accuracy(logits, y)
    return accuracy


def pretty_print(*values):
    """Print ``values`` on one line as left-aligned, fixed-width columns.

    Strings are printed unchanged. Any other value is converted with
    ``np.asarray`` and rendered by ``np.array2string`` using 5 fixed decimal
    places. Each column is padded to at least 13 characters, and columns are
    separated by three spaces.

    Args:
        *values: Strings, Python numbers, NumPy scalars or NumPy arrays.
    """
    col_width = 13

    def format_val(v):
        if not isinstance(v, str):
            # np.array2string reads ``.size``/``.shape`` from its argument,
            # so plain Python ints/floats would raise AttributeError.
            # np.asarray is a no-op for ndarrays and wraps scalars.
            v = np.array2string(np.asarray(v), precision=5, floatmode='fixed')
        return v.ljust(col_width)

    print("   ".join(format_val(v) for v in values))




def plot_acc(tr, ts, fname, out):
    """Plot mean +/- one standard deviation of train and test accuracy.

    Mean and standard deviation are taken over axis 0 of ``tr`` and ``ts``.
    The mean is drawn as a thick line and the +/- std band as a shaded
    region. The figure is saved to ``<out>/<fname>_acc.png`` and then closed.

    Args:
        tr: Train accuracies, indexable as ``tr[i][step]``. Axis 0 is
            presumably independent runs and axis 1 iterations (TODO confirm
            with callers).
        ts: Test accuracies with the same layout as ``tr``.
        fname: Prefix of the output file name.
        out: Directory the PNG is written into (relative or absolute).
    """
    tr_mean = np.mean(tr, axis=0)
    tr_std = np.std(tr, axis=0)
    ts_mean = np.mean(ts, axis=0)
    ts_std = np.std(ts, axis=0)
    # Each curve gets its own x-range, so train and test may have
    # different numbers of points.
    tr_x = np.arange(len(tr[0]))
    ts_x = np.arange(len(ts[0]))

    fig = plt.figure()
    try:
        ax = fig.add_subplot(1, 1, 1)

        ax.fill_between(tr_x, tr_mean + tr_std,
                        tr_mean - tr_std, facecolor='red', alpha=0.5)
        ax.fill_between(ts_x, ts_mean + ts_std,
                        ts_mean - ts_std, facecolor='blue', alpha=0.75)

        ax.plot(tr_x, tr_mean, color='red',
                linestyle='solid', linewidth=3, label='train accuracy')
        ax.plot(ts_x, ts_mean, color='blue',
                linestyle='solid', linewidth=3, label='test accuracy')
        # Set limits on this axes explicitly instead of pyplot's "current" one.
        ax.set_ylim(0.0, 1.0)
        ax.set_title('IRM')
        ax.set_xlabel('Iteration (full batch)')
        ax.set_ylabel('Accuracy')
        ax.legend()
        # os.path.join keeps absolute ``out`` paths absolute (the old
        # './{}/' prefix made them relative to the working directory).
        fig.savefig(os.path.join(out, fname + "_acc.png"))
    finally:
        # Release the figure; otherwise repeated calls accumulate open figures.
        plt.close(fig)
