import os
import time
import pickle
from collections import defaultdict
import math

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from divmorph.config import cfg
from divmorph.envs.vec_env.vec_video_recorder import VecVideoRecorder
from divmorph.utils import file as fu
from divmorph.utils import model as mu
from divmorph.utils import optimizer as ou
from divmorph.utils.meter import TrainMeter

from .buffer import Buffer
from .envs import get_ob_rms
from .envs import make_vec_envs
from .envs import set_ob_rms
from .inherit_weight import restore_from_checkpoint
from .svd_model import ActorCritic
from .svd_model import Agent

from .moe_transformer import Gate
from .svd_module import SVDLinearOptimized

class PPO_SVD:
    def __init__(self, print_model=True):
        """Set up environments, model, rollout buffer, optimizer and meters.

        Args:
            print_model: If true, print the parameter count of the model.
        """
        self.envs = make_vec_envs()
        self.file_prefix = cfg.ENV_NAME

        self.device = torch.device(cfg.DEVICE)

        # The model class is resolved by name from this module's globals, so
        # it has to be imported at the top of this file.
        model_cls = globals()[cfg.MODEL.ACTOR_CRITIC]
        self.actor_critic = model_cls(
            self.envs.observation_space, self.envs.action_space
        )
        if cfg.PPO.CHECKPOINT_PATH:
            # The checkpoint loader returns observation statistics which are
            # pushed back into the vectorized envs.
            ob_rms = restore_from_checkpoint(self.actor_critic)
            set_ob_rms(self.envs, ob_rms)

        if print_model:
            print("Num params: {}".format(mu.num_params(self.actor_critic)))

        self.actor_critic.to(self.device)
        self.agent = Agent(self.actor_critic)

        self.buffer = Buffer(self.envs.observation_space, self.envs.action_space.shape)

        # Split the parameters into three learning-rate groups by name.
        gene_params = []
        tailor_params = []
        base_params = []
        for name, param in self.actor_critic.named_parameters():
            if "gene" in name:
                gene_params.append(param)
            elif "other" in name:
                # NOTE(review): parameters whose name contains "other" are
                # trained with TAILOR_LR -- confirm this substring is intended.
                tailor_params.append(param)
            else:
                base_params.append(param)

        # Group order is (gene, tailor, base); presumably ou.set_lr_gene relies
        # on this order -- verify before reordering.
        self.optimizer = optim.Adam(
            [
                {"params": gene_params, "lr": cfg.PPO.GENE_LR},
                {"params": tailor_params, "lr": cfg.PPO.TAILOR_LR},
                {"params": base_params, "lr": cfg.PPO.BASE_LR},
            ],
            eps=cfg.PPO.EPS,
            weight_decay=cfg.PPO.WEIGHT_DECAY,
        )

        # One multiplicative scale per optimizer param group.
        self.lr_scale = [1. for _ in self.optimizer.param_groups]

        self.train_meter = TrainMeter()

        # Name of the first state-dict entry containing "log_std"; None when
        # the model has no such entry, so the attribute is always defined.
        self.log_std_param = next(
            (name for name in self.actor_critic.state_dict() if "log_std" in name),
            None,
        )

        self.fps = 0

    def train(self):
        """Run the PPO loop: collect rollouts, update the model, log, checkpoint.

        Iterates from ``cfg.PPO.START_ITER`` to ``cfg.PPO.MAX_ITERS`` and stops
        earlier when ``cfg.PPO.EARLY_EXIT`` is set and the iteration reaches
        ``cfg.PPO.EARLY_EXIT_MAX_ITERS``. Every iteration appends the last
        step's ``indices_mu`` to ``indices.txt`` and rewrites the rolling
        checkpoint; every ``cfg.LOG_PERIOD`` iterations the stats are logged
        and dumped to ``<file_prefix>_results.json``.
        """
        self.save_sampled_agent_seq(0)
        obs = self.envs.reset()
        self.buffer.to(self.device)
        self.start = time.time()

        unimal_ids = self.envs.get_unimal_idx()

        for key in obs:
            print(key, obs[key].size())

        # NOTE(review): stat_save_freq is stored but not read in this method.
        self.stat_save_freq = 100 if cfg.PPO.MAX_ITERS > 1000 else 10

        for cur_iter in range(cfg.PPO.START_ITER, cfg.PPO.MAX_ITERS):

            if cfg.PPO.EARLY_EXIT and cur_iter >= cfg.PPO.EARLY_EXIT_MAX_ITERS:
                break

            lr, lr_gene, lr_tailor = ou.get_iter_lr_gene(cur_iter)
            ou.set_lr_gene(self.optimizer, lr, lr_gene, lr_tailor, self.lr_scale)

            # Stays None only if the rollout loop below runs zero steps.
            indices_mu = None
            for step in range(cfg.PPO.TIMESTEPS):
                unimal_ids = self.envs.get_unimal_idx()

                (
                    val,
                    act,
                    logp,
                    dropout_mask_v,
                    dropout_mask_mu,
                    _,
                    indices_mu,
                ) = self.agent.act(obs, unimal_ids=unimal_ids, task=cfg.ENV.TASK_NAME)
                next_obs, reward, done, infos = self.envs.step(act)

                self.train_meter.add_ep_info(infos)

                # 0.0 where the env reported done, 1.0 otherwise.
                masks = torch.tensor(
                    [[0.0] if done_ else [1.0] for done_ in done],
                    dtype=torch.float32,
                    device=self.device,
                )
                # 0.0 where the step info carries a "timeout" key, 1.0 otherwise.
                timeouts = torch.tensor(
                    [[0.0] if "timeout" in info else [1.0] for info in infos],
                    dtype=torch.float32,
                    device=self.device,
                )

                self.buffer.insert(obs, act, logp, val, reward, masks, timeouts, dropout_mask_v, dropout_mask_mu, unimal_ids)
                obs = next_obs

            # Append the indices produced at the last rollout step of this
            # iteration, paired with the ids fetched at that same step.
            if indices_mu is not None:
                last_indices_mu = indices_mu[0].detach().cpu().tolist()

                save_path = os.path.join(cfg.OUT_DIR, "indices.txt")
                with open(save_path, "a") as f:
                    f.write(f"=== Iter {cur_iter} ===\n")
                    for uid, im in zip(unimal_ids, last_indices_mu):
                        f.write(f"unimal_id: {uid}\tindices_mu: {im}\n")

            unimal_ids = self.envs.get_unimal_idx()

            next_val = self.agent.get_value(obs, unimal_ids=unimal_ids, task=cfg.ENV.TASK_NAME)
            self.buffer.compute_returns(next_val)
            self.train_on_batch(cur_iter)
            self.save_sampled_agent_seq(cur_iter)

            self.train_meter.update_mean()
            # Check LOG_PERIOD > 0 first so a period of 0 disables logging
            # instead of raising ZeroDivisionError in the modulo.
            if (
                cfg.LOG_PERIOD > 0
                and cur_iter >= 0
                and cur_iter % cfg.LOG_PERIOD == 0
            ):
                self._log_stats(cur_iter)

                file_name = "{}_results.json".format(self.file_prefix)
                path = os.path.join(cfg.OUT_DIR, file_name)
                # Use the current iteration so the stored FPS reflects the
                # steps actually taken so far, not the planned total.
                self._log_fps(cur_iter, log=False)
                stats = self.train_meter.get_stats()
                stats["fps"] = self.fps
                fu.save_json(stats, path)
                print(f"OUT_DIR {cfg.OUT_DIR}")

            if cur_iter % 100 == 0:
                self.save_model(cur_iter)

            # Rolling checkpoint (checkpoint_-2.pt), rewritten every iteration.
            self.save_model(-2)

        print("Finished Training: {}".format(self.file_prefix))

    def train_on_batch(self, cur_iter):
        """Run the PPO update epochs over the rollouts stored in the buffer.

        The whole update stops (the method returns) as soon as a minibatch's
        approximate KL exceeds ``cfg.PPO.KL_TARGET_COEF * 0.01``; that check is
        skipped when ``cfg.PPO.KL_TARGET_COEF`` is None.

        Args:
            cur_iter: Training iteration index, used only in the early-stop
                log message.
        """
        adv = self.buffer.ret - self.buffer.val
        adv = (adv - adv.mean()) / (adv.std() + 1e-5)

        clip_ratio = cfg.PPO.CLIP_EPS

        # Resolve the log-std tensor once instead of rebuilding the full
        # state dict on every minibatch. state_dict() holds references to the
        # module's tensors, so later reads see the optimizer's in-place updates.
        log_std_name = getattr(self, "log_std_param", None)
        log_std_tensor = None
        if log_std_name is not None:
            log_std_tensor = self.actor_critic.state_dict()[log_std_name]

        for i in range(cfg.PPO.EPOCHS):
            batch_sampler = self.buffer.get_sampler(adv)

            for j, batch in enumerate(batch_sampler):
                val, _, logp, ent, _, _, _, _ = self.actor_critic(
                    batch["obs"],
                    batch["act"],
                    dropout_mask_v=batch['dropout_mask_v'],
                    dropout_mask_mu=batch['dropout_mask_mu'],
                    unimal_ids=batch['unimal_ids'],
                    task_name=cfg.ENV.TASK_NAME,
                )

                ratio = torch.exp(logp - batch["logp_old"])
                approx_kl = (batch["logp_old"] - logp).mean().item()

                if cfg.PPO.KL_TARGET_COEF is not None and approx_kl > cfg.PPO.KL_TARGET_COEF * 0.01:
                    self.train_meter.add_train_stat("approx_kl", approx_kl)
                    log_str = f'early stop iter {cur_iter} at epoch {i + 1}/{cfg.PPO.EPOCHS}, batch {j + 1} with approx_kl {approx_kl}'
                    print(log_str)
                    log_path = os.path.join(cfg.OUT_DIR, "log.txt")
                    with open(log_path, "a", encoding="utf-8") as f:
                        f.write(log_str + "\n")
                    return

                surr1 = ratio * batch["adv"]

                # Clipped ratio; clip_frac is the fraction of entries the
                # clamp actually changed.
                clipped_ratio = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)
                clip_frac = (ratio != clipped_ratio).float().mean().item()
                surr2 = clipped_ratio * batch["adv"]

                pi_loss = -torch.min(surr1, surr2).mean()

                if cfg.PPO.USE_CLIP_VALUE_FUNC:
                    # Keep the new value prediction within clip_ratio of the
                    # value stored at rollout time and take the larger loss.
                    val_pred_clip = batch["val"] + (val - batch["val"]).clamp(
                        -clip_ratio, clip_ratio
                    )
                    val_loss = (val - batch["ret"]).pow(2)
                    val_loss_clip = (val_pred_clip - batch["ret"]).pow(2)
                    val_loss = 0.5 * torch.max(val_loss, val_loss_clip).mean()
                else:
                    val_loss = 0.5 * (batch["ret"] - val).pow(2).mean()

                self.optimizer.zero_grad()

                loss = val_loss * cfg.PPO.VALUE_COEF
                loss += pi_loss
                loss += -ent * cfg.PPO.ENTROPY_COEF
                loss.backward()

                norm = nn.utils.clip_grad_norm_(
                    self.actor_critic.parameters(), cfg.PPO.MAX_GRAD_NORM
                )
                self.train_meter.add_train_stat("grad_norm", norm.item())

                # The "std" stat is skipped when the model exposes no log-std
                # entry instead of failing on a missing attribute.
                if log_std_tensor is not None:
                    log_std = log_std_tensor.cpu().numpy()[0]
                    std = np.mean(np.exp(log_std))
                    self.train_meter.add_train_stat("std", float(std))

                self.train_meter.add_train_stat("approx_kl", approx_kl)
                self.train_meter.add_train_stat("pi_loss", pi_loss.item())
                self.train_meter.add_train_stat("val_loss", val_loss.item())
                self.train_meter.add_train_stat("ratio", ratio.mean().item())
                self.train_meter.add_train_stat("surr1", surr1.mean().item())
                self.train_meter.add_train_stat("surr2", surr2.mean().item())
                self.train_meter.add_train_stat("clip_frac", clip_frac)

                self.optimizer.step()


    def save_model(self, cur_iter, path=None):
        """Save the model state dict and observation statistics to two files.

        Args:
            cur_iter: Iteration number embedded in the checkpoint file name
                (``checkpoint_<cur_iter>.pt`` under ``cfg.OUT_DIR``).
            path: Destination of the first copy; when falsy it defaults to
                ``<cfg.OUT_DIR>/<file_prefix>.pt``.
        """
        primary_path = path or os.path.join(cfg.OUT_DIR, self.file_prefix + ".pt")
        iter_path = os.path.join(cfg.OUT_DIR, f"checkpoint_{cur_iter}.pt")
        # Each destination gets a freshly gathered [state_dict, ob_rms] pair.
        for destination in (primary_path, iter_path):
            torch.save(
                [self.actor_critic.state_dict(), get_ob_rms(self.envs)],
                destination,
            )

    def _log_stats(self, cur_iter):
        """Log throughput for ``cur_iter``, then the training meter's stats."""
        self._log_fps(cur_iter, log=True)
        meter = self.train_meter
        meter.log_stats()

    def _log_fps(self, cur_iter, log=True):
        """Update ``self.fps`` from the steps done and the time since ``self.start``.

        Args:
            cur_iter: Iteration index passed to ``env_steps_done``.
            log: When true, print the summary line and append it to
                ``log.txt`` under ``cfg.OUT_DIR``.
        """
        env_steps = self.env_steps_done(cur_iter)
        elapsed = time.time() - self.start
        # Report 0 rather than dividing by a zero (or negative) elapsed time.
        self.fps = int(env_steps / elapsed) if elapsed > 0 else 0
        if log:
            log_msg = "Updates {}, num timesteps {}, FPS {}".format(
                cur_iter, env_steps, self.fps
            )
            print(log_msg)
            # Same file and encoding as the early-stop messages written by
            # train_on_batch.
            log_path = os.path.join(cfg.OUT_DIR, "log.txt")
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(log_msg + "\n")

    def env_steps_done(self, cur_iter):
        """Return ``(cur_iter + 1) * cfg.PPO.NUM_ENVS * cfg.PPO.TIMESTEPS``."""
        completed_iters = cur_iter + 1
        steps_across_envs = completed_iters * cfg.PPO.NUM_ENVS
        return steps_across_envs * cfg.PPO.TIMESTEPS


    def save_sampled_agent_seq(self, cur_iter):
        """Sample a sequence of agent indices and save it to ``sampling.json``.

        Does nothing when ``cfg.ENV.WALKERS`` has at most one entry. Agents are
        drawn with probability proportional to ``(1000 / ep_len) ** PROB_ALPHA``
        where ``ep_len`` is a per-agent episode-length estimate chosen by
        ``cfg.ENV.TASK_SAMPLING``.

        Args:
            cur_iter: Training iteration; with "balanced_replay_buffer" the
                measured episode lengths are only used from iteration 30 on.

        Raises:
            ValueError: If ``cfg.ENV.TASK_SAMPLING`` or (when it is consulted)
                ``cfg.TASK_SAMPLING.AVG_TYPE`` is not a supported value.
        """
        walkers = cfg.ENV.WALKERS
        num_agents = len(walkers)

        if num_agents <= 1:
            return

        meters = self.train_meter.agent_meters
        strategy = cfg.ENV.TASK_SAMPLING
        if strategy == "uniform_random_strategy":
            ep_lens = [1000] * num_agents
        elif strategy == "balanced_replay_buffer":
            if cur_iter < 30:
                ep_lens = [1000] * num_agents
            else:
                avg_type = cfg.TASK_SAMPLING.AVG_TYPE
                if avg_type == "ema":
                    ep_lens = [np.mean(meters[agent].ep_len_ema) for agent in walkers]
                elif avg_type == "moving_window":
                    ep_lens = [np.mean(meters[agent].ep_len) for agent in walkers]
                else:
                    # Previously fell through and failed later with an
                    # unbound ``ep_lens``.
                    raise ValueError(
                        "Unsupported cfg.TASK_SAMPLING.AVG_TYPE: {!r}".format(avg_type)
                    )
        else:
            raise ValueError(
                "Unsupported cfg.ENV.TASK_SAMPLING: {!r}".format(strategy)
            )

        # Non-positive (or NaN) lengths fail the ``> 0`` test and get weight 100.
        weights = [1000.0 / ep_len if ep_len > 0 else 100 for ep_len in ep_lens]
        weights = np.power(weights, cfg.TASK_SAMPLING.PROB_ALPHA)
        # Sum once instead of once per element.
        total = sum(weights)
        probs = [w / total for w in weights]

        avg_ep_len = np.mean([np.mean(meters[agent].ep_len) for agent in walkers])
        if np.isnan(avg_ep_len):
            avg_ep_len = 100
        ep_per_env = cfg.PPO.TIMESTEPS / avg_ep_len
        size = int(ep_per_env * cfg.PPO.NUM_ENVS * 50)
        sampled = np.random.choice(range(0, num_agents), size=size, p=probs)
        # Convert numpy integers to plain ints before saving as JSON.
        task_list = [int(idx) for idx in sampled]
        path = os.path.join(cfg.OUT_DIR, "sampling.json")
        fu.save_json(task_list, path)
