from time import time

from .buffer import FactBuffer, FactEpisodeBuffer, _split_mask_batch, _pad_mask_data
from .config import get_alg_args
from .mixer import AdditiveMixer, MonotonicMixer, QPLEXMixer
from common.imports import *
from common.logger import Logger
from common.utils import linear_schedule, to_list_tensor, to_agent_shape
from ..dec_MAHDDRQN.agent import DuelRQNetwork, RQNetwork

from environments.eval import Evaluator

class FactMAHDDRQN:
    def __init__(self, venv: gym.Env, run_name: str, start_time: float, args: Dict[str, Any], device):
        """Build the per-agent recurrent Q-networks and the mixer, then run training.

        The whole training run happens inside this constructor: it returns (or
        raises) only once the step loop is over, and it always closes the
        logger and ``venv`` on the way out.

        Args:
            venv: Vectorised multi-agent environment. The code relies on
                ``single_observation_space`` / ``single_action_space`` being
                per-agent sequences, on ``venv.call('get_state')``, and on
                ``step`` returning ``info['mac_done']`` plus, on episode end,
                ``info['final_observation']`` and ``info['final_info']``.
            run_name: Name handed to ``Logger`` when ``args.track`` is set.
            start_time: Reference ``time()`` value used for the SPS printout.
            args: Run arguments. Read through ``vars(args)`` and merged with
                the algorithm-specific ones returned by ``get_alg_args()``.
            device: Torch device for the networks and the rollout tensors.

        Raises:
            AssertionError: If ``args.train_freq`` is not a multiple of
                ``args.n_envs``.
        """
        args = ap.Namespace(**vars(args), **vars(get_alg_args()))

        n_envs = args.n_envs
        assert args.train_freq % n_envs == 0, \
            f"Invalid train frequency: {args.train_freq}. Must be multiple of n_envs {n_envs}"
        logger = Logger(run_name, args) if args.track else None
        evaluator = Evaluator(args, logger, device)

        # Per-agent online / target recurrent Q-networks and replay buffers.
        self.qnet, self.tg_qnet, self.buffer = [{} for _ in range(3)]
        qnet_init = DuelRQNetwork if args.dueling else RQNetwork

        n_agents = len(venv.single_observation_space)
        agents = range(n_agents)
        net_params = []
        for a in agents:
            self.qnet[a] = qnet_init(venv, args, a).to(device)
            self.tg_qnet[a] = deepcopy(self.qnet[a])
            # Targets are never optimised directly; they only follow the soft update below.
            for w_tg_qnet in self.tg_qnet[a].parameters():
                w_tg_qnet.requires_grad = False
            net_params += list(self.qnet[a].parameters())
            self.buffer[a] = FactBuffer(venv, args, a, device)

        self.detach = args.detach
        self.extra_info = args.extra_info
        self.reward_scheme = args.reward_scheme
        state_dim = len(venv.call('get_state')[0])
        if self.extra_info == 1:
            # The joint observation replaces the state as extra mixer input.
            state_dim = sum(space.n for space in venv.single_observation_space)

        # NOTE(review): the mixers are never moved to `device` here; confirm they
        # handle device placement themselves before training on GPU.
        if args.mixer == 'qmix':
            self.mixer = MonotonicMixer(n_agents, state_dim)
            self.tg_mixer = deepcopy(self.mixer)
            self.mix = self._forward_mix
        elif args.mixer == 'qplex':
            act_dim = sum([action_space.n for action_space in venv.single_action_space])
            self.mixer = QPLEXMixer(n_agents, state_dim, act_dim)
            self.tg_mixer = deepcopy(self.mixer)
            self.mix = self._attention_mix
        else:
            self.mixer, self.tg_mixer = AdditiveMixer(), AdditiveMixer()
            self.mix = self._forward_mix

        for w_tg_mixer in self.tg_mixer.parameters():
            w_tg_mixer.requires_grad = False
        net_params += list(self.mixer.parameters())

        # A single optimiser drives every agent network and the mixer.
        qnet_optim = optim.Adam(net_params, lr=args.lr)

        init_step = 1
        global_step = 0
        obs, _ = venv.reset()
        obs = to_list_tensor(obs, device)

        # One in-progress episode buffer per (env, agent).
        ep_buffer = [{a: FactEpisodeBuffer(venv, args, a, device) for a in agents} for _ in range(n_envs)]

        # Rollout state, kept per agent to cope with inhomogeneous spaces.
        # m_valid[a][i] mirrors the env's `mac_done` flag (presumably: agent a's
        # macro-action in env i is over and a new one must be selected).
        m_obs = {a: th.empty(obs[a].shape).to(device) for a in agents}
        m_act = {a: th.empty((n_envs), dtype=int).to(device) for a in agents}
        m_valid = {a: th.ones((n_envs), dtype=bool).to(device) for a in agents}
        m_reward = {a: th.zeros((n_envs)).to(device) for a in agents}
        m_j_valid = th.ones((n_envs), dtype=bool).to(device)
        m_h = {a: th.zeros([n_envs, args.h_size]).to(device) for a in agents}
        m_next_h = deepcopy(m_h)
        m_j_step = th.zeros((n_envs), dtype=int).to(device)
        m_j_gamma = th.ones((n_envs)).to(device)
        m_j_reward = th.zeros((n_envs)).to(device)

        # Bookkeeping for the first env only (handy when debugging).
        rew, episode, length = 0, 0, 0
        try:
            for step in range(init_step, int(args.total_timesteps // n_envs)):
                global_step += n_envs
                epsilon = linear_schedule(
                    args.eps_start, args.eps_end, args.eps_decay_frac * args.total_timesteps, global_step
                )
                hyst = linear_schedule(
                    args.hyst_start, args.hyst_end, args.hyst_frac * args.total_timesteps, global_step, False
                )

                with th.no_grad():
                    for a in agents:
                        valid = m_valid[a]
                        # Only entries flagged valid latch a new observation / hidden
                        # state, pick a new action and restart their reward accumulator.
                        m_obs[a][valid] = obs[a][valid]
                        m_h[a][valid] = m_next_h[a][valid]
                        (m_act[a][valid], m_next_h[a][valid]) = self.qnet[a].get_action(
                            m_obs[a][valid], m_h[a][valid], epsilon
                        )
                        m_reward[a][valid] = 0

                    # Joint accumulators restart wherever any agent was flagged valid.
                    m_j_step[m_j_valid] = 0
                    m_j_gamma[m_j_valid] = 1
                    m_j_reward[m_j_valid] = 0

                state = to_list_tensor(venv.call('get_state'), device, add_dim=False)
                next_obs, reward, term, trunc, info = venv.step(m_act)
                next_state = to_list_tensor(venv.call('get_state'), device, add_dim=False)

                # Episode-end flags, computed once from the raw env outputs. Doing it
                # after `term` becomes a device tensor breaks on CUDA (NumPy cannot
                # read GPU tensors) and previously repeated the work for every agent.
                dones = np.logical_or(term, trunc)

                length += 1
                rew += np.mean(reward[0])

                next_obs = to_list_tensor(next_obs, device)
                reward = to_agent_shape(reward, device, float)  # from 1 array for each env to (n_agents, n_envs)

                term = th.tensor(term).to(device)
                m_valid = {a: md for a, md in enumerate(to_agent_shape(info['mac_done'], device, bool))}
                m_j_valid = th.max(th.stack(list(m_valid.values())), dim=0)[0].to(device)
                m_j_step += 1
                # Undiscounted accumulation of the mean agent reward.
                m_j_reward += sum(reward) / n_agents
                m_j_gamma = args.gamma ** (m_j_step - 1)

                # Handle termination/truncation: for finished envs the values to store
                # come from `final_observation` / `final_info`, not from the post-reset
                # observation returned by `step`.
                real_next_obs = deepcopy(next_obs)
                real_m_valid = deepcopy(m_valid)
                for a in agents:
                    m_reward[a] += reward[a]

                    for i, done in enumerate(dones):
                        if done:
                            real_next_obs[a][i] = th.tensor(info["final_observation"][i][a]).to(device)
                            real_m_valid[a][i] = th.tensor(info["final_info"][i]['mac_done'][a]).to(device)

                # Second pass on purpose: every agent's mac_done must be patched
                # before the joint valid flag can be computed.
                real_m_j_valid = th.max(th.stack(list(real_m_valid.values())), dim=0)[0].to(device)
                for a in agents:
                    for i, done in enumerate(dones):
                        ep_buffer[i][a].store(
                            m_obs[a][i],
                            m_act[a][i],
                            real_m_valid[a][i],
                            m_reward[a][i],
                            real_m_j_valid[i],
                            m_j_reward[i],
                            m_j_gamma[i],
                            real_next_obs[a][i] if real_m_valid[a][i] else m_obs[a][i],
                            term[i],
                            m_h[a][i],
                            m_next_h[a][i],   # Could handle everything with m_h, but it's a bit more annoying
                            state[i],
                            next_state[i]
                        )

                        if done:
                            ep_buffer[i][a].pad_done()     # Make the episode buffer a multiple of seq_len for sampling
                            self.buffer[a].store(ep_buffer[i][a])
                            ep_buffer[i][a] = FactEpisodeBuffer(venv, args, a, device)

                            # Step, gamma and reward are reset at the top of the next
                            # iteration (m_valid is true after a reset); only the
                            # hidden state has to be cleared here.
                            m_next_h[a][i] = th.zeros([1, args.h_size]).to(device)

                if dones[0]:
                    rew, length = 0, 0
                    episode += 1

                obs = next_obs

                if global_step % args.eval_freq == 0:
                    evaluator.evaluate(global_step, self.qnet)
                    if args.verbose: print(f"SPS={int(global_step / (time() - start_time))}")

                if global_step > args.learning_starts:
                    if global_step % args.train_freq == 0:
                        mix_y, mix_qvalues = self.mix(agents, args.gamma, venv.single_action_space)

                        td_err = (mix_y - mix_qvalues)
                        # Hysteretic update: negative TD errors are scaled by `hyst`.
                        if args.hysteretic: td_err = th.max(hyst * td_err, td_err)
                        qnet_loss = (td_err**2).mean()

                        qnet_optim.zero_grad(True)
                        qnet_loss.backward()
                        qnet_optim.step()

                    if global_step % args.tg_qnet_freq == 0:
                        # Soft (Polyak) update of the target networks and target mixer.
                        for a in agents:
                            for tg_qnet_param, qnet_param in zip(self.tg_qnet[a].parameters(), self.qnet[a].parameters()):
                                tg_qnet_param.data.copy_(
                                    args.tau * qnet_param.data + (1.0 - args.tau) * tg_qnet_param.data
                                )
                        for tg_mixer_param, mixer_param in zip(self.tg_mixer.parameters(), self.mixer.parameters()):
                            tg_mixer_param.data.copy_(
                                args.tau * mixer_param.data + (1.0 - args.tau) * tg_mixer_param.data
                            )

        finally:
            # Close the environment even if closing the logger fails.
            try:
                if logger: logger.close()
            finally:
                venv.close()

    def _forward_mix(self, agents, gamma, *args):
        """Compute the TD target and the mixed Q-value for the additive / monotonic mixers.

        Every agent samples the same episodes and start indices from its own
        buffer (agent 0 draws them, the others reuse them). Per-agent chosen
        Q-values are stacked column-wise and passed through the mixer; the
        target uses double-Q action selection (online net picks, target net
        evaluates) wherever ``m_valids`` is set and the stored action elsewhere.

        ``self.detach`` controls how entries with ``m_valids == False`` are
        treated: 1 blocks their gradient, 2 replaces them (online and target)
        with zeros, 3 does both, any other value leaves them untouched.

        Args:
            agents: Iterable of agent indices; must start with agent 0.
            gamma: Discount factor applied to the bootstrapped target.
            *args: Ignored; accepted so the call matches ``_attention_mix``.

        Returns:
            Tuple ``(mix_y, mix_qvalues)``: the TD target (computed without
            gradient) and the differentiable mixed Q-value.
        """
        def select_extra_info(joint_obs, batch, state_key):
            # Extra mixer input selected by --extra-info: joint observation (1),
            # zeros (2), ones (3), otherwise the environment state.
            if self.extra_info == 1: return th.cat(joint_obs, dim=1)
            if self.extra_info == 2: return th.zeros_like(batch[state_key])
            if self.extra_info == 3: return th.ones_like(batch[state_key])
            return batch[state_key]

        tmp_qvalues, tmp_tg_qvalues_ = [], []
        m_j_obs, next_m_j_obs = [], []
        m_j_reward = []
        for a in agents:
            # Sequential sample (i.e., for each agent, sample from the same episodes and steps)
            if a == 0: batch, ep_idxs, start_idxs = self.buffer[a].sample()
            else: batch, _, _ = self.buffer[a].sample(ep_idxs, start_idxs)

            _split_mask_batch(batch)
            (
                pad_observations, mask_observations,
                pad_observations_, mask_observations_,
                pad_histories, pad_histories_
            ) = _pad_mask_data(batch)

            m_j_obs.append(pad_observations[mask_observations])
            next_m_j_obs.append(pad_observations_[mask_observations_])

            valid = batch['m_valids'].view(-1, 1)

            with th.no_grad():
                tg_qvalues_, _ = self.tg_qnet[a](pad_observations_, pad_histories_)
                qvalues_, _ = self.qnet[a](pad_observations_, pad_histories_)

                # Get only non padded entries
                tg_qvalues_ = tg_qvalues_[mask_observations_]
                qvalues_ = qvalues_[mask_observations_]

                # Taking a new m_act only if the previous one is over
                tmp_actions_ = th.argmax(qvalues_, dim=-1, keepdims=True).type(th.int64)
                actions_ = th.where(valid, tmp_actions_, batch['m_acts'])

                tg_qvalues_ = tg_qvalues_.gather(-1, actions_)
                if self.detach == 2 or self.detach == 3:
                    tg_qvalues_ = th.where(valid, tg_qvalues_, th.tensor(0))
                tmp_tg_qvalues_.append(tg_qvalues_)

                if self.reward_scheme == 1:
                    # Joint reward rebuilt from the agents' own rewards, counted only where valid.
                    m_j_reward.append(th.where(valid, batch['m_rewards'], th.tensor(0)))

            qvalues, _ = self.qnet[a](pad_observations, pad_histories)
            qvalues = qvalues[mask_observations]
            qvalues = qvalues.gather(-1, batch['m_acts'])
            if self.detach == 1:
                tmp_qvalues.append(th.where(valid, qvalues, qvalues.detach()))
            elif self.detach == 2:
                tmp_qvalues.append(th.where(valid, qvalues, th.tensor(0)))
            elif self.detach == 3:
                qvalues = th.where(valid, qvalues, qvalues.detach())
                tmp_qvalues.append(th.where(valid, qvalues, th.tensor(0)))
            else:
                tmp_qvalues.append(qvalues)

        # From here on `batch` is the last agent's batch; only its joint
        # (agent-independent) entries are read.
        with th.no_grad():
            next_extra_info = select_extra_info(next_m_j_obs, batch, 'next_states')

            # Only the monotonic mixer is conditioned on the extra info.
            if isinstance(self.mixer, MonotonicMixer):
                mix_tg_qvalue = self.tg_mixer(th.hstack(tmp_tg_qvalues_), next_extra_info)
            else:
                mix_tg_qvalue = self.tg_mixer(th.hstack(tmp_tg_qvalues_))

            if self.reward_scheme == 1:
                m_j_reward = th.stack(m_j_reward).sum(dim=0)
            else:
                m_j_reward = batch['m_j_rewards']

            mix_y = m_j_reward + gamma * batch['m_j_gammas'] * mix_tg_qvalue * (1 - batch['dones'])

        extra_info = select_extra_info(m_j_obs, batch, 'states')

        if isinstance(self.mixer, MonotonicMixer):
            mix_qvalues = self.mixer(th.hstack(tmp_qvalues), extra_info)
        else:
            mix_qvalues = self.mixer(th.hstack(tmp_qvalues))

        return mix_y, mix_qvalues
                 
    def _attention_mix(self, agents, gamma, action_spaces):
        """Compute the TD target and the mixed Q-value for the QPLEX mixer.

        Sampling, double-Q target action selection and the ``self.detach``
        handling of entries with ``m_valids == False`` (1: block gradient,
        2: zero them, 3: both) work as in ``_forward_mix``. In addition, the
        one-hot encodings of the taken / target actions and the per-agent
        "max" Q-values are collected; the mixer is then called twice
        (``is_v=True`` and ``is_v=False``) and the two outputs are summed.

        Args:
            agents: Iterable of agent indices; must start with agent 0.
            gamma: Discount factor applied to the bootstrapped target.
            action_spaces: Per-agent action spaces; ``action_spaces[a].n`` is
                the number of classes used for the one-hot encoding.

        Returns:
            Tuple ``(mix_y, mix_qvalues)``: the TD target (computed without
            gradient) and the differentiable mixed Q-value.
        """
        def select_extra_info(joint_obs, batch, state_key):
            # Extra mixer input selected by --extra-info: joint observation (1),
            # zeros (2), ones (3), otherwise the environment state.
            if self.extra_info == 1: return th.cat(joint_obs, dim=1)
            if self.extra_info == 2: return th.zeros_like(batch[state_key])
            if self.extra_info == 3: return th.ones_like(batch[state_key])
            return batch[state_key]

        tmp_qvalues, tmp_tg_qvalues_ = [], []
        tmp_max_qvalue, tmp_max_tg_qvalue, tmp_oh_actions, tmp_tg_oh_actions = [[] for _ in range(4)]
        m_j_obs, next_m_j_obs = [], []
        m_j_reward = []
        for a in agents:
            # Sequential sample (i.e., for each agent, sample from the same episodes and steps)
            if a == 0: batch, ep_idxs, start_idxs = self.buffer[a].sample()
            else: batch, _, _ = self.buffer[a].sample(ep_idxs, start_idxs)

            _split_mask_batch(batch)
            (
                pad_observations, mask_observations,
                pad_observations_, mask_observations_,
                pad_histories, pad_histories_
            ) = _pad_mask_data(batch)

            m_j_obs.append(pad_observations[mask_observations])
            next_m_j_obs.append(pad_observations_[mask_observations_])

            valid = batch['m_valids'].view(-1, 1)

            with th.no_grad():
                tg_qvalues_, _ = self.tg_qnet[a](pad_observations_, pad_histories_)
                qvalues_, _ = self.qnet[a](pad_observations_, pad_histories_)

                # Get only non padded entries
                tg_qvalues_ = tg_qvalues_[mask_observations_]
                qvalues_ = qvalues_[mask_observations_]

                # Taking a new m_act only if the previous one is over
                tmp_actions_ = th.argmax(qvalues_, dim=-1, keepdims=True).type(th.int64)
                actions_ = th.where(valid, tmp_actions_, batch['m_acts'])

                tg_qvalues_ = tg_qvalues_.gather(-1, actions_)
                if self.detach == 2 or self.detach == 3:
                    tg_qvalues_ = th.where(valid, tg_qvalues_, th.tensor(0))
                tmp_tg_qvalues_.append(tg_qvalues_)

                tmp_oh_actions.append(F.one_hot(batch['m_acts'], num_classes=action_spaces[a].n).squeeze(1))
                tmp_tg_oh_actions.append(F.one_hot(actions_, num_classes=action_spaces[a].n).squeeze(1))

                # NOTE(review): tg_qvalues_ was already gathered, so this max reduces
                # a size-1 dimension and equals the chosen target Q-value. Confirm
                # that this (and the 1-D shape fed to hstack below) is what
                # QPLEXMixer expects, rather than the max over all actions.
                tmp_max_tg_qvalue.append(th.max(tg_qvalues_, dim=-1)[0])

                if self.reward_scheme == 1:
                    # Joint reward rebuilt from the agents' own rewards, counted only where valid.
                    m_j_reward.append(th.where(valid, batch['m_rewards'], th.tensor(0)))

            qvalues, _ = self.qnet[a](pad_observations, pad_histories)
            qvalues = qvalues[mask_observations]

            qvalues = qvalues.gather(-1, batch['m_acts'])
            if self.detach == 1:
                tmp_qvalues.append(th.where(valid, qvalues, qvalues.detach()))
            elif self.detach == 2:
                tmp_qvalues.append(th.where(valid, qvalues, th.tensor(0)))
            elif self.detach == 3:
                qvalues = th.where(valid, qvalues, qvalues.detach())
                tmp_qvalues.append(th.where(valid, qvalues, th.tensor(0)))
            else:
                tmp_qvalues.append(qvalues)

            # NOTE(review): same remark as for the target: qvalues is already the
            # gathered Q-value of the taken action here.
            tmp_max_qvalue.append(th.max(qvalues, dim=-1)[0])

        # From here on `batch` is the last agent's batch; only its joint
        # (agent-independent) entries are read.
        with th.no_grad():
            tmp_tg_qvalues_ = th.hstack(tmp_tg_qvalues_)
            tmp_tg_oh_actions = th.hstack(tmp_tg_oh_actions)
            tmp_max_tg_qvalue = th.hstack(tmp_max_tg_qvalue)

            next_extra_info = select_extra_info(next_m_j_obs, batch, 'next_states')

            mix_tg_v = self.tg_mixer(tmp_tg_qvalues_, next_extra_info, is_v=True)
            mix_tg_adv = self.tg_mixer(
                tmp_tg_qvalues_,
                next_extra_info,
                tmp_tg_oh_actions,
                tmp_max_tg_qvalue,
                is_v=False
            )

            mix_tg_qvalue = mix_tg_v + mix_tg_adv

            if self.reward_scheme == 1:
                m_j_reward = th.stack(m_j_reward).sum(dim=0)
            else:
                m_j_reward = batch['m_j_rewards']

            mix_y = m_j_reward + gamma * batch['m_j_gammas'] * mix_tg_qvalue * (1 - batch['dones'])

        tmp_qvalues = th.hstack(tmp_qvalues)
        tmp_oh_actions = th.hstack(tmp_oh_actions)
        tmp_max_qvalue = th.hstack(tmp_max_qvalue)

        extra_info = select_extra_info(m_j_obs, batch, 'states')

        mix_v = self.mixer(tmp_qvalues, extra_info, is_v=True)
        mix_adv = self.mixer(
            tmp_qvalues,
            extra_info,
            tmp_oh_actions,
            tmp_max_qvalue,
            is_v=False
        )

        mix_qvalues = mix_v + mix_adv

        return mix_y, mix_qvalues
    