
from Embedder import SentenceTransformerTextEmbedder
from KeywordList import KeywordList
# from Design import Design
from KeywordList import KeywordList
from DiversityMetric import DiversityMetric

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel, RBF, Matern, ExpSineSquared
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import numpy as np
import matplotlib.pyplot as plt
import os
from sklearn.metrics import pairwise_distances
import random
import json
import logging
import logging_config

class GPRSimulator:
    """Gaussian-process helper that embeds entries and ranks candidate entries.

    Every entry is mapped to a feature vector by summing text embeddings of
    its hints (``"sum_hint"``) or of the observations attached to its hints
    (``"sum_obs"``).  A ``GaussianProcessRegressor`` fitted on entries that
    carry a feedback embedding predicts the feedback of candidate entries,
    and ``_top1_idx`` selects one candidate in "explore" or "exploit" mode.
    """

    # Names accepted for ``embedding_type``.
    _EMBEDDING_TYPES = ("sum_hint", "sum_obs")

    # Kernel name -> zero-argument factory.  Factories (not instances) give
    # each simulator its own kernel object and let the name be validated
    # before anything is constructed.
    _KERNEL_FACTORIES = {
        "dotproduct": lambda: DotProduct() + WhiteKernel(),
        "matern_nu2.5": lambda: Matern(nu=2.5) + WhiteKernel(),
        "matern_nu1.5": lambda: Matern(nu=1.5) + WhiteKernel(),
        "rbf": lambda: RBF() + WhiteKernel(),
    }

    def __init__(
        self,
        embedding_type: str = "sum_hint",  # sum_hint, sum_obs
        reduce_feature_dim: int = None,
        warmup: int = 100,
        window_size: int = 2,
        kernel: str = "dotproduct",
    ):
        """Configure the simulator.

        Args:
            embedding_type: ``"sum_hint"`` embeds the hint strings themselves;
                ``"sum_obs"`` embeds the observation mapped to each hint.
            reduce_feature_dim: If not ``None``, feature vectors are
                standardized and projected to this many PCA components in
                ``_convert_to_vectors``.
            warmup: Stored on the instance; not read by any method of this
                class (presumably the number of initial entries used as-is
                by a driver loop -- TODO confirm).
            window_size: Fixed value returned by ``_get_window_size``.  Pass
                ``None`` to derive the window size from the diversity metric.
            kernel: One of ``"dotproduct"``, ``"matern_nu2.5"``,
                ``"matern_nu1.5"``, ``"rbf"``; each is summed with a
                ``WhiteKernel``.

        Raises:
            ValueError: If ``embedding_type`` or ``kernel`` is not recognized.
        """
        # Validate the cheap arguments first so that a bad configuration is
        # rejected before the text embedder is constructed.
        if embedding_type not in self._EMBEDDING_TYPES:
            raise ValueError(f"Unknown embedding type: {embedding_type}")
        kernel_factory = self._KERNEL_FACTORIES.get(kernel)
        if kernel_factory is None:
            raise ValueError(f"Unknown kernel {kernel}")

        self.embedder = SentenceTransformerTextEmbedder()
        self.embedding_type = embedding_type
        self.reduce_feature_dim = reduce_feature_dim
        self.kernel = kernel_factory()
        self.warmup = warmup
        self.window_size = window_size

    def load_data(
        self,
        entries,
        m_hint_obs: dict = None
    ):
        """Build one feature vector per entry by summing hint embeddings.

        Args:
            entries: Non-empty sequence of dicts with an ``"id"`` and a
                ``"hints"`` key.  The ids must be consecutive integers in
                sequence order.
            m_hint_obs: Mapping hint -> observation.  Required for
                ``"sum_obs"``.  For ``"sum_hint"`` it only supplies the hint
                vocabulary (its keys); when omitted, the vocabulary is
                collected from the hints of ``entries`` instead.

        Returns:
            Dict mapping entry id to the element-wise sum of the embeddings
            of that entry's hints.
        """
        if self.embedding_type == "sum_obs":
            assert m_hint_obs is not None, \
                "m_hint_obs is required when embedding_type is 'sum_obs'"

        base_entry_id = entries[0]["id"]
        end_entry_id = entries[-1]["id"]
        assert list(range(base_entry_id, end_entry_id + 1)) == [e["id"] for e in entries]

        # Parse each entry's hints once; they are needed for the vocabulary
        # fallback below and for the per-entry sums.
        keywords_by_id = {
            e["id"]: list(KeywordList(e["hints"]).keyword_list)
            for e in entries
        }

        if m_hint_obs is not None:
            hint_list = sorted(m_hint_obs)
        else:
            # "sum_hint" without an explicit mapping: the hints that occur in
            # the entries are exactly the ones that need an embedding.
            hint_list = sorted({h for kws in keywords_by_id.values() for h in kws})
        m_hint_id = {h: i for i, h in enumerate(hint_list)}
        print(f"#hints = {len(hint_list)}")

        if self.embedding_type == "sum_hint":
            embedding_list = self.embedder.create_embedding(hint_list)
        else:
            obs_list = [m_hint_obs[h] for h in hint_list]
            embedding_list = self.embedder.create_embedding(obs_list)

        # NOTE(review): an entry with no hints sums an empty list, which
        # numpy evaluates to the scalar 0.0 rather than a zero vector.
        data_X = {
            entry_id: np.sum(
                [embedding_list[m_hint_id[h]] for h in kws],
                axis=0
            )
            for entry_id, kws in keywords_by_id.items()
        }
        print(f"feature_dim = {len(next(iter(data_X.values())))}")
        assert len(data_X) == len(entries)
        return data_X

    def _predict(self, training_X, training_y, predicting_X):
        """Fit a GP on the training set and predict the candidate set.

        Args:
            training_X: Non-empty sequence of feature vectors.
            training_y: One target vector per training sample.
            predicting_X: Non-empty sequence of feature vectors with the same
                length as the training vectors.

        Returns:
            - predicting_y: [n_predicting_samples, n_values]
            - predicting_std: [n_predicting_samples, n_values]
        """
        # Validate before logging: the log message indexes training_X[0] and
        # training_y[0], which would raise IndexError on empty input.
        assert len(training_X) > 0
        assert len(predicting_X) > 0
        assert len(training_X) == len(training_y)

        n_features = len(training_X[0])
        n_values = len(training_y[0])
        n_predicting = len(predicting_X)
        assert all(len(x) == n_features for x in predicting_X)

        logging.info(
            f"_predicting:\n\tn_training_samples = {len(training_X)}"
            f"\n\tn_features = {n_features}"
            f"\n\tn_predicting_samples = {n_predicting}"
            f"\n\tn_values = {n_values}"
        )

        gpr = GaussianProcessRegressor(
            kernel=self.kernel,
            random_state=0
        ).fit(training_X, training_y)

        predicting_y, predicting_std = gpr.predict(predicting_X, return_std=True)

        # Bring both outputs to the documented 2-D layout (a single-target
        # prediction may come back 1-D) and check each one on its own;
        # `predicting_y + predicting_std` on arrays would add element-wise
        # instead of concatenating.
        predicting_y = np.asarray(predicting_y).reshape(n_predicting, -1)
        predicting_std = np.asarray(predicting_std).reshape(n_predicting, -1)
        assert predicting_y.shape == (n_predicting, n_values)
        assert predicting_std.shape == (n_predicting, n_values)

        return predicting_y, predicting_std

    def _get_diversity(self, entries):
        """Return ``DiversityMetric.get_diversity(entries, 100)``.

        The meaning of the constant 100 is defined by ``DiversityMetric``
        (not visible here).
        """
        return DiversityMetric.get_diversity(entries, 100)

    def _get_window_size(self, entries):
        """Return the candidate window size for the next selection step.

        A configured ``self.window_size`` always wins.  Otherwise the size is
        ``round(1 / (diversity / len(entries)))``, clamped from 0 up to 1.

        NOTE(review): in the derived case the division fails when ``entries``
        is empty or the diversity is zero.
        """
        if self.window_size is not None:
            return self.window_size
        s = round(1 / (self._get_diversity(entries) / len(entries)))
        if s == 0:
            s = 1
        return s

    def _convert_to_vectors(
        self,
        training_entries,
        predicting_entries,
        data_X
    ):
        """Turn entries into ``(training_X, training_y, predicting_X)`` lists.

        Training entries are kept only if
        ``DiversityMetric.check_feedback_embedding`` accepts them and their id
        has a feature vector in ``data_X``.  Targets are the entry's
        ``"feedback_embedding"`` scaled by 100.  Every predicting entry must
        have a feature vector.

        When ``self.reduce_feature_dim`` is set, training and predicting
        features are standardized and PCA-projected together before being
        split back into the two lists.
        """
        training_X = []
        training_y = []
        for e in training_entries:
            # Single pass so the filter runs once per entry.  The strict
            # `== True` is kept because the helper's return type is not
            # visible here.
            if DiversityMetric.check_feedback_embedding(e) == True and e["id"] in data_X:  # noqa: E712
                training_X.append(data_X[e["id"]])
                training_y.append([x * 100 for x in e["feedback_embedding"]])
        logging.info(f"training_entries: {len(training_entries)} -> {len(training_X)}")
        assert len(training_y) == len(training_X)

        assert all(e["id"] in data_X for e in predicting_entries)
        predicting_X = [data_X[e["id"]] for e in predicting_entries]
        logging.info(f"predicting_entries: {len(predicting_entries)}")

        if self.reduce_feature_dim is None:
            return training_X, training_y, predicting_X

        n_training = len(training_X)
        normalized_X = StandardScaler().fit_transform(training_X + predicting_X)
        reduced_X = PCA(
            n_components=self.reduce_feature_dim,
            random_state=0
        ).fit_transform(normalized_X)
        reduced_training_X = list(reduced_X[:n_training])
        reduced_predicting_X = list(reduced_X[n_training:])
        assert len(reduced_training_X) == len(training_X)
        assert len(reduced_predicting_X) == len(predicting_X)
        assert all(
            len(v) == self.reduce_feature_dim
            for v in reduced_training_X + reduced_predicting_X
        )
        return reduced_training_X, training_y, reduced_predicting_X

    def _top1_idx(self, training_y, predicting_y, explore_exploit):
        """Return the index into ``predicting_y`` of the selected candidate.

        Args:
            training_y: Non-empty sequence of known target vectors.
            predicting_y: Non-empty sequence of predicted target vectors.
            explore_exploit: ``"explore"`` picks the prediction whose nearest
                training target (Euclidean) is farthest away; ``"exploit"``
                picks the prediction with the smallest mean value.
        """
        assert len(training_y) > 0
        assert len(predicting_y) > 0
        if explore_exploit == "explore":
            dist_matrix = pairwise_distances(
                X=predicting_y,
                Y=training_y,
                metric="euclidean"
            )
            assert len(dist_matrix) == len(predicting_y)

            # Distance from each prediction to its closest known target.
            diversity_list = dist_matrix.min(axis=1)
            assert len(diversity_list) == len(predicting_y)

            return np.argsort(diversity_list)[-1]  # the most diverse one

        assert explore_exploit == "exploit"
        # These are miss ratios, so the minimum is the best candidate.
        quality_list = [float(np.mean(y)) for y in predicting_y]
        assert len(quality_list) == len(predicting_y)
        return np.argsort(quality_list)[0]  # the one with the minimum average miss ratio



    # def simulate(
    #     self,
    #     entries,
    #     m_hint_obs,
    # ):
    #     data_X = self.load_data(entries, m_hint_obs)

    #     simulated_entries = [e for e in entries[:self.warmup]]

    #     cur_entry_idx = self.warmup
    #     cur_window_size = self._get_window_size(simulated_entries)
    #     cur_mode = "explore"

    #     batch_size = self.warmup
        
    #     while cur_entry_idx < len(entries):
    #         if (len(simulated_entries) % batch_size) > (batch_size * 0.5):
    #             cur_mode = "exploit"
    #         else:
    #             cur_mode = "explore"

    #         logging.info(f"cur_mode = {cur_mode}")

    #         if cur_window_size == 1:
    #             logging.info("cur_window_size = 1")
    #             simulated_entries.append(entries[cur_entry_idx])
    #             cur_entry_idx += 1
    #             cur_window_size = self._get_window_size(simulated_entries)
    #             continue
                
    #         predicting_entries = []
    #         while len(predicting_entries) < cur_window_size and cur_entry_idx < len(entries):
    #             cur_entry = entries[cur_entry_idx]
    #             if cur_entry["id"] in data_X:
    #                 predicting_entries.append(cur_entry)
    #             cur_entry_idx += 1
            
    #         if len(predicting_entries) < cur_window_size:
    #             assert cur_entry_idx == len(entries)
    #             if len(predicting_entries) == 1:
    #                 break
    #         else:
    #             assert len(predicting_entries) == cur_window_size
    #         assert len(predicting_entries) >= 2

    #         training_X, training_y, predicting_X = self._convert_to_vectors(
    #             training_entries=simulated_entries,
    #             predicting_entries=predicting_entries,
    #             data_X=data_X
    #         )

    #         predicting_y, predicting_std = self._predict(
    #             training_X=training_X,
    #             training_y=training_y,
    #             predicting_X=predicting_X
    #         )
    #         assert len(predicting_y) == len(predicting_entries)

    #         # select the top1 entry
    #         top1_idx = self._top1_idx(
    #             training_y=training_y,
    #             predicting_y=predicting_y,
    #             explore_exploit=cur_mode,
    #         )
    #         assert 0 <= top1_idx < len(predicting_entries)
    #         top1_entry = predicting_entries[top1_idx]
    #         simulated_entries.append(top1_entry)
    #         cur_window_size = self._get_window_size(simulated_entries)

    #     return simulated_entries
