from collections import defaultdict

from dreamcoder.frontier import *
from dreamcoder.program import *
from dreamcoder.type import *
from dreamcoder.utilities import *

import time

class GrammarFailure(Exception):
    """Raised when a program is structurally inconsistent with its type under a grammar,
    e.g. an application whose number of arguments does not match the function's type."""

class SketchEnumerationFailure(Exception):
    """Signals a failure while enumerating completions of a sketch
    (for example, a sketch that does not type-check)."""

class NoCandidates(Exception):
    """Raised when neither a production nor a variable in scope can produce the requested type."""


class Grammar(object):
    """A probabilistic grammar over typed programs.

    ``productions`` is a list of ``(logProbability, type, program)`` triples.
    Every variable (de Bruijn index) shares the single log weight
    ``logVariable``. When ``continuationType`` is set and equals the requested
    type, only the smallest-index (most recently bound) non-function variable
    of that type is offered as a candidate (see ``buildCandidates``).
    """

    def __init__(self, logVariable, productions, continuationType=None):
        self.logVariable = logVariable
        self.productions = productions

        self.continuationType = continuationType

        # Program -> log weight. Every variable is looked up through Index(0),
        # which carries the shared variable weight.
        self.expression2likelihood = dict((p, l) for l, _, p in productions)
        self.expression2likelihood[Index(0)] = self.logVariable

    def randomWeights(self, r):
        """returns a new grammar with random weights drawn from r. calls `r` w/ old weight"""
        return Grammar(logVariable=r(self.logVariable),
                       productions=[(r(l), t, p)
                                    for l, t, p in self.productions],
                       continuationType=self.continuationType)

    def strip_primitive_values(self):
        """Return a copy whose productions have been passed through `strip_primitive_values`."""
        return Grammar(logVariable=self.logVariable,
                       productions=[(l, t, strip_primitive_values(p))
                                    for l, t, p in self.productions],
                       continuationType=self.continuationType)

    def unstrip_primitive_values(self):
        """Return a copy whose productions have been passed through `unstrip_primitive_values`."""
        return Grammar(logVariable=self.logVariable,
                       productions=[(l, t, unstrip_primitive_values(p))
                                    for l, t, p in self.productions],
                       continuationType=self.continuationType)

    def __setstate__(self, state):
        """
        Legacy support for loading grammar objects without the imperative type filled in

        If the pickled state lacks ``continuationType``, it is guessed from the
        production types: "turtle" if any type mentions turtle, else "tower"
        if any mentions tower, else None.
        """
        assert 'logVariable' in state
        assert 'productions' in state
        if 'continuationType' in state:
            continuationType = state['continuationType']
        else:
            if any('turtle' in str(t) for l, t, p in state['productions']):
                continuationType = baseType("turtle")
            elif any('tower' in str(t) for l, t, p in state['productions']):
                continuationType = baseType("tower")
            else:
                continuationType = None

        self.__init__(state['logVariable'], state['productions'], continuationType=continuationType)

    @staticmethod
    def fromProductions(productions, logVariable=0.0, continuationType=None):
        """Make a grammar from primitives and their relative logpriors.

        :param productions: iterable of ``(logPrior, primitive)`` pairs; each
            primitive's type is obtained from ``primitive.infer()``.
        """
        return Grammar(logVariable, [(l, p.infer(), p)
                                     for l, p in productions],
                       continuationType=continuationType)

    @staticmethod
    def uniform(primitives, continuationType=None):
        """Make a grammar in which every primitive and the variable weight are 0.0."""
        return Grammar(0.0, [(0.0, p.infer(), p) for p in primitives], continuationType=continuationType)

    def __len__(self):
        """Number of productions (variables are not counted)."""
        return len(self.productions)

    def __str__(self):
        """Tabular rendering: one ``logProbability<TAB>type<TAB>program`` line per production.

        Primitives are listed before other productions and, within each group,
        by decreasing log probability. Non-function invented productions also
        show their evaluated value when evaluation succeeds.
        """
        def productionKey(production):
            l, _, p = production
            return not isinstance(p, Primitive), l is not None and -l

        if self.continuationType is not None:
            lines = ["continuation : %s" % self.continuationType]
        else:
            lines = []
        lines += ["%f\tt0\t$_" % self.logVariable]
        for l, t, p in sorted(self.productions, key=productionKey):
            if l is not None:
                line = "%f\t%s\t%s" % (l, t, p)
            else:
                line = "-Inf\t%s\t%s" % (t, p)
            if not t.isArrow() and isinstance(p, Invented):
                try:
                    line += "\teval = %s" % (p.evaluate([]))
                except Exception:
                    # Best effort only: the value is decoration. Do not catch
                    # BaseException, which would also swallow KeyboardInterrupt.
                    pass

            lines.append(line)
        return "\n".join(lines)

    def json(self):
        """JSON-serializable dict with the variable weight, productions and optional continuation type."""
        j = {"logVariable": self.logVariable,
             "productions": [{"expression": str(p), "logProbability": l}
                             for l, _, p in self.productions]}
        if self.continuationType is not None:
            j["continuationType"] = self.continuationType.json()
        return j

    def _immutable_code(self):
        """Hashable key used by __eq__/__hash__ (continuationType is not part of it)."""
        return self.logVariable, tuple(self.productions)

    def __eq__(self, o):
        # Returning NotImplemented lets Python fall back to identity instead of
        # raising AttributeError when comparing against a non-Grammar.
        if not isinstance(o, Grammar):
            return NotImplemented
        return self._immutable_code() == o._immutable_code()

    def __ne__(self, o): return not (self == o)

    def __hash__(self): return hash(self._immutable_code())

    @property
    def primitives(self):
        """The program of every production, in production order."""
        return [p for _, _, p in self.productions]

    def removeProductions(self, ps):
        """Return a copy of this grammar without the productions whose program is in `ps`."""
        return Grammar(
            self.logVariable,
            [(l, t, p) for (l, t, p) in self.productions if p not in ps],
            continuationType=self.continuationType)

    def buildCandidates(self, request, context, environment,
                        # Should the log probabilities be normalized?
                        normalize=True,
                        # Should be returned a table mapping primitives to
                        # their candidate entry?
                        returnTable=False,
                        # Should we return probabilities vs log probabilities?
                        returnProbabilities=False,
                        # Must be a leaf (have no arguments)?
                        mustBeLeaf=False):
        """Primitives that are candidates for being used given a requested type
        If returnTable is false (default): returns [((log)likelihood, tp, primitive, context)]
        if returntable is true: returns {primitive: ((log)likelihood, tp, context)}

        Variables in `environment` whose return type unifies with `request`
        split the variable weight evenly (logVariable - log(#variables)).

        :raises NoCandidates: if no production or variable fits the request.
        """
        if returnProbabilities:
            assert normalize

        candidates = []
        variableCandidates = []
        for l, t, p in self.productions:
            try:
                newContext, t = t.instantiate(context)
                newContext = newContext.unify(t.returns(), request)
                t = t.apply(newContext)
                if mustBeLeaf and t.isArrow():
                    continue
                candidates.append((l, t, p, newContext))
            except UnificationFailure:
                continue
        for j, t in enumerate(environment):
            try:
                newContext = context.unify(t.returns(), request)
                t = t.apply(newContext)
                if mustBeLeaf and t.isArrow():
                    continue
                variableCandidates.append((t, Index(j), newContext))
            except UnificationFailure:
                continue

        if self.continuationType == request:
            # Only the most recently bound non-function variable of the
            # continuation type may be used; function-typed ones are kept.
            terminalIndices = [v.i for t, v, k in variableCandidates if not t.isArrow()]
            if terminalIndices:
                smallestIndex = Index(min(terminalIndices))
                variableCandidates = [(t, v, k) for t, v, k in variableCandidates
                                      if t.isArrow() or v == smallestIndex]

        candidates += [(self.logVariable - log(len(variableCandidates)), t, p, k)
                       for t, p, k in variableCandidates]
        if candidates == []:
            raise NoCandidates()

        if normalize:
            z = lse([l for l, t, p, k in candidates])
            if returnProbabilities:
                candidates = [(exp(l - z), t, p, k)
                              for l, t, p, k in candidates]
            else:
                candidates = [(l - z, t, p, k) for l, t, p, k in candidates]

        if returnTable:
            return {p: (l, t, k) for l, t, p, k in candidates}
        else:
            return candidates

    def sample(self, request, maximumDepth=6, maxAttempts=None):
        """Sample a program of type `request` from this grammar.

        Sampling is retried whenever it hits a NoCandidates dead end. With
        `maxAttempts` set, None is returned once the number of failed
        attempts exceeds `maxAttempts` (i.e. after maxAttempts + 1 failures);
        otherwise it retries forever.
        """
        attempts = 0

        while True:
            try:
                _, e = self._sample(
                    request, Context.EMPTY, [], maximumDepth=maximumDepth)
                return e
            except NoCandidates:
                if maxAttempts is not None:
                    attempts += 1
                    if attempts > maxAttempts:
                        return None
                continue

    def _sample(self, request, context, environment, maximumDepth):
        """Recursive worker for `sample`; returns (context, program)."""
        if request.isArrow():
            context, expression = self._sample(
                request.arguments[1], context, [
                    request.arguments[0]] + environment, maximumDepth)
            return context, Abstraction(expression)

        candidates = self.buildCandidates(request, context, environment,
                                          normalize=True,
                                          returnProbabilities=True,
                                          # Force it to terminate in a
                                          # leaf; a primitive with no
                                          # function arguments
                                          mustBeLeaf=maximumDepth <= 1)
        newType, chosenPrimitive, context = sampleDistribution(candidates)

        # Sample the arguments
        xs = newType.functionArguments()
        returnValue = chosenPrimitive

        for x in xs:
            x = x.apply(context)
            context, x = self._sample(x, context, environment, maximumDepth - 1)
            returnValue = Application(returnValue, x)

        return context, returnValue

    def likelihoodSummary(self, context, environment, request, expression, silent=False):
        """Summarize the choices made when generating `expression` at type `request`.

        :returns: ``(context, summary)``. `summary` is None when the expression
            cannot be derived (and `silent` is true; otherwise an assertion
            fires). If a continuation type is set and the head is a variable
            that is not an allowed candidate, the summary has a constant of
            NEGATIVEINFINITY.
        :raises GrammarFailure: if an application's argument count does not
            match its head's type.
        """
        if request.isArrow():
            if not isinstance(expression, Abstraction):
                if not silent:
                    eprint("Request is an arrow but I got", expression)
                return context, None
            return self.likelihoodSummary(context,
                                          [request.arguments[0]] + environment,
                                          request.arguments[1],
                                          expression.body,
                                          silent=silent)
        # Build the candidates
        candidates = self.buildCandidates(request, context, environment,
                                          normalize=False,
                                          returnTable=True)

        # A list of everything that would have been possible to use here
        possibles = [p for p in candidates.keys() if not p.isIndex]
        numberOfVariables = sum(p.isIndex for p in candidates.keys())
        if numberOfVariables > 0:
            possibles += [Index(0)]

        f, xs = expression.applicationParse()

        if f not in candidates:
            if self.continuationType is not None and f.isIndex:
                ls = LikelihoodSummary()
                ls.constant = NEGATIVEINFINITY
                # Must match every other return: callers unpack (context, summary).
                return context, ls

            if not silent:
                eprint(f, "Not in candidates")
                eprint("Candidates is", candidates)
                eprint("request is", request)
                eprint("xs", xs)
                eprint("environment", environment)
                assert False
            return context, None

        thisSummary = LikelihoodSummary()
        thisSummary.record(f, possibles,
                           constant=-math.log(numberOfVariables) if f.isIndex else 0)

        _, tp, context = candidates[f]
        argumentTypes = tp.functionArguments()
        if len(xs) != len(argumentTypes):
            eprint("PANIC: not enough arguments for the type")
            eprint("request", request)
            eprint("tp", tp)
            eprint("expression", expression)
            eprint("xs", xs)
            eprint("argumentTypes", argumentTypes)
            # This should absolutely never occur
            raise GrammarFailure((context, environment, request, expression))

        for argumentType, argument in zip(argumentTypes, xs):
            argumentType = argumentType.apply(context)
            context, newSummary = self.likelihoodSummary(
                context, environment, argumentType, argument, silent=silent)
            if newSummary is None:
                return context, None
            thisSummary.join(newSummary)

        return context, thisSummary

    def bestFirstEnumeration(self, request, maximumFrontier=10**3):
        """Enumerate programs of type `request` in order of increasing description length.

        Partial expansions (thunks) and finished programs share one priority
        queue keyed by accumulated cost (negative normalized log probability),
        so programs pop out cheapest first.

        :param maximumFrontier: stop once this many programs have been found.
        :returns: list of programs, cheapest first; shorter than
            `maximumFrontier` if the search space is exhausted.
        """
        from heapq import heappush, heappop
        from itertools import count

        pq = []
        # Monotone tie-breaker: entries with equal cost must never fall back
        # to comparing their payloads (closures and programs are not orderable).
        tieBreaker = count()

        def push(cost, item):
            heappush(pq, (cost, next(tieBreaker), item))

        def expansion(cost, f, tp, newContext, environment, k):
            """Thunk that expands the arguments of `f`; binds all values now (no late binding)."""
            return lambda: ga(cost, f, tp.functionArguments(),
                              context=newContext, environment=environment, k=k)

        def g(parentCost, request, _=None,
              context=None, environment=None,
              k=None):
            """
            k is a continuation.
            k: Expects to be called with MDL, context, expression.
            """

            assert k is not None
            if context is None:
                context = Context.EMPTY
            if environment is None:
                environment = []

            if request.isArrow():
                g(parentCost,
                  request.arguments[1],
                  context=context,
                  environment=[request.arguments[0]] + environment,
                  k=lambda MDL, newContext, p: k(MDL, newContext, Abstraction(p)))
            else:
                candidates = self.buildCandidates(request,
                                                  context,
                                                  environment,
                                                  normalize=True,
                                                  returnProbabilities=False,
                                                  returnTable=True)
                for f, (ll, tp, newContext) in candidates.items():
                    cost = parentCost - ll
                    push(cost, expansion(cost, f, tp, newContext, environment, k))

        def ga(costSoFar, f, argumentTypes, _=None,
               context=None, environment=None,
               k=None):
            """Apply `f` to freshly enumerated arguments one at a time, then call `k`."""
            if argumentTypes == []:
                k(costSoFar, context, f)
            else:
                t1 = argumentTypes[0].apply(context)
                g(costSoFar, t1, context=context, environment=environment,
                  k=lambda newCost, newContext, argument:
                  ga(newCost, Application(f, argument), argumentTypes[1:],
                     context=newContext, environment=environment,
                     k=k))

        def receiveResult(MDL, _, expression):
            push(MDL, expression)

        g(0., request, context=Context.EMPTY, environment=[], k=receiveResult)
        frontier = []
        while pq and len(frontier) < maximumFrontier:
            _, _, action = heappop(pq)
            if isinstance(action, Program):
                frontier.append(action)
            else:
                action()
        return frontier

    def closedLikelihoodSummary(self, request, expression, silent=False):
        """Likelihood summary of a closed `expression` (empty context and environment).

        On GrammarFailure the failing inputs are pickled under ``failures/``
        for debugging and an AssertionError is raised.
        """
        try:
            _, summary = self.likelihoodSummary(Context.EMPTY, [], request, expression, silent=silent)
        except GrammarFailure as e:
            import os
            failureExport = 'failures/grammarFailure%s.pickle' % (
                time.time() + getPID())
            eprint("PANIC: Grammar failure, exporting to ", failureExport)
            # Create the directory so the dump cannot mask the original
            # failure with a FileNotFoundError.
            os.makedirs(os.path.dirname(failureExport), exist_ok=True)
            with open(failureExport, 'wb') as handle:
                pickle.dump((e, self, request, expression), handle)
            # Explicit raise (not `assert False`) so this also fires under -O.
            raise AssertionError("grammar failure, exported to %s" % failureExport) from e

        return summary

    def logLikelihood(self, request, expression):
        """Log prior of `expression` at type `request` under this grammar; asserts if underivable."""
        summary = self.closedLikelihoodSummary(request, expression)
        if summary is None:
            eprint(
                "FATAL: program [ %s ] does not have a likelihood summary." %
                expression, "r = ", request, "\n", self)
            assert False
        return summary.logLikelihood(self)

    def rescoreFrontier(self, frontier):
        """Copy of `frontier` whose log priors are recomputed under this grammar."""
        return Frontier([FrontierEntry(e.program,
                                       logPrior=self.logLikelihood(frontier.task.request, e.program),
                                       logLikelihood=e.logLikelihood)
                         for e in frontier],
                        frontier.task)

    def productionUses(self, frontiers):
        """Returns the expected number of times that each production was used. {production: expectedUses}

        Expectations are over each non-empty frontier's posterior after
        rescoring under this grammar. Every primitive has an entry (0. if
        unused). Keys recorded by the likelihood summary that are not
        primitives (variables are recorded as Index(0)) are included too.
        """
        frontiers = [self.rescoreFrontier(f).normalize()
                     for f in frontiers if not f.empty]
        uses = {p: 0. for p in self.primitives}
        for f in frontiers:
            for e in f:
                summary = self.closedLikelihoodSummary(f.task.request,
                                                       e.program)
                weight = math.exp(e.logPosterior)
                # summary.uses maps expression -> count; iterate its items.
                for p, u in summary.uses.items():
                    uses[p] = uses.get(p, 0.) + u * weight
        return uses

    def insideOutside(self, frontiers, pseudoCounts, iterations=1):
        """Re-estimate production weights from `frontiers` by expected counts.

        Each iteration weights every entry's uses by its posterior and sets
        each weight to log(actual + pseudoCounts) - log(possible + pseudoCounts).

        :returns: the re-estimated Grammar (this grammar is not modified).
        """
        # Replace programs with (likelihood summary, uses)
        frontiers = [Frontier([FrontierEntry((summary, summary.toUses()),
                                             logPrior=summary.logLikelihood(self),
                                             logLikelihood=e.logLikelihood)
                               for e in f
                               for summary in [self.closedLikelihoodSummary(f.task.request, e.program)]],
                              task=f.task)
                     for f in frontiers]

        g = self
        for i in range(iterations):
            u = Uses()
            for f in frontiers:
                f = f.normalize()
                for e in f:
                    _, eu = e.program
                    u += math.exp(e.logPosterior) * eu

            lv = math.log(u.actualVariables + pseudoCounts) - \
                math.log(u.possibleVariables + pseudoCounts)
            g = Grammar(lv,
                        [(math.log(u.actualUses.get(p, 0.) + pseudoCounts) -
                          math.log(u.possibleUses.get(p, 0.) + pseudoCounts),
                          t, p)
                         for _, t, p in g.productions],
                        continuationType=self.continuationType)
            if i < iterations - 1:
                # Rescore the cached summaries under the new grammar.
                frontiers = [Frontier([FrontierEntry((summary, uses),
                                                     logPrior=summary.logLikelihood(g),
                                                     logLikelihood=e.logLikelihood)
                                       for e in f
                                       for (summary, uses) in [e.program]],
                                      task=f.task)
                             for f in frontiers]
        return g

    def frontierMDL(self, frontier):
        """Best (maximum) logLikelihood + logPrior over the entries of `frontier`."""
        return max(e.logLikelihood + self.logLikelihood(frontier.task.request, e.program)
                   for e in frontier)

    def enumeration(self, context, environment, request, upperBound,
                    maximumDepth=20,
                    lowerBound=0.):
        '''Enumerates all programs whose MDL satisfies: lowerBound <= MDL < upperBound

        Yields ``(logPrior, context, program)`` triples.'''
        if upperBound < 0 or maximumDepth == 1:
            return

        if request.isArrow():
            v = request.arguments[0]
            for l, newContext, b in self.enumeration(context, [v] + environment,
                                                     request.arguments[1],
                                                     upperBound=upperBound,
                                                     lowerBound=lowerBound,
                                                     maximumDepth=maximumDepth):
                yield l, newContext, Abstraction(b)

        else:
            candidates = self.buildCandidates(request, context, environment,
                                              normalize=True)

            for l, t, p, newContext in candidates:
                mdl = -l
                if not (mdl < upperBound):
                    continue

                xs = t.functionArguments()
                # The head's cost is charged against both bounds before the
                # arguments are enumerated.
                for aL, aK, application in\
                    self.enumerateApplication(newContext, environment, p, xs,
                                              upperBound=upperBound + l,
                                              lowerBound=lowerBound + l,
                                              maximumDepth=maximumDepth - 1):
                    yield aL + l, aK, application

    def enumerateApplication(self, context, environment,
                             function, argumentRequests,
                             # Upper bound on the description length of all of
                             # the arguments
                             upperBound,
                             # Lower bound on the description length of all of
                             # the arguments
                             lowerBound=0.,
                             maximumDepth=20,
                             originalFunction=None,
                             argumentIndex=0):
        """Enumerate applications of `function` to arguments of `argumentRequests`.

        Yields ``(logPrior of the arguments, context, application)``; argument
        choices rejected by `violatesSymmetry` are skipped.
        """
        if upperBound < 0. or maximumDepth == 1:
            return
        if originalFunction is None:
            originalFunction = function

        if argumentRequests == []:
            if lowerBound <= 0. and 0. < upperBound:
                yield 0., context, function
            else:
                return
        else:
            argRequest = argumentRequests[0].apply(context)
            laterRequests = argumentRequests[1:]
            for argL, newContext, arg in self.enumeration(context, environment, argRequest,
                                                          upperBound=upperBound,
                                                          lowerBound=0.,
                                                          maximumDepth=maximumDepth):
                if violatesSymmetry(originalFunction, arg, argumentIndex):
                    continue

                newFunction = Application(function, arg)
                for resultL, resultK, result in self.enumerateApplication(newContext, environment, newFunction,
                                                                          laterRequests,
                                                                          upperBound=upperBound + argL,
                                                                          lowerBound=lowerBound + argL,
                                                                          maximumDepth=maximumDepth,
                                                                          originalFunction=originalFunction,
                                                                          argumentIndex=argumentIndex + 1):
                    yield resultL + argL, resultK, result

    def sketchEnumeration(self, context, environment, request, sk, upperBound,
                          maximumDepth=20,
                          lowerBound=0.):
        '''Enumerates all sketch instantiations whose MDL satisfies: lowerBound <= MDL < upperBound

        Only holes are enumerated; the fixed parts of `sk` contribute no cost.
        Yields ``(logPrior, context, program)`` triples. An ill-typed sketch
        prints a message and yields nothing.'''
        if upperBound < 0. or maximumDepth == 1:
            return

        if sk.isHole:
            yield from self.enumeration(context, environment, request, upperBound,
                                        maximumDepth=maximumDepth,
                                        lowerBound=lowerBound)
        elif request.isArrow():
            assert sk.isAbstraction
            v = request.arguments[0]
            for l, newContext, b in self.sketchEnumeration(context, [v] + environment,
                                                           request.arguments[1],
                                                           sk.body,
                                                           upperBound=upperBound,
                                                           lowerBound=lowerBound,
                                                           maximumDepth=maximumDepth):
                yield l, newContext, Abstraction(b)

        else:
            f, xs = sk.applicationParse()
            if f.isIndex:
                ft = environment[f.i].apply(context)
            elif f.isInvented or f.isPrimitive:
                context, ft = f.tp.instantiate(context)
            elif f.isAbstraction:
                assert False, "sketch is not in beta longform"
            elif f.isHole:
                assert False, "hole as function not yet supported"
            elif f.isApplication:
                assert False, "should never happen - bug in applicationParse"
            else:
                assert False

            try:
                context = context.unify(ft.returns(), request)
            except UnificationFailure:
                print("Exception: sketch is ill-typed")
                return  # so that we can continue evaluating
                # raise SketchEnumerationFailure() #"sketch is ill-typed"
            ft = ft.apply(context)
            argumentRequests = ft.functionArguments()

            assert len(argumentRequests) == len(xs)

            yield from self.sketchApplication(context, environment,
                                              f, xs, argumentRequests,
                                              upperBound=upperBound,
                                              lowerBound=lowerBound,
                                              maximumDepth=maximumDepth - 1)

    def sketchApplication(self, context, environment,
                          function, arguments, argumentRequests,
                          # Upper bound on the description length of all of
                          # the arguments
                          upperBound,
                          # Lower bound on the description length of all of
                          # the arguments
                          lowerBound=0.,
                          maximumDepth=20):
        """Enumerate applications of `function` whose arguments instantiate the sketches in `arguments`.

        Yields ``(logPrior of the arguments, context, application)``.
        """
        if upperBound < 0. or maximumDepth == 1:
            return

        if argumentRequests == []:
            if lowerBound <= 0. and 0. < upperBound:
                yield 0., context, function
            else:
                return
        else:
            argRequest = argumentRequests[0].apply(context)
            laterRequests = argumentRequests[1:]
            firstSketch = arguments[0]
            laterSketches = arguments[1:]
            for argL, newContext, arg in self.sketchEnumeration(context, environment, argRequest,
                                                                firstSketch,
                                                                upperBound=upperBound,
                                                                lowerBound=0.,
                                                                maximumDepth=maximumDepth):

                newFunction = Application(function, arg)
                for resultL, resultK, result in self.sketchApplication(newContext, environment, newFunction,
                                                                       laterSketches, laterRequests,
                                                                       upperBound=upperBound + argL,
                                                                       lowerBound=lowerBound + argL,
                                                                       maximumDepth=maximumDepth):

                    yield resultL + argL, resultK, result

    def sketchLogLikelihood(self, request, full, sk, context=None, environment=None):
        """
        calculates mdl of full program 'full' from sketch 'sk'

        Returns ``(logLikelihood, context)``: the log probability under this
        grammar of the sub-programs of `full` that fill the holes of `sk`
        (matching fixed parts contribute nothing), and the type context after
        scoring them. `context` defaults to Context.EMPTY and `environment`
        to [] (resolved at call time).
        """
        if context is None:
            context = Context.EMPTY
        if environment is None:
            environment = []

        if sk.isHole:
            # Keep the context produced while scoring the hole so that type
            # bindings reach the remaining arguments (as likelihoodSummary does).
            context, summary = self.likelihoodSummary(context, environment, request, full)
            if summary is None:
                eprint(
                    "FATAL: program [ %s ] does not have a likelihood summary." %
                    full, "r = ", request, "\n", self)
                assert False
            return summary.logLikelihood(self), context

        elif request.isArrow():
            assert sk.isAbstraction and full.isAbstraction
            v = request.arguments[0]
            return self.sketchLogLikelihood(request.arguments[1], full.body, sk.body,
                                            context=context, environment=[v] + environment)

        else:
            sk_f, sk_xs = sk.applicationParse()
            full_f, full_xs = full.applicationParse()
            if sk_f.isIndex:
                assert sk_f == full_f, "sketch and full program don't match on an index"
                ft = environment[sk_f.i].apply(context)
            elif sk_f.isInvented or sk_f.isPrimitive:
                assert sk_f == full_f, "sketch and full program don't match on a primitive"
                context, ft = sk_f.tp.instantiate(context)
            elif sk_f.isAbstraction:
                assert False, "sketch is not in beta longform"
            elif sk_f.isHole:
                assert False, "hole as function not yet supported"
            elif sk_f.isApplication:
                assert False, "should never happen - bug in applicationParse"
            else:
                assert False

            try:
                context = context.unify(ft.returns(), request)
            except UnificationFailure:
                assert False, "sketch is ill-typed"
            ft = ft.apply(context)
            argumentRequests = ft.functionArguments()

            # NOTE(review): may not hold if holes appear in function position.
            assert len(argumentRequests) == len(sk_xs) == len(full_xs)

            return self.sketchllApplication(context, environment,
                                            sk_f, sk_xs, full_f, full_xs, argumentRequests)

    def sketchllApplication(self, context, environment,
                            sk_function, sk_arguments, full_function, full_arguments, argumentRequests):
        """Sum `sketchLogLikelihood` over corresponding sketch/full arguments.

        :returns: ``(logLikelihood, context)``; with no arguments left the
            log likelihood is a zero torch tensor (on CUDA when available).
        """
        if not argumentRequests:
            import torch
            zero = torch.tensor([0.])
            # Only move to the GPU when one exists, so CPU-only runs work.
            if torch.cuda.is_available():
                zero = zero.cuda()
            return zero, context

        argRequest = argumentRequests[0].apply(context)
        laterRequests = argumentRequests[1:]

        sk_firstSketch = sk_arguments[0]
        full_firstSketch = full_arguments[0]
        sk_laterSketches = sk_arguments[1:]
        full_laterSketches = full_arguments[1:]

        argL, newContext = self.sketchLogLikelihood(argRequest, full_firstSketch, sk_firstSketch,
                                                    context=context, environment=environment)

        # The partially applied heads are threaded through for symmetry with
        # sketchApplication; they do not affect the score.
        sk_newFunction = Application(sk_function, sk_firstSketch)
        full_newFunction = Application(full_function, full_firstSketch)

        resultL, context = self.sketchllApplication(newContext, environment, sk_newFunction, sk_laterSketches,
                                                    full_newFunction, full_laterSketches, laterRequests)

        return resultL + argL, context

    def enumerateNearby(self, request, expr, distance=3.0):
        """Enumerate programs with local mutations in subtrees with small description length"""
        if distance <= 0:
            yield expr
        else:
            def mutations(tp, loss):
                for l, _, expr in self.enumeration(
                        Context.EMPTY, [], tp, distance - loss):
                    yield expr, l
            yield from Mutator(self, mutations).execute(expr, request)

    def enumerateHoles(self, request, expr, k=3, return_obj=None):
        """Enumerate programs with a single hole within mdl distance

        Keeps the `k` (program, score) pairs with the highest score yielded by
        Mutator and returns them sorted by decreasing score. Returns [] when
        k <= 0.

        :param return_obj: factory for the hole placeholder; defaults to `Hole`
            (looked up at call time).
        """
        # TODO: make it possible to enumerate sketches with multiple holes
        if return_obj is None:
            return_obj = Hole

        def mutations(tp, loss, is_left_application=False):
            """
            to allow applications lhs to become a hole,
            remove the condition below and ignore all the is_left_application kwds
            """
            if not is_left_application:
                yield return_obj(), 0

        top_k = []
        for candidate, score in Mutator(self, mutations).execute(expr, request):
            if len(top_k) < k:
                top_k.append((candidate, score))
            elif top_k:
                # Full: replace the (first) lowest-scoring entry if strictly beaten.
                i, (_, worst) = min(enumerate(top_k), key=lambda x: x[1][1])
                if score > worst:
                    top_k[i] = (candidate, score)
        return sorted(top_k, key=lambda x: -x[1])

    def untorch(self):
        """Copy of this grammar with tensor weights converted to plain floats (first element)."""
        return Grammar(self.logVariable.data.tolist()[0],
                       [(l.data.tolist()[0], t, p)
                        for l, t, p in self.productions],
                       continuationType=self.continuationType)

class LikelihoodSummary(object):
    '''Summarizes the terms that will be used in a likelihood calculation.

    Attributes:
        uses: dict mapping each production that was chosen to how many times it
            was chosen. All variables are collapsed onto ``Index(0)``.
        normalizers: dict mapping each frozenset of productions that were
            available at a choice point to how many choice points offered it.
        constant: additive term accumulated from ``record(..., constant=...)``.

    Under a grammar whose ``expression2likelihood`` maps productions to log
    weights, the log likelihood is::

        constant + sum(uses[p] * w[p]) - sum(normalizers[ps] * lse(w[p] for p in ps))
    '''

    def __init__(self):
        self.uses = {}
        self.normalizers = {}
        self.constant = 0.

    def __str__(self):
        usesText = ", ".join("%s: %d" % (k, v) for k, v in self.uses.items())
        normalizersText = ", ".join("%s: %d" % (k, v) for k, v in self.normalizers.items())
        return "LikelihoodSummary(constant = %f,\nuses = {%s},\nnormalizers = {%s})" % (
            self.constant, usesText, normalizersText)

    def record(self, actual, possibles, constant=0.):
        """Record one choice point: `actual` was chosen from `possibles`.

        `constant` is added to the summary's constant term.
        """
        # Variables are all normalized to be $0
        if isinstance(actual, Index):
            actual = Index(0)

        # A frozenset is hashable and order-insensitive, so no sort is needed first.
        possibles = frozenset(possibles)

        self.constant += constant
        self.uses[actual] = self.uses.get(actual, 0) + 1
        self.normalizers[possibles] = self.normalizers.get(possibles, 0) + 1

    def join(self, other):
        """Add `other`'s counts and constant into this summary, in place."""
        self.constant += other.constant
        for k, v in other.uses.items():
            self.uses[k] = self.uses.get(k, 0) + v
        for k, v in other.normalizers.items():
            self.normalizers[k] = self.normalizers.get(k, 0) + v

    def logLikelihood(self, grammar):
        """Log likelihood under `grammar`, i.e. numerator - denominator.

        `lse` is presumably log-sum-exp. Raises KeyError if a production is
        missing from `grammar.expression2likelihood`.
        """
        return self.constant + \
            sum(count * grammar.expression2likelihood[p] for p, count in self.uses.items()) - \
            sum(count * lse([grammar.expression2likelihood[p] for p in ps])
                for ps, count in self.normalizers.items())

    def logLikelihood_overlyGeneral(self, grammar):
        """Calculates log likelihood of this summary, given that the summary might refer to productions that don't occur in the grammar.

        Productions absent from `grammar` get weight NEGATIVEINFINITY inside the
        normalizers. Used productions must still be present in `grammar`.
        """
        return self.constant + \
            sum(count * grammar.expression2likelihood[p] for p, count in self.uses.items()) - \
            sum(count * lse([grammar.expression2likelihood.get(p,NEGATIVEINFINITY) for p in ps])
                for ps, count in self.normalizers.items())

    def numerator(self, grammar):
        """constant + sum of log weights of the productions actually used."""
        return self.constant + \
            sum(count * grammar.expression2likelihood[p] for p, count in self.uses.items())

    def denominator(self, grammar):
        """Sum over choice points of the lse of the available productions' log weights."""
        return \
            sum(count * lse([grammar.expression2likelihood[p] for p in ps])
                for ps, count in self.normalizers.items())

    def toUses(self):
        """Convert to a `Uses` of possible/actual counts for variables and non-variable productions."""
        possibleVariables = sum(count
                                for ps, count in self.normalizers.items()
                                if Index(0) in ps)
        actualVariables = self.uses.get(Index(0), 0.)
        actualUses = {k: v
                      for k, v in self.uses.items()
                      if not k.isIndex}
        # Accumulate counts directly instead of expanding each production into a
        # count-length list. Non-positive counts are skipped, as the former
        # Counter-based version effectively did.
        possibleUses = {}
        for ps, count in self.normalizers.items():
            if count <= 0:
                continue
            for p in ps:
                if not p.isIndex:
                    possibleUses[p] = possibleUses.get(p, 0) + count
        return Uses(possibleVariables, actualVariables,
                    possibleUses, actualUses)


class Uses(object):
    '''Tracks uses of different grammar productions.

    Holds possibly fractional (e.g. expected) counts:
      * possibleVariables / actualVariables: how often a variable was
        available / actually chosen;
      * possibleUses / actualUses: dicts mapping a production to how often it
        was available / actually chosen.

    Supports:
      * scaling by a number (``*``, ``*=``);
      * addition (``+``, ``+=``, and ``sum(...)`` through ``__radd__``);
      * weighted combination via ``Uses.join``.
    '''

    def __init__(self, possibleVariables=0., actualVariables=0.,
                 possibleUses=None, actualUses=None):
        # None sentinels: the former `{}` defaults were one dict shared by every
        # default-constructed instance (including `Uses.empty`). An in-place
        # update such as `+=` on one of them leaked into all the others.
        self.actualVariables = actualVariables
        self.possibleVariables = possibleVariables
        self.possibleUses = {} if possibleUses is None else possibleUses
        self.actualUses = {} if actualUses is None else actualUses

    def __str__(self):
        return "Uses(actualVariables = %f, possibleVariables = %f, actualUses = %s, possibleUses = %s)" %\
            (self.actualVariables, self.possibleVariables, self.actualUses, self.possibleUses)

    def __repr__(self): return str(self)

    def __mul__(self, a):
        """Return a new Uses with every count multiplied by `a`."""
        return Uses(a * self.possibleVariables,
                    a * self.actualVariables,
                    {p: a * u for p, u in self.possibleUses.items()},
                    {p: a * u for p, u in self.actualUses.items()})

    def __imul__(self, a):
        """Multiply every count by `a` in place."""
        self.possibleVariables *= a
        self.actualVariables *= a
        for p in self.possibleUses:
            self.possibleUses[p] *= a
        for p in self.actualUses:
            self.actualUses[p] *= a
        return self

    def __rmul__(self, a):
        return self * a

    def __radd__(self, o):
        # Lets `sum(listOfUses)` work: sum starts from the integer 0.
        if o == 0:
            return self
        return self + o

    def __add__(self, o):
        """Return a new Uses whose counts are the element-wise sums of both operands."""
        if o == 0:
            return self

        def merge(x, y):
            z = x.copy()
            for k, v in y.items():
                z[k] = v + x.get(k, 0.)
            return z
        return Uses(self.possibleVariables + o.possibleVariables,
                    self.actualVariables + o.actualVariables,
                    merge(self.possibleUses, o.possibleUses),
                    merge(self.actualUses, o.actualUses))

    def __iadd__(self, o):
        """Add `o`'s counts into this Uses in place."""
        self.possibleVariables += o.possibleVariables
        self.actualVariables += o.actualVariables
        for k, v in o.possibleUses.items():
            self.possibleUses[k] = self.possibleUses.get(k, 0.) + v
        for k, v in o.actualUses.items():
            self.actualUses[k] = self.actualUses.get(k, 0.) + v
        return self

    @staticmethod
    def join(z, *weightedUses):
        """Combine `(logWeight, Uses)` pairs into one weighted total. Consumes weightedUses.

        When two or more pairs are given, each Uses is scaled in place by
        ``exp(logWeight - z)`` and the scaled counts are summed into a new Uses.
        A single pair is returned as-is, without scaling. No pairs yields a
        fresh empty Uses.
        """
        if not weightedUses:
            # A fresh object rather than the shared `Uses.empty`, so callers may
            # mutate the result safely.
            return Uses()
        if len(weightedUses) == 1:
            return weightedUses[0][1]
        for w, u in weightedUses:
            u *= exp(w - z)
        total = Uses()
        total.possibleVariables = sum(
            u.possibleVariables for _, u in weightedUses)
        total.actualVariables = sum(u.actualVariables for _, u in weightedUses)
        total.possibleUses = defaultdict(float)
        total.actualUses = defaultdict(float)
        for _, u in weightedUses:
            for k, v in u.possibleUses.items():
                total.possibleUses[k] += v
            for k, v in u.actualUses.items():
                total.actualUses[k] += v
        return total


Uses.empty = Uses()

class ContextualGrammar:
    """Grammar whose production weights are conditioned on the parent node.

    Bundles several `Grammar` objects and picks one per choice point (see
    `likelihoodSummary`, `_sample`, `enumeration`):
      * `noParent`       -- used when there is no parent (parent is None);
      * `variableParent` -- used when the parent is a variable (`parent.isIndex`);
      * `library[e][i]`  -- used for the i-th argument of production `e`.

    The constructor asserts that all component grammars share the same
    primitives and `continuationType`, and that `library[e]` holds one grammar
    per argument of `e`'s inferred type.
    """
    def __init__(self, noParent, variableParent, library):
        # library: dict mapping each production to a list of Grammars, one per
        # argument position (checked by the assertions below).
        assert isinstance(noParent, Grammar)
        assert isinstance(variableParent, Grammar)
        for e, gs in library.items():
            for g in gs:
                assert isinstance(g, Grammar)
        self.noParent, self.variableParent, self.library = noParent, variableParent, library

        # Productions/primitives are taken from `noParent`. The weight slot is
        # None because weights depend on the context.
        self.productions = [(None,t,p) for _,t,p in self.noParent.productions ]
        self.primitives = [p for _,_2,p in self.productions ]

        self.continuationType = noParent.continuationType
        assert variableParent.continuationType == self.continuationType

        assert set(noParent.primitives) == set(variableParent.primitives)
        assert set(variableParent.primitives) == set(library.keys())
        for e,gs in library.items():
            assert len(gs) == len(e.infer().functionArguments())
            for g in gs:
                assert set(g.primitives) == set(library.keys())
                assert g.continuationType == self.continuationType

    def untorch(self):
        """Return a copy in which `untorch()` has been applied to every component grammar."""
        return ContextualGrammar(self.noParent.untorch(), self.variableParent.untorch(),
                                 {e: [g.untorch() for g in gs ]
                                  for e,gs in self.library.items() })

    def randomWeights(self, r):
        """returns a new grammar with random weights drawn from r. calls `r` w/ old weight"""
        return ContextualGrammar(self.noParent.randomWeights(r),
                                 self.variableParent.randomWeights(r),
                                 {e: [g.randomWeights(r) for g in gs]
                                  for e,gs in self.library.items() })
    def __str__(self):
        """Human-readable dump of every component grammar, labelled by context."""
        lines = ["No parent:",str(self.noParent),"",
                 "Variable parent:",str(self.variableParent),"",
                 ""]
        for e,gs in self.library.items():
            for j,g in enumerate(gs):
                lines.extend(["Parent %s, argument index %s"%(e,j),
                              str(g),
                              ""])
        return "\n".join(lines)

    def json(self):
        """JSON-serializable dict of the component grammars.

        "productions" lists, per library entry, the program string and the
        JSON of each per-argument grammar.
        """
        return {"noParent": self.noParent.json(),
                "variableParent": self.variableParent.json(),
                "productions": [{"program": str(e),
                                 "arguments": [gp.json() for gp in gs ]}
                                    for e,gs in self.library.items() ]}

    def __len__(self):
        return len(self.noParent) # should be the same for all, per assertions in constructor

    @staticmethod
    def fromGrammar(g):
        """Lift a plain Grammar into a ContextualGrammar that uses `g` in every context.

        A ContextualGrammar argument is returned unchanged.
        """
        if isinstance(g, ContextualGrammar):
            return g
        return ContextualGrammar(g, g,
                                 {e: [g]*len(e.infer().functionArguments())
                                  for e in g.primitives })


    class LS: # likelihood summary
        """One LikelihoodSummary per context, mirroring the owner's grammar structure."""
        def __init__(self, owner):
            self.noParent = LikelihoodSummary()
            self.variableParent = LikelihoodSummary()
            self.library = {e: [LikelihoodSummary() for _ in gs]  for e,gs in owner.library.items() }

        def record(self, parent, parentIndex, actual, possibles, constant):
            """Record a choice in the summary for the context selected by (parent, parentIndex)."""
            if parent is None: ls = self.noParent
            elif parent.isIndex: ls = self.variableParent
            else: ls = self.library[parent][parentIndex]
            ls.record(actual, possibles, constant=constant)

        def join(self, other):
            """Merge `other`'s per-context summaries into this one, in place."""
            self.noParent.join(other.noParent)
            self.variableParent.join(other.variableParent)
            for e,gs in self.library.items():
                for g1,g2 in zip(gs, other.library[e]):
                    g1.join(g2)

        # The three methods below sum the corresponding LikelihoodSummary
        # quantity over all contexts, each scored by the owner's matching grammar.
        def logLikelihood(self, owner):
            return self.noParent.logLikelihood(owner.noParent) + \
                   self.variableParent.logLikelihood(owner.variableParent) + \
                   sum(r.logLikelihood(g)
                       for e, rs in self.library.items()
                       for r,g in zip(rs, owner.library[e]) )
        def numerator(self, owner):
            return self.noParent.numerator(owner.noParent) + \
                   self.variableParent.numerator(owner.variableParent) + \
                   sum(r.numerator(g)
                       for e, rs in self.library.items()
                       for r,g in zip(rs, owner.library[e]) )
        def denominator(self, owner):
            return self.noParent.denominator(owner.noParent) + \
                   self.variableParent.denominator(owner.variableParent) + \
                   sum(r.denominator(g)
                       for e, rs in self.library.items()
                       for r,g in zip(rs, owner.library[e]) )

    def likelihoodSummary(self, parent, parentIndex, context, environment, request, expression):
        """Summarize the choices made to build `expression` at type `request`.

        Returns:
            (context, LS): the updated type context and the per-context summary.
        """
        # Arrow types must be met by a lambda. Descend into its body with the
        # argument type pushed onto the environment, keeping the same parent.
        if request.isArrow():
            assert expression.isAbstraction
            return self.likelihoodSummary(parent, parentIndex,
                                          context,
                                          [request.arguments[0]] + environment,
                                          request.arguments[1],
                                          expression.body)
        if parent is None: g = self.noParent
        elif parent.isIndex: g = self.variableParent
        else: g = self.library[parent][parentIndex]
        candidates = g.buildCandidates(request, context, environment,
                                       normalize=False, returnTable=True)

        # A list of everything that would have been possible to use here
        # All in-scope variables are collapsed into a single Index(0) option.
        possibles = [p for p in candidates.keys() if not p.isIndex]
        numberOfVariables = sum(p.isIndex for p in candidates.keys())
        if numberOfVariables > 0:
            possibles += [Index(0)]

        f, xs = expression.applicationParse()

        assert f in candidates

        thisSummary = self.LS(self)
        # Choosing one specific variable among numberOfVariables costs an extra
        # -log(numberOfVariables) on top of the collapsed Index(0) choice.
        thisSummary.record(parent, parentIndex,
                           f, possibles,
                           constant= -math.log(numberOfVariables) if f.isIndex else 0)

        _, tp, context = candidates[f]
        argumentTypes = tp.functionArguments()
        assert len(xs) == len(argumentTypes)

        # Each argument is summarized with `f` as its parent and `i` as its position.
        for i, (argumentType, argument) in enumerate(zip(argumentTypes, xs)):
            argumentType = argumentType.apply(context)
            context, newSummary = self.likelihoodSummary(f, i,
                                                         context, environment, argumentType, argument)
            thisSummary.join(newSummary)

        return context, thisSummary

    def closedLikelihoodSummary(self, request, expression):
        """Likelihood summary of `expression` with no parent, an empty context and an empty environment."""
        return self.likelihoodSummary(None,None,
                                      Context.EMPTY,[],
                                      request, expression)[1]

    def logLikelihood(self, request, expression):
        return self.closedLikelihoodSummary(request, expression).logLikelihood(self)

    def sample(self, request, maximumDepth=8, maxAttempts=None):
        """Draw a random program of type `request`, retrying whenever NoCandidates is raised.

        Returns None once more than `maxAttempts` attempts have failed. Retries
        forever when `maxAttempts` is None.
        """
        attempts = 0
        while True:
            try:
                _, e = self._sample(None, None, Context.EMPTY, [], request, maximumDepth)
                return e
            except NoCandidates:
                if maxAttempts is not None:
                    attempts += 1
                    if attempts > maxAttempts: return None
                continue

    def _sample(self, parent, parentIndex, context, environment, request, maximumDepth):
        """Recursive helper for `sample`; returns (context, program).

        NoCandidates is not raised here. Presumably it comes from buildCandidates /
        sampleDistribution when nothing fits (TODO confirm).
        """
        if request.isArrow():
            context, body = self._sample(parent, parentIndex, context,
                                         [request.arguments[0]] + environment,
                                         request.arguments[1],
                                         maximumDepth)
            return context, Abstraction(body)
        if parent is None: g = self.noParent
        elif parent.isIndex: g = self.variableParent
        else: g = self.library[parent][parentIndex]
        # At the depth limit, only leaves may be chosen.
        candidates = g.buildCandidates(request, context, environment,
                                       normalize=True, returnProbabilities=True,
                                       mustBeLeaf=(maximumDepth <= 1))
        newType, chosenPrimitive, context = sampleDistribution(candidates)

        xs = newType.functionArguments()
        returnValue = chosenPrimitive

        for j,x in enumerate(xs):
            x = x.apply(context)
            context, x = self._sample(chosenPrimitive, j, context, environment, x, maximumDepth - 1)
            returnValue = Application(returnValue, x)

        return context, returnValue

    def expectedUsesMonteCarlo(self, request, debug=None):
        """Estimate, from 10000 samples, how often each primitive occurs per sampled program.

        Only invented primitives are counted, unless the grammar has no inventions,
        in which case all primitives are counted. Entry i of the returned numpy
        array corresponds to the i-th primitive in `str`-sorted order. Each sample
        gets one attempt (`maxAttempts=0`); failed samples are skipped.

        NOTE(review): if every sample fails, n == 0 and the final division yields
        NaNs rather than an error.
        """
        import numpy as np
        n = 0
        u = [0.]*len(self.primitives)
        primitives = list(sorted(self.primitives, key=str))
        noInventions = all( not p.isInvented for p in primitives )
        primitive2index = {primitive: i
                           for i, primitive in enumerate(primitives)
                           if primitive.isInvented or noInventions }
        eprint(primitive2index)
        ns = 10000
        with timing(f"calculated expected uses using Monte Carlo simulation w/ {ns} samples"):
            for _ in range(ns):
                p = self.sample(request, maxAttempts=0)
                if p is None: continue
                n += 1
                if debug and n < 10:
                    eprint(debug, p)
                for _, child in p.walk():
                    if child not in primitive2index: continue
                    u[primitive2index[child]] += 1.0
        u = np.array(u)/n
        if debug:
            eprint(f"Got {n} samples. Feature vector:\n{u}")
            eprint(f"Likely used primitives: {[p for p,i in primitive2index.items() if u[i] > 0.5]}")
            eprint(f"Likely used primitive indices: {[i for p,i in primitive2index.items() if u[i] > 0.5]}")
        return u

    def featureVector(self, _=None, requests=None, onlyInventions=True, normalize=True):
        """
        Returns the probabilities licensed by the type system.
        This is like the grammar productions, but with irrelevant junk removed.
        Its intended use case is for clustering; it should be strictly better than the raw transition matrix.

        The result is a flat numpy array. It concatenates one group of weights
        for the no-parent grammar (productions whose return type unifies with
        some request, or all when there are no requests), then one group per
        (parent, argument index) pair (productions whose return type unifies
        with that argument's return type). With `normalize`, each non-empty
        group becomes exp(w - lse(group)) and empty groups are dropped.
        """
        # Default requests are chosen heuristically from the grammar's continuation
        # type or its primitives.
        if requests is None:
            if self.continuationType: requests = {self.continuationType}
            elif any( 'REAL' == str(p) for p in self.primitives ): requests = set()
            elif any( 'STRING' == str(p) for p in self.primitives ): requests = {tlist(tcharacter)}
            else: requests = set()
        requests = {r.returns() for r in requests}
        features = []
        logWeights = []
        for l,t,p in sorted(self.noParent.productions,
                            key=lambda z: str(z[2])):
            if onlyInventions and not p.isInvented: continue
            if any( canUnify(r, t.returns()) for r in requests ) or len(requests) == 0:
                logWeights.append(l)
        features.append(logWeights)
        for parent in sorted(self.primitives, key=str):
            if onlyInventions and not parent.isInvented: continue
            if parent not in self.library: continue
            argumentTypes = parent.infer().functionArguments()
            for j,g in enumerate(self.library[parent]):
                argumentType = argumentTypes[j]
                logWeights = []
                for l,t,p in sorted(g.productions,
                                    key=lambda z: str(z[2])):
                    if onlyInventions and not p.isInvented: continue
                    if canUnify(argumentType.returns(), t.returns()):
                        logWeights.append(l)
                features.append(logWeights)

        if normalize:
            features = [ [math.exp(w - z) for w in lw ]
                         for lw in features
                         if lw
                         for z in [lse(lw)] ]
        import numpy as np
        return np.array([f
                         for lw in features
                         for f in lw])

    def enumeration(self,context,environment,request,upperBound,
                    parent=None, parentIndex=None,
                    maximumDepth=20,
                    lowerBound=0.):
        '''Enumerates all programs whose MDL satisfies: lowerBound <= MDL < upperBound

        Yields (logLikelihood, context, program) triples, where MDL = -logLikelihood.
        The grammar used for the head is chosen from (parent, parentIndex).
        '''
        if upperBound < 0 or maximumDepth == 1:
            return

        if request.isArrow():
            v = request.arguments[0]
            for l, newContext, b in self.enumeration(context, [v] + environment,
                                                     request.arguments[1],
                                                     parent=parent, parentIndex=parentIndex,
                                                     upperBound=upperBound,
                                                     lowerBound=lowerBound,
                                                     maximumDepth=maximumDepth):
                yield l, newContext, Abstraction(b)
        else:
            if parent is None: g = self.noParent
            elif parent.isIndex: g = self.variableParent
            else: g = self.library[parent][parentIndex]

            candidates = g.buildCandidates(request, context, environment,
                                           normalize=True)

            for l, t, p, newContext in candidates:
                mdl = -l
                if not (mdl < upperBound):
                    continue

                # The head's cost (-l) is deducted from both bounds before
                # enumerating its arguments, with the head as their parent.
                xs = t.functionArguments()
                for aL, aK, application in\
                    self.enumerateApplication(newContext, environment, p, xs,
                                              parent=p,
                                              upperBound=upperBound + l,
                                              lowerBound=lowerBound + l,
                                              maximumDepth=maximumDepth - 1):
                    yield aL + l, aK, application

    def enumerateApplication(self, context, environment,
                             function, argumentRequests,
                             # Upper bound on the description length of all of
                             # the arguments
                             upperBound,
                             # Lower bound on the description length of all of
                             # the arguments
                             lowerBound=0.,
                             maximumDepth=20,
                             parent=None,
                             originalFunction=None,
                             argumentIndex=0):
        """Enumerate applications of `function` to arguments of types `argumentRequests`.

        Yields (logLikelihood of the arguments, context, application). Arguments
        are enumerated left to right. Argument `argumentIndex` uses the grammar
        for (parent, argumentIndex). Arguments rejected by `violatesSymmetry`
        are skipped.
        """
        assert parent is not None
        if upperBound < 0. or maximumDepth == 1:
            return
        if originalFunction is None:
            originalFunction = function

        if argumentRequests == []:
            # All arguments are filled; the remaining budget must admit cost 0.
            if lowerBound <= 0. and 0. < upperBound:
                yield 0., context, function
            else:
                return
        else:
            argRequest = argumentRequests[0].apply(context)
            laterRequests = argumentRequests[1:]
            for argL, newContext, arg in self.enumeration(context, environment, argRequest,
                                                          parent=parent, parentIndex=argumentIndex,
                                                          upperBound=upperBound,
                                                          lowerBound=0.,
                                                          maximumDepth=maximumDepth):
                if violatesSymmetry(originalFunction, arg, argumentIndex):
                    continue

                # argL is this argument's log likelihood, so adding it shrinks both
                # bounds for the remaining arguments.
                newFunction = Application(function, arg)
                for resultL, resultK, result in self.enumerateApplication(newContext, environment, newFunction,
                                                                          laterRequests,
                                                                          parent=parent,
                                                                          upperBound=upperBound + argL,
                                                                          lowerBound=lowerBound + argL,
                                                                          maximumDepth=maximumDepth,
                                                                          originalFunction=originalFunction,
                                                                          argumentIndex=argumentIndex + 1):
                    yield resultL + argL, resultK, result
        


def violatesSymmetry(f, x, argumentIndex):
    """Return True when applying primitive `f` to `x` at `argumentIndex` should be pruned.

    `enumerateApplication` uses this to skip arguments that produce apparently
    redundant forms, for example `(+ 0 ...)` or `(car (cons ...))`. Only the
    head of `x` (after peeling off applications) is inspected. If `f` or that
    head is not a primitive, nothing is pruned.
    """
    if not f.isPrimitive:
        return False
    head = x
    while head.isApplication:
        head = head.f
    if not head.isPrimitive:
        return False

    fn, arg = f.name, head.name
    if fn in ("car", "cdr", "empty?"):
        return arg in ("cons", "empty")
    if fn in ("index", "map", "zip"):
        return arg == "empty"
    if fn == "+":
        return arg == "0" or (argumentIndex == 1 and arg == "+")
    if fn == "-":
        return argumentIndex == 1 and arg == "0"
    if fn == "zero?":
        return arg in ("0", "1")
    if fn == "range":
        return arg == "0"
    if fn == "fold":
        return argumentIndex == 1 and arg == "empty"
    return False

def batchLikelihood(jobs):
    """Takes as input a set of (program, request, grammar) and returns a dictionary mapping each of these to its likelihood under the grammar

    Args:
        jobs: iterable of (program, request, grammar) triples. It is materialized
            once, so one-shot iterators work too.

    Returns:
        dict mapping each (program, request, grammar) triple to its log
        likelihood. Empty input yields an empty dict.

    Each distinct (program, request) pair is summarized only once, against a
    uniform grammar over the union of all jobs' primitives. The summary is then
    re-scored under each job's grammar with `logLikelihood_overlyGeneral`, which
    treats productions absent from that grammar as impossible when normalizing.
    """
    # The jobs are traversed several times below, so materialize them once.
    jobs = list(jobs)
    if not jobs:
        return {}
    superGrammar = Grammar.uniform(list({p for _1, _2, g in jobs for p in g.primitives}),
                                   continuationType=jobs[0][-1].continuationType)
    programsAndRequests = {(program, request)
                           for program, request, _ in jobs}
    with timing(f"Calculated {len(programsAndRequests)} likelihood summaries"):
        summary = {(program, request): superGrammar.closedLikelihoodSummary(request, program)
                   for program, request in programsAndRequests}
    with timing("Calculated log likelihoods from summaries"):
        response = {}
        for program, request, grammar in jobs:
            response[(program, request, grammar)] = \
                summary[(program, request)].logLikelihood_overlyGeneral(grammar)
    return response

if __name__ == "__main__":
    # Self-check: the log likelihood reported by `enumeration` for each program
    # must match the one recomputed by `logLikelihood` on the same program.
    from dreamcoder.domains.arithmetic.arithmeticPrimitives import *
    request = arrow(tint, tint)
    baseGrammar = Grammar.uniform([k0, k1, addition, subtraction])
    contextual = ContextualGrammar.fromGrammar(baseGrammar).randomWeights(lambda *a: random.random())
    for enumeratedLL, _, program in contextual.enumeration(Context.EMPTY, [], request, 12.):
        recomputedLL = contextual.logLikelihood(request, program)
        print(enumeratedLL, program, recomputedLL)
        assert abs(enumeratedLL - recomputedLL) < 0.0001
