import d3rlpy
import torch
import copy
import gym
import random
import numpy as np
import pandas as pd
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from d3rlpy.metrics.scorer import evaluate_on_environment
from typing import Sequence
from abc import ABCMeta, abstractmethod
from d3rlpy.models.torch.encoders import Encoder
from d3rlpy.algos import BC, BCQ, BEAR, CQL
from d3rlpy.models.torch.encoders import VectorEncoderWithAction

def assign_episode_to_dataset(dataset):
    """Label every step of ``dataset`` with the index of its episode.

    Walks ``dataset.terminals`` in order. A step receives the current
    episode index; a truthy terminal flag closes that episode, so the
    following step starts the next one.

    Args:
        dataset: Any object exposing an iterable ``terminals`` attribute.

    Returns:
        A NumPy array with one episode index per entry of
        ``dataset.terminals``.
    """
    ids = []
    episode = 0
    for done in dataset.terminals:
        ids.append(episode)
        # The terminal step itself still belongs to the episode it ends.
        episode += 1 if done else 0
    return np.array(ids)

class MyEncoder(VectorEncoderWithAction):
    """Observation/action encoder built on ``VectorEncoderWithAction``.

    The constructor forwards every argument to the parent class by keyword.
    ``forward`` concatenates the observation with the action and returns the
    resulting hidden features. It relies on ``_discrete_action``,
    ``_fc_encode``, ``_use_batch_norm``, ``_bns``, ``_dropout_rate``,
    ``_dropouts`` and ``action_size``, which are expected to be provided by
    the parent class.
    """

    def __init__(
        self,
        observation_shape: Sequence[int],
        hidden_units: Sequence[int],
        action_size: int,
        use_batch_norm: bool = False,
        dropout_rate: float = None,
        use_dense: bool = False,
        activation: nn.Module = nn.ReLU()
    ):
        """Forward all configuration to ``VectorEncoderWithAction``.

        Args:
            observation_shape: Shape of a single observation.
            hidden_units: Hidden layer sizes passed to the parent.
            action_size: Action dimensionality (or number of classes).
            use_batch_norm: Passed through to the parent.
            dropout_rate: Passed through to the parent; ``None`` disables
                the dropout step in ``forward``.
            use_dense: Passed through to the parent.
            activation: Activation module passed through to the parent.
        """
        super().__init__(
            observation_shape=observation_shape,
            action_size=action_size,
            hidden_units=hidden_units,
            use_batch_norm=use_batch_norm,
            dropout_rate=dropout_rate,
            use_dense=use_dense,
            activation=activation,
        )

    def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Encode a batch of observation/action pairs.

        Args:
            x: Observation batch; concatenated with ``action`` along dim 1.
            action: Action batch. When the encoder is in discrete-action
                mode it is flattened, cast to ``long`` and one-hot encoded
                with ``self.action_size`` classes first.

        Returns:
            The output of ``_fc_encode``, passed through the last
            batch-norm and/or dropout layer when those are enabled.
        """
        if self._discrete_action:
            # ``F`` is not imported in this file; go through ``nn`` so the
            # discrete branch no longer raises NameError.
            action = nn.functional.one_hot(
                action.view(-1).long(), num_classes=self.action_size
            ).float()
        x = torch.cat([x, action], dim=1)
        h = self._fc_encode(x)
        if self._use_batch_norm:
            h = self._bns[-1](h)
        if self._dropout_rate is not None:
            h = self._dropouts[-1](h)
        return h

def remove_adjacent_duplicate_labels(labels):
    """Collapse runs of equal consecutive labels into a single label.

    Only *adjacent* repeats are removed; a label may reappear later in the
    result (``[1, 1, 2, 1]`` becomes ``[1, 2, 1]``).

    Args:
        labels: Any iterable of labels comparable with ``!=`` (list, tuple,
            1-D NumPy array, generator, ...). May be empty.

    Returns:
        A new list with the first label of each run, in order. An empty
        input yields an empty list.
    """
    collapsed = []
    for label in labels:
        # ``not collapsed`` covers the first element, so no up-front
        # ``labels[0]`` access is needed and empty input is safe.
        if not collapsed or label != collapsed[-1]:
            collapsed.append(label)
    return collapsed

def calculate_label_combination_counts(cluster_labels, time_slice_length):
    """Count collapsed label sequences over every sliding window.

    Each window of ``time_slice_length`` consecutive labels is reduced with
    ``remove_adjacent_duplicate_labels`` and the resulting tuple is tallied.

    Args:
        cluster_labels: Sliceable sequence of cluster labels.
        time_slice_length: Number of consecutive labels per window.

    Returns:
        Dict mapping each collapsed label tuple to the number of windows
        that produced it.
    """
    tally = {}
    last_start = len(cluster_labels) - time_slice_length
    for start in range(last_start + 1):
        window = cluster_labels[start:start + time_slice_length]
        key = tuple(remove_adjacent_duplicate_labels(window))
        if key in tally:
            tally[key] += 1
        else:
            tally[key] = 1
    return tally

def select_replacement_index(label_combinations_counts, poisoned_clusters_list, slice_start_idx, time_slice_length):
    """Pick the candidate whose collapsed label sequence is most frequent.

    Args:
        label_combinations_counts: Dict from collapsed label tuples to
            counts; sequences absent from it count as 0.
        poisoned_clusters_list: One cluster-label sequence per candidate.
        slice_start_idx: Accepted for interface compatibility; not used.
        time_slice_length: Accepted for interface compatibility; not used.

    Returns:
        Index of the first candidate with the highest count, or ``None``
        when no candidate has a count above -1 (e.g. the list is empty).
    """
    candidate_counts = [
        label_combinations_counts.get(
            tuple(remove_adjacent_duplicate_labels(candidate)), 0
        )
        for candidate in poisoned_clusters_list
    ]

    winner = None
    top_count = -1
    for position, count in enumerate(candidate_counts):
        # Strict comparison keeps the earliest candidate on ties.
        if count > top_count:
            winner, top_count = position, count
    return winner

def poison():
    """Perturb selected time slices of the walker2d-medium-v0 D4RL dataset.

    Steps performed:
      1. Load the dataset, build a ``MyEncoder`` and try to load weights for
         it from the checkpoint at ``../walker2d_meduim_model_cql.pt``.
      2. Encode every (observation, action) pair and cluster the features
         with k-means (k=8).
      3. Rank all sliding windows of ``time_slice_length`` steps and greedily
         pick non-overlapping windows until at least 50000 steps are chosen.
      4. For every chosen time step, draw 50 randomly perturbed candidate
         windows (uniform noise of up to 5% of each value's magnitude), keep
         the candidate whose collapsed cluster-label sequence is most
         frequent in the clean data, and write it back through
         ``dataset.observations`` / ``dataset.actions``.

    Returns:
        The dataset object returned by ``d3rlpy.datasets.get_d4rl`` after the
        write-backs above.
    """
    import warnings  # only needed for the checkpoint sanity check below

    dataset, _ = d3rlpy.datasets.get_d4rl('walker2d-medium-v0')
    # dataset, env = d3rlpy.datasets.get_d4rl('hopper-medium-v0')
    # dataset, env = d3rlpy.datasets.get_d4rl('halfcheetah-medium-v0')

    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    # ``random.randint`` is used below, so the stdlib RNG must be seeded too
    # for the run to be reproducible.
    random.seed(seed)

    k = 8
    max_perturbation_ratio = 0.05
    time_slice_length = 5
    max_selected_steps = 50000
    candidates_per_slice = 50

    encoder = MyEncoder(
        observation_shape=(17,),
        action_size=6,
        hidden_units=[256, 256, 256],
        use_batch_norm=False,
        dropout_rate=None,
        use_dense=False,
        activation=nn.ReLU()
    )

    # ---- 1. load encoder weights -----------------------------------------
    path = "../walker2d_meduim_model_cql.pt"
    # NOTE: torch.load unpickles the file; only use checkpoints you trust.
    # The encoder and every tensor below live on the CPU, so map there.
    state_dict = torch.load(path, map_location="cpu")
    # NOTE(review): weights are read from the "_policy" entry while the
    # prefixes below name "_q_funcs.*" modules — confirm the intended entry.
    model = state_dict["_policy"]

    # The original passed these three literals to ``str.replace`` without
    # commas, which concatenated them into one string that never matched.
    # Strip each prefix separately. If several prefixes map onto the same
    # stripped key, the entry seen last wins.
    encoder_prefixes = (
        "_q_funcs.0._encoder.",
        "_q_funcs.1._encoder.",
        "_q_funcs.2._encoder.",
    )
    new_model = dict()
    for key, value in model.items():
        new_key = key
        for prefix in encoder_prefixes:
            new_key = new_key.replace(prefix, "")
        new_model[new_key] = value

    load_result = encoder.load_state_dict(new_model, strict=False)
    if load_result.missing_keys:
        # strict=False hides this otherwise: those parameters keep their
        # random initial values and the features below are then meaningless.
        warnings.warn(
            "Encoder parameters not found in checkpoint and left at their "
            f"initial values: {list(load_result.missing_keys)}"
        )
    encoder.eval()
    print("success1")

    # ---- 2. encode and cluster the clean data ----------------------------
    encoded_ObservationsActions = []
    with torch.no_grad():  # inference only; avoids building autograd graphs
        for observation, action in zip(dataset.observations, dataset.actions):
            observation_tensor = torch.Tensor(observation).unsqueeze(0)
            action_tensor = torch.Tensor(action).unsqueeze(0)
            encoded_observationaction = encoder(observation_tensor, action_tensor)
            encoded_ObservationsActions.append(
                encoded_observationaction.detach().numpy()
            )
    print("success2")

    encoded_ObservationsActions_array = np.vstack(encoded_ObservationsActions)
    assert encoded_ObservationsActions_array.ndim == 2, "Features array must be 2D for k-means clustering."

    kmeans = KMeans(n_clusters=k, random_state=42)
    cluster_labels = kmeans.fit_predict(encoded_ObservationsActions_array)
    print('success3')

    # ---- 3. rank windows and pick non-overlapping ones -------------------
    results = []
    for i in range(len(cluster_labels) - time_slice_length + 1):
        unique_slice_labels = remove_adjacent_duplicate_labels(
            cluster_labels[i:i + time_slice_length]
        )
        coverage = len(set(unique_slice_labels)) / time_slice_length
        results.append({
            'Starting Point': i,
            'Label Combination': ''.join([str(label) for label in unique_slice_labels]),
            'Count': 1,
            'Coverage': coverage
        })

    label_combinations = calculate_label_combination_counts(cluster_labels, time_slice_length)
    results_df = pd.DataFrame(results)
    # NOTE(review): 'Starting Point' is unique per row, so every group has
    # size 1 and 'Count' is always 1; the sort below is therefore decided by
    # 'Coverage' alone. Confirm whether the per-combination count was meant.
    grouped = results_df.groupby(['Label Combination', 'Coverage', 'Starting Point'])
    count_df = grouped.size().reset_index(name='Count')
    sorted_df = count_df.sort_values(by=['Count', 'Coverage'], ascending=[True, False])

    selected_indices = set()    # O(1) overlap test
    selected_time_steps = []    # same steps, in selection order
    for _, row in sorted_df.iterrows():
        start_idx = row['Starting Point']
        time_steps = range(start_idx, start_idx + time_slice_length)
        if selected_indices.isdisjoint(time_steps):
            selected_time_steps.extend(time_steps)
            selected_indices.update(time_steps)
            if len(selected_time_steps) >= max_selected_steps:
                break

    # ---- 4. generate candidates and write the best one back --------------
    num_steps = len(dataset.observations)
    # NOTE(review): this iterates over every selected time step, not only the
    # first step of each selected window, so consecutive iterations overlap.
    # Kept as in the original; confirm this is intended.
    for slice_start_idx in selected_time_steps:
        poison_start_idx = random.randint(0, 5)
        window_start = slice_start_idx + poison_start_idx
        # The window read below (and the write-back range, which ends no
        # later) must fit in the dataset; the original raised IndexError here.
        if window_start + time_slice_length > num_steps:
            continue

        poisoned_data_list = []
        poisoned_clusters_list = []
        for _ in range(candidates_per_slice):
            poisoned_observations = []
            poisoned_actions = []
            for current_idx in range(window_start, window_start + time_slice_length):
                current_observation = dataset.observations[current_idx]
                current_action = dataset.actions[current_idx]
                # Noise scales with each component's magnitude.
                noise_magnitude_obs = np.abs(current_observation) * max_perturbation_ratio
                noise_magnitude_act = np.abs(current_action) * max_perturbation_ratio
                random_noise_obs = np.random.uniform(-1, 1, current_observation.shape) * noise_magnitude_obs
                random_noise_act = np.random.uniform(-1, 1, current_action.shape) * noise_magnitude_act
                poisoned_observations.append(current_observation + random_noise_obs)
                poisoned_actions.append(current_action + random_noise_act)
            poisoned_data_list.append((poisoned_observations, poisoned_actions))

            encoded_features = []
            with torch.no_grad():
                for observations, actions in zip(poisoned_observations, poisoned_actions):
                    observations = torch.Tensor(observations).unsqueeze(0)
                    actions = torch.Tensor(actions).unsqueeze(0)
                    encoded_feat = encoder(observations, actions).detach().numpy()
                    encoded_features.append(encoded_feat.ravel())

            encoded_features_array = np.vstack(encoded_features)
            poisoned_clusters_list.append(kmeans.predict(encoded_features_array))

        best_poisoned_idx = select_replacement_index(
            label_combinations, poisoned_clusters_list, slice_start_idx, time_slice_length
        )

        if best_poisoned_idx is not None:
            best_poisoned_observations, best_poisoned_actions = poisoned_data_list[best_poisoned_idx]
            # NOTE(review): candidates were derived from steps starting at
            # ``window_start`` but are written starting at ``slice_start_idx``
            # (as in the original). Confirm the offset is intended.
            for offset, (poisoned_obs, poisoned_act) in enumerate(
                zip(best_poisoned_observations, best_poisoned_actions)
            ):
                dataset.observations[slice_start_idx + offset] = poisoned_obs
                dataset.actions[slice_start_idx + offset] = poisoned_act
    return dataset

if __name__ == "__main__":
    # Script entry point: run the poisoning pipeline; the returned dataset
    # is not used here.
    poison()