import copy
import numpy as np

import torch
import torch.nn as nn


def init(module, weight_init, bias_init, gain=1):
    """Run initializer callables on a module's parameters and return it.

    ``weight_init`` is called with ``module.weight.data`` and the keyword
    argument ``gain``. ``bias_init`` is called with ``module.bias.data``
    only when the module actually has a bias. The same module object is
    returned so the call can wrap a layer constructor inline.
    """
    weight_init(module.weight.data, gain=gain)
    has_bias = module.bias is not None
    if has_bias:
        bias_init(module.bias.data)
    return module


def get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    # deepcopy ensures the clones share no parameters with each other or the source.
    copies = (copy.deepcopy(module) for _ in range(N))
    return nn.ModuleList(copies)


def check(input):
    """Convert a NumPy array to a ``torch.Tensor``; return anything else as-is.

    ``torch.from_numpy`` shares memory with the array instead of copying it.
    ``isinstance`` is used (rather than an exact type comparison) so that
    ``np.ndarray`` subclasses such as ``np.memmap`` are converted too,
    instead of silently passing through unconverted.
    """
    if isinstance(input, np.ndarray):
        return torch.from_numpy(input)
    return input


def exists(val):
    """Report whether ``val`` holds a value, i.e. is anything but ``None``."""
    if val is None:
        return False
    return True


class Multi_Agent_Distribute:
    """Wrap one torch distribution per agent and expose stacked statistics.

    Every statistic is read from each distribution in ``distri_list`` and
    combined with ``torch.stack(..., dim=1)``, so agent ``i``'s entry sits at
    index ``i`` of dimension 1.

    Args:
        distri_list: Sequence of per-agent distribution objects.
        action_type: ``'Discrete'`` makes the wrapper collect ``probs``;
            any other value makes it collect ``covariance_matrix``.

    Attributes:
        mean: Stacked per-agent ``mean``.
        covariance_matrix: Stacked per-agent ``covariance_matrix``, or
            ``None`` when ``action_type == 'Discrete'``.
        probs: Stacked per-agent ``probs``, or ``None`` otherwise.
    """

    def __init__(self, distri_list, action_type):
        self.distri_list = distri_list
        self.num_agent = len(distri_list)
        self.mean = self.get_mean()
        is_discrete = action_type == 'Discrete'
        # Distributions exposing ``probs`` (e.g. torch Categorical) have no
        # ``covariance_matrix`` attribute; reading it unconditionally raised
        # AttributeError, so each statistic is collected only when it applies.
        self.covariance_matrix = (
            None if is_discrete else self.get_covariance_matrix()
        )
        self.probs = self.get_probs() if is_discrete else None

    def _stack(self, attr):
        """Read ``attr`` from every agent's distribution and stack on dim 1."""
        return torch.stack(
            [getattr(distri, attr) for distri in self.distri_list], dim=1
        )

    def log_prob(self, v):
        """Return each agent's log-probability of its slice of ``v``.

        Agent ``i``'s distribution scores ``v[:, i, :]``; the per-agent
        results are stacked on dim 1.

        NOTE(review): the indexing assumes ``v`` is laid out as
        (batch, num_agent, action_dim) -- confirm against callers.
        """
        return torch.stack(
            [
                distri.log_prob(v[:, agent_i, :])
                for agent_i, distri in enumerate(self.distri_list)
            ],
            dim=1,
        )

    def get_mean(self):
        """Return every agent's ``mean`` stacked on dim 1."""
        return self._stack('mean')

    def get_covariance_matrix(self):
        """Return every agent's ``covariance_matrix`` stacked on dim 1."""
        return self._stack('covariance_matrix')

    def get_probs(self):
        """Return every agent's ``probs`` stacked on dim 1."""
        return self._stack('probs')
