from datetime import datetime

import numpy as np
from scipy.signal import fftconvolve

from rayuela.base.semiring import Semiring
from numba import jit


@jit(nopython=True)
def fast_multiply(a, b, N):
    """Return the first ``N`` coefficients of the convolution of ``a`` and ``b``.

    Entry ``k`` of the output is ``sum(a[j] * b[k - j] for j in 0..k)``, so
    both inputs must provide at least ``N`` entries. The output is a freshly
    allocated ``np.zeros(N)`` array.
    """
    out = np.zeros(N)
    for k in range(N):
        # Accumulate in a scalar, in the same left-to-right order, and store
        # the finished coefficient once.
        acc = 0.0
        for j in range(k + 1):
            acc += a[j] * b[k - j]
        out[k] = acc
    return out


#@jit(nopython=True)
def star_jit(v, vz_star):
    """Compute the vector part of a star (closure) by forward recurrence.

    The result ``r`` satisfies ``r[0] = vz_star`` and, for ``1 <= k < len(v)``,
    ``r[k] = vz_star * sum(v[m] * r[k - m] for m in 1..k)``.

    Args:
        v: Indexable sequence of coefficients; ``v[0]`` itself is not read
            here because its star is supplied through ``vz_star``.
        vz_star: Star of the leading coefficient ``v[0]``.

    Returns:
        A NumPy array with ``len(v) + 1`` entries. The final entry is never
        written and therefore stays ``0``.
    """
    size = len(v)
    # TODO: make size N
    closure = np.zeros(size + 1)
    closure[0] = vz_star

    for k in range(1, size):
        # Every closure[k - m] used below was filled in by an earlier
        # iteration, because k - m < k.
        acc = 0.0
        for m in range(1, k + 1):
            acc += v[m] * closure[k - m]
        closure[k] = vz_star * acc

    return closure


class OccurrenceWeight(Semiring):
    """This class represents an element in a semiring that tracks
    the number of times a given symbol has been accepted in a WFSA.

    It lifts a probabilistic WFA to an 'OccurrenceWeight', maintaining the
    weight but also accruing the number of times a weight has been seen by
    means of the index of the non-zero element in the second component of
    the tuple.

    An element is stored in ``self.x`` as ``[w, v]``: ``w`` is the scalar base
    weight and ``v`` is a NumPy vector with ``N + 1`` entries. Addition is
    component-wise; multiplication multiplies the base weights and convolves
    the vectors, truncated to the length of the left operand's vector.
    """
    # Placeholders: this class only initialises them to None.
    # NOTE(review): presumably assigned by client code once N is known
    # (e.g. from get_one / get_zero) -- verify against callers.
    one = None
    zero = None

    def __init__(self, x) -> None:
        """Store the pair ``(w, v)`` and derive ``N`` from the length of ``v``."""
        self.x = list(x)
        self.N = self.x[1].shape[0] - 1

    @property
    def value(self):
        """Return the scalar base weight."""
        # used by rayuela as sampling probability
        return self.x[0]

    @classmethod
    def get_one(cls, N):
        """Return the multiplicative identity for vectors with ``N + 1`` entries."""
        one = OccurrenceWeight((1.0, np.zeros(N + 1)))
        one.x[1][0] = 1.0
        return one

    @classmethod
    def get_zero(cls, N):
        """Return the additive identity for vectors with ``N + 1`` entries."""
        return OccurrenceWeight((0.0, np.zeros(N + 1)))

    def __str__(self) -> str:
        """Render both components rounded to the default precision."""
        rounded = self.__round__()
        return str(rounded[0]) + ", " + str(rounded[1])

    def __repr__(self) -> str:
        return self.__str__()

    def __unicode__(self) -> str:
        return self.__str__()

    def __add__(self, y):
        """Component-wise sum of two elements."""
        return OccurrenceWeight((self.x[0] + y.x[0], self.x[1] + y.x[1]))

    def __neg__(self):
        """Component-wise negation."""
        return OccurrenceWeight((-self.x[0], -self.x[1]))

    def __sub__(self, y):
        """Component-wise difference of two elements."""
        return OccurrenceWeight((self.x[0] - y.x[0], self.x[1] - y.x[1]))

    def __abs__(self):
        """Component-wise absolute value."""
        return OccurrenceWeight((abs(self.x[0]), abs(self.x[1])))

    def __lmul__(self, y):
        return self.__mul__(y)

    def __rmul__(self, y):
        # Reflected multiplication simply reuses __mul__.
        return self.__mul__(y)

    def old__mul__(self, y):
        """Earlier multiplication variant kept alongside ``__mul__``.

        Always uses ``fftconvolve`` and rescales both vectors by the largest
        absolute entry before convolving, undoing the scaling afterwards.
        """
        if y is None:
            y = 0

        if isinstance(y, (int, float)):
            return OccurrenceWeight((self.x[0] * y, self.x[1] * y))

        zw = self.x[0] * y.x[0]

        N = self.x[1].shape[0]
        # scale up for stability
        a = self.x[1]
        b = y.x[1]

        max_val = max(np.max(np.abs(a)), np.max(np.abs(b)))
        scale_factor = 1.0 if max_val == 0 else 1.0 / max_val

        a_scaled = a * scale_factor
        b_scaled = b * scale_factor
        result_scaled = fftconvolve(a_scaled, b_scaled, mode='full')[:N]
        zv = result_scaled / (scale_factor ** 2)

        # Clipping to avoid overflow issues
        zw = np.clip(zw, 0, 1e100)
        zv = np.clip(zv, 0, 1e100)
        return OccurrenceWeight((zw, zv))

    def __mul__(self, y):
        """Multiply by a scalar or by another element.

        ``None`` is treated as the scalar 0. An ``int``/``float`` scales both
        components. Otherwise the base weights are multiplied and the vectors
        are convolved, keeping as many coefficients as this element's vector
        has; both results are then clipped to ``[0, 1e100]``.
        """
        if y is None:
            y = 0

        if isinstance(y, (int, float)):
            return OccurrenceWeight((self.x[0] * y, self.x[1] * y))

        zw = self.x[0] * y.x[0]

        N = self.x[1].shape[0]
        if N < 1000:  # Threshold for using fast_multiply vs fftconvolve
            zv = fast_multiply(self.x[1], y.x[1], N)
        else:
            zv = fftconvolve(self.x[1], y.x[1], mode='full')[:N]

        zw = np.clip(zw, 0, 1e100)
        zv = np.clip(zv, 0, 1e100)
        return OccurrenceWeight((zw, zv))

    def star(self):
        """Return the star (closure) of this element.

        The base weight becomes ``1 / (1 - w)`` and the vector is computed by
        ``star_jit`` from the star of its leading entry. The returned element
        has a vector of the same length as this element's vector.
        """
        N = self.N
        vstar = self.get_one(N).copy()
        # the base weight
        vstar.x[0] = 1.0 / (1.0 - self.x[0])
        # (v*)_0 = (v_0)*
        vz_star = 1 / (1 - self.x[1][0])
        # star_jit returns one entry more than it is given and never writes
        # that last entry. Trim it so the starred element has the same vector
        # length as every other element; otherwise `+` fails on mismatched
        # shapes and `*` would index past the end of the right operand.
        size = self.x[1].shape[0]
        vstar.x[1] = star_jit(self.x[1], vz_star=vz_star)[:size]
        return vstar

    def __truediv__(self, y):
        """Divide both components by a scalar or by ``y``'s base weight."""
        if isinstance(y, (int, float)):
            return OccurrenceWeight((self.x[0] / y, self.x[1] / y))
        return OccurrenceWeight((self.x[0] / y.x[0], self.x[1] / y.x[0]))

    def __eq__(self, other) -> bool:
        """Two elements are equal when base weights and vectors both match."""
        # Defer to Python's default handling for foreign operands (None,
        # plain numbers, ...) instead of failing on the missing `.x`.
        if not isinstance(other, OccurrenceWeight):
            return NotImplemented
        return self.x[0] == other.x[0] and np.array_equal(self.x[1], other.x[1])

    def __round__(self, precission=3):
        """Return a ``(weight, vector)`` tuple rounded to ``precission`` decimals."""
        return (round(self.x[0], precission), np.around(self.x[1], precission))

    def __invert__(self):
        """Invert the base weight (0 stays 0); the vector is passed through as-is."""
        if self.x[0] == 0:
            return OccurrenceWeight((0, self.x[1]))
        return OccurrenceWeight((1.0 / self.x[0], self.x[1]))

    # Ordering comparisons look at the base weight only and also accept plain
    # int/float operands.
    def __lt__(self, other):
        if isinstance(other, (int, float)):
            return self.x[0] < other
        return self.x[0] < other.x[0]

    def __gt__(self, other):
        if isinstance(other, (int, float)):
            return self.x[0] > other
        return self.x[0] > other.x[0]

    def __le__(self, other):
        if isinstance(other, (int, float)):
            return self.x[0] <= other
        return self.x[0] <= other.x[0]

    def __ge__(self, other):
        if isinstance(other, (int, float)):
            return self.x[0] >= other
        return self.x[0] >= other.x[0]

    @classmethod
    def zeros(cls, N, M=None) -> np.ndarray:
        """Return an array built from ``N`` (or ``N`` rows of ``M``) references to ``cls.zero``.

        Every slot refers to the same ``cls.zero`` object; nothing is copied.
        """
        if M is not None:
            nm = []
            # If there is an issue here, check if N should be N+1
            for _ in range(N):
                nm.append([cls.zero for _ in range(M)])
            return np.array(nm)
        zeros = [cls.zero for _ in range(N)]
        return np.array(zeros)

    @classmethod
    def lift_weight(cls, target, w, N, symbol=None):
        """Lift a weight ``w`` into an element with an ``N + 1`` entry vector.

        If ``w`` exposes ``.value`` it is unwrapped first (taking element 0
        when the unwrapped value is not a float). ``w`` is placed at index 1
        of the vector when ``symbol`` is given and ``target == symbol.value``,
        and at index 0 otherwise.
        """
        v = np.zeros(N + 1)

        if hasattr(w, "value"):
            w = w.value
            if not isinstance(w, float):
                w = w[0]

        if symbol is None:
            v[0] = w
        elif target == symbol.value:
            v[1] = w
        else:
            v[0] = w

        return OccurrenceWeight((w, v))

    def copy(self):
        """Return a new element with a copied vector."""
        return OccurrenceWeight((self.x[0], self.x[1].copy()))

    def __hash__(self):
        """Hash consistent with ``__eq__``: combines the weight and the vector entries."""
        return hash(self.x[0]) + hash(tuple(self.x[1]))