import d3rlpy
import torch
import copy
import gym
import argparse
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from d3rlpy.metrics.scorer import evaluate_on_environment
from typing import Sequence
from abc import ABCMeta, abstractmethod
from d3rlpy.models.torch.encoders import Encoder
from d3rlpy.algos import BCQ, BEAR, CQL
from d3rlpy.models.torch import (
    DeterministicPolicy,
    EnsembleContinuousQFunction,
    EnsembleQFunction,
    Policy,
)

def assign_episode_to_dataset(dataset):
    """Return an array holding the episode index of every transition.

    Indices start at 0. A transition whose terminal flag is truthy still
    belongs to the current episode; the counter only advances for the
    transition that follows it.

    Args:
        dataset: object exposing an iterable ``terminals`` attribute.

    Returns:
        numpy array with one integer per entry of ``dataset.terminals``.
    """

    def _labels():
        episode_index = 0
        for is_terminal in dataset.terminals:
            yield episode_index
            # Advance only after emitting the label of the terminal step.
            episode_index += 1 if is_terminal else 0

    return np.array(list(_labels()))


class QFunction(nn.Module):
    """Three-layer ReLU MLP (two hidden layers of width 256).

    Args:
        input_size: width of the last axis of ``x``.
        output_size: number of values produced per input row.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        hidden_width = 256
        # Layer order is significant: it fixes the state_dict keys
        # (network.0 / network.2 / network.4) and the parameter-init order.
        layers = [
            nn.Linear(input_size, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, output_size),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x, action=None):
        # ``action`` is accepted only so this module can be called with an
        # (x, action) pair; it is ignored here, so any action information
        # must already be part of ``x``.
        del action
        return self.network(x)

class EnsembleQFunction(nn.Module):
    """Ensemble of independent ``QFunction`` networks whose outputs are averaged.

    NOTE: this definition shadows the ``EnsembleQFunction`` imported from
    ``d3rlpy.models.torch`` at the top of the file.

    Args:
        input_size: width of the last axis of the input fed to every member.
        output_size: number of values each member produces per row.
        num_ensembles: number of member networks.
    """

    def __init__(self, input_size, output_size, num_ensembles):
        super(EnsembleQFunction, self).__init__()
        self._q_funcs = nn.ModuleList(
            [QFunction(input_size, output_size) for _ in range(num_ensembles)]
        )
        # Despite the name, this stores the per-member output width; kept
        # under this attribute name for existing readers.
        self._action_size = output_size

    def forward(self, x):
        """Return the element-wise mean of all member outputs.

        Member outputs are stacked along a new leading ensemble axis, so the
        result has the shape of a single member's output (e.g. (batch,
        output_size) for 2-D ``x``). Any ``output_size`` is supported.
        """
        values = torch.stack([q_func(x) for q_func in self._q_funcs], dim=0)
        return values.mean(dim=0)


class EnsembleContinuousQFunction(EnsembleQFunction):
    """Ensemble Q-function over concatenated (observation, action) inputs.

    ``forward`` concatenates ``x`` and ``action`` along the last axis before
    evaluating every member, so ``input_size`` must equal the observation
    width plus the action width.

    Supported ``reduction`` values across the ensemble axis:
      * ``"mean"`` / ``"min"`` / ``"max"``
      * ``"mix"``: ``lam * min + (1 - lam) * max``
      * ``"none"``: the stacked per-member values, unreduced
    """

    def __init__(self, input_size, output_size, num_ensembles):
        super(EnsembleContinuousQFunction, self).__init__(input_size, output_size, num_ensembles)

        # NOTE(review): this network is not used by forward() or
        # compute_target(); predictions come only from self._q_funcs. It is
        # left in place because removing it changes this module's
        # state_dict()/parameters(). Confirm no checkpoint relies on the
        # "network.*" keys before deleting it.
        self.network = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, output_size)
        )

    def _ensemble_values(self, x, action):
        """Evaluate every member on ``cat([x, action], -1)``.

        Returns the member outputs stacked along a new leading ensemble axis.
        """
        joint = torch.cat([x, action], dim=-1)
        return torch.stack([q_func(joint) for q_func in self._q_funcs], dim=0)

    def forward(self, x, action, reduction="mean"):
        return self._reduce_ensemble(self._ensemble_values(x, action), reduction)

    def _reduce_ensemble(self, values, reduction, lam=0.75):
        """Collapse the leading ensemble axis of ``values`` per ``reduction``.

        Raises:
            ValueError: for an unknown ``reduction``.
        """
        if reduction == "mean":
            return values.mean(dim=0)
        elif reduction == "min":
            return values.min(dim=0).values
        elif reduction == "max":
            return values.max(dim=0).values
        elif reduction == "mix":
            return lam * values.min(dim=0).values + (1.0 - lam) * values.max(dim=0).values
        elif reduction == "none":
            return values
        else:
            raise ValueError(f"Invalid reduction option: {reduction}")

    def __call__(self, x, action, reduction="mean"):
        # Goes through nn.Module.__call__ so hooks still run; the explicit
        # signature just documents the (x, action, reduction) call form.
        return super().__call__(x, action, reduction)

    def compute_target(self, x, action, reduction="min", lam=0.75):
        """Return ensemble values reduced for target computation.

        ``lam`` is only used by ``reduction="mix"``. No ``torch.no_grad`` is
        applied here; wrap the call yourself if gradients are not wanted.
        """
        return self._reduce_ensemble(self._ensemble_values(x, action), reduction, lam)
        
        
def _linf_uniform_noise(shape, max_perturbation_ratio):
    """Draw uniform noise rescaled row-wise to an L-inf norm of the ratio.

    Samples U(-1, 1) of ``shape`` with the global NumPy RNG. Each slice along
    the last axis is then divided by its largest absolute entry and
    multiplied by ``max_perturbation_ratio``. As a result, the largest-magnitude
    entry of every row is exactly ``±max_perturbation_ratio`` and all other
    entries are smaller in magnitude.
    """
    noise = np.random.uniform(low=-1, high=1, size=shape)
    # No clipping needed: uniform(-1, 1) samples already lie in [-1, 1).
    noise /= np.linalg.norm(noise, axis=-1, ord=np.inf, keepdims=True)
    noise *= max_perturbation_ratio
    return noise


def observation_addperturb(observations, max_perturbation_ratio):
    """Return ``observations`` plus row-wise L-inf-scaled uniform noise.

    ``max_perturbation_ratio`` is an absolute offset added to the raw values,
    not a fraction of the observation's magnitude. The input is not modified.
    """
    return observations + _linf_uniform_noise(observations.shape, max_perturbation_ratio)


def action_addperturb(actions, max_perturbation_ratio):
    """Return ``actions`` plus row-wise L-inf-scaled uniform noise.

    Same noise model as ``observation_addperturb``; the input is not
    modified and the result is not clipped to any action bounds.
    """
    return actions + _linf_uniform_noise(actions.shape, max_perturbation_ratio)

def _compute_q_values(q_function, observations, actions, batch_size):
    """Score every (observation, action) pair with ``reduction="min"``.

    Evaluates in mini-batches of ``batch_size`` rows under ``torch.no_grad()``.
    Assumes one output value per row (output_size == 1).

    Returns:
        1-D numpy array with one Q estimate per transition; empty when there
        are no transitions.
    """
    chunks = []
    with torch.no_grad():
        for begin in range(0, len(observations), batch_size):
            end = begin + batch_size
            obs_batch = torch.as_tensor(np.asarray(observations[begin:end]), dtype=torch.float32)
            act_batch = torch.as_tensor(np.asarray(actions[begin:end]), dtype=torch.float32)
            q_batch = q_function(obs_batch, act_batch, reduction="min")
            chunks.append(q_batch.cpu().numpy().reshape(-1))
    if not chunks:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(chunks)


def _window_mean_scores(q_values, length):
    """Return the mean of ``q_values[i:i + length]`` for every full window ``i``.

    Windows that would run past the end are not scored.
    """
    num_windows = len(q_values) - length + 1
    if num_windows <= 0:
        return np.empty(0, dtype=np.float64)
    window_index = np.arange(num_windows)[:, None] + np.arange(length)
    return q_values[window_index].mean(axis=1)


def _select_top_windows(window_scores, length, budget):
    """Greedily pick non-overlapping windows in descending score order.

    Stops as soon as one more window would exceed ``budget`` transitions.

    Returns:
        list of ``(start, length)`` tuples.
    """
    # Stable sort on negated scores: descending order, ties keep the lower
    # start index first.
    order = np.argsort(-window_scores, kind="stable")
    covered = np.zeros(len(window_scores) + length - 1, dtype=bool)
    selected = []
    num_selected = 0
    for start in order:
        if num_selected + length > budget:
            break
        stop = start + length
        # Reject ANY overlap with an already chosen window, including one
        # that starts inside this window. This prevents perturbing a
        # transition twice and keeps the budget count exact.
        if covered[start:stop].any():
            continue
        covered[start:stop] = True
        selected.append((int(start), length))
        num_selected += length
    return selected


def poison(
    dataset_name="walker2d-medium-v0",
    model_path="../walker2d_meduim_model_cql.pt",
    seed=42,
    window_length=5,
    budget=10000,
    max_perturbation=0.05,
    batch_size=4096,
):
    """Perturb the highest-Q windows of a D4RL dataset in place and return it.

    Steps:
      1. Load ``dataset_name`` via d3rlpy and seed torch and NumPy.
      2. Load Q-function weights from ``model_path`` into an
         ``EnsembleContinuousQFunction`` and score every transition with the
         ensemble minimum.
      3. Rank every window of ``window_length`` consecutive transitions by
         mean Q-value and keep non-overlapping windows until ``budget``
         transitions are selected.
      4. Add L-inf-scaled uniform noise of magnitude ``max_perturbation`` to
         the observation and action of each selected transition.

    Defaults reproduce the walker2d-medium / CQL setup, so ``poison()`` can
    be called with no arguments.

    Returns:
        The d3rlpy dataset whose ``observations`` and ``actions`` were
        modified in place.
    """
    dataset, _env = d3rlpy.datasets.get_d4rl(dataset_name)

    torch.manual_seed(seed)
    np.random.seed(seed)

    # The Q-function consumes observation and action concatenated along the
    # last axis (e.g. 17 + 6 for walker2d).
    input_size = np.asarray(dataset.observations).shape[-1] + np.asarray(dataset.actions).shape[-1]
    output_size = 1
    num_q_funcs = 3

    # torch.load unpickles the file: only point model_path at trusted checkpoints.
    state_dict = torch.load(model_path, map_location="cpu")
    q_function = EnsembleContinuousQFunction(input_size, output_size, num_q_funcs)
    load_result = q_function.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        # strict=False leaves unmatched parameters at their random init;
        # surface that instead of silently scoring with an untrained network.
        print(
            f"warning: {len(load_result.missing_keys)} parameter tensors not found "
            f"in {model_path}: {load_result.missing_keys[:5]}"
        )
    q_function.eval()

    q_values = _compute_q_values(q_function, dataset.observations, dataset.actions, batch_size)
    print("success")

    # NOTE(review): windows are scored purely by position and may span an
    # episode boundary; assign_episode_to_dataset could restrict them to a
    # single episode if that is the intent.
    window_scores = _window_mean_scores(q_values, window_length)
    selected_windows = _select_top_windows(window_scores, window_length, budget)

    # Perturb observation then action per transition, in order, so the NumPy
    # random stream is consumed deterministically for a given seed.
    for start, length in selected_windows:
        for i in range(start, start + length):
            dataset.observations[i] = observation_addperturb(dataset.observations[i], max_perturbation)
            dataset.actions[i] = action_addperturb(dataset.actions[i], max_perturbation)
    return dataset

if __name__ == "__main__":
    # Run the poisoning pipeline only when executed as a script, not on import.
    poison()
