import numpy as np
import logging
from tqdm import tqdm
import pickle
from d3rlpy.dataset import (
    ReplayBuffer,
    FIFOBuffer,
    MDPDataset,
)
from d3rlpy.algos import CQLConfig
from d3rlpy.preprocessing import MinMaxActionScaler


# Module-level logger; `Evaluation` uses it for progress messages.
logger = logging.getLogger(__name__)


def _split_into_trajectories(positions, terminals):

    indices = np.where(terminals)[0]
    indices = np.append(-1, indices)
    positions_ = []
    for start, stop in zip(indices[:-1], indices[1:]):

        positions_.append(positions[(start+1):(stop+1)])
    return positions_


def _compute_returns(positions, gamma=0.99):
    """
    Computes the discounted return based on the distance to the optimal position.
    """
    G = 0
    for t in reversed(range(len(positions))):
        G = -np.linalg.norm(positions[t]) + gamma * G
    return G


def compute_trajectory_quality(collection, gamma=0.99):
    """
    Average discounted, position-based return over the complete episodes.

    This does NOT use the environment rewards. Each step contributes
    ``-||position||`` (see ``_compute_returns``), so values closer to zero are
    better.

    Args:
        collection: Object exposing ``positions`` and ``terminals`` (e.g. a
            ``DataCollection``).
        gamma (float): Discount factor forwarded to ``_compute_returns``.

    Returns:
        float: Mean discounted return over the episodes that end with a terminal
        flag. Transitions after the last terminal are ignored. Returns ``nan``
        when there is no complete episode.
    """
    episodes = _split_into_trajectories(collection.positions, collection.terminals)
    if not episodes:
        # Avoid numpy's "mean of empty slice" RuntimeWarning; the value is nan either way.
        return float("nan")

    return np.array([_compute_returns(p, gamma=gamma) for p in episodes]).mean()


def compute_state_exploration(collection, lam=0.1):
    """
    Count how many axis-aligned grid cells of side length `lam` contain a position.

    Args:
        collection: Object exposing ``positions``. This is an ``(n, d)``
            array-like; a 1-D array is treated as ``n`` scalar positions.
        lam (float): Side length of the grid cells. Must be positive.

    Returns:
        int: Number of distinct cells covered by at least one position
        (0 for an empty collection).

    Raises:
        ValueError: If `lam` is not a positive number.
    """
    # Explicit check instead of `assert` (stripped under -O); `not lam > 0` also rejects NaN.
    if not lam > 0:
        raise ValueError(f"lam must be positive, got {lam!r}")

    positions = np.asarray(collection.positions)
    if len(positions) == 0:
        return 0

    # Integer grid coordinates of the cell containing each position
    # (floor keeps negative coordinates in their own cells).
    cell_indices = np.floor(positions / lam).astype(int)

    # Distinct rows = distinct occupied cells.
    return len(np.unique(cell_indices, axis=0))



class DataCollection:
    """
    Container for collected transitions plus the matching d3rlpy ``MDPDataset``.

    Args:
        observations (np.ndarray): Per-step observations. 3-D image stacks
            ``(N, H, W)`` get a singleton channel axis. Channel-last RGB stacks
            ``(N, H, W, 3)`` are transposed to channel-first ``(N, 3, H, W)``.
            Other shapes are kept as-is.
        actions: Per-step actions.
        rewards: Per-step rewards.
        terminals: Per-step terminal flags.
        positions: Per-step positions (used by the dataset-scoring functions).
        is_augmented: Optional per-step flags marking augmented transitions.
    """

    def __init__(self, observations, actions, rewards, terminals, positions, is_augmented=None):

        self.observations = self._to_channel_first(observations)
        self.actions = actions
        self.rewards = rewards
        self.terminals = terminals
        self.positions = positions
        self.is_augmented = is_augmented

        self.data = MDPDataset(
            observations=self.observations,
            actions=self.actions,
            rewards=self.rewards,
            terminals=self.terminals,
        )

    @staticmethod
    def _to_channel_first(observations):
        """
        Return `observations` in a channel-first image layout where applicable.

        - ``(N, H, W)`` becomes ``(N, 1, H, W)``.
        - ``(N, H, W, 3)`` becomes ``(N, 3, H, W)``.
        - Anything else is returned unchanged.

        The two image cases are mutually exclusive. The previous code added the
        channel axis first and then tested ``ndim == 3`` again, so the
        channel-last branch could never run.
        """
        if observations.ndim == 3:
            return observations[:, np.newaxis]
        # NOTE(review): an already channel-first array whose width is 3 would
        # also match here (e.g. on reload) — unlikely for real images.
        if observations.ndim == 4 and observations.shape[3] == 3:
            return np.transpose(observations, (0, 3, 1, 2))
        return observations

    def save(self, path):
        """
        Pickle the raw arrays of this collection to `path`.

        The stored observations are the already-normalized (channel-first) ones.
        """
        data = {
            'observations': self.observations,
            'actions': self.actions,
            'rewards': self.rewards,
            'terminals': self.terminals,
            'positions': self.positions,
            'is_augmented': self.is_augmented
        }
        with open(path, 'wb') as f:
            pickle.dump(data, f)

    @classmethod
    def load(cls, path):
        """
        Load a collection previously written by ``save``.

        Security: uses ``pickle.load``, which can execute arbitrary code. Only
        load files from trusted sources.
        """
        with open(path, 'rb') as f:
            data = pickle.load(f)
        return cls(**data)


class Evaluation:
    """
    Evaluate an augmentor by the data it collects and by the offline models
    trained on that data.

    Args:
        env (Environment): Environment used both for data collection and for the
            online evaluation of trained models. It must provide ``reset()``
            (returning ``(obs, info)``), ``step(action)`` (returning a 5-tuple
            whose first element is the next observation) and
            ``get_position_diff_to_optimum()``.
        expert_policy (Policy): Expert policy. Stored on the instance but not used
            by the methods of this class.
        n_episodes_per_collection (int): Maximum number of episodes per
            collection run.
        n_episodes_per_model_eval (int): Number of evaluation episodes per
            trained model.
        n_trained_models (int): Number of models trained per collection run.
        n_collection_runs (int): Number of collection runs to perform.
        training_args (dict | None): Keyword arguments for ``model.fit``.
            Defaults to ``{"n_steps": 100000, "n_steps_per_epoch": 500}``.
        model_args (dict | None): Keyword arguments for ``CQLConfig``. Defaults
            to actor/critic learning rates of 1e-4, ``conservative_weight=0.1``
            and ``alpha_threshold=1``.
        max_transitions (int): Maximum number of transitions kept per collection
            run. Also the capacity of the FIFO replay buffer used for training.
        device (str): Device passed to ``CQLConfig(...).create``.
    """

    def __init__(
        self,
        env,
        expert_policy,
        n_episodes_per_collection=10,
        n_episodes_per_model_eval=10,
        n_trained_models=2,
        n_collection_runs=2,
        training_args=None,
        model_args=None,
        max_transitions=10_000,
        device='cpu',
    ):

        # Fresh dicts per instance, so callers never share a mutable default.
        if training_args is None:
            training_args = {
                "n_steps": int(1e5),
                "n_steps_per_epoch": 500,
            }

        self.training_args = training_args

        if model_args is None:
            model_args = {
                "actor_learning_rate": 0.0001,
                "critic_learning_rate": 0.0001,
                "conservative_weight": 0.1,
                "alpha_threshold": 1,
            }

        self.model_args = model_args

        self.max_transitions = max_transitions

        self.env = env
        self.expert_policy = expert_policy
        self.n_episodes_per_collection = n_episodes_per_collection
        self.n_episodes_per_model_eval = n_episodes_per_model_eval
        self.n_trained_models = n_trained_models
        self.n_collection_runs = n_collection_runs
        self.device = device

    def _fit_model_offline(self, collection):
        """
        Train a CQL model offline on `collection` and return it.

        The collection's episodes are loaded into a FIFO replay buffer limited to
        ``max_transitions``. Actions are scaled with ``MinMaxActionScaler``.
        """
        dataset = ReplayBuffer(
            buffer=FIFOBuffer(limit=self.max_transitions),
            episodes=collection.data.episodes,
        )

        model = CQLConfig(
            action_scaler=MinMaxActionScaler(),
            **self.model_args,
        ).create(device=self.device)

        model.fit(dataset, **self.training_args)
        return model

    def _eval_model_online(self, model, steps=20):
        """
        Roll out `model` in `env` and record its distance to the optimum.

        Args:
            model: Trained d3rlpy algorithm exposing ``sample_action``.
            steps (int): Number of environment steps per evaluation episode.

        Returns:
            list[list[float]]: One list per episode (``n_episodes_per_model_eval``
            in total). Each list holds, for every step, the norm of
            ``env.get_position_diff_to_optimum()`` after that step.
        """
        scores_all_runs = []

        for _ in range(self.n_episodes_per_model_eval):
            obs, _ = self.env.reset()
            scores = []

            for _ in range(steps):
                batch = obs[np.newaxis]
                # Average of two draws from the policy (presumably to damp sampling noise).
                action = (model.sample_action(batch) + model.sample_action(batch)) / 2

                # NOTE(review): terminated/truncated flags are ignored, so every rollout
                # runs for exactly `steps` steps — confirm the env tolerates this.
                obs, _, _, _, _ = self.env.step(action[0])

                scores.append(np.linalg.norm(self.env.get_position_diff_to_optimum()))

            scores_all_runs.append(scores)
        return scores_all_runs

    def _score_dataset(self, collection):
        """
        Compute absolute scores for a dataset.

        The scores still need to be normalized by those of the expert policy
        (i.e., data collected with the VoidAugmentation).
        """
        return {
            "trajectory_quality": compute_trajectory_quality(collection),
            "state_exploration": compute_state_exploration(collection),
        }

    def evaluate_augmentor(self, augmentor, train_models=True):
        """
        Run ``n_collection_runs`` collection runs with `augmentor` and score each.

        Args:
            augmentor: Object driving data collection (see ``collect_data``).
            train_models (bool): If True, train and evaluate ``n_trained_models``
                offline models on every collected dataset.

        Returns:
            tuple[dict, list]: ``scores`` has two keys. ``"data"`` holds one
            ``_score_dataset`` dict per run. ``"models"`` holds, per run, one
            ``_eval_model_online`` result per trained model, and stays empty when
            `train_models` is False. The second element is the list of collected
            datasets.
        """
        scores = {
            'data': [],
            'models': [],
        }

        datasets = []

        for n_collection in range(self.n_collection_runs):
            logger.info("Collecting data %d/%d", n_collection + 1, self.n_collection_runs)
            data = self.collect_data(augmentor)
            datasets.append(data)
            logger.info("Compute scores")
            scores["data"].append(self._score_dataset(data))

            if train_models:
                scores_model = []
                for n_model in range(self.n_trained_models):
                    logger.info("Train model %d/%d", n_model + 1, self.n_trained_models)
                    model = self._fit_model_offline(data)
                    logger.info("Evaluate model %d/%d", n_model + 1, self.n_trained_models)
                    scores_model.append(self._eval_model_online(model))

                scores["models"].append(scores_model)

        return scores, datasets

    @staticmethod
    def _terminals_with_final_episode_end(terminals, limit):
        """
        Return a copy of the first `limit` terminal flags with the last one set.

        A copy is required because slicing a numpy array yields a view. Writing
        into that view would silently modify the augmentor's own buffer.

        Raises:
            ValueError: If there are no transitions to mark.
        """
        clipped = np.array(terminals[:limit], copy=True)
        if len(clipped) == 0:
            raise ValueError("No transitions were collected; cannot build a DataCollection.")
        clipped[-1] = 1  # Ensure the last transition is terminal
        return clipped

    def collect_data(self, augmentor):
        """
        Collect transitions from `env` using `augmentor`.

        ``augmentor.init()`` is called once. Then ``augmentor.run(obs)`` is called
        once per reset episode, until ``n_episodes_per_collection`` episodes have
        run or more than ``max_transitions`` observations have accumulated. The
        augmentor's buffers are truncated to ``max_transitions``, and the last
        kept transition is marked terminal (on a copy of the flags).

        Returns:
            DataCollection: The truncated data.

        Raises:
            ValueError: If the augmentor produced no transitions.
        """
        augmentor.init()

        for _ in tqdm(range(self.n_episodes_per_collection), ncols=100):
            obs, _ = self.env.reset()
            augmentor.run(obs)

            if augmentor.observations.shape[0] > self.max_transitions:
                break

        limit = self.max_transitions
        terminals = self._terminals_with_final_episode_end(augmentor.terminals, limit)
        return DataCollection(
            observations=augmentor.observations[:limit],
            actions=augmentor.actions[:limit],
            rewards=augmentor.rewards[:limit],
            terminals=terminals,
            positions=augmentor.positions[:limit],
            is_augmented=augmentor.is_augmented[:limit],
        )
