import torch as th
import numpy as np
from copy import deepcopy

class ADLearner:
    """Thin learner wrapper that delegates training to the controller (``mac``).

    The learner itself holds no networks. Agent updates go through
    ``mac.train_agent`` and the action-representation model is reached via
    ``mac.vae``. After ``vae_train`` the two values returned by
    ``mac.vae.get_c_rate`` are kept on the learner (``c_rate``, ``recon_s``)
    and copied onto the controller.
    """

    # Extra iterations allowed beyond ``train_step`` while waiting for the
    # VAE loss to stop improving.
    VAE_EXTRA_STEPS = 20000
    # Console / stat logging period (in iterations) during VAE training.
    VAE_LOG_INTERVAL = 500
    # Number of most recent losses averaged for logging and convergence checks;
    # also the spacing (in iterations) between convergence checks.
    VAE_LOSS_WINDOW = 50
    # Number of consecutive non-improving checks tolerated before stopping.
    VAE_PATIENCE = 10

    def __init__(self, mac, scheme, logger, args):
        """Store collaborators.

        Args:
            mac: Controller exposing ``train_agent``, ``vae``, ``cuda``,
                ``save_models`` and ``load_models``.
            scheme: Accepted for interface compatibility; not used here.
            logger: Object exposing ``log_stat`` and ``console_logger``.
            args: Config namespace; ``learner_log_interval`` is read here.
        """
        self.args = args
        self.mac = mac
        self.logger = logger

        # Start one interval "in the past" so the very first train() call logs.
        self.log_stats_t = -self.args.learner_log_interval - 1
        self.c_rate = None
        self.recon_s = None

    def train(self, replay_buffer, t_env: int, episode_num: int) -> None:
        """Run one agent update and log its losses, throttled by env steps.

        Args:
            replay_buffer: Passed straight to ``mac.train_agent``.
            t_env: Current environment step; used as the logging timestamp.
            episode_num: Unused; kept for interface compatibility.
        """
        critic_loss, actor_loss = self.mac.train_agent(replay_buffer, self.args.hyar_batch_size)
        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("critic_loss", critic_loss, t_env)
            if actor_loss is not None:
                self.logger.log_stat("actor_loss", actor_loss, t_env)
            # Remember when we last logged; without this the interval check
            # above is always true and every call logs.
            self.log_stats_t = t_env

    def vae_train(self, replay_buffer, batch_size, train_step, log_vae_loss=False) -> None:
        """Train the VAE until its loss plateaus, then refresh ``c_rate``/``recon_s``.

        At least ``train_step`` iterations are run. From then on, every
        ``VAE_LOSS_WINDOW`` iterations the mean of the last ``VAE_LOSS_WINDOW``
        losses is compared with the best mean seen so far; training stops once
        more than ``VAE_PATIENCE`` consecutive checks fail to improve, or after
        ``train_step + VAE_EXTRA_STEPS`` iterations at the latest.

        Args:
            replay_buffer: Object whose ``sample_tuple(batch_size)`` returns a
                9-tuple (state, discrete action, parameter action, two
                embeddings, next state, state/next-state term, reward,
                not-done); only four of these are used here.
            batch_size: Batch size for each VAE update.
            train_step: Minimum number of iterations before convergence
                checks start (truncated with ``int``).
            log_vae_loss: When true, also send the loss components to
                ``logger.log_stat`` every ``VAE_LOG_INTERVAL`` iterations.
        """
        min_steps = int(train_step)
        max_steps = min_steps + self.VAE_EXTRA_STEPS
        window = self.VAE_LOSS_WINDOW

        loss_history = []
        stalled_checks = 0
        best_mean_loss = float('inf')

        for counter in range(max_steps):
            (state, discrete_action, parameter_action, _, _, _,
             state_next_state, _, _) = replay_buffer.sample_tuple(batch_size)
            vae_loss, recon_loss_s, recon_loss_c, KL_loss = self.mac.vae.unsupervised_loss(
                state,
                # reshape(-1) flattens like reshape(1, -1).squeeze() but stays
                # 1-D when there is only a single element.
                discrete_action.reshape(-1).long(),
                parameter_action,
                state_next_state,
                batch_size,
                self.args.embed_lr,
            )
            loss_history.append(np.mean([vae_loss]))

            if counter % self.VAE_LOG_INTERVAL == 0 and counter >= self.VAE_LOG_INTERVAL:
                self.logger.console_logger.info("vae_loss: {}, recon_loss_s: {}, recon_loss_c: {}, KL_loss: {}".format(vae_loss, recon_loss_s, recon_loss_c, KL_loss))
                self.logger.console_logger.info("Epoch {}/{} loss:: {}".format(counter, max_steps, np.mean(loss_history[-window:])))
                if log_vae_loss:
                    self.logger.log_stat("vae/vae_loss", vae_loss.item(), counter)
                    self.logger.log_stat("vae/recon_loss_s", recon_loss_s.item(), counter)
                    self.logger.log_stat("vae/recon_loss_c", recon_loss_c.item(), counter)
                    self.logger.log_stat("vae/KL_loss", KL_loss.item(), counter)

            # Terminate initial phase once action representations have converged.
            steps_done = len(loss_history)
            if steps_done < min_steps or (steps_done - min_steps) % window != 0:
                continue

            recent_mean = np.mean(loss_history[-window:])
            if recent_mean < best_mean_loss:
                best_mean_loss = recent_mean
                stalled_checks = 0
            else:
                stalled_checks += 1

            if stalled_checks > self.VAE_PATIENCE:
                self.logger.console_logger.info("Converged... {}".format(steps_done))
                break

        # Draw a separate batch to compute the values cached on learner and controller.
        rate_batch_size = self.args.vae_get_c_rate_batch_size
        (state_, discrete_action_, parameter_action_, _, _, _,
         state_next_state_, _, _) = replay_buffer.sample_tuple(batch_size=rate_batch_size)
        self.c_rate, self.recon_s = self.mac.vae.get_c_rate(
            state_,
            discrete_action_.reshape(-1).long(),
            parameter_action_,
            state_next_state_,
            batch_size=rate_batch_size,
            range_rate=2,
        )
        # The controller gets its own copies so later mutation on either side
        # does not leak to the other.
        self.mac.c_rate = deepcopy(self.c_rate)
        self.mac.recon_s = deepcopy(self.recon_s)

    def cuda(self) -> None:
        """Forward to ``mac.cuda()``."""
        self.mac.cuda()

    def save_models(self, path) -> None:
        """Forward to ``mac.save_models(path)``."""
        self.mac.save_models(path)

    def load_models(self, path) -> None:
        """Forward to ``mac.load_models(path)``."""
        self.mac.load_models(path)
