"""
Implicit diffusion Q-learning (IDQL) trainer for diffusion policy.

Pixel input is not supported yet.

"""

from collections import deque
from util.timer import Timer
import hydra
import os
import pickle
import einops
import numpy as np
import torch
import logging
import wandb
from copy import deepcopy
import torchjd
from torchjd.aggregation import UPGrad

from util.scheduler import CosineAnnealingWarmupRestarts
from agent.finetune.train_agent import TrainAgent
from agent.ri.utils import StdScheduler, RIMixin

log = logging.getLogger(__name__)


class TrainIDQLDiffusionAgent(TrainAgent, RIMixin):
    """Online IDQL trainer for a diffusion policy with optional IL regularization.

    Each iteration rolls out the policy in ``self.venv``, stores the transitions
    in FIFO replay buffers, and then updates the V critic, the Q critic, the
    target critic and the actor. When ``rl_per_bc != -1``, every
    ``rl_per_bc``-th gradient step additionally draws a batch from
    ``self.dataloader_train`` and combines the IL loss with the actor loss
    through ``torchjd.backward`` using a ``UPGrad`` aggregator.
    """

    def __init__(self, cfg):
        """Build optimizers, LR schedulers and hyper-parameters from ``cfg``.

        Args:
            cfg: configuration object; this constructor reads ``cfg.train.*``
                and ``cfg.il_train.*`` and forwards ``cfg`` to the parent
                constructor and to ``StdScheduler``.
        """
        super().__init__(cfg)

        # note the discount factor gamma here is applied to reward every act_steps, instead of every env step
        self.gamma = cfg.train.gamma

        # Warm up period for critic before actor updates
        self.n_critic_warmup_itr = cfg.train.n_critic_warmup_itr

        # Optimizers and LR schedulers
        self.actor_optimizer = torch.optim.AdamW(
            self.model.actor.parameters(),
            lr=cfg.train.actor_lr,
            weight_decay=cfg.train.actor_weight_decay,
        )
        self.actor_lr_scheduler = self._build_cosine_scheduler(
            self.actor_optimizer,
            cfg.train.actor_lr,
            cfg.train.actor_lr_scheduler,
        )
        self.critic_q_optimizer = torch.optim.AdamW(
            self.model.critic_q.parameters(),
            lr=cfg.train.critic_lr,
            weight_decay=cfg.train.critic_weight_decay,
        )
        self.critic_v_optimizer = torch.optim.AdamW(
            self.model.critic_v.parameters(),
            lr=cfg.train.critic_lr,
            weight_decay=cfg.train.critic_weight_decay,
        )
        self.critic_v_lr_scheduler = self._build_cosine_scheduler(
            self.critic_v_optimizer,
            cfg.train.critic_lr,
            cfg.train.critic_lr_scheduler,
        )
        self.critic_q_lr_scheduler = self._build_cosine_scheduler(
            self.critic_q_optimizer,
            cfg.train.critic_lr,
            cfg.train.critic_lr_scheduler,
        )

        # Buffer size
        self.buffer_size = cfg.train.buffer_size

        # Actor params
        self.use_expectile_exploration = cfg.train.use_expectile_exploration

        # Scaling reward
        self.scale_reward_factor = cfg.train.scale_reward_factor

        # Updates
        self.replay_ratio = cfg.train.replay_ratio
        self.critic_tau = cfg.train.critic_tau

        # Whether to use deterministic mode when sampling at eval
        self.eval_deterministic = cfg.train.get("eval_deterministic", False)

        # Sampling
        self.num_sample = cfg.train.eval_sample_num

        # RI: an IL batch is mixed in every `rl_per_bc` gradient steps; -1 disables it
        self.rl_per_bc = cfg.train.rl_per_bc
        self.num_il_epochs = cfg.train.num_il_epochs
        log.info(f"IL epochs: {self.num_il_epochs} | RL per BC: {self.rl_per_bc}")

        self.il_max_grad_norm = cfg.il_train.get("max_grad_norm", float("inf"))
        log.info(f"IL max grad norm: {self.il_max_grad_norm} | RL max grad norm: {self.max_grad_norm} | RL per BC: {self.rl_per_bc} | Num IL epochs: {self.num_il_epochs}")
        # NOTE(review): the gd_* attributes are not read in this class; presumably
        # consumed by RIMixin — confirm.
        self.gd_norm_upper_ratio = cfg.il_train.get("gd_norm_upper_ratio", 1.5)
        self.gd_norm_lower_ratio = cfg.il_train.get("gd_norm_lower_ratio", 0.8)
        self.gd_norm_base_alpha = cfg.il_train.get("gd_norm_base_alpha", 0.95)
        self.gd_norm_base = None
        self.gd_op = cfg.il_train.get("grad_op", None)

        self.init_il()

        self.std_scheduler = StdScheduler(cfg)

    @staticmethod
    def _build_cosine_scheduler(optimizer, max_lr, scheduler_cfg):
        """Create the cosine warm-up scheduler shared by actor and critics."""
        return CosineAnnealingWarmupRestarts(
            optimizer,
            first_cycle_steps=scheduler_cfg.first_cycle_steps,
            cycle_mult=1.0,
            max_lr=max_lr,
            min_lr=scheduler_cfg.min_lr,
            warmup_steps=scheduler_cfg.warmup_steps,
            gamma=1.0,
        )

    @staticmethod
    def _mean_or_nan(values):
        """Return ``np.mean(values)``, or NaN for an empty list without a RuntimeWarning."""
        return np.mean(values) if values else np.float64(np.nan)

    def _summarize_finished_episodes(self, firsts_trajs, reward_trajs):
        """Aggregate rewards of episodes that both start and end within the iteration.

        Args:
            firsts_trajs: array of shape (n_steps + 1, n_envs); an entry equal
                to 1 marks the first step of an episode.
            reward_trajs: array of shape (n_steps, n_envs) with per-step rewards.

        Returns:
            Tuple ``(num_episode_finished, avg_episode_reward, avg_best_reward,
            success_rate)``; all zeros when no episode completed.
        """
        episodes_start_end = []
        for env_ind in range(self.n_envs):
            env_steps = np.where(firsts_trajs[:, env_ind] == 1)[0]
            for start, end in zip(env_steps[:-1], env_steps[1:]):
                if end - start > 1:
                    episodes_start_end.append((env_ind, start, end - 1))

        if not episodes_start_end:
            log.info("[WARNING] No episode completed within the iteration!")
            return 0, 0, 0, 0

        reward_trajs_split = [
            reward_trajs[start : end + 1, env_ind]
            for env_ind, start, end in episodes_start_end
        ]
        episode_reward = np.array(
            [np.sum(reward_traj) for reward_traj in reward_trajs_split]
        )
        episode_best_reward = np.array(
            [np.max(reward_traj) / self.act_steps for reward_traj in reward_trajs_split]
        )
        success_rate = np.mean(
            episode_best_reward >= self.best_reward_threshold_for_success
        )
        return (
            len(reward_trajs_split),
            np.mean(episode_reward),
            np.mean(episode_best_reward),
            success_rate,
        )

    def run(self):
        """Run the training loop until ``self.itr`` reaches ``self.n_train_itr``.

        Per iteration: (1) roll out ``self.n_steps`` action chunks in all
        environments, (2) summarize the episodes that finished, (3) unless in
        eval mode, perform ``num_batch`` critic/actor updates from the replay
        buffers, (4) step the LR schedulers, save checkpoints, update the
        sampling std, and log / pickle the metrics.
        """
        # FIFO replay buffers for obs, action, and reward
        obs_buffer = deque(maxlen=self.buffer_size)
        next_obs_buffer = deque(maxlen=self.buffer_size)
        action_buffer = deque(maxlen=self.buffer_size)
        reward_buffer = deque(maxlen=self.buffer_size)
        terminated_buffer = deque(maxlen=self.buffer_size)

        # Start training loop
        timer = Timer()
        run_results = []
        cnt_train_step = 0
        last_itr_eval = False
        done_venv = np.zeros((1, self.n_envs))
        # None until the first reset, so the first iteration always resets
        prev_obs_venv = None
        # Defaults keep the logging below valid when no update has run yet
        loss_actor = loss_critic = float("nan")

        # RI
        rl_iter = 0
        il_iterator = iter(self.dataloader_train)
        aggregator = UPGrad()

        while self.itr < self.n_train_itr:
            # Prepare video paths for each envs --- only applies for the first set of episodes if allowing reset within iteration and each iteration has multiple episodes from one env
            options_venv = [{} for _ in range(self.n_envs)]
            if self.itr % self.render_freq == 0 and self.render_video:
                for env_ind in range(self.n_render):
                    options_venv[env_ind]["video_path"] = os.path.join(
                        self.render_dir, f"itr-{self.itr}_trial-{env_ind}.mp4"
                    )

            # Define train or eval - all envs restart
            eval_mode = self.itr % self.val_freq == 0 and not self.force_train
            if eval_mode:
                self.model.eval()
            else:
                self.model.train()

            # Reset env before iteration starts (1) if specified, (2) at eval mode, (3) right after eval mode, or (4) if the envs have never been reset
            firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
            if (
                self.reset_at_iteration
                or eval_mode
                or last_itr_eval
                or prev_obs_venv is None
            ):
                prev_obs_venv = self.reset_env_all(options_venv=options_venv)
                firsts_trajs[0] = 1
            else:
                # if done at the end of last iteration, the envs are just reset
                firsts_trajs[0] = done_venv
            # Recorded only after the reset decision, so the check above sees
            # the mode of the previous iteration rather than the current one.
            last_itr_eval = eval_mode
            reward_trajs = np.zeros((self.n_steps, self.n_envs))

            # Collect a set of trajectories from env
            for step in range(self.n_steps):
                # Select action
                with torch.no_grad():
                    cond = {
                        "state": torch.from_numpy(prev_obs_venv["state"])
                        .float()
                        .to(self.device)
                    }
                    samples = (
                        self.model(
                            cond=cond,
                            deterministic=eval_mode and self.eval_deterministic,
                            num_sample=self.num_sample,
                            use_expectile_exploration=self.use_expectile_exploration,
                        )
                        .cpu()
                        .numpy()
                    )  # n_env x horizon x act
                action_venv = samples[:, : self.act_steps]

                # Apply multi-step action
                obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = (
                    self.venv.step(action_venv)
                )
                done_venv = terminated_venv | truncated_venv
                reward_trajs[step] = reward_venv
                firsts_trajs[step + 1] = done_venv

                # add to buffer
                if not eval_mode:
                    # Copy the state array itself: a shallow dict copy would share
                    # it with obs_venv, and the final-obs patch below would then
                    # overwrite the observation used for the next action.
                    next_state = obs_venv["state"].copy()
                    for i in range(self.n_envs):
                        if truncated_venv[i]:
                            next_state[i] = info_venv[i]["final_obs"]["state"]
                    obs_buffer.append(prev_obs_venv["state"])
                    next_obs_buffer.append(next_state)
                    action_buffer.append(action_venv)
                    reward_buffer.append(reward_venv * self.scale_reward_factor)
                    terminated_buffer.append(terminated_venv)

                    # count steps --- not accounting for done within action chunk
                    cnt_train_step += self.n_envs * self.act_steps

                # update for next step
                prev_obs_venv = obs_venv

            # Summarize episode reward --- this needs to be handled differently depending on whether the environment is reset after each iteration. Only count episodes that finish within the iteration.
            (
                num_episode_finished,
                avg_episode_reward,
                avg_best_reward,
                success_rate,
            ) = self._summarize_finished_episodes(firsts_trajs, reward_trajs)

            il_losses = []
            il_values = []
            # Update models
            if not eval_mode:
                num_batch = int(
                    self.n_steps * self.n_envs / self.batch_size * self.replay_ratio
                )

                # np.array copies the buffered arrays, so no deepcopy is needed;
                # the first two axes (step, env) are flattened into one sample axis.
                flatten_pattern = "s e h d -> (s e) h d"
                obs_trajs = einops.rearrange(np.array(obs_buffer), flatten_pattern)
                next_obs_trajs = einops.rearrange(
                    np.array(next_obs_buffer), flatten_pattern
                )
                action_trajs = einops.rearrange(
                    np.array(action_buffer), flatten_pattern
                )
                replay_rewards = np.array(reward_buffer).reshape(-1)
                replay_terminated = np.array(terminated_buffer).reshape(-1)

                for _ in range(num_batch):
                    use_il = self.rl_per_bc != -1 and rl_iter % self.rl_per_bc == 0
                    if use_il:
                        try:
                            il_batch = next(il_iterator)
                        except StopIteration:
                            # dataloader exhausted: start a new pass
                            il_iterator = iter(self.dataloader_train)
                            il_batch = next(il_iterator)
                        il_loss = self.get_il_loss(il_batch)
                        il_value = self.get_il_value(il_batch)
                        il_loss_val = il_loss.item()
                        il_value_val = il_value.item()
                        log.info(f"IL loss {il_loss_val:8.4f} | IL value {il_value_val:8.4f}")
                        il_losses.append(il_loss_val)
                        il_values.append(il_value_val)

                    # Sample batch
                    inds = np.random.choice(len(obs_trajs), self.batch_size)
                    obs_b = torch.from_numpy(obs_trajs[inds]).float().to(self.device)
                    next_obs_b = (
                        torch.from_numpy(next_obs_trajs[inds]).float().to(self.device)
                    )
                    actions_b = (
                        torch.from_numpy(action_trajs[inds]).float().to(self.device)
                    )
                    reward_b = (
                        torch.from_numpy(replay_rewards[inds]).float().to(self.device)
                    )
                    terminated_b = (
                        torch.from_numpy(replay_terminated[inds])
                        .float()
                        .to(self.device)
                    )

                    # update critic value function
                    critic_loss_v = self.model.loss_critic_v(
                        {"state": obs_b}, actions_b
                    )
                    self.critic_v_optimizer.zero_grad()
                    critic_loss_v.backward()
                    self.critic_v_optimizer.step()

                    # update critic q function
                    critic_loss_q = self.model.loss_critic_q(
                        {"state": obs_b},
                        {"state": next_obs_b},
                        actions_b,
                        reward_b,
                        terminated_b,
                        self.gamma,
                    )
                    self.critic_q_optimizer.zero_grad()
                    critic_loss_q.backward()
                    self.critic_q_optimizer.step()

                    # update target q function
                    self.model.update_target_critic(self.critic_tau)
                    loss_critic = critic_loss_q.detach() + critic_loss_v.detach()

                    # Update policy with collected trajectories - no weighting
                    loss_actor = self.model.loss(
                        actions_b,
                        {"state": obs_b},
                    )
                    self.actor_optimizer.zero_grad()
                    if use_il:
                        # combine the IL and RL gradients with the UPGrad aggregator
                        torchjd.backward([il_loss, loss_actor], aggregator)
                    else:
                        loss_actor.backward()
                    # the actor is frozen while the critic warms up
                    if self.itr >= self.n_critic_warmup_itr:
                        if self.max_grad_norm is not None:
                            torch.nn.utils.clip_grad_norm_(
                                self.model.actor.parameters(), self.max_grad_norm
                            )
                        self.actor_optimizer.step()
                    rl_iter += 1

                # Only log values that were actually produced in this iteration
                if num_batch > 0:
                    log.info(f"RL iter {rl_iter} | loss actor {loss_actor.item():8.4f} | loss critic {loss_critic.item():8.4f}")
                if il_losses:
                    log.info(f"IL loss {il_losses[-1]:8.4f} | IL value {il_values[-1]:8.4f}")

            # Update lr
            self.actor_lr_scheduler.step()
            self.critic_v_lr_scheduler.step()
            self.critic_q_lr_scheduler.step()

            # Save model
            if self.itr % self.save_model_freq == 0 or self.itr == self.n_train_itr - 1:
                self.save_model()
            if self.rl_per_bc != -1 and self.std_scheduler.expl_schedule is not None:
                # NOTE(review): on iterations without IL batches (e.g. eval)
                # current_il_val is NaN — confirm StdScheduler.step tolerates it.
                self.model.set_min_sampling_denoising_std(
                    self.std_scheduler.step(
                        itr=self.itr,
                        current_il_val=self._mean_or_nan(il_values),
                    )
                )

            # Log loss and save metrics
            run_results.append(
                {
                    "itr": self.itr,
                    "step": cnt_train_step,
                }
            )
            if self.itr % self.log_freq == 0:
                elapsed = timer()
                run_results[-1]["time"] = elapsed
                if eval_mode:
                    log.info(
                        f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
                    )
                    if self.use_wandb:
                        wandb.log(
                            {
                                "success rate - eval": success_rate,
                                "avg episode reward - eval": avg_episode_reward,
                                "avg best reward - eval": avg_best_reward,
                                "num episode - eval": num_episode_finished,
                            },
                            step=self.itr,
                            commit=False,
                        )
                    run_results[-1]["eval_success_rate"] = success_rate
                    run_results[-1]["eval_episode_reward"] = avg_episode_reward
                    run_results[-1]["eval_best_reward"] = avg_best_reward
                else:
                    log.info(
                        f"{self.itr}: step {cnt_train_step:8d} | loss actor {loss_actor:8.4f} | reward {avg_episode_reward:8.4f} | t:{elapsed:8.4f}"
                    )
                    if self.use_wandb:
                        wandb.log(
                            {
                                "total env step": cnt_train_step,
                                "loss - actor": loss_actor,
                                "loss - critic": loss_critic,
                                "il loss": self._mean_or_nan(il_losses),
                                "il value": self._mean_or_nan(il_values),
                                "il value ema": self.std_scheduler.il_val_ema,
                                "avg episode reward - train": avg_episode_reward,
                                "num episode - train": num_episode_finished,
                            },
                            step=self.itr,
                            commit=True,
                        )
                    run_results[-1]["train_episode_reward"] = avg_episode_reward
                with open(self.result_path, "wb") as f:
                    pickle.dump(run_results, f)
            self.itr += 1
