import time
import math
import logging
from logging.handlers import MemoryHandler
from collections import OrderedDict
from typing import Optional

import torch


class Cache:
    """Process-wide LRU cache keyed by tuples built with :meth:`gen_key`.

    Only one instance ever exists: ``Cache(...)`` always returns the same object.
    NOTE(review): ``__init__`` still runs on every ``Cache(capacity)`` call, so
    constructing it again resets the stored entries, the capacity and the hit
    counter on the shared instance — confirm callers rely on this.
    """

    # can only have one instance
    _instance = None

    def __new__(cls, capacity: int):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, capacity: int):
        self.capacity = capacity  # max number of stored entries; <= 0 disables storing
        self.cache = OrderedDict()  # ordered from least to most recently used
        self.num_hits = 0  # successful lookups since the last (re)initialisation

    def __getitem__(self, key: tuple) -> Optional[tuple]:
        """Return the value stored for *key* and mark it most recently used; None on a miss."""
        if key not in self.cache:
            return None
        self.num_hits += 1
        self.cache.move_to_end(key)
        return self.cache[key]

    def __setitem__(self, key: tuple, value: tuple) -> None:
        """Store *value* under *key*, evicting least recently used entries if full."""
        if self.capacity <= 0:
            # Nothing can be stored; evicting from an empty dict would raise KeyError.
            return
        if key in self.cache:
            # Overwriting does not grow the cache, so no other entry must be evicted.
            self.cache.move_to_end(key)
        else:
            while len(self.cache) >= self.capacity:
                self.cache.popitem(last=False)  # drop the least recently used entry
        self.cache[key] = value

    def __contains__(self, key):
        # Membership tests are deliberately unsupported: use ``cache[key]`` and compare
        # against None (that path also updates LRU order and hit statistics). Without an
        # override, ``in`` would fall back to probing __getitem__(0), __getitem__(1), ...
        # forever, because misses return None instead of raising IndexError.
        raise NotImplementedError(
            "Cache does not support `in`; use `cache[key] is not None` instead"
        )

    # Not iterable: the __getitem__-based iteration fallback would never terminate.
    __iter__ = None

    @classmethod
    def gen_key(
        cls,
        head_vars: list[list[int]],  # num_rules x sum_head_arities
        body_vars: list[list[list[int]]],  # num_rules x max_occurrence_in_body x sum_body_arities
        head_atom: torch.Tensor,  # num_rules
        body_atom: torch.Tensor  # num_rules x max_occurrence_in_body x num_predicates
    ) -> tuple:
        """Build a hashable cache key from the rule description.

        Tensors are converted to nested tuples of Python scalars. ``torch.Tensor``
        hashes by object identity, so keys that still contained tensor elements
        would never match a key built from equal but distinct tensors.
        """
        return (
            cls._freeze(head_vars),
            cls._freeze(body_vars),
            cls._freeze(head_atom),
            cls._freeze(body_atom),
        )

    @staticmethod
    def _freeze(value):
        """Recursively turn tensors / lists / tuples into nested tuples of Python scalars."""
        if isinstance(value, torch.Tensor):
            value = value.tolist()
        if isinstance(value, (list, tuple)):
            return tuple(Cache._freeze(item) for item in value)
        return value


class ExponentialScheduler:
    """Decay (or grow) a value geometrically from ``init_value`` towards ``final_value``.

    The schedule restarts every ``restart_period`` epochs; within a period the
    value is ``init_value * (final_value / init_value) ** progress`` with
    ``progress`` in ``[0, 1)``, computed via ``exp(log(ratio) * progress)``.
    """

    def __init__(self, init_value, final_value, restart_period):
        self.init_value, self.final_value, self.restart_period = (
            init_value,
            final_value,
            restart_period,
        )
        # Log of the overall ratio; requires final_value / init_value > 0.
        self.rate = math.log(final_value / init_value)

    def get_value(self, epoch: int):
        """
            epoch: current training epoch
        """
        period = self.restart_period
        progress = (epoch % period) / period
        scale = math.exp(self.rate * progress)
        return self.init_value * scale
    

# class ExponentialAnnealingRestartScheduler:

#     def __init__(self, init_value, final_value, restart_period):
#         self.init_value = init_value
#         self.final_value = final_value
#         self.restart_period = restart_period

#         self.a = math.log10(self.init_value)
#         self.b = math.log10(self.final_value) - math.log10(self.init_value)

#     def get_value(self, epoch: int):
#         """
#             epoch: current training epoch
#         """

#         frac = (epoch % self.restart_period) / self.restart_period
#         return math.pow(
#             10,
#             self.a + self.b * frac  # = math.log10(self.init_value) + (math.log10(self.final_value) - math.log10(self.init_value)) * frac
#         )  # = self.init_value * math.pow(self.final_value / self.init_value, frac)


class CosineAnnealingRestartScheduler:
    """Anneal a value from ``init_value`` towards ``final_value`` along a half cosine.

    The schedule restarts every ``restart_period`` epochs: at the start of each
    period the value is ``init_value`` and it approaches ``final_value`` as the
    period ends.
    """

    def __init__(self, init_value, final_value, restart_period):
        self.init_value = init_value
        self.final_value = final_value
        self.restart_period = restart_period

    def get_value(self, epoch: int):
        """
            epoch: current training epoch
        """
        period = self.restart_period
        frac = (epoch % period) / period
        amplitude = self.init_value - self.final_value
        # Goes from 2 (frac == 0) down towards 0 (frac -> 1).
        cosine_weight = 1 + math.cos(frac * math.pi)
        return self.final_value + amplitude * cosine_weight / 2
    

class LinearAnnealingRestartScheduler:
    """Interpolate a value linearly from ``init_value`` towards ``final_value``.

    The schedule restarts every ``restart_period`` epochs; ``final_value`` itself
    is approached but not reached within a period.
    """

    def __init__(self, init_value, final_value, restart_period):
        self.init_value = init_value
        self.final_value = final_value
        self.restart_period = restart_period

    def get_value(self, epoch: int):
        """
            epoch: current training epoch
        """
        period = self.restart_period
        progress = (epoch % period) / period
        span = self.final_value - self.init_value
        return self.init_value + span * progress


# for jit purpose: plain list loop instead of dict.fromkeys (presumably for torch.jit scripting)
def deduplicate_list(lst: list[int]) -> list[int]:
    """Return the distinct elements of ``lst`` in order of first appearance.

    ``list(dict.fromkeys(lst))`` is the usual eager-Python spelling; this
    version sticks to an explicit loop over a typed list. Membership is checked
    against the output list, so it is quadratic in the number of distinct values.
    """
    unique: list[int] = []
    for value in lst:
        if value in unique:
            continue
        unique.append(value)
    return unique


# for jit purpose
def sorting_permutation(lst: list[int]) -> tuple[list[int], list[int]]:
    """Return ``(permutation, sorted_elem)`` for ``lst``.

    ``sorted_elem`` is ``lst`` in ascending order and ``permutation`` holds the
    original index of each sorted element, i.e.
    ``sorted_elem[i] == lst[permutation[i]]``. Equal elements keep their
    original relative order (stable sort), like
    ``sorted(enumerate(lst), key=lambda x: x[1])`` would.

    Tensor results and list results live in separately named variables because
    TorchScript requires every variable to keep a single static type.
    """
    # Explicit int64 dtype so the empty-list case is integral as well.
    values, indices = torch.sort(torch.tensor(lst, dtype=torch.long), stable=True)
    permutation: list[int] = indices.tolist()
    sorted_elem: list[int] = values.tolist()
    return permutation, sorted_elem


def apply_permutation(lst: list[int], perm: list[int]) -> list[int]:
    """Reorder ``lst`` by ``perm``: the i-th output element is ``lst[perm[i]]``."""
    reordered = [lst[source_index] for source_index in perm]
    return reordered


def inverse_permutation(p: list[int]) -> list[int]:
    """Return ``q`` with ``q[p[i]] == i`` for every position ``i``.

    ``p`` is expected to be a permutation of ``range(len(p))``; this is not
    validated.
    """
    size = len(p)
    inverse = [0] * size
    for position, target in enumerate(p):
        inverse[target] = position
    return inverse

def is_ascending(lst: list[int]) -> bool:
    """Return True if ``lst`` is non-decreasing (empty and single-element lists qualify)."""
    neighbours = zip(lst, lst[1:])
    return all(previous <= current for previous, current in neighbours)


def no_duplicate(lst: list[int]) -> bool:
    """Return True if no element of ``lst`` occurs more than once.

    Tracks already-seen elements in a set, so it runs in linear time (elements
    must be hashable) and stops at the first repeat, instead of copying and
    scanning the remaining tail of the list for every position.
    """
    seen = set()
    for elem in lst:
        if elem in seen:
            return False
        seen.add(elem)
    return True


def format_unground_atom(predicate_name: str, vars_num: list[int]):
    """Render an atom over numbered variables, e.g. ``("p", [0, 2]) -> "p(X0, X2)"``."""
    arguments = ", ".join([f"X{num}" for num in vars_num])
    return f"{predicate_name}({arguments})"

def format_unground_rule(heads: list[str], bodys: list[str]) -> list[str]:
    """Render one rule string ``"<head> ← <b1> ∧ <b2> ∧ ..."`` per head.

    All heads share the same body conjunction. The body is joined once, so
    ``bodys`` may be any iterable of strings; previously a generator was
    exhausted by the first head, leaving every later rule with an empty body.
    """
    body = " ∧ ".join(bodys)
    return [head + " ← " + body for head in heads]


class PeriodicHandler(MemoryHandler):
    """Buffer log records in memory and write them to ``filename`` in batches.

    Records are written through an internally created ``logging.FileHandler``
    (append mode). For each incoming record the buffer is flushed when:

    * both ``capacity`` and ``period`` are None (every record is written at once);
    * the record's level is at least ``self.flushLevel``;
    * ``capacity`` is set and the buffer holds at least ``capacity`` records;
    * ``period`` is set and at least ``period`` seconds elapsed since the last flush.

    The period is only checked when a record arrives — there is no background
    timer — so buffered records wait for the next record, ``flush()`` or ``close()``.
    """

    def __init__(
        self,
        filename: str,
        capacity: Optional[int] = None,
        period: Optional[int] = None,  # seconds
    ):
        file_handler = logging.FileHandler(filename, mode='a')
        file_handler.setFormatter(logging.Formatter("[%(levelname)s] [%(asctime)s] [%(name)s:%(lineno)d] %(message)s"))
        MemoryHandler.__init__(self, capacity, target=file_handler, flushOnClose=True)

        self.period = period
        self.last_flush_time = time.time()

    def shouldFlush(self, record):
        """Decide whether the buffer (which already contains *record*) should be flushed."""
        if self.capacity is None and self.period is None:
            return True
        if record.levelno >= self.flushLevel:
            return True
        if self.capacity is not None and len(self.buffer) >= self.capacity:
            return True
        if self.period is not None and time.time() - self.last_flush_time >= self.period:
            return True
        return False

    def flush(self):
        """Write buffered records to the file and restart the period timer.

        Resetting the timer here (instead of in ``shouldFlush``) also accounts for
        explicit ``flush()`` calls and the flush performed by ``close()``.
        """
        super().flush()
        self.last_flush_time = time.time()

    def close(self):
        """Flush, detach and close the file handler this instance created.

        ``MemoryHandler.close`` only drops its reference to the target; since the
        target was created here, it is closed here so its file is released.
        """
        file_handler = self.target
        try:
            super().close()
        finally:
            if file_handler is not None:
                file_handler.close()
