import ray
import pickle
from d3rlpy.dataset import ReplayBufferBase
from d3rlpy.metrics import EvaluatorProtocol
from d3rlpy.interface import QLearningAlgoProtocol
from d3rlpy.preprocessing import MinMaxActionScaler
from pathlib import Path
import d3rlpy
from d3rlpy.dataset import FIFOBuffer
from d3rlpy.datasets import (
    ReplayBuffer,
    InfiniteBuffer,
)
import json
import pandas as pd
import glob
from typing import (
    Annotated,
    Optional,
)
from typer import (
    run,
    Option,
    Context,
)
import os
import numpy as np
from lift.config import Config
from lift.network import SimpleEncoderFactory
from d3rlpy.models.encoders import PixelEncoderFactory
from collect_augmentor_data import load_env

from lift.evaluation import DataCollection
from lift.policy_augmentations import Shortcuts


# Base directory under which log output is written. Read from the
# WORKING_DIR environment variable at import time, so importing this module
# raises KeyError when the variable is not set.
working_dir = os.environ["WORKING_DIR"]


def _generate_local_action_grid(grid_size=10, extent=0.2):
    """Return a (grid_size**2, 2) array of 2D action offsets in [-extent, extent]^2."""
    x = np.linspace(-extent, extent, grid_size)
    y = np.linspace(-extent, extent, grid_size)
    xv, yv = np.meshgrid(x, y)
    actions = np.stack([xv.flatten(), yv.flatten()], axis=-1)
    return actions


def compute_qvalues(
    model,
    state_grid_x=np.linspace(0, 1, 10),
    state_grid_y=np.linspace(0, 1, 10),
    action_grid_size=100,
):
    """Evaluate Q(s, a) on a local action grid for every state of a 2D grid.

    The states are the Cartesian product of ``state_grid_x`` and
    ``state_grid_y`` (x is the outer loop, y the inner loop). The same square
    grid of ``action_grid_size**2`` two-dimensional actions is scored for
    every state via ``model.predict_value``.

    Args:
        model: Object exposing ``predict_value(states, actions)``.
        state_grid_x: 1D sample points along the first state dimension.
        state_grid_y: 1D sample points along the second state dimension.
        action_grid_size: Number of action samples per axis.

    Returns:
        Tuple ``(states, q_values)``. ``states`` has one ``[x, y]`` row per
        grid state; ``q_values`` stacks the ``predict_value`` output of each
        state in the same order (presumably shape
        ``(n_states, action_grid_size**2)`` -- depends on the model).
    """
    # NOTE: the default grids are arrays created once at definition time; they
    # are only read here, never modified.
    grid_x = np.asarray(state_grid_x)

    # Half-width of the action grid: the x-range of the state grid divided by
    # twice the number of x sample points.
    action_extent = (grid_x.max() - grid_x.min()) / (2 * len(grid_x))

    states = np.array([[x, y] for x in grid_x for y in state_grid_y])

    actions = _generate_local_action_grid(action_grid_size, action_extent)
    num_actions = actions.shape[0]

    all_qs = []
    for state in states:
        # Pair this single state with every candidate action.
        state_batch = np.repeat(state[None, :], num_actions, axis=0)
        all_qs.append(model.predict_value(state_batch, actions))

    return states, np.array(all_qs)

def compute_greedy_vector_field(
    model,
    state_grid_x=np.linspace(0, 1, 20),
    state_grid_y=np.linspace(0, 1, 20),
):
    """Query the model's predicted action at every point of a 2D state grid.

    Args:
        model: Object exposing ``predict(states)``.
        state_grid_x: 1D sample points along the first state dimension.
        state_grid_y: 1D sample points along the second state dimension.

    Returns:
        Tuple ``(states, actions)`` where ``states`` has shape (N, 2), with
        the x coordinate varying fastest, and ``actions`` is whatever
        ``model.predict`` returns for that batch.
    """
    grid_x, grid_y = np.meshgrid(state_grid_x, state_grid_y)
    # One (x, y) row per grid point.
    states = np.column_stack((grid_x.ravel(), grid_y.ravel()))
    return states, model.predict(states)


class SaveCallback():
    """Epoch callback that checkpoints the model and optionally dumps debug data.

    Every ``save_interval`` epochs the model is written to
    ``<path>/epoch_<epoch>.pt``. When enabled, Q-value samples and the greedy
    action field are pickled alongside it as ``q_values_<epoch>.pkl`` and
    ``gvf_<epoch>.pkl``.

    Args:
        path: Directory that receives all output files.
        save_interval: Save only on epochs divisible by this number.
        with_qvalues: Also dump the result of ``compute_qvalues``.
        with_gvf: Also dump the result of ``compute_greedy_vector_field``.
    """

    def __init__(self, path, save_interval=1, with_qvalues=False, with_gvf=False):
        self.mode_save_dir = path
        self.interval = save_interval
        self.with_qvalues = with_qvalues
        self.with_gvf = with_gvf

    def __call__(self, algo, epoch, total_steps):
        """Save artefacts for ``epoch``; ``total_steps`` is accepted but unused."""
        if epoch % self.interval != 0:
            return

        algo.save_model(os.path.join(self.mode_save_dir, f'epoch_{epoch}.pt'))

        if self.with_qvalues:
            target = os.path.join(self.mode_save_dir, f'q_values_{epoch}.pkl')
            states, q_values = compute_qvalues(algo)
            payload = {'states': states, 'q_values': q_values}
            with open(target, 'wb') as f:
                pickle.dump(payload, f)

        if self.with_gvf:
            target = os.path.join(self.mode_save_dir, f'gvf_{epoch}.pkl')
            states, actions = compute_greedy_vector_field(algo)
            payload = {'states': states, 'actions': actions}
            with open(target, 'wb') as f:
                pickle.dump(payload, f)


class CustomEnvironmentEvaluator(EvaluatorProtocol):
    r"""Scores a policy by fixed-length rollouts in an environment.

    For each of ``n_trials`` episodes the environment is reset and stepped
    exactly ``steps`` times. The score of an episode is the largest
    single-step reward seen in it, and the evaluator returns the median of
    the episode scores.

    Args:
        env: Gym-style environment whose ``reset()`` returns
            ``(observation, info)`` and whose ``step()`` returns a 5-tuple.
        n_trials: Number of episodes to evaluate.
        epsilon: Probability of taking a random action (sampled from
            ``env.action_space``) instead of the policy's action.
        steps: Number of environment steps per episode.
    """
    _n_trials: int
    _epsilon: float

    def __init__(
        self,
        env,
        n_trials: int = 10,
        epsilon: float = 0.0,
        steps: int =30
    ):
        self._env = env
        self._n_trials = n_trials
        self._epsilon = epsilon
        self._steps = steps

    @staticmethod
    def _batch_observation(observation):
        """Add a leading batch axis to an array or to each array of a sequence."""
        if isinstance(observation, np.ndarray):
            return np.expand_dims(observation, axis=0)
        if isinstance(observation, (tuple, list)):
            return [np.expand_dims(o, axis=0) for o in observation]
        raise ValueError(
            f"Unsupported observation type: {type(observation)}"
        )

    def __call__(
        self, algo: QLearningAlgoProtocol, dataset: ReplayBufferBase
    ) -> float:
        """Return the median over episodes of the maximum per-step reward.

        ``dataset`` is not used; it is accepted to match the evaluator call
        signature.

        Raises:
            ValueError: If an observation is neither an ndarray nor a
                tuple/list of arrays.
        """
        episode_rewards = []
        for _ in range(self._n_trials):
            observation, _ = self._env.reset()
            episode_reward = []

            for _ in range(self._steps):
                batched = self._batch_observation(observation)

                # Only draw from the RNG when exploration is requested, so the
                # default epsilon=0.0 leaves numpy's global random state alone.
                if self._epsilon > 0.0 and np.random.random() < self._epsilon:
                    action = self._env.action_space.sample()
                else:
                    action = algo.predict(batched)[0]

                # NOTE(review): terminated/truncated are ignored, so the
                # rollout always runs for `steps` steps -- confirm the env
                # tolerates being stepped after an episode has ended.
                observation, reward, terminated, truncated, _ = self._env.step(action)
                episode_reward.append(float(reward))

            episode_rewards.append(np.max(episode_reward))

        return float(np.median(episode_rewards))


def eval_model_online(env, model, n_episodes_per_model_eval=20, steps=30):
    """Roll out the model and record the distance to the optimum after each step.

    Args:
        env: Environment whose ``reset()`` returns ``(obs, info)`` and whose
            ``step()`` returns a 5-tuple. ``get_position_diff_to_optimum()``
            is read from ``env`` itself or, if absent there, from ``env.env``.
        model: Object exposing ``predict(batch_of_observations)``.
        n_episodes_per_model_eval: Number of episodes to run.
        steps: Number of steps per episode (termination flags are ignored).

    Returns:
        List with one entry per episode; each entry is a list of ``steps``
        built-in floats, the Euclidean norm of the position difference to the
        optimum after each step. Built-in floats keep the result
        JSON-serializable even when the environment reports float32 values.
    """

    def _distance_to_optimum():
        # Fall back to the wrapped env when the outer object lacks the helper.
        source = env if hasattr(env, 'get_position_diff_to_optimum') else env.env
        return float(np.linalg.norm(source.get_position_diff_to_optimum()))

    scores_all_runs = []
    for _ in range(n_episodes_per_model_eval):
        obs, _ = env.reset()
        scores = []

        for _ in range(steps):
            # The model expects a batch; feed a batch of one and unpack it.
            action = model.predict(obs[np.newaxis])
            obs, reward, terminated, truncated, _ = env.step(action[0])
            scores.append(_distance_to_optimum())

        scores_all_runs.append(scores)
    return scores_all_runs


def _get_model_path(data_path, c_weight='', action_scaler=False, use_shortcuts=False):
    action_scaler_str = 's_' if action_scaler else ''
    if use_shortcuts:
        if '_void_' in data_path:
            data_path = data_path.replace('_void_', f'_void_{use_shortcuts}_')
        elif 'lift' in data_path:
            data_path = data_path.replace('lift', f'lift_{use_shortcuts}')
        else:
            data_path = data_path.replace('_data_', f'{use_shortcuts}_data_')
    return data_path.replace('_data_', f'_model{c_weight}_{action_scaler_str}').replace('.pkl', '')


def _get_score_path(data_path, c_weight='', action_scaler=False, use_shortcuts=False):
    action_scaler_str = 's_' if action_scaler else ''
    if use_shortcuts:
        if '_void_' in data_path:
            data_path = data_path.replace('_void_', f'_void_{use_shortcuts}_')
        elif 'lift' in data_path:
            data_path = data_path.replace('lift', f'lift_{use_shortcuts}')
        else:
            data_path = data_path.replace('_data_', f'{use_shortcuts}_data_')

    return data_path.replace('_data_', f'_scores{c_weight}_{action_scaler_str}').replace('.pkl', '.json')

def get_last_created_folder(directory):
    """Return the immediate subdirectory of ``directory`` with the largest ctime.

    Args:
        directory: Path (str or ``Path``) of the directory to scan.

    Returns:
        ``Path`` of the subdirectory whose ``st_ctime`` is greatest. Note that
        ``st_ctime`` is platform dependent (on Unix it is the time of the last
        metadata change rather than the creation time).

    Raises:
        ValueError: If ``directory`` contains no subdirectories.
    """
    subdirs = [entry for entry in Path(directory).iterdir() if entry.is_dir()]
    if not subdirs:
        raise ValueError("No subfolders found")

    # max() keeps the first of several equally-recent entries, the same entry
    # a stable descending sort would place first.
    return max(subdirs, key=lambda entry: entry.stat().st_ctime)

def load_best_model(model, log_path):
    """Reload into ``model`` the checkpoint of the best-scoring epoch.

    The scores are read from ``environment.csv`` (header-less, columns
    epoch / steps / score) inside the most recent subfolder of ``log_path``;
    the checkpoint itself is read from ``<log_path>/epoch_<best>.pt``.

    Args:
        model: Object exposing ``load_model(path)``.
        log_path: Directory holding the epoch checkpoints and the log
            subfolders.

    Returns:
        The same ``model`` object after loading the checkpoint.
    """
    run_dir = get_last_created_folder(log_path)
    scores = pd.read_csv(
        run_dir / 'environment.csv',
        header=None,
        names=['epoch', 'steps', 'score'],
    )

    # Index by epoch so idxmax yields the epoch number of the top score.
    best_epoch = scores.set_index('epoch')['score'].idxmax()

    model.load_model(os.path.join(log_path, f"epoch_{best_epoch}.pt"))
    return model

def _fd_count():
    try:
        return len(os.listdir(f"/proc/{os.getpid()}/fd"))
    except Exception:
        return -1

@ray.remote
def train_on_file(params, data_path,c_weight, action_scaler=False, use_shortcuts=False):
    """Ray task: train an offline model on one dataset file and score it online.

    Loads the data collection at ``data_path``, builds a replay buffer and a
    matching environment, fits the configured algorithm, reloads the best
    checkpoint and writes its online evaluation scores to a JSON file.

    Args:
        params: Parameter dict. It is modified in place (``"online"`` is set
            to False) and passed to ``Config.setup_params``; ``params["env"]``
            selects the environment, encoder and device.
        data_path: Path of the dataset file. The output paths and the
            ``distortion`` setting of the environment are derived from it.
        c_weight: Stored as ``conservative_weight`` in the model arguments
            and embedded in the output file names.
        action_scaler: When true, a ``MinMaxActionScaler`` spanning the
            environment's action space is added to the model arguments.
        use_shortcuts: When truthy, passed as ``select`` to the ``Shortcuts``
            transition picker and embedded in the output paths.

    Returns:
        None. Results are written to disk.
    """
    # Log the open file-descriptor count at task start (diagnostic output).
    print(f"[PID {os.getpid()}] start FD={_fd_count()} file={data_path}")
    params["online"] = False
    _, model_args, training_args = Config.setup_params(params)

    # Output locations: JSON scores, model/checkpoint logs, tensorboard logs.
    path_scores = _get_score_path(data_path, c_weight, action_scaler,use_shortcuts)
    path_model_logs = _get_model_path(data_path, c_weight, action_scaler,use_shortcuts)
    path_tf_logs = os.path.join(working_dir,'lift/offlineRL_logs', params['env'])

    collection = DataCollection.load(data_path)

    dataset = ReplayBuffer(
        buffer=InfiniteBuffer(),
        transition_picker=Shortcuts(select=use_shortcuts) if use_shortcuts else None,
        episodes=collection.data.episodes,
    )
    # Action dimensionality is taken from the first recorded episode.
    n_actions=collection.data.episodes[0].actions.shape[-1]
    print(collection.data.episodes[0].observations.shape)

    env = load_env(
        {
            "env": params["env"],
            "n_actions": n_actions,
            # The distortion name is the text between "<env>_<n_actions>_" and
            # the next "/" in data_path; raises IndexError if that marker is
            # missing from the path.
            "distortion": data_path.split(f"{params['env']}_{n_actions}_")[1].split('/')[0],
        }
    )
    # NOTE(review): `obs` is not used afterwards; only the reset call itself
    # has an effect here.
    obs, _ = env.reset()
    # Fixed hyperparameters override whatever Config.setup_params produced.
    model_args["actor_learning_rate"] = 0.001
    model_args["critic_learning_rate"] = 0.001
    model_args["conservative_weight"] = c_weight

    if action_scaler:
        model_args["action_scaler"]=MinMaxActionScaler(
            minimum=env.action_space.low.tolist(),
            maximum=env.action_space.high.tolist(),
        )
    n_epochs = 40
    n_steps_per_epoch = 400

    # The 'po' environment uses the simple encoder on CPU; every other
    # environment uses the pixel encoder on the first CUDA device.
    if params["env"]=='po':
        encoder_factory = SimpleEncoderFactory(10)
        device='cpu'
    else:
        encoder_factory = PixelEncoderFactory()
        device='cuda:0'

    model = Config.algorithms[training_args["model"]]["cls"](
        actor_encoder_factory=encoder_factory,
        critic_encoder_factory=encoder_factory,
        batch_size=500, 
        **model_args,
    ).create(device=device)


    logger_adapter = d3rlpy.logging.CombineAdapterFactory([
       d3rlpy.logging.FileAdapterFactory(root_dir=path_model_logs),
       d3rlpy.logging.TensorboardAdapterFactory(root_dir=path_tf_logs),
    ])

    # Set to True to also dump Q-value and greedy-action grids every epoch.
    debug = False

    model.fit(
        dataset,
        n_steps=n_epochs*n_steps_per_epoch,
        n_steps_per_epoch=n_steps_per_epoch,
        logger_adapter=logger_adapter,
        epoch_callback=SaveCallback(
            path_model_logs,
            with_qvalues=debug,
            with_gvf=debug,
        ),
        evaluators={
            # load_best_model below reads its scores from 'environment.csv'.
            'environment': CustomEnvironmentEvaluator(env, n_trials=20)
        },
    )

    # Swap in the weights of the epoch with the highest evaluator score.
    model = load_best_model(model, path_model_logs)

    scores = {"evaluation_scores": eval_model_online(env, model, n_episodes_per_model_eval=20, steps=40)}

    with open(path_scores, 'w') as f:
        json.dump(scores, f)

def main(
    env: Annotated[str, Option(help="Environment used")] = 'po',
    model: Annotated[str, Option(help="Model")] = 'CQL',
    data_path: Annotated[
        Optional[str], Option(help="path to dataset")
    ] = "lift_eval",
    ctx: Context = Option(None, hidden=True),
):
    """Launch one Ray training task per dataset file and wait for all of them.

    Dataset files are found under ``$WORKING_DIR/<data_path>`` with the glob
    ``*<env>*/*_data_*.pkl``. Each file is trained once for every combination
    of the hard-coded conservative weight, shortcut and action-scaler settings
    listed below.

    Args:
        env: Environment name; also used to filter dataset folders.
        model: Model name (forwarded to the tasks through ``ctx.params``).
        data_path: Dataset directory relative to ``WORKING_DIR``.
        ctx: Command context; its ``params`` dict is forwarded to every task.
    """
    params = ctx.params

    data_path = os.path.join(os.environ["WORKING_DIR"], params.get('data_path'))
    data_files = glob.glob(data_path + f'/*{env}*/*_data_*.pkl')
    print(data_files)

    ray.init(address='auto')

    # The 'po' environment trains on CPU, so its tasks request no GPU.
    num_gpus = 0 if params['env']=='po' else 1

    # Single-valued sweeps kept as lists so further settings are easy to add.
    tasks = [
        train_on_file.options(num_cpus=1, num_gpus=num_gpus).remote(
            params,
            file,
            c_weight=c_weight,
            action_scaler=action_scaler,
            use_shortcuts=use_shortcuts,
        )
        for c_weight in [5]
        for use_shortcuts in [False]
        for action_scaler in [False]
        for file in data_files
    ]

    # Block until every training task has finished.
    ray.get(tasks)

# Script entry point: hand `main` to typer's `run` so its parameters are
# filled from the command line.
if __name__ == '__main__':
    run(main)
