import argparse

import numpy as np

# Names of perturbation types (going by the variable name).
# NOTE(review): not referenced by any code visible in this file; confirm it is
# still needed before relying on or removing it.
perturb_types = ['contrast', 'defocus_blur', 'elastic_transform', 'gaussian_blur', 'gaussian_noise']


def count_detection_req_sample(datas, batch_size, n_perturb_sample, thr=0.8):
  """Count samples past ``n_perturb_sample`` until the detection rate reaches ``thr``.

  Each array in ``datas`` is averaged over axis 0, giving one rate per column.
  The first column ``i`` whose rate is >= ``thr`` is converted to
  ``(i + 1) * batch_size - n_perturb_sample``.

  Returns:
    A list aligned with ``datas``. An entry is -1 when no column ever
    reaches ``thr``.
  """
  def _first_hit(d):
    hit_cols = np.where(d.mean(axis=0) >= thr)[0]
    if hit_cols.size == 0:
      return -1
    return (hit_cols[0] + 1) * batch_size - n_perturb_sample

  return [_first_hit(d) for d in datas]


def cnt_fp(datas, batch_size, timepoints):
  """Average each result array's column at every requested sample timepoint.

  For every array ``d`` in ``datas``, timepoint ``t`` is mapped to column
  ``t // batch_size``, and the mean of that column (over axis 0) is recorded.

  Returns:
    A list aligned with ``datas``. Each entry is a dict keyed by timepoint,
    in ``timepoints`` order.
  """
  # NOTE(review): count_detection_req_sample treats column i as
  # "(i + 1) * batch_size samples seen", which would put timepoint t at column
  # t // batch_size - 1. Confirm which mapping matches the CSV layout.
  return [
    {t: d[:, t // batch_size].mean() for t in timepoints}
    for d in datas
  ]


def change_perturb_type(old_name, sch_type):
  """Turn a scheme-tagged name into a table column header.

  The ``'_' + sch_type`` tag is removed, underscores become spaces, the text
  is title-cased, and a trailing plural ``'s'`` is dropped. For example,
  ``change_perturb_type('dogs_c3', 'c3') == 'Dog'``.

  Args:
    old_name: name that may contain ``'_' + sch_type``.
    sch_type: scheme tag to strip.

  Returns:
    The display name.
  """
  new_name = old_name.replace('_' + sch_type, '')
  new_name = new_name.replace('_', ' ').title()
  # Drop only a trailing plural 's'. An unconditional [:-1] truncates names
  # that are not plural, e.g. 'gaussian_noise' -> 'Gaussian Nois'.
  if new_name.endswith('s'):
    new_name = new_name[:-1]
  return new_name


def convert_to_text(cnt_map):
  """Format per-algorithm detection counts as LaTeX cells, bolding the best.

  Args:
    cnt_map: mapping of key -> sequence of counts (one per algorithm).

  Returns:
    Dict with the same keys. Each value is a list of strings; the smallest
    non-negative count(s) are wrapped in ``\\textbf{...}``.

  Negative counts never count as "best". -1 is the "never detected" sentinel
  returned by count_detection_req_sample. Any other negative value means the
  threshold was crossed before ``n_perturb_sample`` samples.
  """
  res = {}
  for key, cnts in cnt_map.items():
    valid = [v for v in cnts if v >= 0]
    best = min(valid) if valid else None
    res[key] = [
      '\\textbf{{{}}}'.format(v) if best is not None and v == best else '{}'.format(v)
      for v in cnts
    ]
  return res


def convert_to_text_fp(fp_map, alpha):
  """Format false-positive rates as percentage strings for LaTeX.

  Each rate is multiplied by 100 and printed with two decimals. Rates strictly
  greater than ``alpha`` are wrapped in ``\\textbf{...}``.

  Args:
    fp_map: mapping of key -> list of dicts (timepoint -> rate).
    alpha: threshold, on the same scale as the rates, above which a cell is
      bolded.

  Returns:
    Same nesting as ``fp_map``, with string cells.
  """
  def _fmt(rate):
    cell = '{:.2f}'.format(rate * 100)
    if rate > alpha:
      return '\\textbf{{{}}}'.format(cell)
    return cell

  return {
    key: [{s: _fmt(row[s]) for s in row} for row in rows]
    for key, rows in fp_map.items()
  }


def write_latex(result_data, table_title, sch_type=None):
  """Render the detection-count table as a LaTeX ``table`` environment.

  Args:
    result_data: mapping of column name -> sequence of at least three
      pre-formatted cells. Rows are labelled Ours, H2, H5, matching
      ``hs = [-1, 2, 5]`` in main.
    table_title: caption text.
    sch_type: scheme tag passed to change_perturb_type when building column
      headers. Defaults to the module-level ``args.sch`` set by the CLI entry
      point; that global only exists when the file is run as a script.

  Returns:
    The LaTeX source, without a trailing newline.
  """
  if sch_type is None:
    # Backward-compatible fallback: this used to read the CLI global directly.
    sch_type = args.sch
  # The third data row comes from H=5 (hs = [-1, 2, 5]); it was mislabelled 'H4'.
  row_labels = ['Ours', 'H2', 'H5']

  parts = [
    '\\begin{table}\n',
    '\\caption{{{}}}\n'.format(table_title),
    '\\label{tab:}\n',
    '\\begin{center}\n',
    '\\begin{tabular}{' + 'l' + ('c' * len(result_data)) + '}\n',
    '\\bf Algorithms & '
    + ' & '.join(['\\bf {}'.format(change_perturb_type(c, sch_type)) for c in result_data])
    + '\n',
    '\\\\ \\hline \\\\\n',
  ]
  for i, label in enumerate(row_labels):
    parts.append(label + ' & ' + ' & '.join(['{}'.format(result_data[key][i]) for key in result_data]) + '\\\\\n')
  parts += ['\\end{tabular}\n', '\\end{center}\n', '\\end{table}']
  return ''.join(parts)


def write_fpr_latex(result_data, timepoints, table_title):
  """Render the false-positive-rate table as a LaTeX ``table`` environment.

  Only the first key of ``result_data`` (insertion order) is rendered. Its
  first three dicts become the Ours / H2 / H5 rows, and each row's cells are
  taken in that dict's iteration order. Cells must already be strings.

  Returns:
    The LaTeX source, without a trailing newline.
  """
  header = ' & '.join(['\\bf {}'.format(t) for t in timepoints])
  per_alg = result_data[list(result_data)[0]]

  def _row(label, cells):
    return label + ' & ' + ' & '.join(cells.values()) + '\\\\'

  lines = [
    '\\begin{table}',
    '\\caption{{{}}}'.format(table_title),
    '\\label{tab:}',
    '\\begin{center}',
    '\\begin{tabular}{l' + 'c' * len(timepoints) + '}',
    '\\bf Algorithms & ' + header,
    '\\\\ \\hline \\\\',
    _row('Ours', per_alg[0]),
    _row('H2', per_alg[1]),
    _row('H5', per_alg[2]),
    '\\end{tabular}',
    '\\end{center}',
    '\\end{table}',
  ]
  return '\n'.join(lines)


def _result_csv_path(args, sch, h, r):
  """Build the results-CSV path for scheme ``sch``, H=``h`` and R=``r``."""
  return './results/sch_dogs_batch_{}/result_ours_sch{}_H{}_alpha{}_M{}_R{}_W{}_lr{:.4f}.csv'.format(
    args.batch_size,
    sch,
    h,
    args.alpha,
    args.M,
    r,
    args.W,
    args.lr)


def _load_int_csv(path):
  """Load a comma-separated results file and cast it to an integer array."""
  # np.int was only an alias of the builtin int; it was removed in NumPy 1.24.
  return np.loadtxt(path, delimiter=',').astype(int)


def main(args):
  """Build the detection-count and FPR LaTeX tables for ``args.sch``.

  For every H in ``hs``, two result CSVs are loaded: one with R=args.R_for_req
  for detection counts and one with R=args.R_for_fpr for false-positive
  rates. Both tables are written to
  ``tbl_latex_<sch>_R<R_for_req>_R<R_for_fpr>_W<W>_lr<lr>.txt`` in the
  current directory.

  Raises:
    ValueError: if ``args.sch`` is not 'c3', 'gradinc' or 'gradincdec'.
  """
  hs = [-1, 2, 5]

  schs = ['dogs_{}'.format(args.sch)]
  if args.sch == 'c3':
    n_perturb_start = 250
    timepoints = [50, 100, 150, 200]  # in # samples
    table_title = 'Multiple shift change with $w={}$'.format(args.W // args.batch_size)
  elif args.sch in ('gradinc', 'gradincdec'):
    n_perturb_start = 200
    timepoints = [50, 100, 150, 200]  # in # samples
    if args.sch == 'gradinc':
      table_title = 'Gradually Increasing with $w={}$'.format(args.W // args.batch_size)
    else:
      table_title = 'Gradually Increasing-Decreasing with $w={}$'.format(args.W // args.batch_size)
  else:
    # Without this, the names above stay unbound and the function fails later
    # with an UnboundLocalError.
    raise ValueError('unsupported sch {!r}; expected c3, gradinc or gradincdec'.format(args.sch))

  cnt_map = {}
  fp_map = {}
  for sch in schs:
    print(sch)
    datas_req = []
    datas_fpr = []
    for h in hs:
      datas_req.append(_load_int_csv(_result_csv_path(args, sch, h, args.R_for_req)))
      datas_fpr.append(_load_int_csv(_result_csv_path(args, sch, h, args.R_for_fpr)))

    cnt_map[sch] = count_detection_req_sample(datas_req, args.batch_size, n_perturb_start, thr=0.70)
    fp_map[sch] = cnt_fp(datas_fpr, args.batch_size, timepoints)

  latex_str = write_latex(convert_to_text(cnt_map), table_title)
  latex_str_fp = write_fpr_latex(convert_to_text_fp(fp_map, args.alpha), timepoints, table_title)

  out_path = 'tbl_latex_{}_R{}_R{}_W{}_lr{}.txt'.format(args.sch,
                                                        args.R_for_req,
                                                        args.R_for_fpr,
                                                        args.W,
                                                        args.lr)
  with open(out_path, 'w', encoding='utf-8') as f:
    f.write(latex_str)
    f.write('\n\n')
    f.write(latex_str_fp)


if __name__ == '__main__':
  parser = argparse.ArgumentParser(
    description='Build LaTeX detection-count and FPR tables from result CSVs.')

  parser.add_argument('--batch_size', default=10, type=int,
                      help='number of samples per batch')
  # Only these schedules are handled by main(); anything else used to crash
  # there with an UnboundLocalError, so reject it at parse time instead.
  parser.add_argument('--sch', default='gradincdec', type=str,
                      choices=['c3', 'gradinc', 'gradincdec'])
  parser.add_argument('--alpha', default=0.01, type=float,
                      help='FPR above this value is bolded; also part of the CSV file name')
  parser.add_argument('--M', default=1000, type=int)
  parser.add_argument('--R_for_req', default=100, type=int)
  parser.add_argument('--R_for_fpr', default=500, type=int)
  parser.add_argument('--W', default=100, type=int)
  parser.add_argument('--lr', default=0.01, type=float)

  # Keep `args` module-level: write_latex reads the global `args.sch`.
  args = parser.parse_args()

  main(args)
