import copy
import torch
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_ as clip_grad

from gen_rl.commons.utils import soft_update, hard_update
from gen_rl.policy.paintGym_actors import ResNet, SimpleCNN, CNN, FlatMLP
from gen_rl.policy.paintGym_critics import ResNet_wobn
from gen_rl.policy.noise_process import GaussianNoise
from gen_rl.policy.models import Actor, Critic, Actor2, Critic2, Actor_TD3, Critic_TD3, Actor_SAC
from gen_rl.policy.wgan import WGAN


class DDPG(object):
    """Actor-critic agent implementing vanilla-DDPG, DDPG, TD3 and (a variant of) SAC updates.

    The actor/critic classes and learning rates are selected from ``args["policy_name"]``.
    For ``args["env_name"] == "paint"`` the networks are replaced by image encoders chosen
    via ``args["paint_type_encoder"]``, and rewards are computed inside the update from a
    reward model (WGAN discriminator) and/or an optional IQA model.

    NOTE(review): ``self._args`` is the caller's dict itself (not a copy). ``__init__``
    rescales ``policy_noise``/``noise_clip`` in place, which the caller will observe, and
    ``global_ts`` is only ever read here (presumably kept current by the training loop).
    """

    def __init__(self, args: dict = None, **kwargs):
        """Build actor/critic networks, their target copies, optimisers and exploration noise.

        Args:
            args: configuration dictionary, stored by reference (see class docstring).
            **kwargs: accepted for interface compatibility; unused.

        Raises:
            ValueError: unknown ``policy_name`` or ``paint_type_encoder``, or a policy that is
                not supported by the latent-state ResNet paint configuration.
        """
        self._args = args
        self._random_act_fn = args["random_act_fn"]
        self.state_model = self.reward_model = None

        # LR config is from the original TD3 repo and is shared across all envs.
        policy_name = self._args["policy_name"].lower()
        if policy_name == "vanilla-ddpg":
            act, cr = Actor, Critic
            lr_act, lr_cr = 0.0001, 0.01
            if self._args["env_name"].lower() == "pendulum":
                lr_cr *= 0.1
        elif policy_name == "ddpg":
            act, cr = Actor2, Critic2
            lr_act, lr_cr = 0.0003, 0.0003
        elif policy_name == "td3":
            act, cr = Actor_TD3, Critic_TD3
            lr_act, lr_cr = 0.0003, 0.0003
        elif policy_name == "sac":
            act, cr = Actor_SAC, Critic_TD3
            lr_act, lr_cr = 0.0003, 0.0003
        else:
            raise ValueError

        if self._args["env_name"].lower() == "paint":
            self._env_decode_fn = args["env_decode_fn"]
            encoder = self._args["paint_type_encoder"].lower()
            if encoder == "simple-cnn":
                act = cr = SimpleCNN
            elif encoder == "cnn":
                act = cr = CNN
            elif encoder == "flat-mlp":
                act = cr = FlatMLP
            elif encoder == "resnet":
                if self._args["if_use_latent_state"] and self._args["if_train_state_model"]:
                    # The critic is fed the state model's latent (see _eval_val_reward), so a
                    # non-image critic is paired with the ResNet actor.
                    if policy_name == "ddpg":
                        act, cr = ResNet, Critic2
                    elif policy_name == "td3":
                        act, cr = ResNet, Critic_TD3
                    elif policy_name == "sac":
                        act, cr = ResNet, Critic_TD3
                    else:
                        raise ValueError
                else:
                    act, cr = ResNet, ResNet_wobn
            else:
                raise ValueError

            # Actor input channels: canvas + ground truth (3 + 3), step counter (1),
            # coord-conv (2).
            # output: (10+3)*5 (action bundle)
            _dim_in = 9
            self.actor = act(num_inputs=_dim_in, depth=18, num_outputs=self._args["action_dim"]).to(
                self._args["device"])
            self.actor_target = act(num_inputs=_dim_in, depth=18, num_outputs=self._args["action_dim"]).to(
                self._args["device"])

            # The critic can additionally see the previous canvas (3 more channels).
            if self._args["if_use_prev_state"]:
                _dim_in += 3
            num_outputs = 64 if self._args["paint_if_patch"] else 1
            self.critic = cr(
                num_inputs=_dim_in, depth=18, num_outputs=num_outputs, if_sigmoid=False, args=self._args
            ).to(self._args["device"])
            self.critic_target = cr(
                num_inputs=_dim_in, depth=18, num_outputs=num_outputs, if_sigmoid=False, args=self._args
            ).to(self._args["device"])

            # Start the target networks as exact copies of the online networks.
            hard_update(target=self.critic_target, source=self.critic)
            hard_update(target=self.actor_target, source=self.actor)

            # Coord-conv channels for a 128 x 128 image: channel 0 holds the row index and
            # channel 1 the column index, both scaled to [0, 1]. The ramp is computed in
            # float64 and then cast, which yields exactly Python's ``i / 127.`` per element.
            ramp = (torch.arange(128, dtype=torch.float64) / 127.).float()
            coord = torch.zeros([1, 2, 128, 128])
            coord[0, 0] = ramp.view(128, 1)  # broadcast across columns
            coord[0, 1] = ramp.view(1, 128)  # broadcast across rows
            self.coord = coord.to(self._args["device"])

            if self._args["paint_if_gan_reward"]:
                self.reward_model = WGAN(if_patch=self._args["paint_if_patch"], device=self._args["device"])

            self.iqa_model = None
        else:
            self.actor = act(args=self._args).to(self._args["device"])
            self.actor_target = copy.deepcopy(self.actor)

            self.critic = cr(args=self._args).to(self._args["device"])
            self.critic_target = copy.deepcopy(self.critic)

        print(self.actor)
        print(self.critic)
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=lr_act)
        self.critic_opt = optim.Adam(self.critic.parameters(), lr=lr_cr)

        _dim = (self._args["num_envs"], self._args["action_dim"])
        self.noise_sampler = GaussianNoise(_dim, 0.0, args["gaussian_noise_std"], self._args["device"])

        self.total_it = 0
        # NOTE(review): these writes mutate the caller's dict (self._args is args), so
        # constructing a second agent from the same dict rescales the values again.
        self._args["policy_noise"] = args["policy_noise"] * args["max_action"]
        self._args["noise_clip"] = args["noise_clip"] * args["max_action"]
        self._args["policy_freq"] = args["policy_freq"]

        if self._args["SAC_if_automatic_entropy_tuning"]:
            # Standard SAC heuristic: target entropy = -|A|. The alpha loss averages over a
            # replay minibatch, so the number of parallel envs must not scale this value.
            self.target_entropy = -float(self._args["action_dim"])
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self._args["device"])
            self.alpha_optim = optim.Adam([self.log_alpha], lr=lr_act)

    def select_action(self, state, epsilon=0.0, if_warmup=False):
        """Return a numpy action for ``state``.

        Args:
            state: current observation (see ``_act`` for accepted forms).
            epsilon: scale of the Gaussian exploration noise; noise is skipped for SAC or
                when ``epsilon <= 0``.
            if_warmup: if True, return ``args["random_act_fn"]()`` instead of querying the actor.
        """
        if if_warmup:
            action = np.asarray(self._random_act_fn())
        else:
            self.eval()
            with torch.no_grad():
                action = self._act(state=state, if_target=False)
                if epsilon > 0.0 and self._args["policy_name"].lower() != "sac":
                    eps = self.noise_sampler.noise(scale=epsilon)
                    # Single (unbatched) action: use the noise row of the first env.
                    if len(action.shape) == 1:
                        eps = eps[0]
                    action += eps
                    action = action.clamp(-self._args["max_action"], self._args["max_action"])
            action = action.cpu().numpy()
            self.train()
        return action

    def _act(self, state, if_target=False):
        """Run the online (or target) actor on ``state``.

        For paint, the first 6 channels are divided by 255, channel 6 by
        ``max_episode_steps``, and the coord-conv grid is appended. Otherwise a non-tensor
        ``state`` is converted to a float32 tensor on the configured device.
        """
        _actor = self.actor_target if if_target else self.actor
        if self._args["env_name"].lower() == "paint":
            state_list = [state[:, :6].float() / 255.0, state[:, 6:7].float() / self._args["max_episode_steps"],
                          self.coord.expand(state.shape[0], 2, 128, 128)]
            state = torch.cat(state_list, 1)
        else:
            if not torch.is_tensor(state):
                state = torch.tensor(state.astype(np.float32)).to(self._args["device"])
        return _actor(state)

    def set_models(self, reward_model, state_model, decompose_obs_fn):
        """Attach externally built reward/state models (``None`` keeps the current one)."""
        if reward_model is not None:
            self.reward_model = reward_model
        if state_model is not None:
            self.state_model = state_model
        self._decompose_obs = decompose_obs_fn

    def update_policy(self, buffer, batch_size=256):
        """Run one actor-critic update and return a dict of logging values.

        Raises:
            ValueError: if ``policy_name`` is not one of the supported algorithms.
        """
        if self._args["env_name"].lower() == "paint" and self._args["paint_if_change_lr"]:
            # Step-wise LR schedule taken from Huang et al.
            for a_opt, c_opt in zip(self.actor_opt.param_groups, self.critic_opt.param_groups):
                if self._args["global_ts"] < 10000 * self._args["max_episode_steps"]:
                    c_opt["lr"], a_opt["lr"] = 3e-4, 1e-3
                elif self._args["global_ts"] < 20000 * self._args["max_episode_steps"]:
                    c_opt["lr"], a_opt["lr"] = 1e-4, 3e-4
                else:
                    c_opt["lr"], a_opt["lr"] = 3e-5, 1e-4

        policy_name = self._args["policy_name"].lower()
        if policy_name in ["vanilla-ddpg", "ddpg"]:
            return self._update_ddpg(buffer, batch_size)
        if policy_name in ["td3", "sac"]:
            return self._update_td3_sac(buffer, batch_size)
        raise ValueError(f"Unsupported policy_name: {self._args['policy_name']}")

    def update_models(self, buffer, batch_size=256):
        """Update the auxiliary models (GAN discriminator, state and reward models).

        The state/reward models are trained at most once per call, and only when
        ``if_train_models`` is set (plus, off paint, when not using an action-value critic).

        Returns:
            dict of scalar losses for logging (empty when nothing is trained).
        """
        if self._args["env_name"].lower() == "paint":
            # _update_gan returns zeros without sampling when paint_if_gan_reward is off.
            dict_loss = self._update_gan(buffer, batch_size)
            if self._args["if_train_models"]:
                dict_loss.update(self._update_state_reward_models(buffer, batch_size))
            return dict_loss
        if not self._args["if_use_act_val_fn"] and self._args["if_train_models"]:
            return self._update_state_reward_models(buffer, batch_size)
        return {}

    def _update_state_reward_models(self, buffer, batch_size=256):
        """One gradient step on the state (dynamics) and/or reward model, as enabled by flags.

        Returns:
            dict with "dynamics-model-loss" and "reward-model-loss" (0.0 when not trained).
        """
        if self._args["if_train_state_model"]:
            self.state_model.train()
        if self._args["if_train_reward_model"]:
            self.reward_model.train()
        obses_t, actions, rewards, obses_tp1, dones = buffer.sample(batch_size)
        if self._args["if_use_latent_state"]:
            # With latent states the buffer yields (previous, current) pairs; keep current.
            obses_t, actions = obses_t[1], actions[1]

        if self._args["if_train_reward_model"]:
            self.reward_model.zero_grad()
            pred = self.reward_model((obses_t, actions))
            r_loss = self.reward_model.criterion(rewards, pred)
            r_loss.backward()
            self.reward_model.optim.step()

        if self._args["if_train_state_model"]:
            self.state_model.zero_grad()
            if self._args["env_name"].lower() == "paint":
                # Predict the next canvas (first 3 channels) from the current one.
                obses_t = obses_t[:, :3, ...].float() / 255.0
                obses_tp1 = obses_tp1[:, :3, ...].float() / 255.0
                pred = self.state_model(obses_t, actions)
                s_loss = self.state_model.criterion(obses_tp1, *pred)
            else:
                pred = self.state_model((obses_t, actions))
                s_loss = self.state_model.criterion(obses_tp1, pred)
            s_loss.backward()
            self.state_model.optim.step()

        if self._args["if_train_state_model"]:
            self.state_model.eval()
        if self._args["if_train_reward_model"]:
            self.reward_model.eval()

        res = {
            "dynamics-model-loss": s_loss.item() if self._args["if_train_state_model"] else 0.0,
            "reward-model-loss": r_loss.item() if self._args["if_train_reward_model"] else 0.0,
        }
        return res

    def _update_gan(self, buffer, batch_size=256):
        """Update the WGAN discriminator on (next canvas, ground truth) pairs when enabled.

        Returns:
            dict with "d-fake", "d-real", "d-penal" (all 0.0 when the GAN reward is disabled).
        """
        if self._args["paint_if_gan_reward"]:
            self.reward_model.train()
            _, _, _, obses_tp1, _ = buffer.sample(batch_size)
            canvas, gt = obses_tp1[:, :3].float() / 255.0, obses_tp1[:, 3: 6].float() / 255.0
            fake, real, penal = self.reward_model.update(canvas, gt)
            self.reward_model.eval()
            res = {"d-fake": fake.item(), "d-real": real.item(), "d-penal": penal.item(), }
        else:
            res = {"d-fake": 0.0, "d-real": 0.0, "d-penal": 0.0, }
        return res

    def _eval_val_reward(self, obs_t, a_t, state=None, target=False, _if_Q1=False, obs_tm1=None, a_tm1=None):
        """Evaluate the critic, and where applicable a model-based reward, for ``(obs_t, a_t)``.

        Args:
            obs_t: observation batch. For paint, channels [0:3] are the canvas, [3:6] the
                ground truth (both divided by 255 here) and [6:7] the step counter.
            a_t: action batch.
            state: decomposed state used as the reward/state-model input when
                ``if_train_reward_model`` is False (non-paint, non-action-value mode).
            target: use the target critic instead of the online critic.
            _if_Q1: for TD3, evaluate only the critic's ``Q1`` head.
            obs_tm1, a_tm1: previous observation/action, used (non-paint) to re-encode
                ``obs_t`` with the state model when ``if_use_latent_state``.

        Returns:
            ``(val, reward)``: the critic output and the reward estimate; ``reward`` stays a
            0-dim zero tensor when no reward model contributes.
        """
        if self._args["policy_name"].lower() == "td3" and _if_Q1:
            _critic = self.critic_target.Q1 if target else self.critic.Q1
        else:
            _critic = self.critic_target if target else self.critic

        reward = torch.tensor(0.0, device=self._args["device"]).float()
        if self._args["env_name"].lower() == "paint":
            canvas_t, gt, T = obs_t[:, :3].float() / 255.0, obs_t[:, 3: 6].float() / 255.0, obs_t[:, 6: 7].float()
            canvas_tp1 = self._env_decode_fn(a_t, canvas_t)
            # Note: env.py multiplies by 255 to recover pixel values; here we stay normalised.
            if self._args["if_use_act_val_fn"]:
                canvas_tp1 = canvas_tp1.detach()

            diff_type = self._args["paint_type_diff_reward"]
            # All branches use out-of-place addition: ``reward`` may still be the 0-dim
            # tensor above, and an in-place add cannot broadcast it to batch shape.
            if diff_type == "original":
                # reward = D(canvas_{t+1}) - D(canvas_t).
                # NOTE(review): the author questioned this sign; "fixed" uses the opposite.
                if self._args["paint_if_gan_reward"]:
                    reward = self.reward_model.cal_reward(canvas_tp1, gt) - self.reward_model.cal_reward(canvas_t, gt)
                if self.iqa_model is not None:
                    reward = reward + (self.iqa_model(canvas_tp1, gt) - self.iqa_model(canvas_t, gt))
            elif diff_type == "fixed":
                # reward = D(canvas_t) - D(canvas_{t+1}), i.e. (L_t - L_t+1).
                if self._args["paint_if_gan_reward"]:
                    reward = self.reward_model.cal_reward(canvas_t, gt) - self.reward_model.cal_reward(canvas_tp1, gt)
                if self.iqa_model is not None:
                    reward = reward + (self.iqa_model(canvas_t, gt) - self.iqa_model(canvas_tp1, gt))
            elif diff_type == "one-step":
                # SPIRAL's reward formulation: score only the resulting canvas.
                if self._args["paint_if_gan_reward"]:
                    reward = self.reward_model.cal_reward(canvas_tp1, gt)
                if self.iqa_model is not None:
                    reward = reward + self.iqa_model(canvas_tp1, gt)
            else:
                raise ValueError

            if self._args["if_train_state_model"]:
                canvas_tp1, z = self.state_model(canvas_t, a_t, if_return_latent=True)

            # Build the critic input from the merged state.
            c = self.coord.expand(obs_t.shape[0], 2, 128, 128)
            if self._args["if_use_act_val_fn"]:
                # Q(s, a): current canvas plus the action.
                _in = (torch.cat([canvas_t, gt, (T + 1) / self._args["max_episode_steps"], c], 1), a_t)
            else:
                # V(s'): the next canvas, optionally with the previous one so the critic can
                # see the stroke that was added.
                state_list = [canvas_tp1, gt, (T + 1) / self._args["max_episode_steps"], c]
                if self._args["if_use_prev_state"]:
                    state_list += [canvas_t]
                _in = torch.cat(state_list, 1)
            if self._args["if_train_state_model"]:
                if self._args["if_use_latent_state"]:
                    val = _critic(z)
                else:
                    val = _critic(_in, z)
            else:
                val = _critic(_in)
        else:
            if self._args["if_use_act_val_fn"]:
                val = _critic(obs_t, a_t)
            else:
                _in = (obs_t, a_t) if self._args["if_train_reward_model"] else (state, a_t)
                reward = self.reward_model(_in)
                obses_tp1 = self.state_model(_in, if_return_latent=self._args["if_use_latent_state"])
                if self._args["if_use_latent_state"]:
                    obs_t = self.state_model((obs_tm1, a_tm1), if_return_latent=True)

                if self._args["if_use_prev_state"]:
                    _in = torch.cat([obs_t, obses_tp1], 1)
                else:
                    _in = obses_tp1
                val = _critic(_in, a_t)
        return val, reward

    def _update_ddpg(self, buffer, batch_size=256):
        """One critic + actor step of (vanilla-)DDPG followed by soft target updates."""
        obses_t, actions, rewards, obses_tp1, dones = buffer.sample(batch_size)
        if self._args["if_use_latent_state"]:
            (obses_tm1, obses_t), (actions_tm1, actions) = obses_t, actions
        state = next_state = None
        # NOTE(review): this check is case-sensitive (unlike the .lower() checks elsewhere)
        # and the inner startswith("mujoco-single") test is always true when env_name is
        # "Pendulum"; confirm which environments should decompose their observations.
        if self._args["env_name"] == "Pendulum":
            if not self._args["if_use_act_val_fn"] and not self._args["env_name"].startswith("mujoco-single"):
                state, next_state, obses_t, obses_tp1 = self._decompose_obs(self._args, obses_t, obses_tp1)

        with torch.no_grad():
            next_action = self._act(state=obses_tp1, if_target=True)
            if self._args["if_use_latent_state"]:
                val_tp1, rewards_tp1 = self._eval_val_reward(
                    obses_tp1, next_action, next_state, True, obs_tm1=obses_t, a_tm1=actions)
            else:
                val_tp1, rewards_tp1 = self._eval_val_reward(obses_tp1, next_action, next_state, True)
            if self._args["if_use_next_reward"]:
                val_tp1 = rewards_tp1 + (1 - dones) * self._args["discount"] * val_tp1
            if self._args["env_name"].lower() == "paint":
                # In PaintGym the immediate reward comes from the discriminator (added below).
                target = (1 - dones) * self._args["discount"] * val_tp1
            else:
                target = rewards + (1 - dones) * self._args["discount"] * val_tp1

        if self._args["if_use_latent_state"]:
            val_t, rewards = self._eval_val_reward(
                obs_t=obses_t, a_t=actions, state=state, target=False, obs_tm1=obses_tm1, a_tm1=actions_tm1)
        else:
            val_t, rewards = self._eval_val_reward(obs_t=obses_t, a_t=actions, state=state, target=False)
        if self._args["env_name"].lower() == "paint":
            target += rewards.detach()
        value_loss = F.mse_loss(val_t, target)
        self.critic_opt.zero_grad()
        value_loss.backward()
        clip_grad(self.critic.parameters(), 1.0)
        self.critic_opt.step()

        # Actor loss: maximise the critic's value of the actor's own action.
        action = self._act(state=obses_t, if_target=False)
        if self._args["if_use_latent_state"]:
            val, rewards_actor = self._eval_val_reward(
                obs_t=obses_t.detach(), a_t=action, state=state, target=False, obs_tm1=obses_tm1, a_tm1=actions_tm1)
        else:
            val, rewards_actor = self._eval_val_reward(obs_t=obses_t.detach(), a_t=action, state=state, target=False)
        if self._args["if_actor_reward"]:
            val = rewards_actor + self._args["discount"] * val

        policy_loss = - val.mean()
        self.actor_opt.zero_grad()
        policy_loss.backward()
        clip_grad(self.actor.parameters(), 1.0)
        self.actor_opt.step()

        soft_update(target=self.critic_target, source=self.critic, tau=self._args["tau"])
        soft_update(target=self.actor_target, source=self.actor, tau=self._args["tau"])

        res = {
            "policy_loss": policy_loss.data.cpu().numpy(),
            "value_loss": value_loss.data.cpu().numpy(),
            "Q(s',a')": val_tp1.mean().item(),
            "Q(s,a)": val_t.mean().item(),
            "Q(s,pi(s))": val.mean().item(),
        }
        if self._args["env_name"].lower() == "paint":
            res["rewards"] = rewards.mean().item()
        return res

    def _update_td3_sac(self, buffer, batch_size=256):
        """One TD3/SAC-style step: clipped double-Q critic update, then the actor update.

        For TD3 the actor and target networks are updated every ``policy_freq`` calls; for
        SAC they are updated on every call.
        """
        self.total_it += 1

        obses_t, actions, rewards, obses_tp1, dones = buffer.sample(batch_size)
        if self._args["if_use_latent_state"]:
            (obses_tm1, obses_t), (actions_tm1, actions) = obses_t, actions

        with torch.no_grad():
            # Target policy smoothing: add clipped Gaussian noise to the target action.
            noise = (torch.randn_like(actions) * self._args["policy_noise"])
            noise = noise.clamp(-self._args["noise_clip"], self._args["noise_clip"])
            next_action = self._act(state=obses_tp1, if_target=True) + noise
            next_action = next_action.clamp(-self._args["max_action"], self._args["max_action"])

            if self._args["if_use_latent_state"]:
                val_tp1, rewards_tp1 = self._eval_val_reward(
                    obses_tp1, next_action, target=True, obs_tm1=obses_t, a_tm1=actions)
            else:
                val_tp1, rewards_tp1 = self._eval_val_reward(obses_tp1, next_action, target=True)

            # Clipped double-Q: take the minimum of the two critic heads.
            val_tp1 = torch.min(*val_tp1)
            target = rewards + (1 - dones) * self._args["discount"] * val_tp1

        if self._args["if_use_latent_state"]:
            val_t, rewards = self._eval_val_reward(
                obs_t=obses_t, a_t=actions, target=False, obs_tm1=obses_tm1, a_tm1=actions_tm1)
        else:
            val_t, rewards = self._eval_val_reward(obs_t=obses_t, a_t=actions, target=False)

        value_loss = F.mse_loss(val_t[0], target) + F.mse_loss(val_t[1], target)
        self.critic_opt.zero_grad()
        value_loss.backward()
        clip_grad(self.critic.parameters(), 1.0)
        self.critic_opt.step()

        policy_loss = 0.0
        if self._args["policy_name"].lower() == "sac":
            # NOTE(review): this actor loss has no entropy (alpha * log_prob) term and
            # self.alpha is not read anywhere in this class -- confirm this is intended.
            action = self._act(state=obses_t, if_target=False)

            if self._args["if_use_latent_state"]:
                val, rewards_actor = self._eval_val_reward(
                    obs_t=obses_t.detach(), a_t=action, target=False, obs_tm1=obses_tm1, a_tm1=actions_tm1)
            else:
                val, rewards_actor = self._eval_val_reward(obs_t=obses_t.detach(), a_t=action, target=False)

            val = torch.min(*val)
            if self._args["env_name"].lower() == "paint" and not self._args["if_use_act_val_fn"]:
                val = rewards_actor + self._args["discount"] * val

            policy_loss = - val.mean()
            self.actor_opt.zero_grad()
            policy_loss.backward()
            clip_grad(self.actor.parameters(), 1.0)
            self.actor_opt.step()

            soft_update(target=self.critic_target, source=self.critic, tau=self._args["tau"])
            soft_update(target=self.actor_target, source=self.actor, tau=self._args["tau"])

            policy_loss = policy_loss.item()

            if self._args["SAC_if_automatic_entropy_tuning"]:
                # Assumes the actor stores the log-prob of its last sampled action in
                # ``self.actor.log_prob`` -- TODO confirm in Actor_SAC.
                alpha_loss = -(self.log_alpha * (self.actor.log_prob + self.target_entropy).detach()).mean()

                self.alpha_optim.zero_grad()
                alpha_loss.backward()
                self.alpha_optim.step()

                self.alpha = self.log_alpha.exp()
        else:
            # Delayed policy updates (TD3).
            if self.total_it % self._args["policy_freq"] == 0:
                action = self._act(state=obses_t, if_target=False)

                if self._args["if_use_latent_state"]:
                    val, rewards_actor = self._eval_val_reward(
                        obs_t=obses_t.detach(), a_t=action, target=False, obs_tm1=obses_tm1, a_tm1=actions_tm1,
                        _if_Q1=True)
                else:
                    val, rewards_actor = self._eval_val_reward(obs_t=obses_t.detach(), a_t=action, target=False,
                                                               _if_Q1=True)

                if self._args["env_name"].lower() == "paint" and not self._args["if_use_act_val_fn"]:
                    val = rewards_actor + self._args["discount"] * val

                policy_loss = - val.mean()
                self.actor_opt.zero_grad()
                policy_loss.backward()
                clip_grad(self.actor.parameters(), 1.0)
                self.actor_opt.step()

                soft_update(target=self.critic_target, source=self.critic, tau=self._args["tau"])
                soft_update(target=self.actor_target, source=self.actor, tau=self._args["tau"])

                policy_loss = policy_loss.item()

        res = {
            "policy_loss": policy_loss,
            "value_loss": value_loss.data.cpu().numpy(),
        }
        return res

    def reset(self, id):
        """Reset the exploration-noise state of environment ``id``."""
        self.noise_sampler.reset(_id=id)

    def save(self, filename, **kwargs):
        """Save online networks, optimisers and any attached models under ``filename`` + suffix."""
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_opt.state_dict(), filename + "_critic_opt")

        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_opt.state_dict(), filename + "_actor_opt")

        if self.reward_model is not None:
            torch.save(self.reward_model.state_dict(), filename + "_reward_model")

        if self.state_model is not None:
            torch.save(self.state_model.state_dict(), filename + "_state_model")

    def load(self, filename):
        """Load what ``save`` wrote and re-create the target networks as copies.

        Tensors are mapped onto ``args["device"]`` so that checkpoints saved on another
        device (e.g. GPU) can be restored here.
        """
        device = self._args["device"]
        self.critic.load_state_dict(torch.load(filename + "_critic", map_location=device))
        self.critic_opt.load_state_dict(torch.load(filename + "_critic_opt", map_location=device))
        self.critic_target = copy.deepcopy(self.critic)

        self.actor.load_state_dict(torch.load(filename + "_actor", map_location=device))
        self.actor_opt.load_state_dict(torch.load(filename + "_actor_opt", map_location=device))
        self.actor_target = copy.deepcopy(self.actor)

        if self.reward_model is not None:
            self.reward_model.load_state_dict(torch.load(filename + "_reward_model", map_location=device))

        if self.state_model is not None:
            self.state_model.load_state_dict(torch.load(filename + "_state_model", map_location=device))

    def eval(self):
        """Put actor, critic and both target networks in eval mode."""
        self.actor.eval()
        self.actor_target.eval()
        self.critic.eval()
        self.critic_target.eval()

    def train(self):
        """Put actor, critic and both target networks in train mode."""
        self.actor.train()
        self.actor_target.train()
        self.critic.train()
        self.critic_target.train()
