import d3rlpy
import torch
import copy
import gym
import numpy as np
import pandas as pd
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from d3rlpy.metrics.scorer import evaluate_on_environment
from typing import Sequence
from abc import ABCMeta, abstractmethod
from d3rlpy.models.torch.encoders import Encoder
from d3rlpy.algos import BC, BCQ, BEAR, CQL
from d3rlpy.models.torch.encoders import VectorEncoderWithAction

def assign_episode_to_dataset(dataset):
    """Label every transition of ``dataset`` with a 0-based episode index.

    Walks ``dataset.terminals`` in order. The index given to a step equals the
    number of truthy terminal flags seen *before* it, so a terminal step still
    belongs to the episode it ends.

    Returns:
        np.ndarray with one episode index per entry of ``dataset.terminals``.
    """
    episode_of_step = []
    episode = 0
    for is_terminal in dataset.terminals:
        episode_of_step.append(episode)
        # A truthy flag closes the current episode; the next step starts a new one.
        episode += bool(is_terminal)
    return np.array(episode_of_step)

class MyEncoder(VectorEncoderWithAction):
    """State-action encoder built on d3rlpy's ``VectorEncoderWithAction``.

    The constructor forwards every argument unchanged to the parent class.
    ``forward`` concatenates observation and action, runs the inherited
    ``_fc_encode`` and then applies the last batch-norm / dropout layer when
    those are enabled.
    """

    def __init__(
        self,
        observation_shape: Sequence[int],
        hidden_units: Sequence[int],
        action_size: int,
        use_batch_norm: bool = False,
        dropout_rate: float = None,
        use_dense: bool = False,
        activation: nn.Module = nn.ReLU()
    ):
        super().__init__(
            observation_shape=observation_shape,
            action_size=action_size,
            hidden_units=hidden_units,
            use_batch_norm=use_batch_norm,
            dropout_rate=dropout_rate,
            use_dense=use_dense,
            activation=activation,
        )

    def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Encode a batch of (observation, action) pairs.

        Args:
            x: observation batch. It is concatenated with ``action`` along
                dim 1, so presumably shaped ``(batch, obs_dim)`` — confirm
                against callers.
            action: continuous action batch, or integer action ids when the
                inherited ``_discrete_action`` flag is set.

        Returns:
            The encoded feature tensor.
        """
        if self._discrete_action:
            # One-hot encode action ids so they can be concatenated with x.
            # Uses nn.functional directly: no ``F`` alias is imported in this
            # module, so ``F.one_hot`` would raise NameError here.
            action = nn.functional.one_hot(
                action.view(-1).long(), num_classes=self.action_size
            ).float()
        h = self._fc_encode(torch.cat([x, action], dim=1))
        if self._use_batch_norm:
            h = self._bns[-1](h)
        if self._dropout_rate is not None:
            h = self._dropouts[-1](h)
        return h

def observation_addperturb(observations, max_perturbation_ratio):
    """Return ``observations`` plus a bounded, random-sign relative perturbation.

    Uniform noise in [-1, 1] with the input's shape is rescaled so that its
    largest absolute entry is exactly 1. It is then multiplied element-wise by
    ``max_perturbation_ratio * observations``, so every element changes by at
    most ``max_perturbation_ratio * |x|``.

    Draws from NumPy's global RNG (``np.random``), so results depend on
    ``np.random.seed``.

    Args:
        observations: array-like of any shape (a single vector or a batch).
        max_perturbation_ratio: maximum relative change per element.

    Returns:
        A new array with the perturbed values; the input is not modified.
    """
    observations = np.asarray(observations)
    noise = np.random.uniform(low=-1, high=1, size=observations.shape)
    # Element-wise max |noise| over the whole array. np.linalg.norm(..., ord=np.inf)
    # only equals this for 1-D input; for 2-D it is the max absolute row sum,
    # which would shrink the perturbation of batched inputs.
    scale = np.max(np.abs(noise))
    perturbation = noise / scale * max_perturbation_ratio * observations
    return observations + perturbation

def action_addperturb(actions, max_perturbation_ratio):
    """Return ``actions`` plus a bounded relative perturbation, clipped to [-1, 1].

    Uniform noise in [-1, 1] with the input's shape is rescaled so that its
    largest absolute entry is exactly 1. It is then multiplied element-wise by
    ``max_perturbation_ratio * actions`` and added. The result is clipped to
    [-1, 1] (presumably the environment's action bounds — confirm).

    Draws from NumPy's global RNG (``np.random``).

    Args:
        actions: array-like of any shape (a single action or a batch).
        max_perturbation_ratio: maximum relative change per element.

    Returns:
        A new clipped array; the input is not modified.
    """
    actions = np.asarray(actions)
    noise = np.random.uniform(low=-1, high=1, size=actions.shape)
    # Element-wise max |noise| (np.linalg.norm with ord=np.inf would use the
    # max absolute row sum for 2-D batches and under-scale the noise).
    scale = np.max(np.abs(noise))
    perturbation = noise / scale * max_perturbation_ratio * actions
    return np.clip(actions + perturbation, -1, 1)


def _load_encoder_weights(
    encoder,
    path,
    checkpoint_key="_policy",
    prefixes=(
        "_q_funcs.0._encoder.",
        "_q_funcs.1._encoder.",
        "_q_funcs.2._encoder.",
    ),
):
    """Copy encoder weights from a saved checkpoint into ``encoder``.

    ``torch.load(path)`` must yield a mapping whose ``checkpoint_key`` entry is
    itself a state dict. Keys in that entry that start with one of
    ``prefixes`` lose the prefix. All other keys pass through unchanged and,
    if the encoder does not know them, are ignored (``strict=False``). When
    several prefixed keys reduce to the same name, the one appearing later in
    the checkpoint wins. The encoder is switched to eval mode afterwards.

    NOTE(review): ``prefixes`` name members of a Q-function ensemble, yet the
    default ``checkpoint_key`` is ``"_policy"``. If loading fails, the critic
    is presumably stored under a different entry of the checkpoint — verify.

    Raises:
        RuntimeError: if not a single encoder parameter was found in the
            checkpoint. Otherwise the encoder would silently keep its random
            initialisation.
    """
    # The encoder is evaluated on CPU, so map GPU-saved tensors to CPU.
    checkpoint = torch.load(path, map_location="cpu")
    remapped = {}
    for key, value in checkpoint[checkpoint_key].items():
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                break
        remapped[key] = value

    result = encoder.load_state_dict(remapped, strict=False)
    expected = set(encoder.state_dict())
    if expected and expected.issubset(result.missing_keys):
        raise RuntimeError(
            f"No encoder parameters found in entry {checkpoint_key!r} of "
            f"{path!r}; check checkpoint_key and the key prefixes."
        )
    encoder.eval()


def _encode_dataset(encoder, observations, actions, batch_size=4096):
    """Return ``encoder(observation, action)`` for every transition as a NumPy array.

    Rows are fed through the encoder in batches of ``batch_size`` and
    concatenated along axis 0, so row ``i`` of the result belongs to
    transition ``i``.
    """
    obs = torch.as_tensor(np.asarray(observations), dtype=torch.float32)
    act = torch.as_tensor(np.asarray(actions), dtype=torch.float32)
    chunks = []
    with torch.no_grad():  # inference only: do not build autograd graphs
        for start in range(0, len(obs), batch_size):
            stop = start + batch_size
            chunks.append(encoder(obs[start:stop], act[start:stop]).numpy())
    if not chunks:
        return np.empty((0, 0), dtype=np.float32)
    return np.concatenate(chunks, axis=0)


def _summarize_time_slices(cluster_labels, time_slice_length, k):
    """Describe every sliding window of cluster labels and count the patterns.

    Inside each window, consecutive repeats are collapsed (labels 3,3,1,1,3
    become 3,1,3) and the result is rendered as a string. Labels are joined
    without a separator while ``k <= 10`` (all labels single-digit) and with
    ``"-"`` otherwise, so multi-digit labels cannot collide.

    NOTE(review): windows slide over the flattened transition sequence and can
    therefore straddle episode boundaries — confirm this is intended.

    Returns:
        ``(combos, summary)``. ``combos[i]`` is the pattern of the window that
        starts at ``i``. ``summary`` has the columns ``Label Combination``,
        ``Coverage`` (distinct labels in the window divided by ``k``) and
        ``Count`` (number of windows showing that pattern).
    """
    if time_slice_length < 1:
        raise ValueError("time_slice_length must be at least 1")
    separator = "" if k <= 10 else "-"
    combos = []
    rows = []
    for start in range(len(cluster_labels) - time_slice_length + 1):
        window = cluster_labels[start:start + time_slice_length]
        collapsed = [window[0]]
        for label in window[1:]:
            if label != collapsed[-1]:
                collapsed.append(label)
        combo = separator.join(str(label) for label in collapsed)
        combos.append(combo)
        rows.append((start, combo, len(set(collapsed)) / k))

    summary = (
        pd.DataFrame(rows, columns=["Starting Point", "Label Combination", "Coverage"])
        .groupby(["Label Combination", "Coverage"])["Starting Point"]
        .count()
        .reset_index()
        .rename(columns={"Starting Point": "Count"})
    )
    return combos, summary


def _select_poison_slices(combos, ranked_summary, n_points, time_slice_length, budget):
    """Pick non-overlapping windows, visiting label patterns in ``ranked_summary`` order.

    For each pattern (in row order), up to ``min(Count, remaining quota)``
    windows with that pattern are taken in increasing start order. A window is
    skipped if it shares any point with an already selected window, including
    windows of the same pattern. Selection stops once
    ``budget // time_slice_length`` windows have been chosen.

    Returns:
        Start indices of the selected windows, in selection order.

    Raises:
        RuntimeError: if fewer than ``budget // time_slice_length``
            non-overlapping windows could be found.
    """
    starts_by_combo = {}
    for start, combo in enumerate(combos):
        starts_by_combo.setdefault(combo, []).append(start)

    wanted = budget // time_slice_length
    remaining = wanted
    taken = np.zeros(n_points, dtype=bool)  # points already covered by a window
    selected = []
    for combo, count in zip(ranked_summary["Label Combination"], ranked_summary["Count"]):
        if remaining <= 0:
            break
        quota = min(int(count), remaining)
        picked = 0
        for start in starts_by_combo.get(combo, ()):
            if picked >= quota:
                break
            stop = start + time_slice_length
            if taken[start:stop].any():
                continue
            taken[start:stop] = True
            selected.append(start)
            picked += 1
        remaining -= picked

    if len(selected) != wanted:
        raise RuntimeError(
            f"Only {len(selected)} of {wanted} non-overlapping time slices of "
            f"length {time_slice_length} could be selected."
        )
    return selected


def poison(
    dataset_name="walker2d-medium-v0",
    model_path="../walker2d_meduim_model_cql.pt",
    k=8,
    max_perturbation_ratio=0.05,
    time_slice_length=5,
    poison_budget=10000,
    seed=42,
    checkpoint_key="_policy",
):
    """Perturb rarely occurring behaviour patterns of a D4RL dataset in place.

    Pipeline:

    1. Load ``dataset_name`` via ``d3rlpy.datasets.get_d4rl`` and seed the
       global torch and NumPy RNGs with ``seed``.
    2. Build a ``MyEncoder`` sized from the dataset and load weights from
       ``model_path`` (entry ``checkpoint_key``).
    3. Encode every (observation, action) pair and cluster the features
       with k-means into ``k`` clusters.
    4. Summarise every window of ``time_slice_length`` consecutive labels
       and write the pattern counts to ``time_slice_coverage.csv``.
    5. Select ``poison_budget // time_slice_length`` non-overlapping windows,
       rarest patterns first.
    6. Perturb the observations and actions of every point in those windows
       by at most ``max_perturbation_ratio`` (relative).

    Other configurations used with this script:
    ``'hopper-medium-v0'`` with ``"../hopper_meduim_model_cql.pt"`` and
    ``'halfcheetah-medium-v0'`` with ``"../half_meduim_model_cql.pt"``.

    Returns:
        The loaded dataset, with its observation/action arrays modified in place.
    """
    dataset, _ = d3rlpy.datasets.get_d4rl(dataset_name)

    torch.manual_seed(seed)
    np.random.seed(seed)

    encoder = MyEncoder(
        # Derived from the data so that other D4RL tasks work without edits.
        observation_shape=tuple(dataset.observations.shape[1:]),
        action_size=int(dataset.actions.shape[1]),
        hidden_units=[256, 256, 256],
        use_batch_norm=False,
        dropout_rate=None,
        use_dense=False,
        activation=nn.ReLU(),
    )
    _load_encoder_weights(encoder, model_path, checkpoint_key)
    print("success1")

    features = _encode_dataset(encoder, dataset.observations, dataset.actions)
    print("success2")

    if features.ndim != 2:
        raise ValueError("Features array must be 2D for k-means clustering.")
    cluster_labels = KMeans(n_clusters=k, random_state=seed).fit_predict(features)
    print('success3')

    combos, summary = _summarize_time_slices(cluster_labels, time_slice_length, k)
    summary.to_csv('time_slice_coverage.csv', index=False)

    # Rarest patterns first; a stable sort keeps ties in a reproducible order.
    ranked = summary.sort_values(by='Count', ascending=True, kind="mergesort")
    starts = _select_poison_slices(
        combos, ranked, len(cluster_labels), time_slice_length, poison_budget
    )

    for start in starts:
        for idx in range(start, start + time_slice_length):
            dataset.observations[idx] = observation_addperturb(
                dataset.observations[idx], max_perturbation_ratio
            )
            dataset.actions[idx] = action_addperturb(
                dataset.actions[idx], max_perturbation_ratio
            )
    return dataset

# Script entry point: run poison() with no arguments.
# NOTE(review): the dataset returned by poison() is discarded here.
if __name__ == '__main__':

    poison()
