import copy
from re import I
from components.episode_buffer import EpisodeBatch
from modules.mixers.vdn import VDNMixer
from modules.mixers.qmix import QMixer
import torch as th
import torch.nn.functional as F
import torch.nn as nn
from torch.optim import RMSprop, Adam
import time 
import numpy as np

def to_cuda(module, device=None):
    """Place ``module`` (assumed to currently live on the CPU) on a device.

    Args:
        module: Object exposing ``cuda()`` and ``to(device)``.
        device: ``None`` to call ``module.cuda()``, the string ``'cpu'`` to
            leave the module untouched, or anything else to forward to
            ``module.to(device)``.

    Returns:
        Whatever the selected call returns; ``module`` itself for ``'cpu'``.
    """
    if device is not None:
        # An explicit target was given: 'cpu' means "already there".
        return module if device == 'cpu' else module.to(device)
    return module.cuda()

class PriorRoleVqvaeLearner:
    """Trains a VQ-VAE on batches of episodes and maintains its codebook.

    The codebook of ``n_codes`` entries is split into ``n_cluster`` contiguous
    groups of ``nk`` entries.  At every timestep each trajectory is assigned a
    slice of one group (chosen by its sequence label, or at random) and that
    slice is handed to the VQ-VAE loss as ``ndx``.
    """

    def __init__(self, mac, vqvae, scheme, logger, args):
        """Set up the optimizer, update bookkeeping and loss trackers.

        Args:
            mac: Unused by this learner; kept for interface compatibility.
            vqvae: The VQ-VAE module to train.  It must be callable on an
                observation batch and provide ``loss_function``,
                ``codebook_update_tdvq`` and an ``emb`` sub-module.
            scheme: Data scheme; only ``scheme["state"]["vshape"]`` is read.
            logger: Object exposing ``log_stat(key, value, t_env)``.
            args: Configuration namespace (``device``, ``offline_batch_size``,
                ``n_codes``, ``n_cluster``, ``lr``, ``learner_log_interval``,
                plus the fields read in :meth:`train`).
        """
        self.args = args
        self.logger = logger
        self.device = self.args.device
        self.bs = args.offline_batch_size

        # One past the last valid code index; used as the "no code" marker.
        self.default_node = self.args.n_codes

        # Number of codebook entries owned by each cluster.
        self.nk = int(self.args.n_codes / self.args.n_cluster)

        self.state_dim = scheme["state"]["vshape"]

        self.vae = vqvae
        self.vae_params = list(self.vae.parameters())
        self.vae_optimizer = Adam(params=self.vae_params, lr=args.lr)

        self.update_vqvae = False
        self.update_codebook = False
        self.last_vqvae_update_episode = 0
        self.last_codebook_update_episode = 0
        self.last_target_update_episode = 0

        # Start far enough in the past that the first train() call logs.
        self.vae_log_t = self.log_stats_t = -self.args.learner_log_interval - 1

        # Latest per-timestep-averaged losses, kept for logging.
        self.vae_losses = self._zero_scalar()
        self.ce_losses = self._zero_scalar()
        self.vq_losses = self._zero_scalar()
        self.commit_losses = self._zero_scalar()
        self.coverage_losses = self._zero_scalar()

        self.criterion = nn.CrossEntropyLoss()

    def _zero_scalar(self):
        """Return a fresh scalar zero tensor on the learner's device."""
        return th.tensor(0.0, device=self.device)

    def _to_device_copy(self, data):
        """Return a detached copy of ``data`` on the learner's device.

        Same copy/detach semantics as ``th.tensor(data)`` but without the
        copy-construct warning PyTorch emits when ``data`` is already a tensor.
        """
        return th.as_tensor(data).detach().clone().to(self.device)

    def _reward_to_go(self, rewards, n_steps):
        """Compute discounted reward-to-go over the first ``n_steps`` steps.

        Args:
            rewards: Tensor indexed as ``rewards[:, t]``.
            n_steps: Number of timesteps to process.

        Returns:
            Tensor shaped like ``rewards`` where entry ``t`` equals
            ``rewards[:, t] + gamma * reward_to_go[:, t + 1]``.
        """
        reward_tgo = th.zeros_like(rewards)
        for t in range(n_steps - 1, -1, -1):
            if t == n_steps - 1:
                reward_tgo[:, t] = rewards[:, t]
            else:
                reward_tgo[:, t] = rewards[:, t] + self.args.gamma * reward_tgo[:, t + 1]
        return reward_tgo

    def _sample_code_indices(self, t, n_steps, seq_labels):
        """Build the per-trajectory codebook index array for timestep ``t``.

        Each cluster ``k`` owns codes ``[nk * k, nk * (k + 1))``; the slice
        used at time ``t`` advances through that range by ``nk / n_steps``
        codes per step (a single code when that ratio is below one).

        Args:
            t: Current timestep.
            n_steps: Total number of timesteps in the batch.
            seq_labels: Optional per-trajectory cluster labels; ``-1`` (or no
                labels at all) selects a cluster uniformly at random.

        Returns:
            Float ``numpy`` array of shape ``(offline_batch_size, slice_len)``.
        """
        dn = self.nk / n_steps

        ndx_cluster = {}
        max_len = 0
        for k in range(self.args.n_cluster):
            ids = self.nk * k + int(dn * t)
            if dn >= 1:
                ide = self.nk * k + int(dn * (t + 1))
                ndx_cluster[k] = np.arange(ids, ide, 1)
            else:
                ndx_cluster[k] = np.array([ids])
            max_len = max(max_len, len(ndx_cluster[k]))

        # NOTE(review): rows are sized by args.offline_batch_size, which is
        # assumed to equal the actual batch size -- confirm against callers.
        ndx = self.default_node * np.ones((self.args.offline_batch_size, max_len))
        for kd in range(self.args.offline_batch_size):
            k_id = -1 if seq_labels is None else seq_labels[kd].item()
            if k_id == -1:
                # Unlabelled trajectory: draw a cluster at random.
                k_id = np.random.randint(0, self.args.n_cluster)
            ndx[kd, :] = ndx_cluster[k_id]
        return ndx

    def train(self, batch: "EpisodeBatch", t_env: int, episode_num: int, seq_centroid=None, RBS_cluster=None, buffer_seq_labels=None, f_classifier=None):
        """Run one VQ-VAE optimisation step on ``batch``.

        Args:
            batch: Episode batch providing ``batch["obs"]``,
                ``batch["reward"]`` and ``max_seq_length``.
            t_env: Environment step counter, used for scheduling and logging.
            episode_num: Episode counter, used to schedule codebook updates.
            seq_centroid: Unused.
            RBS_cluster: Unused.
            buffer_seq_labels: Optional per-trajectory cluster labels.
            f_classifier: Unused.

        Returns:
            Tuple ``(visit_nodes_padded, mean_acc_t0)``.  The first item is a
            CPU tensor of shape ``[batch, episode_limit + 1]`` holding the
            code index visited at each timestep, padded with ``n_codes``.
            The second item is always ``None``.
        """
        seq_labels = buffer_seq_labels
        mean_acc_t0 = None
        n_steps = batch.max_seq_length

        # Episode return and discounted reward-to-go, used by the codebook update.
        rewards_th = self._to_device_copy(batch["reward"]).squeeze(-1)
        sum_rewards = th.sum(rewards_th, dim=1)
        reward_tgo = self._reward_to_go(rewards_th, n_steps)

        self.update_vqvae = True
        self.last_vqvae_update_episode = episode_num

        # Update the codebook every `codebook_update_interval` episodes, and
        # unconditionally during the early phase (t_env <= buffer_update_time).
        if (episode_num - self.last_codebook_update_episode) / self.args.codebook_update_interval >= 1.0:
            self.update_codebook = True
            self.last_codebook_update_episode = episode_num
        else:
            self.update_codebook = t_env <= self.args.buffer_update_time

        # 1. Encode every timestep with the current VQ-VAE ======================
        visit_nodes = []
        buf_state_input = []
        buf_recon = []
        buf_z_e = []
        buf_latent_emb = []

        for t in range(n_steps):
            state_input = self._to_device_copy(batch["obs"][:, t])
            recon, z_e, latent_emb, argmin = self.vae(state_input)

            # Inputs that sum to exactly zero along dim 1 (presumably padded
            # timesteps) are mapped to the placeholder index `n_codes`.
            sums = th.sum(state_input, dim=1)
            zero_index = th.nonzero(sums == 0, as_tuple=False)
            if len(zero_index) > 0:
                argmin[zero_index] = self.args.n_codes

            visit_nodes.append(argmin)
            buf_state_input.append(state_input)
            buf_recon.append(recon)
            buf_z_e.append(z_e)
            buf_latent_emb.append(latent_emb)

        visit_nodes = th.stack(visit_nodes, dim=1)  # per-trajectory code sequence

        # 2. Accumulate the VQ-VAE loss over time ===============================
        vae_losses = self._zero_scalar()
        ce_losses = self._zero_scalar()
        vq_losses = self._zero_scalar()
        commit_losses = self._zero_scalar()
        coverage_losses = self._zero_scalar()

        for t in range(n_steps):
            ndx = self._sample_code_indices(t, n_steps, seq_labels)

            if self.update_codebook:
                # A view into `visit_nodes`, as in the stacked layout.
                argmin_t = visit_nodes[:, t]
                self.vae.codebook_update_tdvq(argmin_t, t_env, sum_rewards, reward_tgo[:, t], seq_labels=seq_labels)

            vae_loss, ce_loss, vq_loss, commit_loss, coverage_loss = \
                self.vae.loss_function(buf_state_input[t], buf_recon[t], buf_z_e[t], buf_latent_emb[t], ndx=ndx)

            # Accumulate the per-step losses returned by loss_function.
            vae_losses += vae_loss
            ce_losses += ce_loss
            vq_losses += vq_loss
            commit_losses += commit_loss
            coverage_losses += coverage_loss

        # Average over timesteps.
        vae_losses /= n_steps
        ce_losses /= n_steps
        vq_losses /= n_steps
        commit_losses /= n_steps
        coverage_losses /= n_steps

        # Stored only for monitoring, so drop the autograd graph.
        self.ce_losses = ce_losses.detach()
        self.vq_losses = vq_losses.detach()
        self.commit_losses = commit_losses.detach()
        self.coverage_losses = coverage_losses.detach()

        # 3. Optimisation step ==================================================
        if self.update_vqvae:
            self.vae_optimizer.zero_grad()
            vae_losses.backward()
            # Clip AFTER backward(): clipping before it sees only zeroed or
            # missing gradients and therefore has no effect.
            th.nn.utils.clip_grad_norm_(self.vae_params, self.args.grad_norm_clip)
            self.vae_optimizer.step()
            self.vae_losses = vae_losses.detach()

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("vae_loss", self.vae_losses.item(), t_env)
            self.logger.log_stat("ce_loss", self.ce_losses.item(), t_env)
            self.logger.log_stat("vq_loss", self.vq_losses.item(), t_env)
            self.logger.log_stat("commit_loss", self.commit_losses.item(), t_env)
            self.logger.log_stat("coverage_loss", self.coverage_losses.item(), t_env)

            self.log_stats_t = t_env

        # Pad the visited-code sequences out to episode_limit + 1 columns.
        visit_nodes_padded = self.args.n_codes * th.ones(visit_nodes.size()[0], self.args.episode_limit + 1)
        visit_nodes_padded[:, :n_steps] = visit_nodes

        return visit_nodes_padded, mean_acc_t0  # cpu

    def cuda(self):
        """Move the VQ-VAE to the default CUDA device."""
        self.vae.cuda()

    def save_models(self, path):
        """Save the VQ-VAE and its codebook state dicts under ``path``."""
        th.save(self.vae.state_dict(), "{}/vae.th".format(path))
        th.save(self.vae.emb.state_dict(), "{}/codebook.th".format(path))

    def _get_input_shape(self, scheme):
        """Return the observation size plus optional last-action / agent-id sizes."""
        input_shape = scheme["obs"]["vshape"]
        if self.args.obs_last_action:
            input_shape += scheme["actions_onehot"]["vshape"][0]
        if self.args.obs_agent_id:
            input_shape += self.args.n_agents

        return input_shape

    def load_models(self, path):
        """Load the VQ-VAE and codebook state dicts saved by :meth:`save_models`.

        Tensors are mapped to CPU storage on load.
        """
        # NOTE(review): depending on the torch version, th.load may unpickle
        # arbitrary objects -- only load checkpoints from trusted sources.
        self.vae.load_state_dict(th.load("{}/vae.th".format(path), map_location=lambda storage, loc: storage))
        self.vae.emb.load_state_dict(th.load("{}/codebook.th".format(path), map_location=lambda storage, loc: storage))