"""
This module implements the Schedule data model.
"""

import functools
from fractions import Fraction

import numpy
import pydantic
from tqdm import tqdm

import problem
import setting
import space


class Schedule(pydantic.BaseModel):
    """A schedule generated by a defense plan.

    Attributes:
        data: The patrolling problem this schedule belongs to. Declared with
            ``exclude=True`` so it is left out when the model is exported.
        schedule: The states of the schedule, indexed by time step.
    """

    data: problem.PatrollingProblem = pydantic.Field(exclude=True)

    schedule: list[space.State]

    def observation(self, t: int, observation_length: int = None):
        """Returns attacker's observation at time t.

        The observation is the tuple of the ``observation_length`` consecutive
        schedule states ending with (and including) the state at time ``t``.

        Args:
            t: Time of the observation.
            observation_length: Number of states in the observation. When
                ``None``, ``self.data.observation_length`` is used.

        Raises:
            IndexError: If ``t < observation_length`` or ``t`` is at or past
                the end of the schedule.
        """
        if observation_length is None:
            observation_length = self.data.observation_length

        if t < observation_length or t >= len(self.schedule):
            raise IndexError(f'Observation time {t} is outside of schedule range.')

        window_start = t - observation_length + 1
        return tuple(self.schedule[window_start:t + 1])

    def capture_probability(self, t: int, j: setting.Target):
        """Returns capture probability if attack at target j was initiated at time t.

        The result is one minus the product, over the time steps ``t`` to
        ``t + self.data.tau[j]`` inclusive, of
        ``1 - self.data.base.coverage[state][j]`` for the state scheduled at
        that step.

        Args:
            t: Time at which the attack is initiated.
            j: The attacked target.

        Raises:
            IndexError: If ``t`` is negative or the window
                ``t .. t + self.data.tau[j]`` runs past the end of the schedule.
        """
        last = t + self.data.tau[j]

        # A negative t would not fail on its own: list indexing would wrap
        # around and silently read states from the end of the schedule.
        if t < 0 or last >= len(self.schedule):
            raise IndexError(f'Capture probability time {t} is outside of schedule range.')

        coverage = self.data.base.coverage
        escape_probability = numpy.prod([1.0 - coverage[self.schedule[i]][j]
                                         for i in range(t, last + 1)])
        return 1.0 - escape_probability

    def reward(self, observation_length: int = None) \
            -> dict[tuple[problem.Observation, setting.Target], tuple[setting.Rational, int]]:
        """For each (observation, target) pair a total reward and a number of instances of the observation.

        For every target ``j`` the time steps from ``observation_length`` up to
        (but excluding) ``len(self.schedule) - self.data.tau[j]`` are scanned.
        A time step is skipped when it is 0 or when the first component of the
        state scheduled at that step is negative. Every remaining step adds
        ``capture_probability(t, j) * self.data.reward[j]`` to the total of the
        pair (observation at ``t``, ``j``) and increments its instance count.

        Args:
            observation_length: Number of states in each observation. When
                ``None``, ``self.data.observation_length`` is used.

        Returns:
            A dict mapping ``(observation, target)`` to
            ``(total reward, number of instances)``.
        """

        if observation_length is None:
            observation_length = self.data.observation_length

        # Number of times each (observation, target) pair occurs.
        instance_count: dict[tuple[problem.Observation, setting.Target], int] = {}
        # Reward accumulated over all occurrences of each pair.
        total_reward: dict[tuple[problem.Observation, setting.Target], setting.Rational] = {}

        for j in self.data.targets:
            # The upper bound keeps the window t .. t + tau[j] inside the schedule.
            for t in range(observation_length, len(self.schedule) - self.data.tau[j]):
                # observation(t, 1) raises for t == 0, so that case is tested first.
                # NOTE(review): a negative first state component presumably marks
                # a state that must not be counted -- confirm against space.State.
                if t == 0 or self.observation(t, 1)[0][0] < 0:
                    continue
                key = (self.observation(t, observation_length), j)
                instance_count[key] = instance_count.get(key, 0) + 1
                total_reward[key] = (total_reward.get(key, 0) +
                                     self.capture_probability(t, j) * self.data.reward[j])

        return {key: (total_reward[key], count) for key, count in instance_count.items()}
