import pickle
import os
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

from utils.motsc.buffer import buffer

class DistributionTester:
    """Embed offline traffic-signal states with t-SNE and save a scatter plot.

    Usage: construct with a directory of pickled data, call :meth:`t_sne`
    with the state feature names to embed, then call
    :meth:`draw_distribution` to write the plot to disk.
    """

    def __init__(self, path_to_offline_dataset, dic_traffic_env_conf=None):
        """Unpickle every regular file in a directory and wrap them in a ``buffer``.

        Args:
            path_to_offline_dataset: Directory whose regular files are each
                unpickled and collected, in sorted filename order, into the
                list handed to ``buffer``. Sub-directories are skipped.
            dic_traffic_env_conf: Optional environment config passed to
                ``buffer``. When ``None``, the config returned by
                :meth:`_default_traffic_env_conf` is used.

        Warning:
            ``pickle.load`` can execute arbitrary code. Only point this at
            datasets you produced or otherwise trust.
        """
        dataset = []
        # os.listdir order is arbitrary; sort so the buffer is filled deterministically.
        for file_name in sorted(os.listdir(path_to_offline_dataset)):
            file_path = os.path.join(path_to_offline_dataset, file_name)
            if not os.path.isfile(file_path):
                # Opening a directory in "rb" mode would raise; skip non-files.
                continue
            with open(file_path, "rb") as f:
                dataset.append(pickle.load(f))

        if dic_traffic_env_conf is None:
            dic_traffic_env_conf = self._default_traffic_env_conf()
        self.chosen_features = None
        self.data_buffer = buffer(dataset, dic_traffic_env_conf)
        self.t_sne_result = None

    @staticmethod
    def _default_traffic_env_conf():
        """Return a newly built copy of the default traffic-env config.

        A fresh dict is built on every call, in case ``buffer`` mutates the
        config it receives.
        """
        return {
            "LIST_STATE_FEATURE": [
                "lane_num_vehicle",
                "lane_num_waiting_vehicle_in",
                "traffic_movement_pressure_queue",
                "traffic_movement_pressure_queue_efficient",
                "lane_enter_running_part",
            ],
            "DIC_REWARD_INFO": {
                # Alternative reward term: "lane_num_waiting_vehicle_in": -0.25
                "traffic_movement_pressure_queue_efficient": -0.25,
            },
            "PHASE": {
                1: [0, 1, 0, 1, 0, 0, 0, 0],
                2: [0, 0, 0, 0, 0, 1, 0, 1],
                3: [1, 0, 1, 0, 0, 0, 0, 0],
                4: [0, 0, 0, 0, 1, 0, 1, 0],
            },
            "NUM_LANE": 12,
        }

    def t_sne(self, features, random_state=None):
        """Fit a 2-D t-SNE embedding of the chosen state features.

        Args:
            features: Feature names forwarded to
                ``self.data_buffer.sample_all_chosen_features``. They are also
                stored in ``self.chosen_features`` to label the plot.
            random_state: Seed forwarded to ``sklearn.manifold.TSNE``. ``None``
                (TSNE's own default) gives a non-reproducible embedding.

        The embedding is stored in ``self.t_sne_result``.
        """
        self.chosen_features = features
        # NOTE(review): assumed to be a 2-D array (n_samples, n_dims); confirm against buffer.
        raw_data = self.data_buffer.sample_all_chosen_features(features)
        t_sne = TSNE(n_components=2, init="pca", random_state=random_state)
        self.t_sne_result = t_sne.fit_transform(raw_data)
        print("[INFO] Fitted from dim {} to dim {}".format(raw_data.shape[-1], self.t_sne_result.shape[-1]))

    @staticmethod
    def _min_max_normalize(points):
        """Scale each column of a 2-D array to [0, 1].

        A column with zero span (all values equal) maps to 0 instead of NaN.
        """
        lo, hi = points.min(0), points.max(0)
        span = hi - lo
        span[span == 0] = 1  # avoid 0/0 -> NaN when every point shares a coordinate
        return (points - lo) / span

    @staticmethod
    def _next_output_path(save_dir, num_features):
        """Return a not-yet-existing plot path for ``num_features`` features.

        Numbering starts at the count of files already tagged with the same
        feature count, as before. It then advances past any index that is
        already taken (possible after a deletion), so an existing plot is
        never overwritten.
        """
        tag = "state_distribute_{}_feat".format(num_features)
        index = sum(1 for name in os.listdir(save_dir) if tag in name)
        while True:
            path = os.path.join(save_dir, "{}_{}.jpg".format(tag, index))
            if not os.path.exists(path):
                return path
            index += 1

    def draw_distribution(self, save_dir, close_figure=True):
        """Save a scatter plot of the normalized t-SNE embedding to ``save_dir``.

        The chosen feature names are written as text at the left edge of the
        plot. Files are named ``state_distribute_{n}_feat_{i}.jpg``.

        Args:
            save_dir: Output directory. It is created if missing.
            close_figure: Close the figure after saving (default) so repeated
                calls do not accumulate open figures. Pass ``False`` to keep
                the figure open, e.g. for ``plt.show()``.

        Raises:
            RuntimeError: If :meth:`t_sne` has not been called yet.
        """
        if self.t_sne_result is None:
            raise RuntimeError("t_sne() must be called before draw_distribution()")
        os.makedirs(save_dir, exist_ok=True)
        result_norm = self._min_max_normalize(self.t_sne_result)

        fig = plt.figure(figsize=(8, 8))
        try:
            plt.scatter(result_norm[:, 0], result_norm[:, 1])
            for index, feature in enumerate(self.chosen_features):
                plt.text(0, 0.05 * index, feature)
            plt.savefig(self._next_output_path(save_dir, len(self.chosen_features)))
        finally:
            if close_figure:
                plt.close(fig)