import argparse

import matplotlib.pyplot as plt
import numpy as np


def load_data(fn, dtype=int):
  """Load a comma-separated numeric text file and cast it to `dtype`.

  Args:
    fn: Anything `np.loadtxt` accepts as its source (e.g. a file path or an
      open text file).
    dtype: Target element type for the returned array. Defaults to the
      builtin `int`; the former `np.int` default was only an alias of it and
      no longer exists in NumPy >= 1.24, which made this module fail at
      import time.

  Returns:
    The parsed values as a NumPy array converted with `astype(dtype)`.
    `np.loadtxt`'s default squeezing applies, so a file holding a single row
    or a single column comes back as a 1-D array.
  """
  # Values are parsed with loadtxt's default (float) type first and then
  # converted, so entries written as "1.0" still load when dtype is integral.
  data = np.loadtxt(fn, delimiter=',').astype(dtype)
  return data


def draw_sum_plot(data, labels, batch_size, sch, M, R, W, lr, colors=None):
  """Plot per-column mean rates for each result array and save the figure.

  For every array in `data` the mean over axis 0 (in percent) is drawn on the
  left y-axis against `column index * batch_size`. A dashed black curve on the
  right y-axis shows the shifted-sample ratio that belongs to the schedule
  named by `sch`. The figure is written to the current working directory as a
  PNG whose name encodes `sch`, `batch_size`, `M`, `R`, `W` and `lr`.

  Args:
    data: Sequence of arrays with at least two dimensions; `data[0]` also
      determines the length of the ratio curve.
    labels: Legend entries for the curves on the left axis.
    batch_size: Number of samples represented by one column.
    sch: Schedule name. 'dogs_c3', 'dogs_gradinc' and 'dogs_gradincdec' have
      a predefined ratio profile; any other name leaves the ratio at zero.
    M, R, W, lr: Only used to build the output file name.
    colors: Optional list of line colors. It is applied only when it has
      exactly one entry per array in `data`; otherwise matplotlib's default
      color cycle is used.
  """
  plt.rcParams.update({
    'font.size': 15,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
  })
  _, rate_axis = plt.subplots()
  ratio_axis = rate_axis.twinx()

  for idx, runs in enumerate(data):
    mean_percent = runs.mean(axis=0)*100
    sample_positions = np.arange(runs.shape[1]) * batch_size
    line_kwargs = {}
    if colors is not None and len(colors) == len(data):
      line_kwargs['color'] = colors[idx]
    rate_axis.plot(sample_positions, mean_percent, **line_kwargs)

  rate_axis.set_ylim([-2, 102])
  # plt.grid() acts on whatever pyplot currently regards as the active axes.
  plt.grid()

  rate_axis.set_xlabel("Number of samples")
  rate_axis.set_ylabel("Detection Rate over repetition (%) ")
  rate_axis.legend(labels, loc="upper left")

  # Ratio profile per schedule as (start, stop, percent) slices over the
  # sample index; a stop of None means "until the end". The boundaries are
  # fixed sample indices and are not rescaled to the data length.
  schedule_segments = {
    'dogs_c3': ((250, 500, 100), (750, None, 100)),
    'dogs_gradinc': ((200, 400, 20), (400, 600, 40), (600, 800, 60), (800, None, 80)),
    'dogs_gradincdec': ((200, 400, 40), (400, 600, 80), (600, 800, 40)),
  }
  total_samples = data[0].shape[1] * batch_size
  ratios = np.zeros(total_samples)
  for start, stop, percent in schedule_segments.get(sch, ()):
    ratios[start:stop] = percent

  ratio_axis.plot(np.arange(total_samples), ratios, color='k', linestyle='--')
  ratio_axis.set_ylabel('Shifted sample ratio (%)')
  ratio_axis.set_ylim([-2, 102])

  name_template = 'sch_sum_{}_B{}_M{}_R{}_W{}_lr{:.4f}.png'
  output_fn = name_template.format(sch.lower(), batch_size, M, R, W, lr)
  plt.savefig(output_fn)


def main(args):
  """Load the result files selected by `args` and render the summary plot.

  One CSV is read for every H value in (-1, 2, 5) combined with the single
  window size `args.W`. Each curve is labelled 'Ours' for H == -1 and
  'H<value>' otherwise, prefixed with 'No Window' when the window size is -1.
  The number of rows of the first loaded array must equal `args.R`.

  Args:
    args: Parsed command-line namespace providing `batch_size`, `sch`,
      `alpha`, `M`, `R`, `W` and `lr`.
  """
  window_sizes = [args.W]
  h_values = [-1, 2, 5]
  # Every (h, w) pair, h varying slowest; files and labels share this order.
  combos = [(h, w) for h in h_values for w in window_sizes]

  path_template = './results/sch_dogs_batch_{}/result_ours_sch{}_H{}_alpha{}_M{}_R{}_W{}_lr{:.4f}.csv'
  data = [
    load_data(path_template.format(args.batch_size, args.sch, h, args.alpha,
                                   args.M, args.R, w, args.lr))
    for h, w in combos
  ]

  labels = []
  for h, w in combos:
    method_name = 'Ours' if h == -1 else 'H{}'.format(h)
    labels.append(('No Window' + method_name) if w == -1 else method_name)

  assert args.R == data[0].shape[0], "R should match to the number of experiments."

  curve_colors = ['green', 'darkorange', 'blueviolet']
  draw_sum_plot(data, labels, args.batch_size, args.sch, args.M, args.R,
                args.W, args.lr, colors=curve_colors)
  # The figure is only saved to disk by draw_sum_plot; release it afterwards.
  plt.close('all')


if __name__ == '__main__':
  # Script entry point: parse the experiment settings and draw the plot.
  cli = argparse.ArgumentParser()

  # (flag, default, type) for every supported command-line option.
  option_specs = (
    ('--batch_size', 10, int),
    ('--sch', 'dogs_gradincdec', str),
    ('--alpha', 0.01, float),
    ('--M', 1000, int),
    ('--R', 100, int),
    ('--W', 100, int),
    ('--lr', 0.01, float),
  )
  for flag, default_value, value_type in option_specs:
    cli.add_argument(flag, default=default_value, type=value_type)

  main(cli.parse_args())
