import d3rlpy
import torch
import copy
import gym
import numpy as np
import pandas as pd
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from d3rlpy.metrics.scorer import evaluate_on_environment
from typing import Sequence
from abc import ABCMeta, abstractmethod
from d3rlpy.models.torch.encoders import Encoder
from d3rlpy.algos import BC, BCQ, BEAR, CQL
from d3rlpy.models.torch.encoders import _VectorEncoder

def assign_episode_to_dataset(dataset):
    """Label every step of ``dataset`` with the index of its episode.

    Iterates over ``dataset.terminals``; a step receives the number of
    truthy terminal flags seen *before* it, so the terminal step itself
    still belongs to the episode it closes.

    Args:
        dataset: Any object exposing an iterable ``terminals`` attribute.

    Returns:
        A NumPy array with one episode id per entry of ``dataset.terminals``.
    """
    episode_of_step = []
    finished_episodes = 0
    for done in dataset.terminals:
        # Record the id first: the terminal step is part of the current episode.
        episode_of_step.append(finished_episodes)
        if not done:
            continue
        finished_episodes += 1
    return np.array(episode_of_step)

def observation_addperturb(observations, max_perturbation_ratio):
    """Return ``observations`` with a random multiplicative perturbation added.

    Uniform noise in [-1, 1) is drawn with the same shape as
    ``observations``, divided by its ``np.linalg.norm(..., ord=np.inf)``,
    scaled by ``max_perturbation_ratio`` and multiplied element-wise by the
    observations themselves. For 1-D input the infinity norm is the largest
    absolute noise value, so each element moves by at most
    ``max_perturbation_ratio * |observation|``.

    Args:
        observations: NumPy array of observation values.
        max_perturbation_ratio: Scale applied to the normalised noise.

    Returns:
        A new array ``observations + perturbation``; the input is not modified.
    """
    noise = np.random.uniform(low=-1, high=1, size=observations.shape)
    noise_scale = np.linalg.norm(noise, ord=np.inf)
    # Keep the original left-to-right evaluation order so results match exactly.
    relative_shift = noise / noise_scale * max_perturbation_ratio
    return observations + relative_shift * observations

def action_addperturb(actions, max_perturbation_ratio):
    """Return ``actions`` with a random multiplicative perturbation, clipped.

    Uniform noise in [-1, 1) is drawn with the same shape as ``actions``,
    divided by its ``np.linalg.norm(..., ord=np.inf)``, scaled by
    ``max_perturbation_ratio`` and multiplied element-wise by the actions.
    The perturbed actions are then clipped to the interval [-1, 1].

    Args:
        actions: NumPy array of action values.
        max_perturbation_ratio: Scale applied to the normalised noise.

    Returns:
        A new array of perturbed, clipped actions; the input is not modified.
    """
    noise = np.random.uniform(low=-1, high=1, size=actions.shape)
    noise_scale = np.linalg.norm(noise, ord=np.inf)
    # Keep the original left-to-right evaluation order so results match exactly.
    relative_shift = noise / noise_scale * max_perturbation_ratio
    shifted = actions + relative_shift * actions
    return np.clip(shifted, -1, 1)


def _select_poison_steps(cluster_labels, time_slice_length, max_selected_steps):
    """Choose non-overlapping windows of consecutive time steps to poison.

    Every window of ``time_slice_length`` consecutive cluster labels is
    scored by its coverage: the number of distinct labels in the window
    divided by the window length. Windows are visited from highest to
    lowest coverage and accepted greedily when they share no time step with
    an already accepted window. Selection stops once at least
    ``max_selected_steps`` time steps have been chosen.

    Args:
        cluster_labels: 1-D sequence of per-step cluster labels.
        time_slice_length: Number of consecutive steps per window.
        max_selected_steps: Stop once this many steps have been selected.

    Returns:
        List of selected time-step indices, in acceptance order. Empty when
        the input is shorter than one window.
    """
    n_windows = len(cluster_labels) - time_slice_length + 1
    if n_windows <= 0:
        # Not enough steps for a single window: nothing can be selected.
        return []

    results = []
    for i in range(n_windows):
        current_slice_labels = cluster_labels[i:i + time_slice_length]
        # Collapse runs of identical consecutive labels into one entry.
        unique_slice_labels = [current_slice_labels[0]]
        for label in current_slice_labels[1:]:
            if label != unique_slice_labels[-1]:
                unique_slice_labels.append(label)
        results.append({
            'Starting Point': i,
            'Label Combination': ''.join(str(label) for label in unique_slice_labels),
            'Coverage': len(set(unique_slice_labels)) / time_slice_length,
        })

    results_df = pd.DataFrame(results)
    # NOTE(review): 'Starting Point' is unique per row and is part of the
    # group key, so every group has size 1 and 'Count' is always 1; only
    # 'Coverage' influences the ordering below. If the intent was to prefer
    # rare label combinations, group by 'Label Combination' alone - confirm
    # before changing, since it alters which windows get selected.
    grouped = results_df.groupby(['Label Combination', 'Coverage', 'Starting Point'])
    count_df = grouped.size().reset_index(name='Count')
    sorted_df = count_df.sort_values(by=['Coverage', 'Count'], ascending=[False, True])

    selected_time_steps = []
    taken = set()  # O(1) overlap checks instead of scanning a growing list
    for start_idx in sorted_df['Starting Point']:
        start_idx = int(start_idx)
        window = range(start_idx, start_idx + time_slice_length)
        if not taken.isdisjoint(window):
            continue
        selected_time_steps.extend(window)
        taken.update(window)
        if len(selected_time_steps) >= max_selected_steps:
            break
    return selected_time_steps


def poison(env_name='walker2d-medium-v0', k=9, max_perturbation_ratio=0.05,
           time_slice_length=9, max_poisoned_steps=10000, seed=42):
    """Load a D4RL dataset and perturb a selected subset of its transitions.

    Steps performed:
      1. Load the dataset with ``d3rlpy.datasets.get_d4rl(env_name)``.
      2. Seed torch and NumPy.
      3. Cluster the concatenated (observation, action) vectors with KMeans.
      4. Select non-overlapping windows of ``time_slice_length`` steps,
         preferring windows that span many distinct clusters.
      5. Perturb the observation and action of every selected step in place.

    The previous implementation iterated over a set that was never filled,
    so no transition was actually perturbed; the selected steps are now
    used.

    Args:
        env_name: D4RL dataset id, e.g. ``'walker2d-medium-v0'``,
            ``'hopper-medium-v0'`` or ``'halfcheetah-medium-v0'``.
        k: Number of KMeans clusters.
        max_perturbation_ratio: Ratio passed to ``observation_addperturb``
            and ``action_addperturb``.
        time_slice_length: Number of consecutive steps per candidate window.
        max_poisoned_steps: Selection stops once at least this many steps
            have been chosen.
        seed: Seed for torch, NumPy and the KMeans ``random_state``.

    Returns:
        The loaded dataset object after the in-place perturbation.
    """
    dataset, _env = d3rlpy.datasets.get_d4rl(env_name)

    torch.manual_seed(seed)
    np.random.seed(seed)

    observations_float64 = np.array(dataset.observations).astype(np.float64)
    actions_float64 = np.array(dataset.actions).astype(np.float64)
    observations_actions = np.hstack((observations_float64, actions_float64))

    kmeans = KMeans(n_clusters=k, random_state=seed)
    cluster_labels = kmeans.fit_predict(observations_actions)
    print(('succes'))

    selected_time_steps = _select_poison_steps(
        cluster_labels, time_slice_length, max_poisoned_steps
    )

    # NOTE(review): assumes dataset.observations / dataset.actions expose the
    # dataset's underlying writable arrays, so item assignment changes the
    # dataset itself - confirm against the installed d3rlpy version.
    for idx in selected_time_steps:
        dataset.observations[idx] = observation_addperturb(dataset.observations[idx], max_perturbation_ratio)
        dataset.actions[idx] = action_addperturb(dataset.actions[idx], max_perturbation_ratio)
    return dataset

if __name__ == "__main__":
    # Script entry point: run poison() when this file is executed directly.
    poison()
