import torch
import torch.nn as nn
import numpy as np


def calculateTerm(values, termParameter, average=True):
    """Aggregate ``values`` with the tilted (TERM) loss.

    For a tilt ``t = termParameter`` this returns
    ``(1/t) * log(mean(exp(t * values)))`` when ``average`` is True, or
    ``(1/t) * log(sum(exp(t * values)))`` otherwise. The reduction runs over
    dim 0 with a numerically stable ``logsumexp``.

    When ``t == 0`` the function returns ``torch.mean(values)`` (over all
    elements), whatever ``average`` is. That is the t -> 0 limit of the
    averaged form.
    """
    if termParameter == 0:
        return torch.mean(values)
    # Subtracting log(n) inside the logsumexp turns the sum into a mean.
    logCount = np.log(values.shape[0]) if average else 0
    scaled = termParameter * values - logCount
    return 1 / termParameter * torch.logsumexp(scaled, dim=0)


class Term:
    """Tilted Empirical Risk Minimization (TERM) objective over several loss groups.

    Each of the ``numberOfLosses`` groups is aggregated with the tilted loss
    ``(1/t) * log(mean(exp(t * losses)))``. When ``averageTerm`` is False,
    ``sum`` replaces ``mean``. A tilt of ``t == 0`` falls back to the plain mean.

    Stochastic mode is on when ``stochasticTerm`` is True and ``t != 0``. In
    that mode a running estimate of each group's tilted loss over the full
    dataset is kept in ``estimatesOnFullDataset``, and ``updateSpecificTerm``
    uses it to reweight minibatch losses. There are two ways to seed the
    estimate:

    - in one shot, with ``initializeStochasticTerm``;
    - incrementally, with ``feedDataForInitialization`` calls followed by one
      ``concludeInitialization``.
    """

    def __init__(self, numberOfLosses, termParameter, stochasticTerm=True, movingAverageLambda=0.9, averageTerm=True):
        self.numberOfTerms = numberOfLosses
        # Current estimate of each group's tilted loss on the full dataset.
        self.estimatesOnFullDataset = torch.zeros(self.numberOfTerms)
        # One tilt per group; all start equal to ``termParameter``.
        self.termParameters = termParameter * torch.ones(self.numberOfTerms)
        # With t == 0 the objective is a plain mean, so no running estimate is needed.
        self.stochasticTerm = stochasticTerm and termParameter != 0
        # Weight given to the newest minibatch when updating the running estimate.
        self.movingAverageLambda = movingAverageLambda
        self.initialized = not self.stochasticTerm
        self.averageTerm = averageTerm

        # Buffers for incremental initialization; concludeInitialization() deletes them.
        self.minibatchSizes = [[] for _ in range(self.numberOfTerms)]
        self.minibatchLosses = [[] for _ in range(self.numberOfTerms)]
        self.mergeThreshold = 1000  # merge every this many minibatches
        self.minibatchCounter = 0

    @torch.no_grad()
    def initializeStochasticTerm(self, allLosses):
        """Seed the full-dataset estimates from all losses at once.

        Args:
            allLosses: sequence holding one tensor of losses per group.

        Raises:
            ValueError: if ``allLosses`` does not have one entry per group.
        """
        if len(allLosses) != self.numberOfTerms:
            raise ValueError("Expected %d loss groups, got %d" % (self.numberOfTerms, len(allLosses)))
        for i, losses in enumerate(allLosses):
            self.estimatesOnFullDataset[i] = calculateTerm(losses, self.termParameters[i], self.averageTerm)
        self.initialized = True

    @torch.no_grad()
    def mergeMinibatches(self):
        """Collapse each group's buffered minibatch tilted losses into one exact value.

        Each buffered ``L_j`` is the tilted loss of a minibatch with ``n_j``
        elements. Let ``N = sum(n_j)`` and ``w_j = n_j / N``. The exact tilted
        loss of the union is:

        - averaged:     ``(1/t) * logsumexp(t * L_j + log w_j)``
        - not averaged: ``(1/t) * logsumexp(t * L_j)``
        - ``t == 0``:   ``sum(w_j * L_j)``, because calculateTerm returns the mean there

        Each group's buffers are replaced by the merged value and the total
        size, so that merging again later stays exact.
        """
        for i in range(self.numberOfTerms):
            losses = self.minibatchLosses[i]
            if not losses:
                continue
            t = self.termParameters[i]
            stacked = torch.stack(losses)
            # Plain Python ints: avoids feeding device tensors back into torch.tensor().
            sizes = [int(n) for n in self.minibatchSizes[i]]
            totalSize = sum(sizes)
            weights = torch.tensor([n / totalSize for n in sizes], dtype=stacked.dtype, device=stacked.device)
            if t == 0:
                merged = torch.sum(weights * stacked)
            else:
                logits = t * stacked
                if self.averageTerm:
                    logits = logits + weights.log()
                merged = 1 / t * torch.logsumexp(logits, dim=0)
            self.minibatchLosses[i] = [merged]
            self.minibatchSizes[i] = [totalSize]

    @torch.no_grad()
    def feedDataForInitialization(self, currentMiniBatchLosses):
        """Buffer one minibatch's per-group tilted losses for incremental initialization.

        Args:
            currentMiniBatchLosses: sequence holding one tensor of losses per
                group. Empty tensors are skipped.

        Every ``mergeThreshold`` calls, the buffers are merged to bound memory.
        """
        self.minibatchCounter += 1
        for i in range(self.numberOfTerms):
            losses = currentMiniBatchLosses[i]
            if len(losses) > 0:
                self.minibatchLosses[i].append(calculateTerm(losses, self.termParameters[i], self.averageTerm))
                self.minibatchSizes[i].append(len(losses))
        if self.minibatchCounter >= self.mergeThreshold:
            self.mergeMinibatches()
            self.minibatchCounter = 0

    def concludeInitialization(self):
        """Finish incremental initialization and free its buffers.

        Raises:
            ValueError: if some group never received any data.
        """
        self.mergeMinibatches()
        for i in range(self.numberOfTerms):
            if len(self.minibatchLosses[i]) == 0:
                raise ValueError("No data was provided from category %d" % i)
            self.estimatesOnFullDataset[i] = self.minibatchLosses[i][0]
        self.initialized = True
        del self.minibatchSizes
        del self.minibatchLosses
        del self.mergeThreshold
        del self.minibatchCounter

    def updateSpecificTerm(self, termIndex, lossValues):
        """Return the training loss for group ``termIndex`` on one minibatch.

        In non-stochastic mode this is the minibatch tilted loss ``L``. In
        stochastic mode the full-dataset estimate ``R`` is first updated as
        ``R <- (1/t) * log((1-λ) e^{tR} + λ e^{tL})``, or as a linear blend
        when t == 0. The function then returns ``L * exp(t * (L.detach() - R))``,
        so the gradient of ``L`` is scaled by ``exp(t * (L - R))``.
        """
        # Keep the running estimates on the losses' device; this avoids
        # cross-device reads and writes on every step.
        if self.estimatesOnFullDataset.device != lossValues.device:
            self.estimatesOnFullDataset = self.estimatesOnFullDataset.to(lossValues.device)
        self.initialized = True
        t = self.termParameters[termIndex]
        termLoss = calculateTerm(lossValues, t, self.averageTerm)
        if not self.stochasticTerm:
            return termLoss
        lam = self.movingAverageLambda
        with torch.no_grad():
            previous = self.estimatesOnFullDataset[termIndex]
            if t == 0:
                updated = (1 - lam) * previous + lam * termLoss
            else:
                # Evaluate the blend in log space, so that a large |t * loss|
                # cannot overflow or underflow exp().
                logKeep = float(np.log(1 - lam)) if lam < 1 else -np.inf
                logNew = float(np.log(lam)) if lam > 0 else -np.inf
                updated = 1 / t * torch.logaddexp(t * previous + logKeep, t * termLoss + logNew)
            self.estimatesOnFullDataset[termIndex] = updated
        return termLoss * torch.exp(t * (termLoss.detach() - self.estimatesOnFullDataset[termIndex]))



