import copy

import torch
from components.episode_buffer import EpisodeBatch
from modules.mixers.vdn import VDNMixer
from modules.mixers.qmix import QMixer
from modules.mixers.qmix3 import QMixer3
from modules.mixers.qatten import QattenMixer
import torch as th
import numpy
from torch.optim import RMSprop, Adam
from modules.layers.act_layer import ActivateLayer
from torch.nn import Linear
from torch.nn import GRU

class QLearner:
    """Q-learner for a multi-agent controller with an optional mixer, plus neuron tracking.

    Besides the (double) Q-learning update in ``train``, this learner registers
    forward hooks on every ``ActivateLayer`` in the agent network and in the
    mixer. The hooks accumulate per-neuron activation statistics into
    ``self.neu_data``. ``find`` turns those statistics into a per-neuron
    classification in ``self.mask``:

    * 0 = value <= tau * mean ("dead");
    * 1 = normal;
    * 2 = value > beta * mean.

    ``recycle``, ``reborn``, ``delete`` and ``reset`` rewrite layer weights
    based on that classification.

    All per-layer dictionaries are keyed by the ``ActivateLayer``'s ``name``
    attribute.
    """

    def __init__(self, mac, scheme, logger, args):
        """Build the mixer, target networks, activation hooks and optimiser.

        Args:
            mac: multi-agent controller. It must expose ``parameters()``,
                ``agent`` (a module), ``init_hidden``, ``forward``,
                ``load_state``, ``cuda``, ``save_models`` and ``load_models``.
            scheme: unused here; kept for interface compatibility.
            logger: logger exposing ``log_stat`` and ``console_logger``.
            args: run configuration (``mixer``, ``name``, ``device``,
                ``n_agents``, ``lr``, ``optim_alpha``, ``optim_eps``,
                ``use_cuda``, ``learner_log_interval`` and the other fields
                read in ``train``).

        Raises:
            ValueError: if ``args.mixer`` names an unknown mixer.
        """
        self.args = args
        self.mac = mac
        self.logger = logger

        self.params = list(mac.parameters())
        self.device = th.device('cuda' if args.use_cuda else 'cpu')
        self.last_target_update_episode = 0

        self.mixer = None
        if args.mixer is not None:
            if args.mixer == "qatten":
                self.mixer = QattenMixer(args)
            elif args.mixer == "vdn":
                self.mixer = VDNMixer()
            elif args.mixer == "qmix":
                self.mixer = QMixer(args)
            elif args.mixer == "qmix3":
                self.mixer = QMixer3(args)
            else:
                raise ValueError("Mixer {} not recognised.".format(args.mixer))
            self.params += list(self.mixer.parameters())
            # Runs whose name contains "no_mix"/"none" share one mixer instead of a lagged copy.
            if "no_mix" in args.name or "none" in args.name:
                self.target_mixer = self.mixer
            else:
                self.target_mixer = copy.deepcopy(self.mixer)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        # Built BEFORE the activation hooks below are registered. deepcopy also
        # copies forward hooks, so copying later would make the target network's
        # forward passes accumulate into self.neu_data as well. The target mixer
        # above is copied before its hooks are registered for the same reason.
        if "no_agent" in args.name or "none" in args.name:
            self.target_mac = mac
        else:
            self.target_mac = copy.deepcopy(mac)

        self.neu_data = {}   # accumulated activation statistics per ActivateLayer
        self.neu_avg = {}    # mean of the per-neuron statistic, filled by find()
        self.mask = {}       # per-neuron classification 0 / 1 / 2, filled by find()
        self.spilt = {}      # candidate neurons considered when picking self.overload
        self.overload = {}   # index of the most-activated candidate neuron, or -1

        for (name, module) in self.mac.agent.named_modules():  # time need to be added
            if isinstance(module, ActivateLayer):
                self.mask[module.name] = th.ones_like(module.weight, device=args.device, dtype=th.int)
                self.overload[module.name] = 0
                self.spilt[module.name] = th.ones_like(module.weight, device=args.device, dtype=th.int)
                self.neu_data[module.name] = th.zeros_like(module.weight, device=args.device)

                def hook(module, fea_in, fea_out):
                    # Presumably fea_out is (batch * n_agents, dim). It is regrouped per
                    # agent, averaged over the batch, and added once per forward call.
                    fea_out = fea_out.reshape(-1, args.n_agents, fea_out.shape[1])
                    self.neu_data[module.name] = self.neu_data[module.name] + th.mean(fea_out, dim=0)
                    return None
                module.register_forward_hook(hook=hook)

        if self.mixer is not None:
            for (name, module) in self.mixer.named_modules():
                if isinstance(module, ActivateLayer):
                    self.mask[module.name] = th.ones_like(module.weight, device=args.device, dtype=th.int)
                    self.spilt[module.name] = th.ones_like(module.weight, device=args.device, dtype=th.int)
                    self.neu_data[module.name] = th.zeros_like(module.weight, device=args.device)

                    def hook(module, fea_in, fea_out):
                        # Mean over the leading (batch) axis, added once per forward call.
                        self.neu_data[module.name] = self.neu_data[module.name] + th.mean(fea_out, dim=0)
                        return None
                    module.register_forward_hook(hook=hook)

            # Mixer size in MB, assuming 4-byte (float32) parameters.
            para = sum([numpy.prod(list(p.size())) for p in self.mixer.parameters()])
            print(para * 4 / 1000 / 1000)

        self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps)

        self.log_stats_t = -self.args.learner_log_interval - 1

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int, check=False):
        """Run one Q-learning update on ``batch`` (or only refresh neuron stats).

        The activation statistics in ``self.neu_data`` are zeroed at the start.
        The online forward passes below then re-accumulate them through the
        hooks.

        Args:
            batch: episode batch providing "reward", "actions", "terminated",
                "filled", "avail_actions" and "state".
            t_env: environment step counter, used for logging and ``find``.
            episode_num: episode counter, used to schedule target updates.
            check: if True, call ``find`` on the freshly recorded activations,
                clear gradients and return without an optimiser step.
        """
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]

        for key in self.neu_data.keys():
            self.neu_data[key] = th.zeros_like(self.neu_data[key], device=self.args.device)

        # Calculate estimated Q-Values
        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3)  # Remove the last dim

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs = self.target_mac.forward(batch, t=t)
            target_mac_out.append(target_agent_outs)

        # We don't need the first timesteps Q-Value estimate for calculating targets
        target_mac_out = th.stack(target_mac_out[1:], dim=1)  # Concat across time

        # Mask out unavailable actions
        target_mac_out[avail_actions[:, 1:] == 0] = -9999999

        # Max over target Q-Values
        if self.args.double_q:
            # Get actions that maximise live Q (for double q-learning)
            mac_out_detach = mac_out.clone().detach()
            mac_out_detach[avail_actions == 0] = -9999999
            cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1]
            target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3)
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]

        # Mix
        if self.mixer is not None:
            chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1])
            target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:])

        # Calculate 1-step Q-Learning targets
        targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals

        # Td-error
        td_error = (chosen_action_qvals - targets.detach())

        mask = mask.expand_as(td_error)

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask

        # Normal L2 loss, take mean over actual data
        loss = (masked_td_error ** 2).sum() / mask.sum()

        if check:
            self.find(t_env)
            self.optimiser.zero_grad()
            return

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip)
        self.optimiser.step()

        if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss", loss.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            mask_elems = mask.sum().item()
            self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item()/mask_elems), t_env)
            self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item()/(mask_elems * self.args.n_agents), t_env)
            self.logger.log_stat("target_mean", (targets * mask).sum().item()/(mask_elems * self.args.n_agents), t_env)
            self.log_stats_t = t_env

    def find(self, t_env):
        """Classify every tracked neuron from the statistics in ``self.neu_data``.

        A 2-D statistic is summed over its leading axis, giving one value per
        neuron. With ``avg`` the mean of those values:

        * ``mask = 0`` where value <= tau * avg;
        * ``mask = 1`` where value > tau * avg;
        * ``mask = 2`` where value > beta * avg.

        ``tau``/``beta`` default to 0.1/3 and are overridden by
        ``args.tau``/``args.beta``.

        ``self.overload[name]`` becomes the index of the largest value among
        neurons with a non-zero ``self.spilt`` entry, or -1 if that value is
        below ``beta * avg``.

        Args:
            t_env: environment step counter, included in the printed summary.
        """
        tau = 0.1
        beta = 3
        if hasattr(self.args, 'tau'):
            tau = self.args.tau
        if hasattr(self.args, 'beta'):
            beta = self.args.beta
        for name, item in self.neu_data.items():
            if len(item.shape) == 2:
                item = item.sum(dim=0)  # agent,dim / dim
            avg = item.mean()
            self.neu_avg[name] = avg
            self.mask[name][item <= tau * avg] = 0
            self.mask[name][item > tau * avg] = 1
            self.mask[name][item > beta * avg] = 2
            # No candidate is above 3 * avg: reopen every neuron as a candidate.
            # NOTE(review): this threshold is a hard-coded 3, not `beta`; confirm intended.
            if th.count_nonzero((self.spilt[name]) & (item > 3 * avg)) == 0:  # find not overload
                self.spilt[name][:] = 1
            value = [round(num, 2) for num in item.tolist()]
            if th.count_nonzero(self.spilt[name]) > 0:
                self.overload[name] = th.argmax(item * self.spilt[name])
                if item[self.overload[name]] < beta * avg:
                    self.overload[name] = -1
            else:
                self.overload[name] = -1

            count_01 = th.count_nonzero(item <= tau * avg).item()
            print("dead_neural_%s" % (name), count_01, item.shape[0], t_env)
            print("neural_value_%s" % (name), value)

    def masker(self):
        """Swap each ActivateLayer's weight tensor with its ``self.mask`` entry.

        Calling this twice in a row restores the original weights and masks.
        """
        for (name, module) in self.mixer.named_modules():
            if isinstance(module, ActivateLayer):
                module.weight.data, self.mask[module.name] = self.mask[module.name], module.weight.data

        for (name, module) in self.mac.agent.named_modules():  # time need to be added
            if isinstance(module, ActivateLayer):
                module.weight.data, self.mask[module.name] = self.mask[module.name], module.weight.data

    def recycle(self):
        """Re-initialise the neurons whose mask equals ``exc`` and zero their outputs.

        ``exc`` is 0 by default, or 2 when the run name contains "_u". Layers
        come from the mixer, plus the agent network when the run name contains
        "_all". They are taken in ``named_modules`` order.

        For each ActivateLayer at position ``i + 2``:

        * the module at ``i`` gets freshly initialised incoming weights/bias
          for the selected neurons (all other neurons keep their values);
        * the module at ``i + 3``, if present and a ``Linear``, has the
          selected neurons' outgoing weights set to zero.
        """
        if '_all' in self.args.name:
            layers = list(self.mixer.named_modules()) + list(self.mac.agent.named_modules())
        else:
            layers = list(self.mixer.named_modules())
        exc = 0
        if '_u' in self.args.name:
            exc = 2
        with th.no_grad():
            for i in range(len(layers) - 2):
                act_layer = layers[i + 2][1]

                if isinstance(act_layer, ActivateLayer):
                    input_layer = layers[i][1]
                    # An ActivateLayer may be the last module, in which case it
                    # has no output layer (indexing i + 3 would be out of range).
                    output_layer = layers[i + 3][1] if i + 3 < len(layers) else None
                    layer_mask = self.mask[act_layer.name]
                    weight = input_layer.weight.data.T.clone()
                    bias = input_layer.bias.data.clone()

                    # Fresh init everywhere, then restore the neurons that are not selected.
                    input_layer.reset_parameters()
                    input_layer.weight.data = th.where(layer_mask != exc, weight, input_layer.weight.data.T).T
                    input_layer.bias.data = th.where(layer_mask != exc, bias, input_layer.bias.data)

                    if isinstance(output_layer, Linear):
                        # Transposed view: row j holds neuron j's outgoing weights.
                        output_weight = output_layer.weight.data.T
                        output_weight[layer_mask == exc] = 0
                        output_layer.weight.data = output_weight.T

    def reborn(self):
        """Re-seed dead neurons (mask 0) from neurons with mask 2.

        This only applies to Linear -> ... -> ActivateLayer -> Linear patterns
        (``named_modules`` positions i, i + 2, i + 3). Layers come from the
        mixer, plus the agent network when the run name contains "_all".

        1. Dead neurons get freshly initialised incoming weights/bias; all other
           neurons keep theirs.
        2. Dead and mask-2 neurons are shuffled. Each mask-2 "donor" ``over``
           is paired with up to ``share_number[over]`` (random in [2, 5]) dead
           neurons ``ids[over + k * len(idx)]``. Each paired dead neuron copies
           the donor's incoming weights and bias scaled by ``dec`` in
           [0.2, 0.8].
        3. Dead neurons' outgoing weights are zeroed. A column-wise softmax over
           ``K + 1`` random rows then splits the donor's outgoing weights:
           recipient ``k`` gets ``donor * div[k] / dec`` and the donor keeps
           ``donor * div[K]``.

        NOTE(review): the ``dec`` / ``1 / dec`` pairing presumably keeps the
        layer's output roughly unchanged for a positively homogeneous
        activation; confirm against ActivateLayer.
        """
        if '_all' in self.args.name:
            layers = list(self.mixer.named_modules()) + list(self.mac.agent.named_modules())
        else:
            layers = list(self.mixer.named_modules())
        with th.no_grad():
            for i in range(len(layers) - 3):
                act_layer = layers[i + 2][1]
                input_layer = layers[i][1]
                output_layer = layers[i + 3][1]

                if isinstance(act_layer, ActivateLayer) and isinstance(input_layer, Linear) and isinstance(output_layer, Linear):
                    layer_mask = self.mask[act_layer.name]
                    weight = input_layer.weight.data.T.clone()
                    bias = input_layer.bias.data.clone()
                    # Dead neurons (recipients), shuffled, with a scale per recipient.
                    ids = th.where((layer_mask == 0))[0]
                    ids = ids[th.randperm(ids.size(0))]
                    dec = th.rand(ids.size(0), device=self.args.device)
                    dec[dec < 0.2] = 0.2
                    dec[dec > 0.8] = 0.8
                    # Mask-2 neurons (donors), shuffled.
                    idx = th.where((layer_mask == 2))[0]
                    idx = idx[th.randperm(idx.size(0))]

                    input_layer.reset_parameters()
                    input_layer.weight.data = th.where(layer_mask != 0, weight, input_layer.weight.data.T).T
                    input_layer.bias.data = th.where(layer_mask != 0, bias, input_layer.bias.data)

                    weight = input_layer.weight.data.T.clone()
                    bias = input_layer.bias.data.clone()
                    share_number = [numpy.random.randint(2, 6) for _ in range(idx.shape[0])]
                    for over in range(idx.shape[0]):
                        # Number of recipients assigned round-robin to this donor.
                        K = (ids.shape[0] - over - 1) // idx.shape[0] + 1
                        K = min(K, share_number[over])
                        for k in range(K):
                            dorm = over + k * idx.shape[0]
                            print(over, dorm, idx.shape[0], ids.shape[0])
                            weight[:, ids[dorm]] = weight[:, idx[over]] * dec[dorm]
                            bias[ids[dorm]] = bias[idx[over]] * dec[dorm]
                    input_layer.weight.data = weight.T
                    input_layer.bias.data = bias

                    # Transposed view: row j holds neuron j's outgoing weights.
                    output_weight = output_layer.weight.data.T
                    output_weight[layer_mask == 0] = 0
                    for over in range(idx.shape[0]):
                        K = (ids.shape[0] - over - 1) // idx.shape[0] + 1
                        K = min(K, share_number[over])
                        emb = output_weight.shape[1]
                        div = th.softmax(th.randn(K + 1, emb, device=self.args.device), dim=0)
                        for k in range(K):
                            dorm = over + k * idx.shape[0]
                            output_weight[ids[dorm]] = output_weight[idx[over]] * div[k] / dec[dorm]

                        output_weight[idx[over]] = output_weight[idx[over]] * div[K]

                    output_layer.weight.data = output_weight.T

    def delete(self):
        """Set each mixer ActivateLayer's weight to 0 for dead neurons (mask 0), 1 otherwise.

        NOTE(review): ``th.where`` with Python int scalars yields an integer
        tensor, so the weight's dtype changes; confirm ActivateLayer expects that.
        """
        for _, act_layer in self.mixer.named_modules():
            if isinstance(act_layer, ActivateLayer):
                layer_mask = self.mask[act_layer.name]
                act_layer.weight.data = th.where(layer_mask == 0, 0, 1)

    def reset(self):
        """Re-initialise every Linear not directly followed by an ActivateLayer or GRU.

        Modules are taken in ``named_modules`` order from the mixer, plus the
        agent network when the run name contains "_all". A Linear that is the
        very last module is always reset.
        """
        if '_all' in self.args.name:
            layers = list(self.mixer.named_modules()) + list(self.mac.agent.named_modules())
        else:
            layers = list(self.mixer.named_modules())
        with th.no_grad():
            for i in range(len(layers)):
                mlp_layer = layers[i][1]
                if isinstance(mlp_layer, Linear):
                    if i == len(layers) - 1:
                        mlp_layer.reset_parameters()
                        continue

                    next_layer = layers[i + 1][1]
                    if not isinstance(next_layer, ActivateLayer) and not isinstance(next_layer, GRU):
                        mlp_layer.reset_parameters()

    def _update_targets(self):
        """Copy the online agent (and mixer, if any) weights into the target networks."""
        self.target_mac.load_state(self.mac)
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        """Move the online/target agents and mixers to the GPU."""
        self.mac.cuda()
        self.target_mac.cuda()
        if self.mixer is not None:
            self.mixer.cuda(self.args.device)
            self.target_mixer.cuda(self.args.device)

    def save_models(self, path):
        """Save agent models, mixer state (``mixer.th``) and optimiser state (``opt.th``) to ``path``."""
        self.mac.save_models(path)
        if self.mixer is not None:
            th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))

    def load_models(self, path):
        """Load models saved by ``save_models``; target networks get the online weights."""
        self.mac.load_models(path)
        # Not quite right but I don't want to save target networks
        self.target_mac.load_models(path)
        if self.mixer is not None:
            mixer_state = th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)
            self.mixer.load_state_dict(mixer_state)
            # Mirror target_mac: otherwise the target mixer keeps its random
            # init until the next _update_targets.
            self.target_mixer.load_state_dict(mixer_state)
        self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))