import torch


class BaseCost(object):
    """Base class for cost terms.

    Subclasses implement ``forward``; instances are evaluated by calling them
    directly, ``cost(state)``. ``reset`` and ``init_iteration`` are no-op
    hooks that subclasses may override.

    Args:
        tensor_kwargs: keyword arguments (e.g. ``device``/``dtype``) stored on
            the instance as ``self.tensor_kwargs``. Defaults to
            ``{"device": "cpu", "dtype": torch.float32}``. A shallow copy is
            stored, so instances never share the caller's dict or each other's.
    """

    def __init__(self, tensor_kwargs=None):
        # None sentinel instead of a dict-literal default: a literal default is
        # evaluated once and shared by every instance, so mutating one
        # instance's tensor_kwargs would silently change all the others.
        if tensor_kwargs is None:
            tensor_kwargs = {"device": "cpu", "dtype": torch.float32}
        self.tensor_kwargs = dict(tensor_kwargs)

    def __call__(self, state, actions=None):
        """Evaluate the cost at ``state`` by delegating to ``forward``.

        NOTE(review): ``actions`` is accepted but NOT passed on; only
        ``state`` reaches ``forward``. This is kept deliberately because
        subclasses may define ``forward(self, state)`` without an ``actions``
        parameter, and forwarding it would break them.
        """
        return self.forward(state)

    def forward(self, state, actions=None):
        """Compute the cost for ``state``; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()

    def reset(self, **kwargs):
        """Hook for resetting internal state; the base implementation does nothing."""
        pass

    def init_iteration(self, **kwargs):
        """Hook called at the start of an iteration; the base implementation does nothing."""
        pass


class CostSum(BaseCost):
    """Weighted sum of several cost terms.

    ``forward`` calls every sub-cost as ``cost(state)``, stacks the results
    along a new last dimension and reduces that dimension against
    ``self.weights`` (elementwise multiply, then sum).

    Args:
        costs: iterable of callables (typically ``BaseCost`` instances). It is
            copied into a list, so later changes to the caller's container do
            not affect this object.
        weights: one weight per cost; defaults to 1.0 for every cost.
        remove_zeros: if True, costs whose weight equals 0.0 are dropped at
            construction time so they are never evaluated.
        tensor_kwargs: keyword arguments passed to ``Tensor.to`` when building
            the weight tensor (e.g. ``device``/``dtype``). Defaults to
            ``{"device": "cpu", "dtype": torch.float32}``.

    Raises:
        ValueError: if ``weights`` does not have exactly one entry per cost.
    """

    def __init__(self, costs, weights=None, remove_zeros=False,
                 tensor_kwargs=None):
        # None sentinel: a dict-literal default would be shared across calls.
        if tensor_kwargs is None:
            tensor_kwargs = {"device": "cpu", "dtype": torch.float32}
        super(CostSum, self).__init__(tensor_kwargs)

        # Materialize once so an iterator is not exhausted by the default
        # weight construction below, and so len() works.
        costs = list(costs)
        if weights is None:
            weights = [1.0] * len(costs)

        # Without this check, zip() below silently drops unmatched costs, and
        # the unfiltered path only fails later with a broadcasting error.
        if len(weights) != len(costs):
            raise ValueError(
                "CostSum got {} costs but {} weights".format(
                    len(costs), len(weights)))

        if remove_zeros:
            # Drop zero-weight costs entirely so they are never evaluated.
            kept = [(c, w) for c, w in zip(costs, weights) if w != 0.0]
            costs = [c for c, _ in kept]
            weights = [w for _, w in kept]

        self.costs = costs
        self.weights = torch.as_tensor(weights).to(**tensor_kwargs)
        self.n_costs = len(self.costs)

    def forward(self, state, actions=None):
        """Return the weighted sum of all sub-costs evaluated at ``state``.

        ``actions`` is accepted for signature compatibility with
        ``BaseCost.forward`` and is unused; sub-costs are called as
        ``cost(state)``.

        Raises:
            RuntimeError: if there are no costs to evaluate (for example every
                weight was zero and ``remove_zeros`` was set).
        """
        if not self.costs:
            raise RuntimeError(
                "CostSum has no costs to evaluate "
                "(were all weights zero with remove_zeros=True?)")
        # Stack on a new trailing dim so the 1-D weight tensor broadcasts over
        # it; presumably each sub-cost returns shape (N,), giving
        # (N, n_costs) here — TODO confirm.
        cost_vals = torch.stack([cost(state) for cost in self.costs], dim=-1)
        return (cost_vals * self.weights).sum(-1)

    def reset(self, **kwargs):
        """Forward ``reset(**kwargs)`` to every sub-cost, in order."""
        for cost in self.costs:
            cost.reset(**kwargs)

    def init_iteration(self, **kwargs):
        """Forward ``init_iteration(**kwargs)`` to every sub-cost, in order."""
        for cost in self.costs:
            cost.init_iteration(**kwargs)
