#! /home/apd10/anaconda3/bin/python
import torch
torch.manual_seed(0)
from torch import nn
import torch.multiprocessing as mp
import torch.utils.data
from torch.autograd import Variable
import numpy as np
from os import path
import os

import pdb
import argparse
from os.path import dirname, abspath, join
import glob
cur_dir = dirname(abspath(__file__))
import yaml
from main_modules.Loop import *
from main_modules.Stats import *
from main_modules.RaceSketch import *
from main_modules.RaceGenSketch import *
from main_modules.DataWriter import *

# Command-line interface: both the pickled sketch and the YAML config are
# mandatory arguments.
parser = argparse.ArgumentParser()
parser.add_argument('--sketch', action="store", dest="sketch", type=str, default=None, required=True,
                    help="path to the pickled race sketch to re-sketch")

parser.add_argument('--config', action="store", dest="config", type=str, default=None, required=True,
                    help="config to setup the training")

results = parser.parse_args()
config_file = results.config
sketch = results.sketch

# NOTE(review): `pickle` is not imported by name in this file; it is presumably
# re-exported by one of the `main_modules` star imports -- confirm.
# SECURITY: pickle.load can execute arbitrary code; only pass sketch files that
# come from a trusted source.
with open(sketch, "rb") as f:
    race_pickle = pickle.load(f)  # race sketch to be embellished with new results

# Explicit check instead of `assert`, which is stripped when Python runs with -O.
# The rest of the script only ever reads entry [0] of each class's memory,
# presumably the single repetition -- confirm against the sketch writer.
if race_pickle["params"]["repetitions"] != 1:
    raise ValueError(
        "expected a sketch with repetitions == 1, got "
        f"{race_pickle['params']['repetitions']!r}"
    )

nclasses = race_pickle["params"]["num_classes"]

# yaml.load() without an explicit Loader raises TypeError on PyYAML >= 6 and
# can construct arbitrary Python objects on older versions; safe_load parses
# plain YAML only.
# NOTE(review): if the config relies on Python-specific YAML tags, switch to
# yaml.load(f, Loader=yaml.FullLoader).
with open(config_file, "r") as f:
    config = yaml.safe_load(f)  # config data

# Rebuild the sketch object from its stored parameters and restore its state.
race = RaceGen(race_pickle["params"])  # original note: do the change for Race as well
race.set_dictionary(race_pickle)
data = Data(config["data"])

# Rewind the data source.
data.reset()

# Build one table per class from the topK frame stored in the sketch: the
# columns whose names start with 'C' become the index, and a zero-initialised
# 'actual_count' column is added.
topks = []
for class_idx in range(nclasses):
    class_topk = race_pickle["memory"][class_idx][0]['topK']
    hcols = [name for name in class_topk.columns if name.startswith('C')]
    indexed = class_topk.set_index(hcols)
    indexed['actual_count'] = 0
    topks.append(indexed)

# `hcols` still holds the column list of the last class processed above.
COLS = list(range(len(hcols)))



# Pass over the data set batch by batch. For every class, count how many
# samples of that class fall on each hash tuple and add those counts to the
# 'actual_count' column of the class's topK table.
data.reset()
num_samples = data.len()
batch_size = data.batch_size()
num_batches = int(np.ceil(num_samples / batch_size))
for _ in tqdm(range(num_batches)):
    if data.end():
        break
    x, y = data.next()
    hashvalues = race.compute_hashes(x)
    for c, table in enumerate(topks):
        # Hash rows belonging to samples labelled with class c.
        class_hashes = hashvalues[y == c]
        hash_frame = pd.DataFrame(np.array(class_hashes))
        hash_frame.columns = hcols
        hash_frame['count'] = 1
        # One row per distinct hash tuple with its number of occurrences.
        batch_counts = hash_frame.groupby(hcols).count()
        # Left merge on the index: only tuples already present in the topK
        # table are kept; tuples unseen in this batch get NaN, then 0.
        joined = pd.merge(table, batch_counts, how='left', left_index=True, right_index=True)
        joined = joined.fillna(0)
        running_total = joined['actual_count'] + joined['count']
        joined['actual_count'] = running_total
        topks[c] = joined.drop(columns=['count'])



# Store the updated per-class tables back into the sketch dictionary.
for class_idx, table in enumerate(topks):
    race_pickle["memory"][class_idx][0]['topK'] = table

# Write the embellished sketch next to the input, with a ".resketch" suffix.
resketch_path = sketch + ".resketch"
with open(resketch_path, "wb") as out:
    pickle.dump(race_pickle, out)
