from copy import deepcopy
import os
import hydra
import pickle
import einops
import numpy as np
import torch
import logging
import wandb
import math

from agent.pretrain.train_diffusion_agent import batch_to_device
from agent.finetune.train_ppo_diffusion_agent import TrainPPODiffusionAgent
from util.timer import Timer
from util.scheduler import CosineAnnealingWarmupRestarts

log = logging.getLogger(__name__)

class TrainDiffusionRIAgent(TrainPPODiffusionAgent):
    def __init__(self, cfg):
        """Set up the parent PPO agent, then the IL-specific hyper-parameters.

        Reads ``rl_per_bc`` and ``num_il_epochs`` from ``cfg.train`` and the
        gradient-norm options from ``cfg.il_train`` (falling back to defaults
        when a key is absent), then calls ``init_il`` to build the IL data
        loaders, optimizer and LR scheduler.
        """
        super().__init__(cfg)

        # Scheduling of IL relative to RL (both read from the RL train config).
        train_cfg = cfg.train
        self.rl_per_bc = train_cfg.rl_per_bc  # Frequency of IL iterations
        self.num_il_epochs = train_cfg.num_il_epochs  # Number of IL epochs per iteration

        # Gradient-norm handling for the IL optimizer.
        il_cfg = cfg.il_train
        self.il_max_grad_norm = il_cfg.get("max_grad_norm", float("inf"))
        log.info(f"IL max grad norm: {self.il_max_grad_norm} | RL max grad norm: {self.max_grad_norm} | RL per BC: {self.rl_per_bc} | Num IL epochs: {self.num_il_epochs}")

        # Running reference norm (None until the first IL pass completes) and
        # the ratios / smoothing factor applied to it in train_il.
        self.gd_norm_base = None
        self.gd_norm_base_alpha = il_cfg.get("gd_norm_base_alpha", 0.95)
        self.gd_norm_lower_ratio = il_cfg.get("gd_norm_lower_ratio", 0.8)
        self.gd_norm_upper_ratio = il_cfg.get("gd_norm_upper_ratio", 1.5)
        # Gradient operation selector consumed by train_il; None selects the
        # plain clipped update.
        self.gd_op = il_cfg.get("grad_op", None)

        self.init_il()

    def init_il(self):
        """Build the IL dataset(s), data loader(s), optimizer and LR scheduler.

        The training dataset is instantiated from ``cfg.il_dataset``. When
        ``cfg.il_train.train_split`` is present and below 1, a validation
        dataset is created as a deep copy of the training dataset restricted to
        the indices returned by ``set_train_val_split``; otherwise
        ``self.dataloader_val`` stays ``None``.
        """
        il_cfg = self.cfg.il_train
        self.il_batch_size = il_cfg.batch_size

        def make_loader(dataset):
            # Worker processes and pinned memory are only enabled when the
            # dataset reports that it lives on the CPU.
            on_cpu = dataset.device == "cpu"
            return torch.utils.data.DataLoader(
                dataset,
                batch_size=self.il_batch_size,
                num_workers=4 if on_cpu else 0,
                shuffle=True,
                pin_memory=True if on_cpu else False,
            )

        self.dataset_train = hydra.utils.instantiate(self.cfg.il_dataset)
        self.dataloader_train = make_loader(self.dataset_train)

        self.dataloader_val = None
        if "train_split" in il_cfg and il_cfg.train_split < 1:
            val_indices = self.dataset_train.set_train_val_split(il_cfg.train_split)
            self.dataset_val = deepcopy(self.dataset_train)
            self.dataset_val.set_indices(val_indices)
            self.dataloader_val = make_loader(self.dataset_val)

        # The IL optimizer only updates the fine-tuned actor's parameters.
        self.il_optimizer = torch.optim.AdamW(
            self.model.actor_ft.parameters(),
            lr=il_cfg.learning_rate,
            weight_decay=il_cfg.weight_decay,
        )

        scheduler_cfg = il_cfg.lr_scheduler
        self.il_lr_scheduler = CosineAnnealingWarmupRestarts(
            self.il_optimizer,
            first_cycle_steps=scheduler_cfg.first_cycle_steps,
            cycle_mult=1.0,
            max_lr=il_cfg.learning_rate,
            min_lr=scheduler_cfg.min_lr,
            warmup_steps=scheduler_cfg.warmup_steps,
            gamma=1.0,
        )

    def train_il(self, num_epochs):
        """Perform IL training on ``self.model.actor_ft`` for ``num_epochs`` epochs.

        For every training batch the IL loss is back-propagated and the global
        gradient L2 norm is measured. What happens next depends on the config:

        * ``cfg.il_train.zero_grad`` truthy: gradients are discarded, no step.
        * ``gd_op == "trust"``: step only when the norm lies inside
          ``[lower_ratio * base, upper_ratio * base]``.
        * ``gd_op == "clamp"``: clip to ``min(il_max_grad_norm, upper_ratio * base)``
          and step.
        * ``gd_op == "break"``: once a norm falls below ``lower_ratio * base`` no
          further steps are taken for the rest of this call.
        * anything else: clip to ``il_max_grad_norm`` and step.

        ``base`` is ``self.gd_norm_base``; while it is ``None`` (before the first
        completed IL pass) the band is unbounded. After the pass, ``base`` is
        updated from the mean raw gradient norm of all batches seen (first
        value, then an exponential moving average with ``gd_norm_base_alpha``).

        Loss and raw-norm statistics are recorded, and the IL LR scheduler is
        stepped, for every batch, including batches whose update was skipped.

        Args:
            num_epochs: number of passes over ``self.dataloader_train``; ``0``
                skips IL training entirely.
        """
        if num_epochs == 0:
            log.info("Skipping IL training")
            return

        def _loss_of(output):
            # il_loss is consumed in this file both as a bare loss tensor and as
            # a (loss, infos) pair; accept either so that training and
            # validation agree.
            return output[0] if isinstance(output, (tuple, list)) else output

        def _mean_or_nan(values):
            # np.mean([]) emits a RuntimeWarning and returns NaN; be explicit.
            return float(np.mean(values)) if len(values) > 0 else float("nan")

        timer = Timer()
        loss_il_epoch = []
        loss_val_epoch = []
        grad_norms = []  # raw (pre-clip) norm of every batch
        effective_grad_norms = []  # norm of the gradients actually applied
        batch_cnt = 0  # number of optimizer steps taken
        break_flag = False

        # The reference norm is only updated at the end of this call, so the
        # trust band is constant for the whole pass.
        if self.gd_norm_base is not None:
            gd_up = self.gd_norm_upper_ratio * self.gd_norm_base
            gd_down = self.gd_norm_lower_ratio * self.gd_norm_base
        else:
            gd_up = float("inf")
            gd_down = float("-inf")
        discard_all = self.cfg.il_train.get("zero_grad", False)

        # Report the IL batch size here (self.batch_size is the RL minibatch size).
        log.info(f"IL training: {self.itr} itr | num_epochs: {num_epochs} | batch_size: {self.il_batch_size} | num_batches: {len(self.dataloader_train)}")

        for epoch in range(num_epochs):
            for batch in self.dataloader_train:
                if self.dataset_train.device == "cpu":
                    batch = batch_to_device(batch)
                self.model.actor_ft.train()
                loss_il = _loss_of(self.model.il_loss(*batch))
                self.il_optimizer.zero_grad()
                loss_il.backward()
                # Compute gradient norm before optimizer step
                grad_norm_il = torch.sqrt(
                    sum(p.grad.detach().pow(2).sum() for p in self.model.actor_ft.parameters()
                        if p.grad is not None)
                ).item()

                # applied_norm stays None when this batch must not update the
                # model; otherwise it is the norm of the gradient being stepped.
                applied_norm = None
                if discard_all:
                    pass
                elif self.gd_op == "trust":
                    # Only use IL trusted batches: norm inside the band.
                    if gd_down <= grad_norm_il <= gd_up:
                        applied_norm = grad_norm_il
                elif self.gd_op == "clamp":
                    # Clamp IL gradients
                    max_norm = min(self.il_max_grad_norm, gd_up)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.actor_ft.parameters(), max_norm
                    )
                    # clip_grad_norm_ returns the pre-clip norm; the applied
                    # gradient has norm min(raw, max_norm).
                    applied_norm = min(grad_norm_il, max_norm)
                elif self.gd_op == "break":
                    # Early stop if gradient norm becomes too low
                    if grad_norm_il < gd_down:
                        break_flag = True
                    if not break_flag:
                        applied_norm = grad_norm_il
                else:
                    # Standard gradient update
                    torch.nn.utils.clip_grad_norm_(
                        self.model.actor_ft.parameters(), self.il_max_grad_norm
                    )
                    applied_norm = min(grad_norm_il, self.il_max_grad_norm)

                if applied_norm is None:
                    self.il_optimizer.zero_grad()
                else:
                    batch_cnt += 1
                    self.il_optimizer.step()
                    effective_grad_norms.append(applied_norm)

                grad_norms.append(grad_norm_il)
                loss_il_epoch.append(loss_il.item())
                self.il_lr_scheduler.step()

            # Validate
            if self.dataloader_val is not None and epoch % self.val_freq == 0:
                self.model.actor_ft.eval()
                for batch_val in self.dataloader_val:
                    if self.dataset_val.device == "cpu":
                        batch_val = batch_to_device(batch_val)
                    with torch.no_grad():
                        loss_val = _loss_of(self.model.il_loss(*batch_val))
                    loss_val_epoch.append(loss_val.item())
                self.model.actor_ft.train()

            # Note: the reported loss is the running mean over all epochs so far.
            log.info(f"IL epoch {epoch}: loss {_mean_or_nan(loss_il_epoch):8.4f} | t:{timer():8.4f}")

        avg_loss_il = _mean_or_nan(loss_il_epoch)
        avg_loss_val = _mean_or_nan(loss_val_epoch) if len(loss_val_epoch) > 0 else None
        avg_grad_norm_il = _mean_or_nan(grad_norms)
        effective_avg_grad_norm_il = _mean_or_nan(effective_grad_norms)

        # Update the reference norm only when at least one batch was observed,
        # so an empty pass can never turn the reference into NaN.
        if len(grad_norms) > 0:
            if self.gd_norm_base is None:
                self.gd_norm_base = avg_grad_norm_il
            else:
                self.gd_norm_base = self.gd_norm_base_alpha * self.gd_norm_base + (1 - self.gd_norm_base_alpha) * avg_grad_norm_il
        base_for_log = self.gd_norm_base if self.gd_norm_base is not None else float("nan")

        log.info(f"IL completed: train loss {avg_loss_il:8.4f} | batch_cnt: {batch_cnt} | t:{timer():8.4f}")
        gd_up = self.gd_norm_upper_ratio * base_for_log
        gd_down = self.gd_norm_lower_ratio * base_for_log
        log.info(f"gd norm {avg_grad_norm_il:8.4f} | base {base_for_log:8.4f} | up {gd_up:8.4f} | low {gd_down:8.4f}")
        if self.use_wandb:
            if avg_loss_val is not None:
                wandb.log(
                    {"loss - val": avg_loss_val}, step=self.itr, commit=False
                )
            wandb.log(
                {
                    "loss - train": avg_loss_il,
                    "IL grad_norm": avg_grad_norm_il,
                    "IL lr": self.il_optimizer.param_groups[0]["lr"],
                    "IL used batch": batch_cnt,
                    "IL base norm": base_for_log,
                    "IL effective grad norm": effective_avg_grad_norm_il,
                },
                step=self.itr,
                commit=True,
            )

    def run(self):
        """Run the interleaved RL + IL fine-tuning loop until ``self.n_train_itr``.

        Each iteration:

        1. decides between eval and train mode and resets the vectorized
           environment when required,
        2. rolls out ``self.n_steps`` policy steps, recording observations,
           denoising chains, rewards and termination flags,
        3. summarizes the episodes that finished inside the rollout,
        4. in train iterations, runs an IL pass every ``self.rl_per_bc`` RL
           iterations and then a PPO-style update on the rollout,
        5. steps LR schedulers, periodically saves the model, and logs /
           pickles the accumulated results to ``self.result_path``.
        """
        # Start training loop
        timer = Timer()
        run_results = []
        cnt_train_step = 0
        last_itr_eval = False
        # None until the first environment reset; forces a reset on the first
        # iteration even when no other reset condition holds.
        prev_obs_venv = None
        done_venv = np.zeros((1, self.n_envs))
        rl_iter = 0
        while self.itr < self.n_train_itr:
            # Prepare video paths for rendering
            options_venv = [{} for _ in range(self.n_envs)]
            if self.itr % self.render_freq == 0 and self.render_video:
                for env_ind in range(self.n_render):
                    options_venv[env_ind]["video_path"] = os.path.join(
                        self.render_dir, f"itr-{self.itr}_trial-{env_ind}.mp4"
                    )

            # Define train or eval mode
            eval_mode = self.itr % self.val_freq == 0 and not self.force_train
            if eval_mode:
                self.model.eval()
            else:
                self.model.train()

            # Reset environment if needed: (1) if configured, (2) in eval mode,
            # (3) right after an eval iteration, or (4) if never reset before.
            firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
            if (
                self.reset_at_iteration
                or eval_mode
                or last_itr_eval
                or prev_obs_venv is None
            ):
                prev_obs_venv = self.reset_env_all(options_venv=options_venv)
                firsts_trajs[0] = 1
            else:
                # Carry over the done flags from the end of the last iteration.
                firsts_trajs[0] = done_venv
            # Record this iteration's mode only after the reset decision, so the
            # check above sees the PREVIOUS iteration's mode.
            last_itr_eval = eval_mode

            # Initialize trajectory holders
            obs_trajs = {
                "state": np.zeros(
                    (self.n_steps, self.n_envs, self.n_cond_step, self.obs_dim)
                )
            }
            chains_trajs = np.zeros(
                (
                    self.n_steps,
                    self.n_envs,
                    self.model.ft_denoising_steps + 1,
                    self.horizon_steps,
                    self.action_dim,
                )
            )
            terminated_trajs = np.zeros((self.n_steps, self.n_envs))
            reward_trajs = np.zeros((self.n_steps, self.n_envs))
            if self.save_full_observations:
                obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
                obs_full_trajs = np.vstack(
                    (obs_full_trajs, prev_obs_venv["state"][:, -1][None])
                )

            # Collect trajectories
            for step in range(self.n_steps):
                with torch.no_grad():
                    cond = {
                        "state": torch.from_numpy(prev_obs_venv["state"])
                        .float()
                        .to(self.device)
                    }
                    samples = self.model(
                        cond=cond,
                        deterministic=eval_mode,
                        return_chain=True,
                    )
                    output_venv = samples.trajectories.cpu().numpy()
                    chains_venv = samples.chains.cpu().numpy()
                # Only the first act_steps of each sampled trajectory are executed.
                action_venv = output_venv[:, : self.act_steps]

                # Step environment
                (
                    obs_venv,
                    reward_venv,
                    terminated_venv,
                    truncated_venv,
                    info_venv,
                ) = self.venv.step(action_venv)
                done_venv = terminated_venv | truncated_venv
                if self.save_full_observations:
                    obs_full_venv = np.array(
                        [info["full_obs"]["state"] for info in info_venv]
                    )
                    obs_full_trajs = np.vstack(
                        (obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
                    )
                obs_trajs["state"][step] = prev_obs_venv["state"]
                chains_trajs[step] = chains_venv
                reward_trajs[step] = reward_venv
                terminated_trajs[step] = terminated_venv
                firsts_trajs[step + 1] = done_venv
                prev_obs_venv = obs_venv
                # Environment steps are only counted in train iterations.
                cnt_train_step += self.n_envs * self.act_steps if not eval_mode else 0

            # Summarize episode rewards
            episodes_start_end = []
            for env_ind in range(self.n_envs):
                env_steps = np.where(firsts_trajs[:, env_ind] == 1)[0]
                for i in range(len(env_steps) - 1):
                    start = env_steps[i]
                    end = env_steps[i + 1]
                    if end - start > 1:
                        episodes_start_end.append((env_ind, start, end - 1))
            if len(episodes_start_end) > 0:
                reward_trajs_split = [
                    reward_trajs[start: end + 1, env_ind]
                    for env_ind, start, end in episodes_start_end
                ]
                num_episode_finished = len(reward_trajs_split)
                episode_reward = np.array(
                    [np.sum(reward_traj) for reward_traj in reward_trajs_split]
                )
                if self.furniture_sparse_reward:
                    episode_best_reward = episode_reward
                else:
                    episode_best_reward = np.array(
                        [
                            np.max(reward_traj) / self.act_steps
                            for reward_traj in reward_trajs_split
                        ]
                    )
                avg_episode_reward = np.mean(episode_reward)
                avg_best_reward = np.mean(episode_best_reward)
                success_rate = np.mean(
                    episode_best_reward >= self.best_reward_threshold_for_success
                )
            else:
                episode_reward = np.array([])
                num_episode_finished = 0
                avg_episode_reward = 0
                avg_best_reward = 0
                success_rate = 0
                log.info("[WARNING] No episode completed within the iteration!")

            # Update models (RL + periodic IL)
            if not eval_mode:
                # Periodic IL update
                if rl_iter % self.rl_per_bc == 0:
                    self.train_il(self.num_il_epochs)
                    log.info(f"IL training: {self.itr} itr | rl_iter: {rl_iter} | num_epochs: {self.num_il_epochs} | IL batch_size: {self.il_batch_size}")

                # RL update: values, log-probabilities and advantages are
                # computed without gradients, in chunks of logprob_batch_size.
                with torch.no_grad():
                    obs_trajs["state"] = (
                        torch.from_numpy(obs_trajs["state"]).float().to(self.device)
                    )
                    num_split = math.ceil(
                        self.n_envs * self.n_steps / self.logprob_batch_size
                    )
                    obs_ts = [{} for _ in range(num_split)]
                    obs_k = einops.rearrange(
                        obs_trajs["state"],
                        "s e ... -> (s e) ...",
                    )
                    obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0)
                    for i, obs_t in enumerate(obs_ts_k):
                        obs_ts[i]["state"] = obs_t
                    values_trajs = np.empty((0, self.n_envs))
                    for obs in obs_ts:
                        values = self.model.critic(obs).cpu().numpy().flatten()
                        values_trajs = np.vstack(
                            (values_trajs, values.reshape(-1, self.n_envs))
                        )
                    chains_t = einops.rearrange(
                        torch.from_numpy(chains_trajs).float().to(self.device),
                        "s e t h d -> (s e) t h d",
                    )
                    chains_ts = torch.split(chains_t, self.logprob_batch_size, dim=0)
                    logprobs_trajs = np.empty(
                        (
                            0,
                            self.model.ft_denoising_steps,
                            self.horizon_steps,
                            self.action_dim,
                        )
                    )
                    for obs, chains in zip(obs_ts, chains_ts):
                        logprobs = self.model.get_logprobs(obs, chains).cpu().numpy()
                        logprobs_trajs = np.vstack(
                            (
                                logprobs_trajs,
                                logprobs.reshape(-1, *logprobs_trajs.shape[1:]),
                            )
                        )
                    if self.reward_scale_running:
                        reward_trajs_transpose = self.running_reward_scaler(
                            reward=reward_trajs.T, first=firsts_trajs[:-1].T
                        )
                        reward_trajs = reward_trajs_transpose.T
                    obs_venv_ts = {
                        "state": torch.from_numpy(obs_venv["state"])
                        .float()
                        .to(self.device)
                    }
                    # Backward recursion over the rollout for advantages; the
                    # last step bootstraps from the critic on the final obs.
                    advantages_trajs = np.zeros_like(reward_trajs)
                    lastgaelam = 0
                    for t in reversed(range(self.n_steps)):
                        if t == self.n_steps - 1:
                            nextvalues = (
                                self.model.critic(obs_venv_ts)
                                .reshape(1, -1)
                                .cpu()
                                .numpy()
                            )
                        else:
                            nextvalues = values_trajs[t + 1]
                        nonterminal = 1.0 - terminated_trajs[t]
                        delta = (
                            reward_trajs[t] * self.reward_scale_const
                            + self.gamma * nextvalues * nonterminal
                            - values_trajs[t]
                        )
                        advantages_trajs[t] = lastgaelam = (
                            delta
                            + self.gamma * self.gae_lambda * nonterminal * lastgaelam
                        )
                    returns_trajs = advantages_trajs + values_trajs

                # Prepare data for updates
                obs_k = {
                    "state": einops.rearrange(
                        obs_trajs["state"],
                        "s e ... -> (s e) ...",
                    )
                }
                # Reuse the flattened float32 chains already on the device
                # instead of copying chains_trajs to the device a second time.
                chains_k = chains_t
                returns_k = (
                    torch.tensor(returns_trajs, device=self.device).float().reshape(-1)
                )
                values_k = (
                    torch.tensor(values_trajs, device=self.device).float().reshape(-1)
                )
                advantages_k = (
                    torch.tensor(advantages_trajs, device=self.device)
                    .float()
                    .reshape(-1)
                )
                logprobs_k = torch.tensor(logprobs_trajs, device=self.device).float()

                # RL update: each sample is a (rollout step, denoising step) pair.
                total_steps = self.n_steps * self.n_envs * self.model.ft_denoising_steps
                clipfracs = []
                grad_norms = []
                for update_epoch in range(self.update_epochs):
                    flag_break = False
                    inds_k = torch.randperm(total_steps, device=self.device)
                    num_batch = max(1, total_steps // self.batch_size)
                    for batch in range(num_batch):
                        start = batch * self.batch_size
                        end = start + self.batch_size
                        inds_b = inds_k[start:end]
                        batch_inds_b, denoising_inds_b = torch.unravel_index(
                            inds_b,
                            (self.n_steps * self.n_envs, self.model.ft_denoising_steps),
                        )
                        obs_b = {"state": obs_k["state"][batch_inds_b]}
                        chains_prev_b = chains_k[batch_inds_b, denoising_inds_b]
                        chains_next_b = chains_k[batch_inds_b, denoising_inds_b + 1]
                        returns_b = returns_k[batch_inds_b]
                        values_b = values_k[batch_inds_b]
                        advantages_b = advantages_k[batch_inds_b]
                        logprobs_b = logprobs_k[batch_inds_b, denoising_inds_b]
                        (
                            pg_loss,
                            entropy_loss,
                            v_loss,
                            clipfrac,
                            approx_kl,
                            ratio,
                            bc_loss,
                            eta,
                        ) = self.model.loss(
                            obs_b,
                            chains_prev_b,
                            chains_next_b,
                            denoising_inds_b,
                            returns_b,
                            values_b,
                            advantages_b,
                            logprobs_b,
                            use_bc_loss=self.use_bc_loss,
                            reward_horizon=self.reward_horizon,
                        )
                        loss = (
                            pg_loss
                            + entropy_loss * self.ent_coef
                            + v_loss * self.vf_coef
                            + bc_loss * self.bc_loss_coeff
                        )
                        clipfracs += [clipfrac]
                        self.actor_optimizer.zero_grad()
                        self.critic_optimizer.zero_grad()
                        if self.learn_eta:
                            self.eta_optimizer.zero_grad()
                        loss.backward()
                        # The actor (and eta) are frozen during critic warmup;
                        # the critic is stepped on every minibatch.
                        if self.itr >= self.n_critic_warmup_itr:
                            if self.max_grad_norm is not None:
                                grad_norm = torch.nn.utils.clip_grad_norm_(
                                    self.model.actor_ft.parameters(), self.max_grad_norm
                                )
                                grad_norms.append(grad_norm.item())
                            self.actor_optimizer.step()
                            if self.learn_eta and batch % self.eta_update_interval == 0:
                                self.eta_optimizer.step()
                        self.critic_optimizer.step()
                        # Stop all remaining updates once the KL estimate
                        # exceeds the configured target.
                        if self.target_kl is not None and approx_kl > self.target_kl:
                            flag_break = True
                            break
                    if flag_break:
                        break

                # Explained variance
                y_pred, y_true = values_k.cpu().numpy(), returns_k.cpu().numpy()
                var_y = np.var(y_true)
                explained_var = (
                    np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
                )
                rl_iter += 1
                log.info(f"RL iter {rl_iter}: explained var {explained_var:8.4f}, eta {eta:8.4f}, approx_kl {approx_kl:8.4f}, ratio {ratio:8.4f}")

            # Render trajectories if applicable
            if (
                self.itr % self.render_freq == 0
                and self.n_render > 0
                and self.traj_plotter is not None
            ):
                self.traj_plotter(
                    obs_full_trajs=obs_full_trajs,
                    n_render=self.n_render,
                    max_episode_steps=self.max_episode_steps,
                    render_dir=self.render_dir,
                    itr=self.itr,
                )

            # Update learning rates and model parameters
            if self.itr >= self.n_critic_warmup_itr:
                self.actor_lr_scheduler.step()
                if self.learn_eta:
                    self.eta_lr_scheduler.step()
            self.critic_lr_scheduler.step()
            self.model.step()
            diffusion_min_sampling_std = self.model.get_min_sampling_denoising_std()

            # Save model
            if self.itr % self.save_model_freq == 0 or self.itr == self.n_train_itr - 1:
                self.save_model()

            # Log results
            run_results.append(
                {
                    "itr": self.itr,
                    "step": cnt_train_step,
                }
            )
            if self.save_trajs:
                run_results[-1]["obs_full_trajs"] = obs_full_trajs
                run_results[-1]["obs_trajs"] = obs_trajs
                run_results[-1]["chains_trajs"] = chains_trajs
                run_results[-1]["reward_trajs"] = reward_trajs
            if self.itr % self.log_freq == 0:
                time = timer()
                run_results[-1]["time"] = time
                if eval_mode:
                    log.info(
                        f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
                    )
                    if self.use_wandb:
                        wandb.log(
                            {
                                "success rate - eval": success_rate,
                                "avg episode reward - eval": avg_episode_reward,
                                "avg best reward - eval": avg_best_reward,
                                "num episode - eval": num_episode_finished,
                            },
                            step=self.itr,
                            commit=False,
                        )
                    run_results[-1]["eval_success_rate"] = success_rate
                    run_results[-1]["eval_episode_reward"] = avg_episode_reward
                    run_results[-1]["eval_best_reward"] = avg_best_reward
                else:
                    log.info(
                        f"{self.itr}: step {cnt_train_step:8d} | loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | eta {eta:8.4f} | t:{time:8.4f}"
                    )
                    # No actor gradient norms are collected during critic
                    # warmup or when max_grad_norm is None; avoid np.mean([]).
                    rl_grad_norm = np.mean(grad_norms) if len(grad_norms) > 0 else np.nan
                    if self.use_wandb:
                        wandb.log(
                            {
                                "total env step": cnt_train_step,
                                "loss": loss,
                                "pg loss": pg_loss,
                                "value loss": v_loss,
                                "bc loss": bc_loss,
                                "eta": eta,
                                "approx kl": approx_kl,
                                "ratio": ratio,
                                "clipfrac": np.mean(clipfracs),
                                "explained variance": explained_var,
                                "avg episode reward - train": avg_episode_reward,
                                "num episode - train": num_episode_finished,
                                "diffusion - min sampling std": diffusion_min_sampling_std,
                                "actor lr": self.actor_optimizer.param_groups[0]["lr"],
                                "critic lr": self.critic_optimizer.param_groups[0]["lr"],
                                "RL grad norm": rl_grad_norm,
                                "RL lr": self.actor_optimizer.param_groups[0]["lr"],
                            },
                            step=self.itr,
                            commit=True,
                        )
                    run_results[-1]["train_episode_reward"] = avg_episode_reward
                with open(self.result_path, "wb") as f:
                    pickle.dump(run_results, f)
            self.itr += 1