import copy
from components.episode_buffer import EpisodeBatch
import numpy as np
import torch as th
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.distributions import Categorical
from modules.critics import REGISTRY as critic_resigtry
from components.standarize_stream import RunningMeanStd
from components.transforms import OneHot

class RewardTranslator(nn.Module):
    """Module that shifts rewards by a single learnable offset."""

    def __init__(self, init, device):
        """Create the learnable offset.

        Args:
            init: Initial value of the offset; stored as a 1x1 parameter.
            device: Device on which the parameter tensor is created.
        """
        super().__init__()
        initial_value = th.tensor([[init]], device=device)
        self.bias = nn.Parameter(initial_value)

    def forward(self, rewards):
        """Return ``rewards`` with the learnable offset added (broadcast)."""
        shifted = rewards + self.bias
        return shifted


class MAOSDQNLearner:
    def __init__(self, mac, scheme, logger, args):
        """Set up controllers, critic, optimisers and optional reward/return helpers.

        Args:
            mac: Online multi-agent controller. It is deep-copied to create the
                target controller and its parameters are given to the agent
                optimiser.
            scheme: Passed unchanged to the critic constructor.
            logger: Logger stored for use during training.
            args: Configuration namespace. Fields read here: n_agents,
                n_actions, sample_epsilon, lr, critic_type,
                learner_log_interval, use_cuda, standardise_rewards,
                translate_rewards, alpha, standardise_returns.
        """
        self.args = args
        self.logger = logger
        self.n_agents, self.n_actions = args.n_agents, args.n_actions

        # Online controller and an independent copy used as the target; the
        # copy gets its own exploration epsilon.
        self.mac = mac
        target_mac = copy.deepcopy(mac)
        target_mac.action_selector.epsilon = args.sample_epsilon
        self.target_mac = target_mac

        self.agent_params = list(mac.parameters())
        self.agent_optimiser = Adam(params=self.agent_params, lr=args.lr)

        # The critic class is looked up by name in the critic registry.
        critic_cls = critic_resigtry[args.critic_type]
        self.critic = critic_cls(scheme, args)
        self.critic_params = list(self.critic.parameters())
        self.critic_optimiser = Adam(params=self.critic_params, lr=args.lr)

        # Bookkeeping counters; log_stats_t starts low enough that the first
        # call to train() passes the logging-interval check.
        self.training_steps = 0
        self.last_target_update_step = 0
        self.log_stats_t = -args.learner_log_interval - 1

        device = "cuda" if args.use_cuda else "cpu"
        self.device = device

        # Optional helpers; each attribute exists only when its flag is set.
        if args.standardise_rewards:
            self.rew_ms = RunningMeanStd(shape=(1,), device=device)
        if args.translate_rewards:
            initial_bias = -np.log(self.n_actions) / args.alpha
            self.rew_tr = RewardTranslator(init=initial_bias, device=device)
            self.rew_optimiser = Adam(params=self.rew_tr.parameters(), lr=args.lr)
        if args.standardise_returns:
            self.ret_ms = RunningMeanStd(shape=(1,), device=device)

        # Placeholder so the agent grad norm can be logged before any agent update.
        self.grad_norm = th.tensor([0], device=device)

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
        """Run one update of the critic, the optional reward bias, and the agents.

        The critic is regressed towards a one-step soft (entropy-regularised)
        target built from the two controllers. After ``args.warmup_steps``
        training steps the agents are also updated against the critic. Finally
        the target controller is updated (hard or soft) and statistics are
        logged at most once per ``args.learner_log_interval`` environment steps.

        Args:
            batch: Container indexed by the keys "state", "obs",
                "avail_actions", "actions_onehot", "reward", "next_obs",
                "next_avail_actions", "filled" and "next_mask". Every field is
                reshaped to a leading dimension of ``args.batch_size``.
            t_env: Environment step counter, used for logging only.
            episode_num: Not used in this method.

        Returns:
            None. Returns early, after logging an error, when the batch mask
            sums to zero.

        Raises:
            Exception: Re-raised from ``clip_grad_norm_`` when critic or agent
                gradients are not finite (``error_if_nonfinite=True``).
        """
        # Get the relevant quantities
        # NOTE(review): the views assume each field holds exactly
        # args.batch_size transitions -- confirm against the replay sampler.
        states = batch["state"].view(self.args.batch_size, -1)
        obs = batch["obs"].view(self.args.batch_size,self.args.n_agents, -1)
        avail_actions = batch["avail_actions"].view(self.args.batch_size, self.args.n_agents, -1)
        actions_onehot = batch["actions_onehot"].view(self.args.batch_size, -1)
        rewards = batch["reward"].view(self.args.batch_size, -1)
        # NOTE(review): next_obs is flattened to (batch, -1) while obs keeps a
        # per-agent axis; presumably mac.forward reshapes internally -- confirm.
        next_obs = batch["next_obs"].view(self.args.batch_size, -1)
        next_avail_actions = batch["next_avail_actions"].view(self.args.batch_size, self.args.n_agents, -1)
        mask = batch["filled"].view(self.args.batch_size, -1)
        next_mask = batch["next_mask"].view(self.args.batch_size, -1)
        
        if self.args.standardise_rewards:
            # NOTE(review): the running reward statistics are updated without
            # the mask (unlike ret_ms below) and before the empty-batch check.
            self.rew_ms.update(rewards)
            rewards = (rewards - self.rew_ms.mean) / th.sqrt(self.rew_ms.var)
        if self.args.translate_rewards:
            # Adds the learnable bias; rewards now carry a graph to rew_tr.
            rewards = self.rew_tr(rewards)
        
        # No experiences to train on in this minibatch
        if mask.sum() == 0:
            self.logger.log_stat("Mask_Sum_Zero", 1, t_env)
            self.logger.console_logger.error("Actor Critic Learner: mask.sum() == 0 at t_env {}".format(t_env))
            return
        
        if self.args.random_output:
            # Debug only: sampler that picks a batch index with probability
            # proportional to its mask value.
            batch_sampler = Categorical(mask.view(-1))

        # Randomly pick index 0 or 1 for this step. The online controller and
        # the critic are called with id=training_id, the target controller
        # with the other index.
        training_id = np.random.randint(2)
        
        mac_outputs = self.mac.forward(obs, id=training_id, refresh_hidden=True)
        target_mac_outputs = self.target_mac.forward(obs, id=1-training_id, refresh_hidden=True)
        critic_outputs = self.critic(states, actions_onehot, id=training_id)
        
        # Calculate critic's loss
        with th.no_grad():
            # Action distribution for the next step: softmax (scaled by alpha)
            # over the online controller's values, with unavailable actions
            # pushed to -1e9 so they receive zero probability.
            next_q = self.mac.forward(next_obs, id=training_id, refresh_hidden=True)
            next_q[next_avail_actions == 0.0] = -1e9
            prob = th.softmax(self.args.alpha * next_q, dim=-1)
            # The values themselves come from the target controller.
            next_q = self.target_mac.forward(next_obs, id=1-training_id, refresh_hidden=True)
            # Soft value: expectation of (q - log(prob)/alpha) under prob,
            # summed over dim 2 and averaged over dim 1 (presumably actions and
            # agents respectively). Zero-probability entries give 0 * inf = nan,
            # which nan_to_num maps to 0.
            next_v = th.mean(th.sum(th.nan_to_num(prob * (next_q - th.log(prob)/self.args.alpha)), dim=2), dim=1).unsqueeze(-1)
            if self.args.standardise_returns:
                # Shift by the running return mean before bootstrapping.
                next_v = next_v + self.ret_ms.mean
            next_v = next_v * next_mask
        
        returns = rewards + self.args.gamma * next_v
        if self.args.standardise_returns:
            self.ret_ms.update(returns, mask)
            returns = returns - self.ret_ms.mean
        
        # Masked mean of the squared TD error; the target is detached so only
        # the critic receives gradients from this loss.
        td_error = (returns.detach() - critic_outputs)*mask
        critic_loss = th.sum(0.5 * td_error**2)/mask.sum()
        if self.args.random_output and np.random.rand() < 0.01:
            batch_id = batch_sampler.sample().long()
            print("learning global q ", states[batch_id], actions_onehot[batch_id], rewards[batch_id], next_v[batch_id], returns[batch_id], critic_outputs[batch_id])
        
        # Optimise critic
        self.critic_optimiser.zero_grad()
        critic_loss.backward()
        try:
            critic_grad_norm = th.nn.utils.clip_grad_norm_(self.critic_params, self.args.grad_norm_clip, error_if_nonfinite=True)
        except Exception as e:
            print("Error! critic gradient not finite")
            print(critic_outputs, returns, next_v)
            raise(e)
        self.critic_optimiser.step()
        
        if self.args.translate_rewards:
            # Returns are not detached here, so this loss reaches the reward
            # bias through `rewards`; it pulls the masked returns towards zero.
            rew_loss = th.sum(0.5 * (returns*mask)**2)/mask.sum()
            self.rew_optimiser.zero_grad()
            rew_loss.backward()
            self.rew_optimiser.step()
        
        # Calculate agents' loss
        if self.training_steps > self.args.warmup_steps:
            # Per-(agent, action) mask: filled transitions and available actions.
            mac_mask = mask.view(-1,1,1).repeat(1, self.n_agents, self.n_actions) * avail_actions
            onehot = OneHot(self.n_actions)

            # Term 1: compare each agent's greedy value with the critic's value
            # of the joint greedy action.
            with th.no_grad():
                # Greedy available action per agent under the online controller.
                chosen_actions = th.max(mac_outputs * avail_actions - 1e9 * (1 - avail_actions), dim=2, keepdim=True)[1]
                chosen_actions_onehot = onehot.transform(chosen_actions)
                # critic_outputs is rebound here: critic value of the joint
                # greedy action, repeated once per agent.
                critic_outputs = self.critic(states, chosen_actions_onehot.view(self.args.batch_size, -1), id=training_id).repeat(1,self.n_agents)    
            agent_outputs = th.gather(mac_outputs, dim=2, index=chosen_actions).squeeze(-1)
            if self.args.random_output and np.random.rand() < 0.01:
                batch_id = batch_sampler.sample().long()
                print("learning local v ",states[batch_id],mac_outputs[batch_id],chosen_actions[batch_id],agent_outputs[batch_id],critic_outputs[batch_id])
            # One-sided penalty: only agent values above the critic value contribute.
            agent_loss = th.sum((F.relu(agent_outputs - critic_outputs.detach())*mask))/mask.sum()
            
            # Term 2: for every agent i and action j, evaluate the critic on the
            # joint action where agent i plays j and all other agents play the
            # action selected from the target controller's scores.
            with th.no_grad():
                chosen_action = self.target_mac.select_actions_by_scores(target_mac_outputs, avail_actions)
                chosen_actions = chosen_action.view(self.args.batch_size,1,1,self.n_agents).repeat(1,self.n_agents,self.n_actions,1)
                for i in range(self.n_agents):
                    for j in range(self.n_actions):
                        # Override agent i's own entry with candidate action j.
                        chosen_actions[:,i,j,i]=j
                chosen_actions_onehot = onehot.transform(chosen_actions.unsqueeze(-1))
                # States are tiled so each (agent, action) candidate gets a row.
                critic_outputs = self.critic(
                    states.repeat(1,self.n_agents*self.n_actions).view(self.args.batch_size*self.n_agents*self.n_actions,-1),
                    chosen_actions_onehot.view(self.args.batch_size*self.n_agents*self.n_actions,-1),
                    id=training_id
                ).view(-1,self.n_agents,self.n_actions)
            agent_outputs = mac_outputs
            if self.args.random_output and np.random.rand() < 0.01:
                batch_id = batch_sampler.sample().long()
                print("learning local q ",states[batch_id],chosen_action[batch_id],agent_outputs[batch_id],critic_outputs[batch_id])
            with th.no_grad():
                # The agents' own detached values are the baseline, so the
                # agent-side exponential below equals 1 in value and only
                # contributes gradient.
                baselines = agent_outputs.detach()
                critic_outputs = th.exp(self.args.beta*(critic_outputs-baselines))
            agent_outputs = th.exp(self.args.beta*(agent_outputs-baselines))
            # Masked squared error in exponentiated space.
            # NOTE(review): the loss is divided by beta twice (here and on the
            # next line), i.e. by beta**2 overall -- presumably intentional
            # normalisation; confirm.
            l2_loss = th.sum(0.5 * ((agent_outputs - critic_outputs.detach())*mac_mask)**2)/mac_mask.sum()/self.args.beta
            l2_loss /= self.args.beta
            
            agent_loss += l2_loss * self.args.agent_loss_scaler
            
            # Optimise agents
            self.agent_optimiser.zero_grad()
            agent_loss.backward()
            try:
                self.grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params, self.args.grad_norm_clip, error_if_nonfinite=True)
            except Exception as e:
                print("Error! agent gradient not finite")
                print(critic_outputs, baselines)
                raise(e)
            self.agent_optimiser.step()
        else:
            # Warm-up: agents are not updated; zeros keep the logging below valid.
            agent_loss = th.tensor([0.0])
            l2_loss = th.tensor([0.0])
        
        self.training_steps += 1

        # A setting above 1 is an interval in training steps between hard
        # updates; a setting of at most 1 is the soft-update rate, applied
        # every step.
        if self.args.target_update_interval_or_tau > 1 and (self.training_steps - self.last_target_update_step) / self.args.target_update_interval_or_tau >= 1.0:
            self._update_targets_hard()
            self.last_target_update_step = self.training_steps
        elif self.args.target_update_interval_or_tau <= 1.0:
            self._update_targets_soft(self.args.target_update_interval_or_tau)
        
        
        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("critic_loss", critic_loss.item(), t_env)
            self.logger.log_stat("critic_grad_norm", critic_grad_norm.item(), t_env)
            self.logger.log_stat("td_error_abs", td_error.abs().sum().item()/mask.sum().item(), t_env)
            # Agent loss without the scaled l2 term, i.e. the one-sided penalty.
            self.logger.log_stat("agent_extra_loss", agent_loss.item()-l2_loss.item()*self.args.agent_loss_scaler, t_env)
            self.logger.log_stat("agent_l2_loss", l2_loss.item(), t_env)
            self.logger.log_stat("agent_grad_norm", self.grad_norm.item(), t_env)
            self.log_stats_t = t_env

    def _update_targets(self):
        """Copy the online agent's weights into the target agent."""
        online_weights = self.mac.agent.state_dict()
        self.target_mac.agent.load_state_dict(online_weights)
    
    def _update_targets_hard(self):
        """Hard update: overwrite the target agent's state with the online agent's."""
        target_agent = self.target_mac.agent
        target_agent.load_state_dict(self.mac.agent.state_dict())
    
    def _update_targets_soft(self, tau):
        """Soft update: move each target parameter towards its online counterpart.

        Args:
            tau: Interpolation weight; each target parameter becomes
                ``(1 - tau) * target + tau * online``.
        """
        parameter_pairs = zip(self.target_mac.parameters(), self.mac.parameters())
        for tgt, src in parameter_pairs:
            blended = tgt.data * (1.0 - tau) + src.data * tau
            # Write in place so the target keeps the same parameter objects.
            tgt.data.copy_(blended)
    
    def cuda(self):
        """Call ``.cuda()`` on the online controller, the critic and the target controller."""
        for component in (self.mac, self.critic, self.target_mac):
            component.cuda()
    
    def save_models(self, path):
        """Write the controller, critic and both optimiser states under ``path``.

        The controller saves itself through ``mac.save_models``. The critic and
        the two optimisers are written as ``critic.th``, ``agent_opt.th`` and
        ``critic_opt.th``.

        NOTE(review): the target controller, the reward translator and the
        running statistics created in ``__init__`` are not saved here.

        Args:
            path: Directory that receives the checkpoint files.
        """
        self.mac.save_models(path)
        checkpoint_items = (
            (self.critic, "{}/critic.th"),
            (self.agent_optimiser, "{}/agent_opt.th"),
            (self.critic_optimiser, "{}/critic_opt.th"),
        )
        for stateful, filename_template in checkpoint_items:
            th.save(stateful.state_dict(), filename_template.format(path))

    def load_models(self, path):
        """Restore the controllers, critic and optimiser states from ``path``.

        Args:
            path: Directory containing the files written by ``save_models``.
        """
        def _keep_storage(storage, loc):
            # map_location hook: return the storage as deserialised.
            return storage

        self.mac.load_models(path)
        critic_state = th.load("{}/critic.th".format(path), map_location=_keep_storage)
        self.critic.load_state_dict(critic_state)

        # Target weights are never saved separately, so the target controller
        # is restored from the online controller's checkpoint (an approximation).
        self.target_mac.load_models(path)

        optimiser_files = (
            (self.agent_optimiser, "{}/agent_opt.th"),
            (self.critic_optimiser, "{}/critic_opt.th"),
        )
        for optimiser, filename_template in optimiser_files:
            optimiser_state = th.load(filename_template.format(path), map_location=_keep_storage)
            optimiser.load_state_dict(optimiser_state)
