import numpy as np
import torch
import torch.nn as nn
from mpc_model.utils import Model, MLPRegression, MLP
from torch.utils.data import Dataset, DataLoader
import wandb
import math

import models.model_utils as u
from models.model_utils import device
import models.networks as nets
from models.hps import HPS
from models.patd3 import TD3

from embedding import ActionRepresentation_vae as hyar


class world_model(nn.Module):
    """Learned world model plus the policy learner used for planning.

    Three heads each map the concatenation ``[state, action]`` to a prediction:
    ``_dyanmics`` (next state), ``_reward`` (reward) and ``_continue`` (a 2-way
    continuation output). ``_Q1``/``_Q2`` are additional Q heads used by ``Q``.

    Depending on ``args.policy_type`` the model also owns a policy learner:
    ``"hps"`` builds an ``HPS`` agent together with twin critics, a state-value
    net and its target copy (trained by ``train_hps``); ``"patd3"`` builds a
    ``TD3`` agent that trains itself. Any other policy type leaves ``_agent``
    unset.
    """

    # Per environment: squash the predicted next state with tanh?
    _STATE_TANH = {
        'Platform-v0': True,
        'Goal-v0': True,
        'hard_goal-v0': True,
        'simple_catch-v0': False,
        'simple_move_4_direction_v1-v0': False,
    }
    # Per environment: squash the predicted reward with tanh?
    _REWARD_TANH = {
        'Platform-v0': True,
        'Goal-v0': False,
        'simple_catch-v0': False,
        'hard_goal-v0': False,
        'simple_move_4_direction_v1-v0': False,
    }

    def __init__(self, args):
        super().__init__()
        # Not advanced inside this class; presumably updated by the caller.
        # Drives the temperature schedule in linear_temp().
        self.timestep = 0
        self.max_timestep = args.max_timesteps

        self.env = args.env
        self.inp_dim = args.state_dim
        self.s_dim = args.state_dim

        # Heads consume [state, action]; the action part is either k_dim + z_dim
        # wide (onepa) or action_dim wide.
        if args.onepa:
            self.inp_dim += args.k_dim + args.z_dim
        else:
            self.inp_dim += args.action_dim

        self.embed = args.embed
        if self.embed:
            self.action_rep = hyar.Action_representation(
                state_dim=args.state_dim,
                action_dim=args.k_dim,
                parameter_action_dim=args.z_dim,
                reduced_action_dim=args.discrete_emb_dim,
                reduce_parameter_action_dim=args.parameter_emb_dim,
            )
            self.ek_dim = args.discrete_emb_dim
            self.ez_dim = args.parameter_emb_dim
            # With an action embedding the heads consume [state, ek, ez] instead.
            self.inp_dim = self.s_dim + self.ek_dim + self.ez_dim

        s_tanh = self._tanh_flag(self._STATE_TANH, args.env)
        self._dyanmics = nets.TanhGaussianPolicy(
            hidden_sizes=[args.dm_layers] * 2,
            oup_dim=self.s_dim,
            inp_dim=self.inp_dim,
            tanh=s_tanh,
        ).to(device)

        r_tanh = self._tanh_flag(self._REWARD_TANH, args.env)
        self._reward = nets.TanhGaussianPolicy(
            hidden_sizes=[args.r_layers] * 2,
            oup_dim=1,
            inp_dim=self.inp_dim,
            tanh=r_tanh,
        ).to(device)
        self._continue = nets.TanhGaussianPolicy(
            hidden_sizes=[args.dm_layers] * 2,
            oup_dim=2,
            inp_dim=self.inp_dim,
            tanh=False,
        ).to(device)

        self._Q1, self._Q2 = nets.q(args).to(device), nets.q(args).to(device)

        self.policy_type = args.policy_type
        if args.policy_type == "hps":
            self._agent = HPS(args).to(device)
            self._build_hps_learner(args)
        elif args.policy_type == "patd3":
            self._agent = TD3(args).to(device)

    @staticmethod
    def _tanh_flag(table, env):
        """Return ``table[env]``, raising NotImplementedError for unknown envs."""
        try:
            return table[env]
        except KeyError:
            # A bare string cannot be raised in Python 3 (it is a TypeError),
            # so wrap the message in a real exception.
            raise NotImplementedError(f"ENV {env} not implemented yet") from None

    def _build_hps_learner(self, args, net_size=300, qf_lr=3e-4, vf_lr=3e-4, policy_lr=3e-4):
        """Create the critics, value nets, optimizers and loss weights used by ``train_hps``."""
        self.qf1 = nets.FlattenMlp(
            hidden_sizes=[net_size, net_size],
            input_size=self.inp_dim,
            output_size=1,
        ).to(device)  # Q(s, a), first critic
        self.qf2 = nets.FlattenMlp(
            hidden_sizes=[net_size, net_size],
            input_size=self.inp_dim,
            output_size=1,
        ).to(device)  # Q(s, a), second critic
        self.vf = nets.FlattenMlp(
            hidden_sizes=[net_size, net_size],
            input_size=self.s_dim,
            output_size=1,
        ).to(device)  # V(s)
        # Soft-updated copy of vf used for bootstrapped Q targets.
        self.target_vf = self.vf.copy()

        self.policy_mean_reg_weight = 1e-3
        self.policy_std_reg_weight = 1e-3
        self.policy_pre_activation_weight = 0.
        self.soft_target_tau = 5e-3

        optimizer_class = torch.optim.Adam
        self.qf1_optimizer = optimizer_class(self.qf1.parameters(), lr=qf_lr)
        self.qf2_optimizer = optimizer_class(self.qf2.parameters(), lr=qf_lr)
        self.vf_optimizer = optimizer_class(self.vf.parameters(), lr=vf_lr)
        self.dpolicy_optimizer = optimizer_class(
            self._agent.discrete_policy.parameters(), lr=policy_lr)
        self.cpolicy_optimizer = optimizer_class(
            self._agent.continuous_policy.parameters(), lr=policy_lr)

        self.discount = args.mpc_gamma
        self.vf_criterion = nn.MSELoss()

    def linear_temp(self, anneal_steps=20_000, start=5.0, end=0.5):
        """Linearly anneal the policy temperature from ``start`` to ``end``.

        The schedule runs over ``anneal_steps`` values of ``self.timestep`` and
        then holds at ``end``. It assumes ``start >= end``.
        """
        return max((1 - (self.timestep / anneal_steps)) * (start - end) + end, end)

    def pi(self, s, std, reparameterize, return_log_prob, deterministic):
        '''Query the policy for a batch of states.

        inp:
        s: [N_policy_traj, s_dim]
        std: unused
        oup:
        dpolicy_outputs: [N_policy_traj, k_dim]
        cpolicy_outputs: [N_policy_traj, z_dim] (first element of the agent's
                         continuous outputs, i.e. the continuous action)

                        reparameterize  return_log_prob deterministic
        train           True            True            False
        train_plan      False           True            False
        evaluate_plan   -               -               True
        TD_target       -               -               True
        estimate_value  -               -               True
        '''
        dpolicy_outputs, cpolicy_outputs = self._agent(
            s,
            reparameterize=reparameterize,
            return_log_prob=return_log_prob,
            deterministic=deterministic,
            temperature=self.linear_temp(),
        )
        return dpolicy_outputs, cpolicy_outputs[0]

    def track_q_grad(self, enable=True):
        """Utility function. Enables/disables gradient tracking of Q-networks."""
        for m in [self._Q1, self._Q2]:
            u.set_requires_grad(m, enable)

    def Q(self, s, a):
        """Predict state-action value (Q) with both Q heads on ``[s, a]``."""
        x = torch.cat([s, a], dim=-1)
        return self._Q1(x), self._Q2(x)

    def next(self, s, a, reparameterize, return_log_prob, deterministic):
        '''Predict (next state, reward, continuation output) for ``[s, a]``.

        For Platform-v0 the tanh-squashed reward is mapped from [-1, 1] to
        [0, 1]. ``c`` is the raw 2-way output of the continuation head;
        presumably index 0 = terminate, 1 = continue (TODO confirm).

                        reparameterize  return_log_prob deterministic
            train           True            True            False
            train_plan      False           True            False
            evaluate_plan   -               -               True
            estimate_value  -               -               True
        '''
        x = torch.cat([s, a], dim=-1)
        head_kwargs = dict(
            reparameterize=reparameterize,
            return_log_prob=return_log_prob,
            deterministic=deterministic,
        )

        s = self._dyanmics(x, **head_kwargs)[0]

        r = self._reward(x, **head_kwargs)[0]
        if self.env == 'Platform-v0':
            r = (r + 1) / 2

        c = self._continue(x, **head_kwargs)[0]

        return s, r, c

    def decode_embed(self, state, ek, ez):
        """Decode an embedded action ``(ek, ez)`` into ``(k, allz)``.

        Requires ``self.c_rate``, which is set by ``vae_train``.

        Raises:
            AttributeError: if ``vae_train`` has not been run yet.
        """
        if not hasattr(self, "c_rate"):
            raise AttributeError(
                "decode_embed requires self.c_rate; call vae_train() first")
        k = self.action_rep.select_discrete_action(ek)

        true_next_parameter_emb = u.true_parameter_action(ez, self.c_rate)
        next_discrete_emb_1 = self.action_rep.get_embedding(k).cpu().view(-1).data.numpy()

        allz = self.action_rep.select_parameter_action(
            state, true_next_parameter_emb, next_discrete_emb_1)

        return k, allz

    def train_policy(self, obs, dactions, cactions, rewards, next_obs, conts):
        """Run one policy-learning step for the configured ``policy_type``.

        Returns the learner's loss tuple, or None for an unknown policy type.
        """
        if self.policy_type == "hps":
            return self.train_hps(obs, dactions, cactions, rewards, next_obs, conts)
        elif self.policy_type == "patd3":
            return self._agent.trainme(obs, dactions, cactions, rewards, next_obs, conts)

    def train_hps(self, obs, dactions, cactions, rewards, next_obs, conts):
        """One gradient step each for the twin critics, the value net and the HPS policy.

        Args:
            obs, next_obs: batches of states.
            dactions, cactions: discrete / continuous actions taken in ``obs``.
            rewards: rewards (scaled by 10 before the Bellman backup).
            conts: continuation mask multiplying the bootstrapped value.

        Returns:
            ``(qf_loss, vf_loss, policy_loss)`` as numpy scalars.
        """
        dpolicy_outputs, cpolicy_outputs = self._agent(
            obs, reparameterize=True, return_log_prob=True,
            deterministic=False, temperature=self.linear_temp())
        new_cactions, policy_mean, policy_log_std, log_pi = cpolicy_outputs[:4]

        q1_pred = self.qf1(obs, dactions, cactions)
        q2_pred = self.qf2(obs, dactions, cactions)
        v_pred = self.vf(obs)
        with torch.no_grad():
            target_v_values = self.target_vf(next_obs)

        # Q update: regress both critics onto r * scale + gamma * c * V_target(s').
        self.qf1_optimizer.zero_grad()
        self.qf2_optimizer.zero_grad()
        reward_scale = 10
        rewards_flat = rewards * reward_scale
        q_target = rewards_flat + conts * self.discount * target_v_values
        qf_loss = torch.mean((q1_pred - q_target) ** 2) + torch.mean((q2_pred - q_target) ** 2)
        qf_loss.backward()
        self.qf1_optimizer.step()
        self.qf2_optimizer.step()

        # Evaluated with the just-updated critics on freshly sampled policy actions.
        min_q_new_actions = self._min_q(obs, dpolicy_outputs, new_cactions)

        # V update towards min Q - log pi (target detached).
        v_target = min_q_new_actions - log_pi
        vf_loss = self.vf_criterion(v_pred, v_target.detach())
        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()
        self._update_target_network()

        # Policy update: minimise alpha * log pi - min Q, plus regularisers.
        # Gradients also reach qf1/qf2; those are cleared by the zero_grad calls
        # at the start of the next Q update.
        alpha = 1
        policy_loss = (alpha * log_pi - min_q_new_actions).mean()

        mean_reg_loss = self.policy_mean_reg_weight * (policy_mean ** 2).mean()
        std_reg_loss = self.policy_std_reg_weight * (policy_log_std ** 2).mean()
        pre_tanh_value = cpolicy_outputs[-1]
        pre_activation_reg_loss = self.policy_pre_activation_weight * (
            (pre_tanh_value ** 2).sum(dim=1).mean()
        )
        policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
        policy_loss = policy_loss + policy_reg_loss

        self.dpolicy_optimizer.zero_grad()
        self.cpolicy_optimizer.zero_grad()
        policy_loss.backward()
        self.dpolicy_optimizer.step()
        self.cpolicy_optimizer.step()

        return (np.mean(u.get_numpy(qf_loss)),
                np.mean(u.get_numpy(vf_loss)),
                np.mean(u.get_numpy(policy_loss)))

    def _min_q(self, obs, dactions, cactions):
        """Element-wise minimum of the two critics' predictions."""
        q1 = self.qf1(obs, dactions, cactions)
        q2 = self.qf2(obs, dactions, cactions)
        return torch.min(q1, q2)

    def _update_target_network(self):
        """Soft-update ``target_vf`` towards ``vf`` with rate ``soft_target_tau``."""
        u.soft_update_from_to(self.vf, self.target_vf, self.soft_target_tau)

    def vae_train(self, train_step, buffer, batch_size=64, embed_lr=1e-4):
        """Train the action-representation VAE, then calibrate ``self.c_rate``.

        Runs up to ``train_step + 10`` updates on batches from ``buffer``. It
        stops early once at least ``train_step`` updates have been made and the
        mean of the last 5 losses is no lower (within 1e-5) than the mean of the
        last 10. A 5000-sample batch is then drawn to compute ``c_rate`` via
        ``action_rep.get_c_rate``.

        Returns:
            ``(c_rate, recon_s)`` from ``get_c_rate``. ``c_rate`` is also stored
            on ``self`` for ``decode_embed``.
        """
        initial_losses = []

        for counter in range(int(train_step) + 10):
            (state, discrete_action, parameter_action,
             _all_parameter_action, _discrete_emb, _parameter_emb,
             _next_state, state_next_state, reward, not_done) = buffer.sample(batch_size)

            vae_loss, recon_loss_s, recon_loss_c, KL_loss = self.action_rep.unsupervised_loss(
                state,
                discrete_action.reshape(1, -1).squeeze().long(),
                parameter_action,
                state_next_state,
                batch_size, embed_lr,
                self._dyanmics, self._reward, self._continue, reward, not_done)

            initial_losses.append(np.mean([vae_loss]))

            if counter % 100 == 0 and counter >= 100:
                print("vae_loss, recon_loss_s, recon_loss_c, KL_loss", vae_loss, recon_loss_s, recon_loss_c, KL_loss)
                print("Epoch {} loss:: {}".format(counter, np.mean(initial_losses[-50:])))

            # Terminate initial phase once action representations have converged.
            if (len(initial_losses) >= train_step
                    and np.mean(initial_losses[-5:]) + 1e-5 >= np.mean(initial_losses[-10:])):
                break

        # ``embed_replay_buffer`` is never set inside this class: use it only if
        # a caller attached one, otherwise calibrate on the training buffer.
        calib_buffer = getattr(self, "embed_replay_buffer", buffer)
        (state_, discrete_action_, parameter_action_,
         _all_parameter_action, _discrete_emb, _parameter_emb,
         _next_state, state_next_state_, _reward, _not_done) = calib_buffer.sample(batch_size=5000)

        c_rate, recon_s = self.action_rep.get_c_rate(
            state_,
            discrete_action_.reshape(1, -1).squeeze().long(),
            parameter_action_,
            state_next_state_,
            batch_size=5000,
            range_rate=2,
        )

        self.c_rate = c_rate

        return c_rate, recon_s

    def debug_grad(self):
        """Print the first trainable parameter of each policy/critic network that exists.

        Critics (``qf1``/``qf2``/``vf``) only exist for ``policy_type == "hps"``;
        missing networks are skipped.
        """
        agent = getattr(self, "_agent", None)
        models = {
            "dpolicy": getattr(agent, "discrete_policy", None),
            "cpolicy": getattr(agent, "continuous_policy", None),
            "Q1": getattr(self, "qf1", None),
            "Q2": getattr(self, "qf2", None),
            "V": getattr(self, "vf", None),
        }
        for model_name, model in models.items():
            if model is None:
                continue
            print(model_name)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    print(name, param.data)
                    break












