import pandas as pd
import numpy as np
from absl import app
from absl import flags
# from TICC_solver import MTTICC
from pygapbide import *
import random
from math import nan
import csv
import os
os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1'


import multiprocessing as mp
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_samples, silhouette_score

import gc
import json
import pickle
from absl import app
from absl import flags
from absl import logging
import gym
from gym import wrappers
from gym.wrappers import time_limit
import d4rl  # pylint: disable=unused-import

import numpy as np
import tensorflow as tf
from tf_agents.environments import gym_wrapper
from tf_agents.environments import suite_mujoco
from tf_agents.environments import tf_py_environment
import tqdm
# Restrict this process to GPU 0.
# NOTE(review): this is set after `import tensorflow`; it relies on CUDA being
# initialized lazily by the device query below — confirm on the target TF version.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    try:
        # Let TensorFlow grow GPU memory on demand instead of reserving it all.
        tf.config.experimental.set_memory_growth(physical_devices[0], True)
    except (ValueError, RuntimeError) as err:
        # ValueError: invalid device; RuntimeError: the runtime was already
        # initialized. Best effort: continue without memory growth.
        logging.warning('Could not enable GPU memory growth: %s', err)

# float32 machine epsilon.
EPS = np.finfo(np.float32).eps
FLAGS = flags.FLAGS

flags.DEFINE_string('env_name', 'pen-human-v1',
                    'Environment for augmentation.')
flags.DEFINE_integer('seed', 0, 'random seed.')
# The silhouette score is only defined for at least 2 clusters, so reject
# smaller bounds at flag-parsing time instead of failing later during clustering.
flags.DEFINE_integer('min_c', 10, 'minimum value for searching cluster number.',
                     lower_bound=2)
flags.DEFINE_integer('max_c', 20, 'maximum value for searching cluster number.',
                     lower_bound=2)
flags.DEFINE_integer('min_sup', 10, 'threshold for support.')
# An empty [min_c, max_c] search range would leave no cluster count to pick.
flags.register_multi_flags_validator(
    ['min_c', 'max_c'],
    lambda f: f['min_c'] <= f['max_c'],
    message='--min_c must be <= --max_c.')


def find_sub_list(sl, l):
    """Locate the first contiguous occurrence of ``sl`` inside ``l``.

    Elements are compared through tuples, so list, tuple and NumPy-array
    inputs can be mixed freely. Without this, a list slice never equals a
    tuple pattern, and comparing a NumPy slice yields an array whose truth
    value is ambiguous.

    Args:
        sl: the sub-sequence to search for.
        l: the sequence to search in.

    Returns:
        ``(start, stop)`` such that ``l[start:stop]`` equals ``sl``, or
        ``(-1, -1)`` if ``sl`` does not occur. An empty ``sl`` matches at
        index 0 and yields ``(0, 0)`` (same convention as ``str.find``).
    """
    target = tuple(sl)
    sll = len(target)
    if sll == 0:
        return 0, 0
    first = target[0]
    for ind, e in enumerate(l):
        # Cheap first-element test before building the full window.
        if e == first and tuple(l[ind:ind + sll]) == target:
            return ind, ind + sll
    return -1, -1

def _silhouette_for(n_clusters, X):
    """Cluster ``X`` with KMeans and score the clustering.

    Kept at module level rather than nested in ``main``:
    ``multiprocessing.Pool`` pickles the callable it sends to workers, and
    functions defined inside another function cannot be pickled.

    Args:
        n_clusters: number of KMeans clusters to fit.
        X: array of samples, one row per observation.

    Returns:
        ``[n_clusters, mean_silhouette_score]``.
    """
    clusterer = KMeans(n_clusters=n_clusters, random_state=10)
    cluster_labels = clusterer.fit_predict(X)
    return [n_clusters, silhouette_score(X, cluster_labels)]


def _select_num_clusters(X, min_c, max_c, processes=5):
    """Return the cluster count in ``[min_c, max_c]`` with the best mean silhouette.

    Candidates are scored in parallel. Results come back in increasing
    cluster-count order, and ``max`` keeps the first maximum, so ties
    resolve to the smallest count.

    Raises:
        ValueError: if ``min_c > max_c`` (empty search range).
    """
    if min_c > max_c:
        raise ValueError(
            'empty cluster search range: min_c={} > max_c={}'.format(min_c, max_c))
    candidates = [(n, X) for n in range(min_c, max_c + 1)]
    with mp.Pool(processes) as pool:
        scores = pool.starmap(_silhouette_for, candidates)
    best_n, _ = max(scores, key=lambda s: s[1])
    return best_n


def _save_raw_data(trajectories, path):
    """Pickle ``[data, intervals]`` as the input file for the MTTICC step.

    ``data[k]`` is trajectory k's observations as nested Python lists.
    ``intervals[k]`` is ``[nan, 1.0, 1.0, ...]``, one entry per observation
    (``[nan]`` for an empty trajectory). Missing parent directories are created.
    """
    data = []
    intervals = []
    for traj in trajectories:
        obs = traj['observations'].tolist()
        data.append(obs)
        # Presumably NaN marks "no previous step" for the first observation;
        # later steps are spaced 1.0 apart.
        intervals.append([nan] + [1.] * (len(obs) - 1))
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, 'wb') as pickle_out:
        pickle.dump([data, intervals], pickle_out)


def _load_cluster_db(fname, lengths):
    """Read per-sample cluster ids from ``fname`` and split them per trajectory.

    The file is parsed as CSV. Every field of every row, in order, is one
    integer cluster id.

    Args:
        fname: path of the cluster-assignment file.
        lengths: number of samples in each trajectory, in dataset order.

    Returns:
        List of lists; element k holds the cluster ids of trajectory k.

    Raises:
        ValueError: if the number of ids differs from ``sum(lengths)``
            (e.g. a stale file from another dataset). Otherwise cluster ids
            and trajectories would be silently misaligned.
    """
    with open(fname, 'r', newline='') as fd:
        cluster = [int(v) for row in csv.reader(fd) for v in row]
    expected = sum(lengths)
    if len(cluster) != expected:
        raise ValueError(
            '{} holds {} cluster ids but the dataset has {} samples'.format(
                fname, len(cluster), expected))
    cluster_db = []
    p = 0
    for n in lengths:
        cluster_db.append(cluster[p:p + n])
        p += n
    return cluster_db


def _find_full_support_pattern(raw_pattern, n_sequences):
    """Return the first mined pattern whose support equals ``n_sequences``.

    Each entry of ``raw_pattern`` is read as ``entry[0]`` = pattern and
    ``entry[1]`` = support.

    Raises:
        ValueError: if no pattern is supported by every sequence.
    """
    for entry in raw_pattern:
        if entry[1] == n_sequences:
            return entry[0]
    raise ValueError(
        'no mined pattern has support equal to the number of '
        'trajectories ({})'.format(n_sequences))


def _extract_segments(trajectories, cluster_db, pattern):
    """Cut the first window of each trajectory whose cluster ids equal ``pattern``.

    The last sample of each trajectory is excluded from the search, so every
    matched step has a successor row for ``next_observations``. Trajectories
    without a match are skipped.

    Returns:
        List of dicts with keys observations, actions, rewards,
        next_observations and terminals. ``terminals`` is True only on the
        segment's final step.
    """
    seg = []
    last = len(pattern) - 1
    for traj, clusters in zip(trajectories, cluster_db):
        begin, end = find_sub_list(pattern, clusters[:-1])
        if begin == -1:
            continue
        obs = traj['observations']
        seg.append({
            'observations': obs[begin:end],
            'actions': traj['actions'][begin:end, :],  # 2D
            'rewards': traj['rewards'][begin:end],
            # end <= len(obs) - 1 because the search skipped the last sample,
            # so this shifted window never runs past the trajectory.
            'next_observations': obs[(begin + 1):(end + 1)],
            'terminals': np.arange(len(pattern)) == last,
        })
    return seg


def main(_):
    """Mine a cluster-id pattern shared by all trajectories and save its segments.

    Steps:
      1. Load the D4RL dataset for ``--env_name``.
      2. Pick a cluster count in ``[--min_c, --max_c]`` by silhouette score.
      3. Write the MTTICC input file.
      4. Read precomputed per-sample cluster ids from
         ``./raw_data/mtticc_clusters.txt``.
      5. Mine cluster-id sequences with Gapbide and take the first pattern
         supported by every trajectory.
      6. Save the matching segment of each trajectory to
         ``./processed_data/train_pattern.npy``.
    """
    np.random.seed(FLAGS.seed)

    # Load data: one dict per trajectory.
    env = gym.make(FLAGS.env_name)
    d4rl_original_data = list(d4rl.sequence_dataset(env))
    if not d4rl_original_data:
        raise ValueError('no trajectories found for {}'.format(FLAGS.env_name))

    # Get # of clusters. Stack all observations row-wise; float64 matches the
    # values KMeans previously received from the list-of-Python-floats input.
    X = np.concatenate(
        [traj['observations'] for traj in d4rl_original_data], axis=0
    ).astype(np.float64)
    num_clusters = _select_num_clusters(X, FLAGS.min_c, FLAGS.max_c)
    print('selected num of clusters: {} !'.format(num_clusters))

    # MTTICC: preprocess data as .data first.
    raw_save_path = './raw_data/data.data'
    _save_raw_data(d4rl_original_data, raw_save_path)

    # NOTE: the MTTICC step that writes ./raw_data/mtticc_clusters.txt is
    # disabled; that file must already exist. To regenerate it:
    # mtticc = MTTICC(fixed_window=1, number_of_clusters=num_clusters, lambda_parameter=11e-3,
    #                 beta=1, maxIters=100, num_proc=6, input_pattern='multiple', window_pattern='fixed')
    # TICC_return_cv = mtticc.fit(raw_save_path)
    # pred_cluster_cv = TICC_return_cv[1]
    # cluster_list = [c for traj_clusters in pred_cluster_cv for c in traj_clusters]
    # np.savetxt('./raw_data/mtticc_clusters.txt', cluster_list, fmt='%d', delimiter=',')

    # Read cluster index by samples and split it per trajectory.
    lengths = [len(traj['observations']) for traj in d4rl_original_data]
    cluster_db = _load_cluster_db('./raw_data/mtticc_clusters.txt', lengths)

    ## get patterns
    # min-support can be set as 1, but we keep it as --min_sup (default 10)
    # for fast tracing; this does not affect top-k selection as long as there
    # are at least --min_sup trajectories.
    # NOTE(review): the two trailing zeros are presumably Gapbide's min/max gap
    # (i.e. contiguous patterns) — confirm against pygapbide.
    g = Gapbide(cluster_db, FLAGS.min_sup, 0, 0)
    raw_pattern = g.run()

    # Directly take the pattern with support == N (top-1). This could be
    # extended to an iterative search for easier black-box usage.
    pattern = _find_full_support_pattern(raw_pattern, len(d4rl_original_data))

    # Given the temporal discrete sub-sequence, extract the PST from each trajectory.
    seg = _extract_segments(d4rl_original_data, cluster_db, pattern)

    out_path = './processed_data/train_pattern.npy'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'wb') as f:
        np.save(f, seg)
