import d3rlpy
import torch
import copy
import gym
import numpy as np
import pandas as pd
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from d3rlpy.metrics.scorer import evaluate_on_environment
from typing import Sequence
from abc import ABCMeta, abstractmethod
from d3rlpy.models.torch.encoders import Encoder
from d3rlpy.algos import BC, BCQ, BEAR, CQL
from d3rlpy.models.torch.encoders import _VectorEncoder

def assign_episode_to_dataset(dataset):
    """Label every transition with the index of the episode it belongs to.

    Episodes are numbered from 0 in order of appearance. A transition whose
    terminal flag is truthy is labelled with the current episode, and the
    counter advances only afterwards, so the terminal step closes its episode.

    Args:
        dataset: Any object exposing an iterable ``terminals`` attribute.

    Returns:
        A NumPy array with one episode id per element of ``dataset.terminals``.
    """
    def _running_ids():
        episode = 0
        for is_terminal in dataset.terminals:
            yield episode
            if is_terminal:
                episode += 1

    return np.array(list(_running_ids()))

def remove_adjacent_duplicate_labels(labels):
    """Collapse each run of equal consecutive labels into a single entry.

    Example: ``[1, 1, 2, 2, 1]`` -> ``[1, 2, 1]``. Non-adjacent repeats are
    kept, so the result still records the order in which labels were visited.

    Args:
        labels: Any iterable of comparable labels (list, tuple, 1-D array, ...).

    Returns:
        A new list of labels. Empty input yields ``[]``; before this change it
        raised ``IndexError``.
    """
    unique_labels = []
    for label in labels:
        # The first label is always kept; later ones only when they differ
        # from the most recently kept label.
        if not unique_labels or label != unique_labels[-1]:
            unique_labels.append(label)
    return unique_labels

def calculate_label_combination_counts(cluster_labels, time_slice_length):
    """Count how often each collapsed label sequence occurs across sliding windows.

    Every window of ``time_slice_length`` consecutive entries of
    ``cluster_labels`` (stride 1) has its runs of repeated labels collapsed.
    The resulting tuple is used as a dictionary key.

    Args:
        cluster_labels: Sliceable sequence of cluster labels.
        time_slice_length: Window length.

    Returns:
        Dict mapping collapsed label tuples to their number of occurrences,
        in first-seen order. Empty when there are fewer labels than
        ``time_slice_length``.
    """
    window_count = len(cluster_labels) - time_slice_length + 1
    combos = (
        tuple(remove_adjacent_duplicate_labels(cluster_labels[first:first + time_slice_length]))
        for first in range(window_count)
    )
    tally = {}
    for combo in combos:
        if combo in tally:
            tally[combo] += 1
        else:
            tally[combo] = 1
    return tally

def select_replacement_index(label_combinations_counts, poisoned_clusters_list, slice_start_idx, time_slice_length):
    """Return the position of the candidate whose collapsed label sequence is most frequent.

    Each candidate label sequence in ``poisoned_clusters_list`` has its runs of
    repeated labels collapsed. The resulting tuple is looked up in
    ``label_combinations_counts``; missing keys count as 0. The first candidate
    with the highest count wins.

    ``slice_start_idx`` and ``time_slice_length`` are accepted but not used.

    Returns:
        The winning candidate's index. Returns ``None`` when
        ``poisoned_clusters_list`` is empty, or when no count exceeds -1
        (only possible with a count table holding negative values).
    """
    scores = [
        label_combinations_counts.get(tuple(remove_adjacent_duplicate_labels(clusters)), 0)
        for clusters in poisoned_clusters_list
    ]
    peak = max(scores, default=-1)
    if peak <= -1:
        return None
    # list.index returns the first occurrence, i.e. ties go to the earliest candidate.
    return scores.index(peak)

def _select_slice_starts(cluster_labels, time_slice_length, max_poisoned_steps):
    """Choose non-overlapping windows of the clustered segment to poison.

    Every window of ``time_slice_length`` consecutive labels gets a coverage
    score: the number of distinct labels it contains divided by
    ``time_slice_length``. Windows are visited in the order produced by the
    groupby/sort below. Each is greedily accepted when it does not overlap an
    already-accepted window. Selection stops once at least
    ``max_poisoned_steps`` steps are covered.

    Args:
        cluster_labels: Per-step cluster labels of the clustered segment.
        time_slice_length: Window length in steps.
        max_poisoned_steps: Stop once this many steps have been selected.

    Returns:
        List of window start positions, relative to the start of the segment,
        in the order they were accepted.
    """
    results = []
    for start in range(len(cluster_labels) - time_slice_length + 1):
        window = cluster_labels[start:start + time_slice_length]
        collapsed = [window[0]]
        for label in window[1:]:
            if label != collapsed[-1]:
                collapsed.append(label)
        results.append({
            'Starting Point': start,
            # NOTE(review): concatenating labels is only unambiguous while every
            # label is a single digit (k <= 10).
            'Label Combination': ''.join(str(label) for label in collapsed),
            'Coverage': len(set(collapsed)) / time_slice_length,
        })
    if not results:
        return []

    # NOTE(review): grouping by 'Starting Point' gives every group size 1, so
    # the 'Count' sort key is constant and the ordering is effectively by
    # Coverage (descending) alone. If the intent was to prefer rare label
    # combinations, group by 'Label Combination' instead -- confirm first.
    count_df = (
        pd.DataFrame(results)
        .groupby(['Label Combination', 'Coverage', 'Starting Point'])
        .size()
        .reset_index(name='Count')
    )
    sorted_df = count_df.sort_values(by=['Count', 'Coverage'], ascending=[True, False])

    taken = set()  # set membership keeps the overlap test O(window) per row
    starts = []
    for start in sorted_df['Starting Point']:
        start = int(start)
        window_steps = range(start, start + time_slice_length)
        if taken.isdisjoint(window_steps):
            taken.update(window_steps)
            starts.append(start)
            if len(taken) >= max_poisoned_steps:
                break
    return starts


def _perturbed_candidates(observations, actions, rows, max_perturbation_ratio, num_candidates):
    """Draw ``num_candidates`` noisy copies of the given dataset rows.

    Each element x is shifted by ``u * (|x| * max_perturbation_ratio)``, with
    ``u`` drawn from ``np.random.uniform(-1, 1)`` (NumPy's global RNG). Every
    element therefore moves by at most ``max_perturbation_ratio`` of its own
    magnitude.

    Returns:
        List of ``(candidate_observations, candidate_actions)`` pairs. Each is
        a 2-D array with one row per entry of ``rows``.
    """
    candidates = []
    for _ in range(num_candidates):
        cand_obs, cand_act = [], []
        for row in rows:
            obs = observations[row]
            act = actions[row]
            # Draw order (observation, then action, step by step) is fixed so
            # seeded runs are reproducible.
            cand_obs.append(obs + np.random.uniform(-1, 1, obs.shape) * (np.abs(obs) * max_perturbation_ratio))
            cand_act.append(act + np.random.uniform(-1, 1, act.shape) * (np.abs(act) * max_perturbation_ratio))
        candidates.append((np.array(cand_obs), np.array(cand_act)))
    return candidates


def poison(
    dataset_name='walker2d-medium-v0',
    *,
    seed=42,
    k=9,
    max_perturbation_ratio=0.05,
    time_slice_length=9,
    segment_ratio=0.01,
    num_candidates=50,
    max_poisoned_steps=10000,
):
    """Perturb selected windows of a D4RL dataset and return the dataset.

    Pipeline:
      1. Load ``dataset_name`` with ``d3rlpy.datasets.get_d4rl``, e.g.
         'walker2d-medium-v0', 'hopper-medium-v0' or 'halfcheetah-medium-v0'.
      2. Seed torch and NumPy, then pick a random contiguous segment holding
         ``segment_ratio`` of the transitions.
      3. Cluster the segment's concatenated (observation, action) vectors
         into ``k`` clusters with k-means (``random_state=seed``).
      4. Select non-overlapping windows of ``time_slice_length`` steps inside
         the segment, covering up to about ``max_poisoned_steps`` steps.
      5. For each window, draw ``num_candidates`` noisy copies (relative noise
         of at most ``max_perturbation_ratio``). Write back the copy whose
         collapsed cluster-label sequence is most frequent among the
         segment's windows.

    Returns:
        The dataset object returned by ``get_d4rl``. The chosen rows are
        written back through item assignment on ``dataset.observations`` and
        ``dataset.actions``.
    """
    dataset, _env = d3rlpy.datasets.get_d4rl(dataset_name)

    torch.manual_seed(seed)
    np.random.seed(seed)

    total_data_length = dataset.observations.shape[0]
    segment_length = int(total_data_length * segment_ratio)
    start_point = np.random.randint(0, total_data_length - segment_length)
    end_point = start_point + segment_length

    # Only the segment is clustered, so only the segment is converted and stacked.
    segment_features = np.hstack((
        np.asarray(dataset.observations[start_point:end_point], dtype=np.float64),
        np.asarray(dataset.actions[start_point:end_point], dtype=np.float64),
    ))
    kmeans = KMeans(n_clusters=k, random_state=seed)
    cluster_labels = kmeans.fit_predict(segment_features)
    label_combinations = calculate_label_combination_counts(cluster_labels, time_slice_length)

    slice_starts = _select_slice_starts(cluster_labels, time_slice_length, max_poisoned_steps)

    for local_start in slice_starts:
        # Window positions are relative to the clustered segment; shift them
        # to absolute dataset rows before reading or writing data.
        row_start = start_point + local_start
        row_stop = row_start + time_slice_length
        candidates = _perturbed_candidates(
            dataset.observations, dataset.actions, range(row_start, row_stop),
            max_perturbation_ratio, num_candidates,
        )
        candidate_clusters = [
            kmeans.predict(np.hstack((cand_obs, cand_act))).tolist()
            for cand_obs, cand_act in candidates
        ]
        best_idx = select_replacement_index(
            label_combinations, candidate_clusters, row_start, time_slice_length,
        )
        if best_idx is None:
            continue
        best_obs, best_act = candidates[best_idx]
        dataset.observations[row_start:row_stop] = best_obs
        dataset.actions[row_start:row_stop] = best_act

    return dataset

if __name__ == "__main__":
    # Script entry point: run the poisoning pipeline once.
    poison()
