import argparse

import matplotlib.pyplot as plt
import numpy as np


def load_data(fn, dtype=int):
  """Load a comma-separated numeric file and cast it to ``dtype``.

  Args:
    fn: path of the CSV file to read.
    dtype: target dtype. Defaults to the builtin ``int``; the former default
      ``np.int`` was only an alias for it and no longer exists in NumPy >= 1.24.

  Returns:
    The array produced by ``np.loadtxt`` (default ``ndmin``, so single-row or
    single-column files come back 1-D), cast with ``astype(dtype)``.
  """
  return np.loadtxt(fn, delimiter=',').astype(dtype)


def draw_sum_plot(data, acc_data, labels, batch_size, sch, M, R, W, lr, colors=None,
                  shift_schedule=((2000, 20), (4000, 40), (6000, 60), (8000, 80))):
  """Plot mean detection-rate and accuracy curves and save the figure as a PNG.

  For each pair ``(d, a)`` from ``data`` / ``acc_data``, the mean over axis 0
  (times 100) is plotted against ``column_index * batch_size``. Detection
  rates go on the left axis. Accuracies go on the right (twin) axis as red
  dotted lines. A black dashed step curve of the shifted-sample ratio is also
  drawn on the right axis.

  Note: this updates the global ``plt.rcParams`` font sizes as a side effect.

  Args:
    data: sequence of 2-D arrays, one per detection curve; axis 0 is averaged.
    acc_data: sequence of 2-D accuracy arrays, paired element-wise with ``data``.
    labels: legend labels for the detection curves, in the order of ``data``.
    batch_size: samples per column; scales the x axis and the ratio curve length.
    sch, M, R, W, lr: used only to build the output file name.
    colors: optional per-curve colors; honoured only if ``len(colors) == len(data)``.
    shift_schedule: ``(start_sample, ratio_percent)`` pairs. From each
      ``start_sample`` onward (until the next larger start) the ratio curve
      equals ``ratio_percent``; before the first start it is 0. The default
      reproduces the 20/40/60/80 % steps at samples 2000/4000/6000/8000.

  Returns:
    The file name the figure was saved to.
  """
  plt.rcParams.update({'font.size': 15,
                       'xtick.labelsize': 12,
                       'ytick.labelsize': 12})

  fig, ax1 = plt.subplots()
  ax2 = ax1.twinx()

  # Per-curve colors are only used when there is exactly one per curve;
  # otherwise matplotlib's default color cycle applies.
  use_colors = colors is not None and len(colors) == len(data)
  for i, (d, a) in enumerate(zip(data, acc_data)):
    x = np.arange(d.shape[1]) * batch_size
    color_kw = {'color': colors[i]} if use_colors else {}
    ax1.plot(x, d.mean(axis=0) * 100, **color_kw)
    ax2.plot(x, a.mean(axis=0) * 100, color='r', linestyle=':')

  ax1.set_ylim([-2, 102])
  plt.grid()

  ax1.set_xlabel("Number of samples")
  ax1.set_ylabel("Detection Rate over repetition (%) ")
  ax1.legend(labels, loc='upper left')

  # Step curve of the shifted-sample ratio, one value per sample index.
  # Applying the steps in ascending start order lets each later step overwrite
  # the tail of the previous one. Starts beyond the array end are no-ops.
  n_samples = data[0].shape[1] * batch_size
  ratios = np.zeros(n_samples)
  for start, pct in sorted(shift_schedule):
    ratios[start:] = pct

  ax2.plot(np.arange(n_samples), ratios, color='k', linestyle='--')
  ax2.set_ylabel('Shifted sample ratio / Accuracy (%)')
  ax2.set_ylim([-2, 102])

  output_fn = 'sch_sum_{}_B{}_M{}_R{}_W{}_lr{:.4f}.png'.format(
      sch.lower(), batch_size, M, R, W, lr)
  fig.savefig(output_fn)
  return output_fn


def _result_path(args, h, w, prefix):
  """Return the CSV path for one (H, W) run; prefix is 'result' or 'result_acc'."""
  return ('./results/sch_gradinc_batch_{}/{}_ours_sch{}_H{}_alpha{}_M{}_R{}_W{}_lr{:.4f}.csv'
          .format(args.batch_size, prefix, args.sch, h, args.alpha,
                  args.M, args.R, w, args.lr))


def _curve_label(h, w):
  """Return the legend label for one (H, W) curve; H == -1 is labelled 'Ours'."""
  name = 'Ours' if h == -1 else 'H{}'.format(h)
  if w == -1:
    return 'No Window ({})'.format(name)
  return name


def main(args):
  """Load the detection/accuracy CSVs for each H value and plot them together.

  For every H in ``[-1, 2, 5]`` and the single window size ``args.W``, reads
  ``result_ours_*.csv`` (cast to int) and ``result_acc_ours_*.csv`` (cast to
  float32) under ``./results/sch_gradinc_batch_<batch_size>/``. It then calls
  ``draw_sum_plot`` and closes all matplotlib figures.

  Args:
    args: parsed namespace with ``batch_size``, ``sch``, ``alpha``, ``M``,
      ``R``, ``W`` and ``lr`` attributes.

  Raises:
    ValueError: if a detection result file does not have ``args.R`` rows.
    Errors from ``np.loadtxt`` (e.g. a missing file) propagate unchanged.
  """
  window_sizes = [args.W]
  hs = [-1, 2, 5]
  data = []
  acc_data = []
  labels = []

  for h in hs:
    for w in window_sizes:
      det_fn = _result_path(args, h, w, 'result')
      d = load_data(det_fn)
      # Explicit check instead of `assert` (stripped under `python -O`), and
      # applied to every curve rather than only the first one.
      if d.shape[0] != args.R:
        raise ValueError(
            'R should match to the number of experiments: '
            '{} has {} rows, expected R={}.'.format(det_fn, d.shape[0], args.R))
      data.append(d)
      acc_data.append(load_data(_result_path(args, h, w, 'result_acc'),
                                dtype=np.float32))
      labels.append(_curve_label(h, w))

  # One color per curve (len(hs) * len(window_sizes) == 3); draw_sum_plot
  # falls back to default colors if the counts differ.
  draw_sum_plot(data, acc_data, labels, args.batch_size, args.sch, args.M, args.R,
                args.W, args.lr, colors=['green', 'darkorange', 'blueviolet'])
  plt.close('all')


if __name__ == '__main__':
  # (flag, default, type) for each command-line option, in registration order.
  _CLI_OPTIONS = (
      ('--batch_size', 10, int),
      ('--sch', 'gaussian_noise2_gradinc', str),
      ('--alpha', 0.01, float),
      ('--M', 10000, int),
      ('--R', 100, int),
      ('--W', 100, int),
      ('--lr', 0.001, float),
  )

  cli = argparse.ArgumentParser()
  for flag, default_value, value_type in _CLI_OPTIONS:
    cli.add_argument(flag, default=default_value, type=value_type)

  main(cli.parse_args())
