import copy
from abc import ABCMeta, abstractmethod
from typing import Optional, Sequence

import d3rlpy
import gym
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from d3rlpy.algos import BC, BCQ, BEAR, CQL
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.models.torch.encoders import Encoder
from d3rlpy.models.torch.encoders import VectorEncoderWithAction
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler

class MyEncoder(VectorEncoderWithAction):
    """Encoder over concatenated (observation, action) vectors.

    Thin subclass of d3rlpy's ``VectorEncoderWithAction``. The constructor only
    forwards its arguments to the base class. ``forward`` is re-implemented
    explicitly so it can be customized (e.g. to expose intermediate features).
    """

    def __init__(
        self,
        observation_shape: Sequence[int],
        hidden_units: Sequence[int],
        action_size: int,
        use_batch_norm: bool = False,
        dropout_rate: Optional[float] = None,
        use_dense: bool = False,
        # NOTE: this default ReLU instance is created once and shared by every
        # encoder built with the default; harmless because ReLU has no
        # parameters or state.
        activation: nn.Module = nn.ReLU(),
    ):
        """Build the encoder.

        Args:
            observation_shape: Shape of a single observation.
            hidden_units: Width of each fully connected layer.
            action_size: Dimensionality of the action input.
            use_batch_norm: Enable batch normalization in the base encoder.
            dropout_rate: Dropout probability, or ``None`` to disable dropout.
            use_dense: Enable the base encoder's dense (skip) connections.
            activation: Activation module applied between layers.
        """
        super().__init__(
            observation_shape=observation_shape,
            action_size=action_size,
            hidden_units=hidden_units,
            use_batch_norm=use_batch_norm,
            dropout_rate=dropout_rate,
            use_dense=use_dense,
            activation=activation,
        )

    def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Encode observations together with their actions.

        Args:
            x: Observation batch. It is concatenated with ``action`` along
                dim 1, so it is presumably (batch, obs_dim); verify against
                callers.
            action: Action batch, or action indices when the base class is
                configured for discrete actions.

        Returns:
            Output of the base class's fully connected stack, followed by the
            last batch-norm / dropout layer when those are enabled.
        """
        if self._discrete_action:
            # Discrete actions are indices; one-hot them so they can be
            # concatenated with the observation features.
            action = F.one_hot(
                action.view(-1).long(), num_classes=self.action_size
            ).float()
        x = torch.cat([x, action], dim=1)
        h = self._fc_encode(x)
        if self._use_batch_norm:
            h = self._bns[-1](h)
        if self._dropout_rate is not None:
            h = self._dropouts[-1](h)
        return h

    
def _elbow_k(k_values, inertias):
    """Return the k at the elbow of an inertia-vs-k curve.

    Both axes are rescaled to [0, 1]. The elbow is the point with the largest
    distance from the straight line (chord) joining the first and last points
    of the curve. Rescaling makes the choice independent of the absolute
    inertia scale.

    Args:
        k_values: Candidate cluster counts, in the same order as ``inertias``.
        inertias: K-means within-cluster sum of squares for each k.

    Returns:
        The element of ``k_values`` at the elbow. With fewer than three
        candidates there is no interior point, so the first k is returned.

    Raises:
        ValueError: If the inputs are empty or have different lengths.
    """
    ks = list(k_values)
    if not ks or len(ks) != len(inertias):
        raise ValueError(
            "k_values and inertias must be non-empty and of equal length"
        )
    if len(ks) < 3:
        return ks[0]

    def _unit(v):
        # Rescale to [0, 1]; a constant series maps to all zeros.
        span = v.max() - v.min()
        return (v - v.min()) / span if span > 0 else np.zeros_like(v)

    x = _unit(np.asarray(ks, dtype=np.float64))
    y = _unit(np.asarray(inertias, dtype=np.float64))
    dx, dy = x[-1] - x[0], y[-1] - y[0]
    # Proportional to each point's perpendicular distance from the chord; the
    # common denominator sqrt(dx**2 + dy**2) does not affect the argmax.
    distance = np.abs(dy * (x - x[0]) - dx * (y - y[0]))
    return ks[int(np.argmax(distance))]


def WCSS_kvalues(
    env_name="halfcheetah-medium-v0",
    k_values=range(2, 20),
    segment_fraction=0.2,
    seed=42,
    show=True,
):
    """Run the K-means elbow method on a random slice of a D4RL dataset.

    Each observation is concatenated with its action. A random contiguous
    segment covering ``segment_fraction`` of the transitions is clustered with
    K-means for every k in ``k_values``. The inertia (WCSS) is plotted against
    k, and the k at the elbow of that curve is reported.

    Args:
        env_name: Dataset name passed to ``d3rlpy.datasets.get_d4rl``
            (e.g. 'walker2d-medium-v0', 'hopper-medium-v0').
        k_values: Candidate numbers of clusters.
        segment_fraction: Fraction of transitions to cluster, in (0, 1].
        seed: Seed for NumPy, PyTorch and K-means.
        show: Whether to display the plot with ``plt.show()``.

    Returns:
        The k selected by the elbow heuristic.

    Raises:
        ValueError: If ``k_values`` is empty, ``segment_fraction`` is out of
            range, or the segment has fewer samples than the largest k.
    """
    k_values = list(k_values)
    if not k_values:
        raise ValueError("k_values must not be empty")
    if not 0 < segment_fraction <= 1:
        raise ValueError("segment_fraction must be in (0, 1]")

    dataset, _ = d3rlpy.datasets.get_d4rl(env_name)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # The feature matrix does not depend on k, so build it once.
    observations = np.asarray(dataset.observations, dtype=np.float64)
    actions = np.asarray(dataset.actions, dtype=np.float64)
    # reshape keeps 2-D actions unchanged and turns 1-D actions into a column.
    actions = actions.reshape(len(actions), -1)
    features = np.hstack((observations, actions))

    # Cluster only a random contiguous segment to keep K-means affordable.
    total_data_length = features.shape[0]
    segment_length = max(1, int(total_data_length * segment_fraction))
    start_point = np.random.randint(0, total_data_length - segment_length + 1)
    end_point = start_point + segment_length
    segment = features[start_point:end_point]
    if segment_length < max(k_values):
        raise ValueError(
            f"segment has {segment_length} samples, fewer than the largest "
            f"k ({max(k_values)})"
        )

    inertias = [
        KMeans(n_clusters=k, random_state=seed).fit(segment).inertia_
        for k in k_values
    ]
    best_k = _elbow_k(k_values, inertias)

    plt.plot(k_values, inertias, "bx-")
    plt.axvline(best_k, color="r", linestyle="--", label=f"elbow k={best_k}")
    plt.xlabel("k")
    plt.ylabel("Inertia")
    plt.title("Elbow Method For Optimal k")
    plt.legend()
    print(f"The optimal number of clusters is {best_k}")
    if show:
        plt.show()
    return best_k
    

if __name__ == "__main__":
    # Script entry point: run the elbow analysis with its default settings.
    WCSS_kvalues()

