import os
import pathlib
from typing import Optional, Union

import gymnasium as gym
# from gym import spaces
import numpy as np

import torch as th
from torch.utils.data import DataLoader

from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
from stable_baselines3.common.vec_env import VecEnv, sync_envs_normalization

from ucrl.classify.classifier import TrajDFData, MujocoNPDataset, collate
from ucrl.common.evaluation import eval_policy_cost_traj
from ucrl.common.on_policy_algorithm import OnPolicyAlgorithmH, OnPolicyAlgorithmC

class EvalCostCallback(EvalCallback):
    """Periodic evaluation callback that tracks cost and safety statistics.

    Every ``eval_freq`` calls the policy is evaluated with
    ``eval_policy_cost_traj``. The callback then:

    * optionally writes every returned trajectory dataframe to
      ``<traj_path>/<n_calls>/<k>.csv`` (``k`` starts at 1),
    * optionally appends the per-episode rewards, costs and lengths to the
      ``.npz`` file at ``log_path``,
    * records the mean reward, cost, episode length and safe proportions
      (plus log/probability scores when available) to the logger,
    * saves the model when a new best mean reward is reached, and
    * triggers the ``callback_on_new_best`` / ``callback_after_eval`` callbacks.

    :param eval_env: Environment used for evaluation.
    :param traj_path: Root directory for trajectory CSV dumps; ``None``
        disables dumping.
    :param callback_on_new_best: Callback triggered on a new best mean reward.
    :param callback_after_eval: Callback triggered after every evaluation.
    :param n_eval_episodes: Number of episodes per evaluation.
    :param eval_freq: Evaluate every ``eval_freq`` calls of the callback.
    :param log_path: Path of the ``.npz`` evaluation log; ``None`` disables it.
    :param best_model_save_path: Directory where ``best_model`` is saved.
    :param deterministic: Forwarded to ``eval_policy_cost_traj``.
    :param render: Forwarded to ``eval_policy_cost_traj``.
    :param verbose: Print evaluation summaries when ``>= 1``.
    :param warn: Forwarded to ``eval_policy_cost_traj``.
    """

    def __init__(
            self,
            eval_env: Union[gym.Env, VecEnv],
            traj_path: Optional[str] = None,
            callback_on_new_best: Optional[BaseCallback] = None,
            callback_after_eval: Optional[BaseCallback] = None,
            n_eval_episodes: int = 5,
            eval_freq: int = 10000,
            log_path: Optional[str] = None,
            best_model_save_path: Optional[str] = None,
            deterministic: bool = True,
            render: bool = False,
            verbose: int = 1,
            warn: bool = True,
    ):
        super().__init__(eval_env=eval_env, callback_on_new_best=callback_on_new_best,
                         callback_after_eval=callback_after_eval, n_eval_episodes=n_eval_episodes,
                         eval_freq=eval_freq, log_path=log_path, best_model_save_path=best_model_save_path,
                         deterministic=deterministic, render=render, verbose=verbose, warn=warn)

        # Mean episode cost of the most recent evaluation.
        self.last_mean_cost = np.inf
        # Per-evaluation lists of episode costs, stored next to the rewards.
        self.evaluations_costs = []
        self.traj_path = traj_path

    def _save_trajectories(self, trajectory_dataframes) -> None:
        """Write each trajectory dataframe to ``<traj_path>/<n_calls>/<k>.csv``."""
        if self.traj_path is None:
            return
        # The target directory is the same for every trajectory of this
        # evaluation, so it is created once up front rather than per file.
        out_dir = os.path.join(self.traj_path, str(self.n_calls))
        pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True)
        for file_number, dataframe in enumerate(trajectory_dataframes, start=1):
            dataframe.to_csv(os.path.join(out_dir, f"{file_number}.csv"), index=False)

    def _append_evaluation_log(self, episode_rewards, episode_costs, episode_lengths) -> None:
        """Append this evaluation to the in-memory history and rewrite the ``.npz`` log."""
        if self.log_path is None:
            return
        assert isinstance(episode_rewards, list)
        assert isinstance(episode_costs, list)
        assert isinstance(episode_lengths, list)
        self.evaluations_timesteps.append(self.num_timesteps)
        self.evaluations_results.append(episode_rewards)
        self.evaluations_costs.append(episode_costs)
        self.evaluations_length.append(episode_lengths)

        kwargs = {}
        # Save success log if present
        if len(self._is_success_buffer) > 0:
            self.evaluations_successes.append(self._is_success_buffer)
            kwargs = dict(successes=self.evaluations_successes)

        np.savez(
            self.log_path,
            timesteps=self.evaluations_timesteps,
            results=self.evaluations_results,
            costs=self.evaluations_costs,
            ep_lengths=self.evaluations_length,
            **kwargs,
        )

    def _on_step(self) -> bool:
        """Run an evaluation when due and report whether training should continue.

        :return: ``False`` if a triggered child callback asked to stop
            training, ``True`` otherwise (including when no evaluation is due).
        """
        if self.eval_freq <= 0 or self.n_calls % self.eval_freq != 0:
            return True

        continue_training = True

        # Sync training and eval env if there is VecNormalize
        if self.model.get_vec_normalize_env() is not None:
            try:
                sync_envs_normalization(self.training_env, self.eval_env)
            except AttributeError as e:
                raise AssertionError(
                    "Training and eval env are not wrapped the same way, "
                    "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
                    "and warning above."
                ) from e

        # Reset success rate buffer
        self._is_success_buffer = []

        (episode_rewards, episode_costs, episode_lengths, episode_safe_props, step_safe_props, episode_log_scores,
         trajectory_dataframes) = eval_policy_cost_traj(
            self.model,
            self.eval_env,
            n_eval_episodes=self.n_eval_episodes,
            render=self.render,
            deterministic=self.deterministic,
            return_episode_rewards=True,
            warn=self.warn,
            callback=self._log_success_callback,
        )

        self._save_trajectories(trajectory_dataframes)
        self._append_evaluation_log(episode_rewards, episode_costs, episode_lengths)

        mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
        mean_cost, std_cost = np.mean(episode_costs), np.std(episode_costs)
        mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
        mean_safe_prop, std_safe_prop = np.mean(episode_safe_props), np.std(episode_safe_props)
        mean_step_safe_prop = np.mean(step_safe_props)

        self.last_mean_reward = float(mean_reward)
        self.last_mean_cost = float(mean_cost)

        if self.verbose >= 1:
            print(
                f"Eval num_timesteps={self.num_timesteps}, "
                f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}, "
                f"episode_cost={mean_cost:.2f} +/- {std_cost:.2f}, "
                f"episode_safe_prop={mean_safe_prop:.2f} +/- {std_safe_prop:.2f}"
            )
            print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")

        # Add to current Logger
        self.logger.record("eval/mean_reward", float(mean_reward))
        self.logger.record("eval/mean_cost", float(mean_cost))
        self.logger.record("eval/mean_ep_length", mean_ep_length)
        self.logger.record("eval/mean_safe_prop", float(mean_safe_prop))
        self.logger.record("eval/mean_step_safe_prop", float(mean_step_safe_prop))
        if episode_log_scores is not None and len(episode_log_scores) > 0:
            # Exponentiate once; the scores are logged in both log and probability space.
            prob_scores = np.exp(episode_log_scores)
            self.logger.record("eval/mean_log_scores", float(np.mean(episode_log_scores)))
            self.logger.record("eval/mean_prob_scores", float(np.mean(prob_scores)))

        if len(self._is_success_buffer) > 0:
            success_rate = np.mean(self._is_success_buffer)
            if self.verbose >= 1:
                print(f"Success rate: {100 * success_rate:.2f}%")
            self.logger.record("eval/success_rate", success_rate)

        # Dump log so the evaluation results are printed with the correct timestep
        self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
        self.logger.dump(self.num_timesteps)

        if mean_reward > self.best_mean_reward:
            if self.verbose >= 1:
                print("New best mean reward!")
            if self.best_model_save_path is not None:
                self.model.save(os.path.join(self.best_model_save_path, "best_model"))
            self.best_mean_reward = float(mean_reward)
            # Trigger callback on new best model, if needed
            if self.callback_on_new_best is not None:
                continue_training = self.callback_on_new_best.on_step()

        # Trigger callback after every evaluation, if needed
        if self.callback is not None:
            continue_training = continue_training and self._on_event()

        return continue_training


class RetrainClassifierCallBack(EvalCallback):
    """Periodically collect fresh trajectories and retrain ``model.classifier``.

    Every ``eval_freq`` calls the current policy is rolled out with
    ``eval_policy_cost_traj``. The resulting trajectories are split into a
    training part and a held-out part, appended to the stored train / test
    datasets, and the classifier is retrained with Adam until the held-out
    accuracy reaches ``TARGET_ACCURACY`` or ``MAX_EPOCHS`` epochs have run.
    The classifier is frozen again afterwards and, if ``retrain_pt_path`` is
    set, its state dict is saved as ``Classifier-<n_calls>.pt``.

    :param eval_env: Environment used to collect trajectories.
    :param train_dataset: Path of the serialized training dataset.
    :param test_dataset: Path of the serialized held-out dataset.
    :param traj_path: Root directory for trajectory CSV dumps; ``None``
        disables dumping.
    :param retrain_pt_path: Directory for classifier checkpoints; ``None``
        disables saving.
    :param callback_after_eval: Callback triggered after every retraining.
    :param n_eval_episodes: Number of episodes collected per retraining.
    :param eval_freq: Retrain every ``eval_freq`` calls of the callback.
    :param log_path: Forwarded to ``EvalCallback``.
    :param deterministic: Forwarded to ``eval_policy_cost_traj``.
    :param render: Forwarded to ``eval_policy_cost_traj``.
    :param verbose: Print the rollout summary when ``>= 1``.
    :param warn: Forwarded to ``eval_policy_cost_traj``.
    """

    model: Union[OnPolicyAlgorithmH, OnPolicyAlgorithmC]
    train_dataset: MujocoNPDataset
    test_dataset: MujocoNPDataset

    # Retraining hyper-parameters (previously inlined literals).
    TEST_FRACTION = 0.1      # share of new trajectories routed to the held-out set
    BATCH_SIZE = 128
    LEARNING_RATE = 0.001
    MAX_EPOCHS = 10
    TARGET_ACCURACY = 0.95   # stop early once held-out accuracy reaches this
    LOG_EVERY_N_BATCHES = 100

    def __init__(
            self,
            eval_env: Union[gym.Env, VecEnv],
            train_dataset: str,
            test_dataset: str,
            traj_path: Optional[str] = None,
            retrain_pt_path: Optional[str] = None,
            callback_after_eval: Optional[BaseCallback] = None,
            n_eval_episodes: int = 1000,
            eval_freq: int = 500_000,
            log_path: Optional[str] = None,
            deterministic: bool = True,
            render: bool = False,
            verbose: int = 1,
            warn: bool = True,
    ):
        super().__init__(eval_env=eval_env, callback_on_new_best=None,
                         callback_after_eval=callback_after_eval, n_eval_episodes=n_eval_episodes,
                         eval_freq=eval_freq, log_path=log_path, best_model_save_path=None,
                         deterministic=deterministic, render=render, verbose=verbose, warn=warn)

        # Mean episode cost of the most recent rollout batch.
        self.last_mean_cost = np.inf
        self.traj_path = traj_path
        self.retrain_pt_path = retrain_pt_path
        # NOTE(security): th.load unpickles arbitrary Python objects; only
        # point these paths at dataset files from a trusted source.
        # NOTE(review): newer PyTorch releases default to weights_only=True,
        # which may refuse to load a pickled dataset object - confirm against
        # the pinned torch version.
        self.train_dataset = th.load(train_dataset)
        self.test_dataset = th.load(test_dataset)

    @staticmethod
    def _safe_ratio(numerator, denominator) -> float:
        """Return ``numerator / denominator`` as a float, or 0.0 for an empty denominator."""
        if denominator > 0:
            return float(numerator / denominator)
        return 0.0

    def _save_trajectories(self, trajectory_dataframes) -> None:
        """Write each trajectory dataframe to ``<traj_path>/<n_calls>/<k>.csv``."""
        if self.traj_path is None:
            return
        # One directory per retraining round; create it once, not per file.
        out_dir = os.path.join(self.traj_path, str(self.n_calls))
        pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True)
        for file_number, dataframe in enumerate(trajectory_dataframes, start=1):
            dataframe.to_csv(os.path.join(out_dir, f"{file_number}.csv"), index=False)

    def _build_dataloaders(self, trajectory_dataframes):
        """Add the new trajectories to both datasets and return ``(train, test)`` loaders."""
        trajectories_data = TrajDFData(trajectory_dataframes, self.eval_env.envs[0].unwrapped.spec.id)
        all_idx = np.arange(trajectories_data.get_num_traj())
        np.random.shuffle(all_idx)
        split_idx = int(len(all_idx) * self.TEST_FRACTION)
        train_idx, test_idx = all_idx[split_idx:], all_idx[:split_idx]

        self.train_dataset.add_augment_data(np_data=trajectories_data, indices=train_idx)
        train_dataloader = DataLoader(self.train_dataset, batch_size=self.BATCH_SIZE,
                                      collate_fn=collate, shuffle=True)

        self.test_dataset.add_augment_data(np_data=trajectories_data, indices=test_idx)
        test_dataloader = DataLoader(self.test_dataset, batch_size=self.BATCH_SIZE,
                                     collate_fn=collate, shuffle=True)
        return train_dataloader, test_dataloader

    def _train_one_epoch(self, train_dataloader, optimizer, epoch: int) -> float:
        """Run one optimisation pass over the training data and return the mean batch loss."""
        classifier = self.model.classifier
        classifier.train()
        window_loss, epoch_loss = 0.0, 0.0
        for batch_idx, (inputs, labels, input_lengths) in enumerate(train_dataloader):
            labels = labels.reshape(-1, 1)
            # Only the loss is needed for training; the metric counts are ignored.
            loss, *_ = classifier.forward_loss_metrics(inputs, labels.float(), input_lengths)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            window_loss += batch_loss
            epoch_loss += batch_loss
            # Print the running mean every LOG_EVERY_N_BATCHES mini-batches.
            if batch_idx % self.LOG_EVERY_N_BATCHES == self.LOG_EVERY_N_BATCHES - 1:
                print(f"[{epoch + 1}, {batch_idx + 1:5d}] loss: {window_loss / self.LOG_EVERY_N_BATCHES:.4f}")
                window_loss = 0.0

        return epoch_loss / len(train_dataloader)

    def _validate(self, test_dataloader):
        """Evaluate the classifier on the held-out data.

        :return: ``(mean_loss, accuracy, precision, recall)`` as plain floats.
            Precision / recall are reported as 0.0 when their denominator is
            empty (no predicted positives / no actual positives).
        """
        classifier = self.model.classifier
        classifier.eval()

        total_loss = 0.0
        correct, true_pos, false_pos, false_neg = 0, 0, 0, 0
        # No parameter update happens here, so skip building autograd graphs.
        with th.no_grad():
            for inputs, labels, input_lengths in test_dataloader:
                labels = labels.reshape(-1, 1)
                loss, n_correct, n_tp, n_fp, _n_tn, n_fn = (
                    classifier.forward_loss_metrics(inputs, labels.float(), input_lengths))
                total_loss += loss.item()
                correct += n_correct
                true_pos += n_tp
                false_pos += n_fp
                false_neg += n_fn

        mean_loss = total_loss / len(test_dataloader)
        # float() turns possible 0-d tensor counts into plain Python numbers
        # so the logger receives scalars.
        accuracy = float(correct / len(test_dataloader.dataset))
        precision = self._safe_ratio(true_pos, true_pos + false_pos)
        recall = self._safe_ratio(true_pos, true_pos + false_neg)
        return mean_loss, accuracy, precision, recall

    def _retrain_classifier(self, train_dataloader, test_dataloader):
        """Unfreeze, retrain and re-freeze the classifier.

        Trains for at most ``MAX_EPOCHS`` epochs, stopping early once the
        held-out accuracy reaches ``TARGET_ACCURACY``.

        :return: ``(train_loss, valid_loss, accuracy, precision, recall)`` of
            the last epoch that ran.
        """
        classifier = self.model.classifier

        # Unfreeze classifier param
        for param in classifier.parameters():
            param.requires_grad_(True)
        classifier.train()
        th.backends.cudnn.enabled = True

        optimizer = th.optim.Adam(classifier.parameters(), lr=self.LEARNING_RATE)

        train_loss, valid_loss = float('Inf'), float('Inf')
        accuracy, precision, recall = float('-Inf'), 0.0, 0.0
        epoch = 0
        while accuracy < self.TARGET_ACCURACY and epoch < self.MAX_EPOCHS:
            train_loss = self._train_one_epoch(train_dataloader, optimizer, epoch)
            print(f"[{epoch + 1}] Training loss: {train_loss:.4f}")

            valid_loss, accuracy, precision, recall = self._validate(test_dataloader)
            print(f"[{epoch + 1}] Validation loss: {valid_loss:.4f}")
            print(
                f"Validation Error: \n Accuracy: {(100 * accuracy):>0.1f}%, Avg loss: {valid_loss:>8f} \n")
            print(f"Validation Error: \n Precision: {(100 * precision):>0.1f}% \n")
            print(f"Validation Error: \n Recall: {(100 * recall):>0.1f}% \n")
            epoch += 1

        # Freeze classifier param
        for param in classifier.parameters():
            param.requires_grad_(False)
        classifier.eval()
        # NOTE(review): cuDNN is left disabled unconditionally, as before;
        # presumably the surrounding policy training needs it off - confirm.
        th.backends.cudnn.enabled = False

        return train_loss, valid_loss, accuracy, precision, recall

    def _on_step(self) -> bool:
        """Collect trajectories and retrain the classifier when due.

        :return: ``False`` if the after-eval callback asked to stop training,
            ``True`` otherwise (including when no retraining is due).
        """
        if self.eval_freq <= 0 or self.n_calls % self.eval_freq != 0:
            return True

        continue_training = True

        # Sync training and eval env if there is VecNormalize
        if self.model.get_vec_normalize_env() is not None:
            try:
                sync_envs_normalization(self.training_env, self.eval_env)
            except AttributeError as e:
                raise AssertionError(
                    "Training and eval env are not wrapped the same way, "
                    "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
                    "and warning above."
                ) from e

        # _log_success_callback appends to this buffer during the rollout;
        # reset it so it does not grow across retraining rounds.
        self._is_success_buffer = []

        # Get data and retrain classifier
        (episode_rewards, episode_costs, episode_lengths, episode_safe_props, step_safe_props, episode_log_scores,
         trajectory_dataframes) = eval_policy_cost_traj(
            self.model,
            self.eval_env,
            n_eval_episodes=self.n_eval_episodes,
            render=self.render,
            deterministic=self.deterministic,
            return_episode_rewards=True,
            warn=self.warn,
            callback=self._log_success_callback,
        )

        self._save_trajectories(trajectory_dataframes)

        train_dataloader, test_dataloader = self._build_dataloaders(trajectory_dataframes)
        (ave_training_loss, ave_valid_loss, valid_accuracy,
         valid_precision, valid_recall) = self._retrain_classifier(train_dataloader, test_dataloader)

        if self.retrain_pt_path is not None:
            pathlib.Path(self.retrain_pt_path).mkdir(parents=True, exist_ok=True)
            th.save(self.model.classifier.state_dict(),
                    os.path.join(self.retrain_pt_path, 'Classifier-' + str(self.n_calls) + '.pt'))

        mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
        mean_cost, std_cost = np.mean(episode_costs), np.std(episode_costs)
        mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
        mean_safe_prop, std_safe_prop = np.mean(episode_safe_props), np.std(episode_safe_props)

        self.last_mean_reward = float(mean_reward)
        self.last_mean_cost = float(mean_cost)

        if self.verbose >= 1:
            print(
                f"Retrain num_timesteps={self.num_timesteps}, "
                f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}, "
                f"episode_cost={mean_cost:.2f} +/- {std_cost:.2f}, "
                f"episode_safe_prop={mean_safe_prop:.2f} +/- {std_safe_prop:.2f}"
            )
            print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")

        # Add to current Logger
        self.logger.record("classifier/mean_training_loss", ave_training_loss)
        self.logger.record("classifier/mean_validation_loss", ave_valid_loss)
        self.logger.record("classifier/validation_accuracy", valid_accuracy)
        self.logger.record("classifier/validation_precision", valid_precision)
        self.logger.record("classifier/validation_recall", valid_recall)

        # Dump log so the evaluation results are printed with the correct timestep
        self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
        self.logger.dump(self.num_timesteps)

        # Trigger callback after every evaluation, if needed
        if self.callback is not None:
            continue_training = continue_training and self._on_event()

        return continue_training
