import ray
import pickle
from tqdm.auto import trange
from typing_extensions import Self
from typer import (
    run,
    Option,
    Context,
)
import os
import numpy as np

from d3rlpy.dataset import ReplayBufferBase
from d3rlpy.metrics import EvaluatorProtocol
from d3rlpy.interface import QLearningAlgoProtocol
from d3rlpy.preprocessing import MinMaxActionScaler
from pathlib import Path
import d3rlpy
import json
import pandas as pd
import glob
from typing import (
    Annotated,
    Optional,
    Callable,
)
from d3rlpy.models.encoders import PixelEncoderFactory
from d3rlpy.base import save_config
from d3rlpy.constants import (
    LoggingStrategy,
)
from d3rlpy.dataset import create_fifo_replay_buffer
from d3rlpy.logging import (
    LOG,
    D3RLPyLogger,
    FileAdapterFactory,
    LoggerAdapterFactory,
)
from d3rlpy.metrics import evaluate_qlearning_with_environment
from d3rlpy.types import GymEnv
from d3rlpy.algos.utility import (
    assert_action_space_with_env,
    build_scalers_with_env,
)
from d3rlpy.algos.qlearning.explorers import Explorer

from lift.network import SimpleEncoderFactory
from lift.evaluation import DataCollection
from collect_augmentor_data import load_env


# Base directory used by train_on_file to build the TensorBoard log path.
# Read once at import time: importing this module raises KeyError when the
# WORKING_DIR environment variable is not set.
working_dir = os.environ["WORKING_DIR"]

def custom_fit_online(
    model,
    env: GymEnv,
    n_episodes: int = 100,
    buffer: Optional[ReplayBufferBase] = None,
    explorer: Optional[Explorer] = None,
    n_steps: int = 1000000,
    n_steps_per_epoch: int = 10000,
    update_interval: int = 1,
    n_updates: int = 1,
    update_start_step: int = 0,
    random_steps: int = 0,
    eval_env: Optional[GymEnv] = None,
    eval_epsilon: float = 0.0,
    eval_n_trials: int = 10,
    save_interval: int = 1,
    experiment_name: Optional[str] = None,
    with_timestamp: bool = True,
    logging_steps: int = 500,
    logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
    logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
    show_progress: bool = True,
    callback: Optional[Callable[[Self, int, int], None]] = None,
) -> None:
    """Run an online training loop that also stops after ``n_episodes`` episodes.

    The loop interacts with ``env`` for at most ``n_steps`` steps, stores every
    transition in ``buffer`` and updates ``model`` from sampled mini-batches.
    Training ends early as soon as ``n_episodes`` episodes have finished
    (terminated or truncated).

    Args:
        model: Algorithm to train; it is built from ``env`` when ``model.impl``
            is ``None``.
        env: Environment used for data collection.
        n_episodes: Number of finished episodes after which training stops.
        buffer: Replay buffer. When ``None``, a FIFO buffer with a limit of
            1000000 is created from ``env``.
        explorer: Optional explorer used to pick actions once the random
            phase is over; otherwise ``model.sample_action`` is used.
        n_steps: Upper bound on the number of environment steps.
        n_steps_per_epoch: Number of steps that make up one pseudo epoch.
        update_interval: Parameter updates run on steps that are a multiple
            of this value.
        n_updates: Number of mini-batch updates per update step.
        update_start_step: Updates only run after this step, and only once the
            buffer holds more than ``model.batch_size`` transitions.
        random_steps: Steps below this index take ``env.action_space.sample()``.
        eval_env: Optional environment evaluated at the end of every epoch.
        eval_epsilon: ``epsilon`` forwarded to the evaluation routine.
        eval_n_trials: Number of evaluation trials per epoch.
        save_interval: The model is saved on epochs divisible by this value.
        experiment_name: Logger experiment name; defaults to
            ``<model class name>_online``.
        with_timestamp: Forwarded to ``D3RLPyLogger``.
        logging_steps: Commit period in steps for ``LoggingStrategy.STEPS``.
        logging_strategy: Commit metrics per epoch or every ``logging_steps``.
        logger_adapter: Factory for the logger backend.
        show_progress: Wrap the step loop in a ``tqdm`` progress bar.
        callback: Called after every step with ``(model, epoch, total_step)``.
    """
    # number of episodes finished so far (drives the early stop below)
    episodes = 0

    # create default replay buffer
    if buffer is None:
        buffer = create_fifo_replay_buffer(1000000, env=env)

    # check action-space
    assert_action_space_with_env(model, env)

    # initialize algorithm parameters
    build_scalers_with_env(model, env)

    # setup algorithm
    if model.impl is None:
        LOG.debug("Building model...")
        model.build_with_env(env)
        LOG.debug("Model has been built.")
    else:
        LOG.warning("Skip building models since they're already built.")

    # setup logger
    if experiment_name is None:
        experiment_name = model.__class__.__name__ + "_online"
    logger = D3RLPyLogger(
        algo=model,
        adapter_factory=logger_adapter,
        experiment_name=experiment_name,
        n_steps_per_epoch=n_steps_per_epoch,
        with_timestamp=with_timestamp,
    )

    # save hyperparameters
    save_config(model, logger)

    # switch based on show_progress flag
    xrange = trange if show_progress else range

    # start training loop
    observation, _ = env.reset()
    rollout_return = 0.0
    for total_step in xrange(1, n_steps + 1):
        with logger.measure_time("step"):
            # sample exploration action
            with logger.measure_time("inference"):
                if total_step < random_steps:
                    action = env.action_space.sample()
                elif explorer:
                    x = observation.reshape((1,) + observation.shape)
                    action = explorer.sample(model, x, total_step)[0]
                else:
                    action = model.sample_action(
                        np.expand_dims(observation, axis=0)
                    )[0]

            # step environment
            with logger.measure_time("environment_step"):
                (
                    next_observation,
                    reward,
                    terminal,
                    truncated,
                    _,
                ) = env.step(action)
                rollout_return += float(reward)

            clip_episode = terminal or truncated

            # store observation
            buffer.append(observation, action, float(reward))

            # reset if terminated
            if clip_episode:
                buffer.clip_episode(terminal)
                observation, _ = env.reset()
                # Must be an increment: the former ``episodes=+1`` assigned
                # the constant +1, so the episode limit was never reached.
                episodes += 1
                logger.add_metric("rollout_return", rollout_return)
                rollout_return = 0.0
            else:
                observation = next_observation

            # pseudo epoch count
            epoch = total_step // n_steps_per_epoch

            if (
                total_step > update_start_step
                and buffer.transition_count > model.batch_size
            ):
                if total_step % update_interval == 0:
                    for _ in range(n_updates):  # controls UTD ratio
                        # sample mini-batch
                        with logger.measure_time("sample_batch"):
                            batch = buffer.sample_transition_batch(
                                model.batch_size
                            )

                        # update parameters
                        with logger.measure_time("algorithm_update"):
                            loss = model.update(batch)

                        # record metrics
                        for name, val in loss.items():
                            logger.add_metric(name, val)

                    if (
                        logging_strategy == LoggingStrategy.STEPS
                        and total_step % logging_steps == 0
                    ):
                        logger.commit(epoch, total_step)

            # call callback if given
            if callback:
                callback(model, epoch, total_step)

        if epoch > 0 and total_step % n_steps_per_epoch == 0:
            # evaluation
            if eval_env:
                eval_score = evaluate_qlearning_with_environment(
                    model,
                    eval_env,
                    n_trials=eval_n_trials,
                    epsilon=eval_epsilon,
                )
                logger.add_metric("evaluation", eval_score)

            if epoch % save_interval == 0:
                logger.save_model(total_step, model)

            # save metrics
            if logging_strategy == LoggingStrategy.EPOCH:
                logger.commit(epoch, total_step)

        # stop once the requested number of episodes has been collected
        if episodes >= n_episodes:
            break

    # clip the last episode
    buffer.clip_episode(False)

    # close logger
    logger.close()


def _generate_local_action_grid(grid_size=10, extent=0.2):
    """Return a (grid_size**2, 2) array of 2D action offsets in [-extent, extent]^2."""
    x = np.linspace(-extent, extent, grid_size)
    y = np.linspace(-extent, extent, grid_size)
    xv, yv = np.meshgrid(x, y)
    actions = np.stack([xv.flatten(), yv.flatten()], axis=-1)
    return actions


def compute_qvalues(
    model,
    state_grid_x=np.linspace(0, 1, 10),
    state_grid_y=np.linspace(0, 1, 10),
    action_grid_size=100,
):
    """Evaluate Q(s, a) on a grid of 2D actions for every state of a 2D grid.

    The states are all ``[x, y]`` pairs of ``state_grid_x`` and
    ``state_grid_y`` (``x`` is the outer loop). The same action grid is used
    for every state: ``action_grid_size**2`` offsets in
    ``[-extent, extent]^2`` where ``extent`` is half the average spacing of
    ``state_grid_x``.

    Args:
        model: Object exposing ``predict_value(states, actions)``.
        state_grid_x: 1D array of x coordinates.
        state_grid_y: 1D array of y coordinates.
        action_grid_size: Number of action samples per axis.

    Returns:
        Tuple ``(states, q_values)`` where ``states`` has one row per state and
        ``q_values`` stacks the ``predict_value`` output of every state.
    """
    # NOTE: this string used to sit below the first statement, where Python
    # does not treat it as a docstring.
    action_extent = (state_grid_x.max() - state_grid_x.min()) / (2 * len(state_grid_x))
    states = np.array([[x, y] for x in state_grid_x for y in state_grid_y])

    actions = _generate_local_action_grid(action_grid_size, action_extent)
    num_actions = actions.shape[0]

    all_qs = []
    for s in states:
        # repeat the state so it pairs with every action of the grid
        s_batch = np.repeat(s[None, :], num_actions, axis=0)
        all_qs.append(model.predict_value(s_batch, actions))

    return states, np.array(all_qs)

def compute_greedy_vector_field(
    model,
    state_grid_x=np.linspace(0, 1, 20),
    state_grid_y=np.linspace(0, 1, 20),
):
    """Query ``model.predict`` on every point of a 2D state grid.

    Returns ``(states, actions)`` where ``states`` has shape ``(N, 2)`` with
    the x coordinate varying fastest, and ``actions`` is whatever
    ``model.predict(states)`` returns.
    """
    grid_x, grid_y = np.meshgrid(state_grid_x, state_grid_y)
    # one row per grid point: (x, y)
    states = np.stack((grid_x.ravel(), grid_y.ravel()), axis=1)
    return states, model.predict(states)



class SaveCallback:
    """Training callback that checkpoints the model and optional diagnostics.

    Every call whose ``epoch`` is a multiple of ``save_interval`` (epoch 0
    included) writes ``epoch_<epoch>.pt`` into ``path``. When enabled, it
    also pickles the Q-value grid (``q_values_<epoch>.pkl``) and the greedy
    action field (``gvf_<epoch>.pkl``). Existing files are overwritten.
    """

    def __init__(self, path, save_interval=1, with_qvalues=False, with_gvf=False):
        self.mode_save_dir = path
        self.interval = save_interval
        self.with_qvalues = with_qvalues
        self.with_gvf = with_gvf

    def _write_pickle(self, filename, payload):
        # Serialise ``payload`` into the save directory under ``filename``.
        target = os.path.join(self.mode_save_dir, filename)
        with open(target, 'wb') as handle:
            pickle.dump(payload, handle)

    def __call__(self, algo, epoch, total_steps):
        # ``total_steps`` is accepted for the callback signature but unused.
        if epoch % self.interval != 0:
            return

        algo.save_model(os.path.join(self.mode_save_dir, f'epoch_{epoch}.pt'))

        if self.with_qvalues:
            grid_states, grid_q = compute_qvalues(algo)
            self._write_pickle(
                f'q_values_{epoch}.pkl',
                {'states': grid_states, 'q_values': grid_q},
            )

        if self.with_gvf:
            field_states, field_actions = compute_greedy_vector_field(algo)
            self._write_pickle(
                f'gvf_{epoch}.pkl',
                {'states': field_states, 'actions': field_actions},
            )


class CustomEnvironmentEvaluator(EvaluatorProtocol):
    r"""Scores a policy by the best single-step reward of fixed-length rollouts.

    Every trial resets the environment and executes exactly ``steps`` actions
    from ``algo.predict``; the terminated/truncated flags returned by the
    environment are ignored. A trial's score is the maximum reward seen in
    that trial, and the evaluator returns the median over all trials.

    Args:
        env: Gym environment.
        n_trials: Number of episodes to evaluate.
        epsilon: Stored on the instance; not used by the rollout.
        steps: Number of environment steps per trial.
    """
    _n_trials: int
    _epsilon: float

    def __init__(
        self,
        env,
        n_trials: int = 10,
        epsilon: float = 0.0,
        steps: int = 30,
    ):
        self._env = env
        self._steps = steps
        self._n_trials = n_trials
        self._epsilon = epsilon

    @staticmethod
    def _with_batch_axis(observation):
        # Prepend a batch dimension to an array or to each item of a sequence.
        if isinstance(observation, np.ndarray):
            return np.expand_dims(observation, axis=0)
        if isinstance(observation, (tuple, list)):
            return [np.expand_dims(part, axis=0) for part in observation]
        raise ValueError(
            f"Unsupported observation type: {type(observation)}"
        )

    def _best_reward_of_trial(self, algo):
        # Roll out one fixed-length trial and return its largest step reward.
        observation, _ = self._env.reset()
        step_rewards = []
        for _ in range(self._steps):
            action = algo.predict(self._with_batch_axis(observation))[0]
            observation, reward, _, _, _ = self._env.step(action)
            step_rewards.append(float(reward))
        # presumably the reward encodes closeness to the optimum, so the max
        # is the nearest point reached — TODO confirm against the environment
        return np.max(step_rewards)

    def __call__(
        self, algo: QLearningAlgoProtocol, dataset: ReplayBufferBase
    ) -> float:
        # ``dataset`` is part of the evaluator signature but is not used.
        trial_scores = [
            self._best_reward_of_trial(algo) for _ in range(self._n_trials)
        ]
        return float(np.median(trial_scores))


def eval_model_online(env, model, n_episodes_per_model_eval=20, steps=30):
    """Roll out ``model.predict`` and record the distance to the optimum per step.

    Runs ``n_episodes_per_model_eval`` rollouts of exactly ``steps`` steps
    each (terminated/truncated flags are ignored). After every step, the norm
    of ``get_position_diff_to_optimum()`` is recorded, taken from ``env`` when
    it has that method and from ``env.env`` otherwise.

    Returns:
        A list with one list of per-step norms for every rollout.
    """

    def _single_rollout():
        obs, _ = env.reset()
        distances = []
        for _ in range(steps):
            action = model.predict(obs[np.newaxis])
            obs, _reward, _terminated, _truncated, _ = env.step(action[0])

            # fall back to the wrapped environment when the helper is missing
            source = env if hasattr(env, 'get_position_diff_to_optimum') else env.env
            distances.append(np.linalg.norm(source.get_position_diff_to_optimum()))
        return distances

    return [_single_rollout() for _ in range(n_episodes_per_model_eval)]


def _get_model_path(data_path, c_weight='', action_scaler=False, cumulate_actions=False):
    action_scaler_str = 's_' if action_scaler else ''

    if '_void_' in data_path:
        data_path = data_path.replace('_void_', '_void_online_')
    else:
        assert 'only void data'


    return data_path.replace('_data_', f'_model{c_weight}_{action_scaler_str}').replace('.pkl', '')


def _get_score_path(data_path, c_weight='', action_scaler=False, cumulate_actions=False):
    action_scaler_str = 's_' if action_scaler else ''

    if '_void_' in data_path:
        data_path = data_path.replace('_void_', '_void_online_')
    else:
        assert 'only void data'

    return data_path.replace('_data_', f'_scores{c_weight}_{action_scaler_str}').replace('.pkl', '.json')



def get_last_created_folder(directory):
    """Return the direct subfolder of ``directory`` with the largest ``st_ctime``.

    NOTE(review): on Unix ``st_ctime`` is the last metadata change rather than
    the creation time — confirm this matches the intended "last created".

    Raises:
        ValueError: If ``directory`` contains no subfolders.
    """
    subfolders = [entry for entry in Path(directory).iterdir() if entry.is_dir()]

    if not subfolders:
        raise ValueError("No subfolders found")

    # newest first wins; ties resolve to the earliest entry in listing order
    return max(subfolders, key=lambda entry: entry.stat().st_ctime)

def load_best_model(model, log_path):
    """Load the checkpoint of the best-scoring epoch into ``model``.

    Scores are read from ``environment.csv`` (columns: epoch, steps, score; no
    header row) inside the most recent subfolder of ``log_path``. The
    checkpoint ``epoch_<best_epoch>.pt`` is then loaded from ``log_path``
    itself, and the same ``model`` object is returned.
    """
    latest_run_dir = get_last_created_folder(log_path)
    score_table = pd.read_csv(
        latest_run_dir / 'environment.csv',
        header=None,
        names=['epoch', 'steps', 'score'],
    )

    # epoch label of the highest score
    score_by_epoch = score_table.set_index('epoch')['score']
    best_epoch = score_by_epoch.idxmax()

    checkpoint = os.path.join(log_path, f"epoch_{best_epoch}.pt")
    model.load_model(checkpoint)
    return model

def _fd_count():
    try:
        return len(os.listdir(f"/proc/{os.getpid()}/fd"))
    except Exception:
        return -1

@ray.remote
def train_on_file(params, data_path, c_weight=5, action_scaler=False, cumulate_actions=False, n_episodes=100):
    """Ray task: warm-start SAC from an offline dataset, train online, save scores.

    Steps:
      1. Load the smaller offline dataset that matches the episode budget
         (``100`` -> the "50" file, ``500`` -> the "100" file) and subtract
         its episode count from ``n_episodes``.
      2. Build the environment and a SAC model with a min-max action scaler.
      3. Collect 100 steps with the untrained model into an online buffer.
      4. Train with ``custom_fit_online`` on a 50/50 mix of online and offline
         data.
      5. Evaluate with ``eval_model_online`` and dump the scores as JSON.

    Args:
        params: CLI parameter dict; ``params['env']`` selects the environment
            and encoder. ``params['online']`` is set to ``False``.
        data_path: Path of the dataset file this run is derived from.
        c_weight: Only used to build the output path names.
        action_scaler: Only used to build the output path names; the
            ``MinMaxActionScaler`` is always applied.
        cumulate_actions: Forwarded to the path helpers.
        n_episodes: Total episode budget; must be 100 or 500.

    Raises:
        ValueError: If ``n_episodes`` is neither 100 nor 500.
    """
    print(f"[PID {os.getpid()}] start FD={_fd_count()} file={data_path}")
    params["online"] = False

    path_scores = _get_score_path(data_path, c_weight, action_scaler, cumulate_actions)
    path_model_logs = _get_model_path(data_path, c_weight, action_scaler, cumulate_actions)
    path_tf_logs = os.path.join(working_dir, 'lift/offlineRL_logs', '14')

    # NOTE(review): ``str.replace`` rewrites every occurrence of the number in
    # the path, not only the episode count — confirm the naming scheme
    # guarantees this is safe.
    if n_episodes == 100:
        collection = DataCollection.load(data_path.replace("100", "50"))
        n_episodes -= 50
    elif n_episodes == 500:
        collection = DataCollection.load(data_path.replace("500", "100"))
        n_episodes -= 100
    else:
        # Previously fell through and failed later with an unbound
        # ``collection`` variable.
        raise ValueError(f"n_episodes must be 100 or 500, got {n_episodes!r}")

    episodes = collection.data.episodes
    dataset = d3rlpy.dataset.create_fifo_replay_buffer(
        limit=10_000,
        episodes=episodes,
    )
    n_actions = collection.data.episodes[0].actions.shape[-1]

    env = load_env(
        {
            "env": params['env'],
            "n_actions": n_actions,
            # the distortion name is the path segment that follows
            # "<env>_<n_actions>_" up to the next "/"
            "distortion": data_path.split(f"{params['env']}_{n_actions}_")[1].split('/')[0],
        }
    )

    model_args = {}
    model_args["actor_learning_rate"] = 0.001
    model_args["critic_learning_rate"] = 0.001
    model_args["action_scaler"] = MinMaxActionScaler(
        minimum=env.action_space.low.tolist(),
        maximum=env.action_space.high.tolist(),
    )
    n_steps_per_epoch = 100
    if params["env"] == 'po':
        encoder_factory = SimpleEncoderFactory(10)
        device = 'cpu'
    else:
        encoder_factory = PixelEncoderFactory()
        device = 'cuda:0'

    model = d3rlpy.algos.SACConfig(
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        batch_size=256,
        **model_args,
    ).create(device=device)

    logger_adapter = d3rlpy.logging.CombineAdapterFactory([
        d3rlpy.logging.FileAdapterFactory(root_dir=path_model_logs),
        d3rlpy.logging.TensorboardAdapterFactory(root_dir=path_tf_logs),
    ])

    online_buffer = d3rlpy.dataset.create_fifo_replay_buffer(
        limit=20_000,
        env=env,
    )
    # seed the online buffer before training starts
    model.collect(
        env=env,
        buffer=online_buffer,
        deterministic=True,
        n_steps=100
    )
    # the seeding rollout above is counted as one episode of the budget
    n_episodes -= 1

    replay_buffer = d3rlpy.dataset.MixedReplayBuffer(
        primary_replay_buffer=online_buffer,
        secondary_replay_buffer=dataset,
        secondary_mix_ratio=0.5,
    )

    custom_fit_online(
        model,
        n_episodes=n_episodes,
        env=env,
        buffer=replay_buffer,
        logger_adapter=logger_adapter,
        n_steps=n_episodes * n_steps_per_epoch,
        n_steps_per_epoch=n_steps_per_epoch,
        random_steps=1_000,
        experiment_name=os.path.basename(path_model_logs),
        n_updates=5,
    )

    scores = {"evaluation_scores": eval_model_online(env, model, n_episodes_per_model_eval=20, steps=40)}

    with open(path_scores, 'w') as f:
        json.dump(scores, f)


def main(
    env: Annotated[str, Option(help="Environment used")] = 'po',
    n_episodes: Annotated[
        int,
        Option(help="Episode budget per run (train_on_file handles 100 or 500)"),
    ] = 500,
    data_path: Annotated[
        Optional[str], Option(help="path to dataset")
    ] = "eval",

    ctx: Context = Option(None, hidden=True),
):
    """CLI entry point: launch one ``train_on_file`` Ray task per dataset file.

    Dataset files are found with the glob
    ``<WORKING_DIR>/<data_path>/*/*_data_*.pkl``. The function connects to an
    already running Ray cluster (``address='auto'``), requests one CPU per
    task (plus one GPU unless ``env`` is ``'po'``) and blocks until every
    task has finished.
    """
    params = ctx.params

    # Resolve data_path relative to WORKING_DIR
    data_path = os.path.join(os.environ["WORKING_DIR"], params.get('data_path'))

    data_files = glob.glob(data_path + '/*/*_data_*.pkl')

    print(data_files)

    ray.init(address='auto')
    tasks = []

    # only the 'po' environment is trained on CPU
    num_gpus = 0 if params['env'] == 'po' else 1

    tasks.extend([
        train_on_file.options(num_cpus=1, num_gpus=num_gpus).remote(
            params,
            file,
            n_episodes=params.get('n_episodes'),
        )
        for file in data_files
    ])

    # wait for all runs; a failed task re-raises its exception here
    ray.get(tasks)

# Script entry point: hand `main` to typer's `run` so its parameters are
# exposed as command-line options.
if __name__ == '__main__':
    run(main)
