import pandas as pd
import numpy as np
import random

import math
from utils import PROBLEM_FEATURES, ACTIONS, PROBLEM_LIST
from tensorflow.keras.models import load_model
import tensorflow as tf

from math import nan
import pickle
import csv
import os 
from scipy.stats import beta
from absl import app

# Build, for every cluster (subgroup), a CSV that mixes original training
# trajectories with augmented ones.

# Script parameters
# NOTE(review): TE_START_IDX is TRAIN_SIZE + 1 -- confirm which of the two
# reflects the real number of training entries.
TE_START_IDX = 1148  # index in the cluster-assignment list where the test slice starts
TRAIN_SIZE = 1147  # number of trajectories written per cluster (train + augmented)
C = 5  # number of clusters iterated over
RATIO = 1  # NOTE(review): not referenced in the code visible here
ENV_NAME = 'ITS'  # used to build input file and directory names

def _read_cluster_assignments(path):
    """Return the cluster ids stored in the CSV file at *path* as a flat list of ints."""
    with open(path, 'r') as fd:
        return [int(cell) for row in csv.reader(fd) for cell in row]


def _last_line_elbo(stats_path):
    """Return the ELBO reported on the last line of a training-stats file.

    Returns None when the file is empty.
    Raises ValueError when the last line carries no parsable 'ELBO: <number>'.
    """
    import re  # function-scope import: `re` is not among the file-level imports

    last_line = None
    with open(stats_path) as f:
        for last_line in f:
            pass
    if last_line is None:
        return None

    # Parse the whole number instead of a fixed six-character slice, which
    # truncated long values and choked on short ones followed by other text.
    found = re.search(
        r'ELBO: \s*([-+]?(?:(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|nan|inf))',
        last_line,
    )
    if found is None:
        raise ValueError(
            "no 'ELBO: <number>' on the last line of {}".format(stats_path))
    return float(found.group(1))


def _select_best_run(env_info, stats_dir="./rl_stats/"):
    """Return the path of the stats file with the highest final ELBO.

    Only files in *stats_dir* whose name contains *env_info* are considered
    (substring match, so 'x_cluster_1' would also match 'x_cluster_10').
    Ties keep the first file encountered. Empty files are skipped.

    Raises FileNotFoundError when no usable stats file exists.
    """
    candidates = [stats_dir + name
                  for name in os.listdir(stats_dir) if env_info in name]

    best_path = None
    best_elbo = float('-inf')
    for path in candidates:
        elbo = _last_line_elbo(path)
        if elbo is None:
            print("skipping empty stats file: {}".format(path))
            continue
        if elbo > best_elbo:
            best_elbo = elbo
            best_path = path

    if best_path is None:
        raise FileNotFoundError(
            "no usable stats file for '{}' in {}".format(env_info, stats_dir))
    return best_path


def _discretize_actions(segments):
    """Turn every action vector into a 0/1 indicator of its maximum, in place.

    Entries equal to the maximum become 1 (ties yield several 1s), all
    others become 0.
    """
    for segment in segments:
        for idx, action in enumerate(segment['actions']):
            values = list(action)
            peak = max(values)
            segment['actions'][idx] = [1 if v == peak else 0 for v in values]


def _trajectories_to_columns(trajectories, first_user_id=100):
    """Flatten trajectories into a column-name -> list-of-values mapping.

    Each trajectory contributes one row per observation. All rows of the
    k-th trajectory get userID ``first_user_id + k``.
    """
    # One independent list per column; dict.fromkeys(keys, []) would share a
    # single list object between all keys.
    columns = {key: []
               for key in ['userID', 'inferred_rew'] + PROBLEM_FEATURES + ACTIONS}

    for offset, traj in enumerate(trajectories):
        observations = traj['observations']
        # extend() grows each column in place; rebuilding the lists with `+`
        # copied the whole column for every trajectory.
        columns['userID'].extend([first_user_id + offset] * len(observations))
        columns['inferred_rew'].extend(traj['rewards'])
        for f_idx, feature in enumerate(PROBLEM_FEATURES):
            columns[feature].extend(obs[f_idx] for obs in observations)
        for a_idx, action_name in enumerate(ACTIONS):
            columns[action_name].extend(step[a_idx] for step in traj['actions'])
    return columns


def main(_):
    """Write one mixed train/augmented CSV per cluster.

    For every cluster id in ``range(C)``:
      1. pick the stats file in ./rl_stats/ with the highest final ELBO,
      2. load the augmented segments saved under that run's name and
         discretize their actions,
      3. load the cluster's training trajectories,
      4. draw a Beta(2, 2) share of TRAIN_SIZE from the training data and
         the remainder from the augmented data (both with replacement),
      5. flatten the selection and save it to
         ../augmented_dataset/ITS_subgroup/c_<cluster>.csv.

    The positional argument supplied by absl is ignored.
    """
    # Sanity check of the test split: share of test entries per cluster.
    cluster = _read_cluster_assignments('../cluster_data/ITS.txt')
    test_cluster = cluster[TE_START_IDX:]
    cluster_share = {
        c: (test_cluster.count(c) / len(test_cluster) if test_cluster else 0.0)
        for c in range(C)
    }
    print("test cluster shares: {}".format(cluster_share))

    for c in range(C):
        env_info = str(ENV_NAME) + '_cluster_' + str(c)

        # Choose the trained model by its training performance (final ELBO).
        best_info = _select_best_run(env_info)
        # Run name = file name minus its last four characters
        # (presumably a three-letter extension such as '.txt').
        run_name = os.path.basename(best_info)[:-4]

        # NOTE: allow_pickle=True executes pickled objects on load; only use
        # with files produced by this pipeline.
        with open('./saved_augmented_data/' + run_name + '_augmented_segment.npy', 'rb') as f:
            augmented_segments = np.load(f, allow_pickle=True)
        _discretize_actions(augmented_segments)

        with open('../processed_data/{}/train_cluster_{}.npy'.format(ENV_NAME, c), 'rb') as f:
            train_data = np.load(f, allow_pickle=True)

        # Decide the mix of training and augmented data. The train share is
        # redrawn until it does not exceed the available training data.
        num_train_need = 0
        if len(train_data) > 0:  # with no training data only 0 can ever be accepted
            while True:
                num_train_need = int(TRAIN_SIZE * beta.rvs(2, 2, size=1)[0])
                if num_train_need <= len(train_data):
                    break
        num_aug_need = int(TRAIN_SIZE - num_train_need)

        # Random selection with replacement from both sources.
        sel_train_data = random.choices(train_data, k=num_train_need)
        sel_aug_data = random.choices(augmented_segments, k=num_aug_need)
        print("cluster:{}, train num:{}, aug num:{}".format(c, num_train_need, num_aug_need))

        new_data = _trajectories_to_columns(sel_train_data + sel_aug_data)
        df = pd.DataFrame.from_dict(new_data)

        df.to_csv('../augmented_dataset/ITS_subgroup/c_{}.csv'.format(c), index=False)

# Script entry point: absl's app.run invokes main() with the command-line
# arguments, which main() ignores.
if __name__ == '__main__':
    app.run(main)