import numpy as np
from abc import ABC, abstractmethod
from typing import Callable
from collections import Counter

class Generative_Model(ABC):
    """Abstract base class for a sampling-based (generative) model.

    Subclasses must provide the abstract properties ``is_mdp``, ``states``,
    ``rewards``, ``sa_pairs`` and ``action_at_state``.  The sampling methods
    additionally read two attributes that are NOT declared here and must be
    supplied by subclasses (as attributes or properties):

    * ``transition_map[sa]`` -- probability vector aligned element-wise with
      ``self.states``;
    * ``reward_map[sa]`` -- probability vector aligned element-wise with
      ``self.rewards``.

    ``states`` and ``rewards`` must be NumPy arrays, because they are filtered
    with boolean masks built from those probability vectors.
    """

    # Maximum number of draws per ``np.random.choice`` call when building
    # empirical measures; bounds peak memory for very large sample counts.
    # Subclasses (or instances) may override it.
    _SAMPLE_CHUNK = 10 ** 7

    @property
    @abstractmethod
    def is_mdp(self):
        """Flag provided by subclasses (not used by the helpers in this class)."""

    @property
    @abstractmethod
    def states(self):
        """NumPy array of states; indexed by masks over ``transition_map[sa]``."""

    @property
    @abstractmethod
    def rewards(self):
        """NumPy array of reward values; indexed by masks over ``reward_map[sa]``."""

    @property
    @abstractmethod
    def sa_pairs(self):
        """Iterable of ``(state, action, ...)`` sequences; items 0 and 1 are used."""

    @property
    @abstractmethod
    def action_at_state(self):
        """Actions available per state (presumably built via ``generate_action_at_state``)."""

    @property
    def r_max(self):
        """Largest value in ``self.rewards``."""
        return max(self.rewards)

    def generate_action_at_state(self):
        """Group ``self.sa_pairs`` into a ``{state: [action, ...]}`` dict.

        Actions are listed in the order in which they appear in ``sa_pairs``.
        """
        action_at_state = {}
        for sa in self.sa_pairs:
            action_at_state.setdefault(sa[0], []).append(sa[1])
        return action_at_state

    def to_frequency(self, samples, empirical_measure=None):
        """Add the occurrence counts of ``samples`` to ``empirical_measure``.

        Args:
            samples: Iterable of hashable values.
            empirical_measure: Optional ``{value: count}`` dict to accumulate
                into; it is updated in place.  A fresh dict is used if omitted.

        Returns:
            The updated ``{value: count}`` dict (the same object that was
            passed in, when one was passed).
        """
        # The previous ``= {}`` default was a single dict shared by every call,
        # so counts from one call leaked into the next.
        if empirical_measure is None:
            empirical_measure = {}
        for element, count in Counter(samples).items():
            empirical_measure[element] = empirical_measure.get(element, 0) + count
        return empirical_measure

    @staticmethod
    def _support(values, distribution):
        """Return ``(values, probabilities)`` restricted to positive probabilities."""
        mask = distribution > 0
        return values[mask], distribution[mask]

    def _sample_counts(self, values, distribution, k):
        """Draw ``k`` samples of ``values`` ~ ``distribution`` and count them.

        Draws happen in chunks of at most ``self._SAMPLE_CHUNK`` so that a huge
        ``k`` never materialises one giant sample array.

        Returns:
            ``{value: count}`` dict in first-occurrence order (empty if ``k == 0``).
        """
        support, probs = self._support(values, distribution)
        chunk = self._SAMPLE_CHUNK
        n_chunks, remains = divmod(k, chunk)
        counts = {}
        for _ in range(n_chunks):
            counts = self.to_frequency(np.random.choice(support, chunk, p=probs), counts)
        if remains > 0:
            counts = self.to_frequency(np.random.choice(support, remains, p=probs), counts)
        return counts

    def _empirical_array(self, values, distribution, k):
        """Empirical frequencies of ``k`` draws, aligned with ``sorted(values)``.

        Values that were never drawn get 0.
        """
        counts = self._sample_counts(values, distribution, k)
        # Divide only observed counts so that k == 0 yields all zeros
        # instead of a ZeroDivisionError.
        frequency = {value: count / k for value, count in counts.items()}
        return np.array([frequency.get(value, 0) for value in sorted(values)])

    @staticmethod
    def _to_value_frequency_pair(counts, k):
        """Convert ``{value: count}`` into ``[values_array, frequencies_array]``.

        Works for an empty ``counts`` (returns two empty arrays).
        """
        values = np.array(list(counts.keys()))
        frequencies = np.array([count / k for count in counts.values()])
        return [values, frequencies]

    def generate_empirical_distribution_s(self, sa, k):
        """Sample ``k`` next states for ``sa`` and return their empirical distribution.

        Args:
            sa: Key into ``self.transition_map``.
            k: Number of samples.

        Returns:
            ``np.ndarray`` of relative frequencies aligned with
            ``sorted(self.states)``; unsampled states get 0.
        """
        return self._empirical_array(self.states, self.transition_map[sa], k)

    def generate_empirical_distribution_r(self, sa, k):
        """Sample ``k`` rewards for ``sa`` and return their empirical distribution.

        Args:
            sa: Key into ``self.reward_map``.
            k: Number of samples.

        Returns:
            ``np.ndarray`` of relative frequencies aligned with
            ``sorted(self.rewards)``; unsampled rewards get 0.
        """
        return self._empirical_array(self.rewards, self.reward_map[sa], k)

    def generate_state(self, sa, k, empirical_measure=False):
        """Sample ``k`` next states for state-action pair ``sa``.

        Args:
            sa: Key into ``self.transition_map``.
            k: Number of samples.
            empirical_measure: If false, return the raw samples.  If true,
                return ``[states, frequencies]``: the distinct sampled states
                (first-occurrence order) and their relative frequencies.
        """
        distribution_of_state = self.transition_map[sa]
        if not empirical_measure:
            support, probs = self._support(self.states, distribution_of_state)
            return np.array(np.random.choice(support, k, p=probs))
        counts = self._sample_counts(self.states, distribution_of_state, k)
        return self._to_value_frequency_pair(counts, k)

    def generate_reward(self, sa, k, empirical_measure=False):
        """Sample ``k`` rewards for state-action pair ``sa``.

        Args:
            sa: Key into ``self.reward_map``.
            k: Number of samples.
            empirical_measure: If false, return the raw samples.  If true,
                return ``[rewards, frequencies]``: the distinct sampled rewards
                (first-occurrence order) and their relative frequencies.
        """
        r_dist = self.reward_map[sa]
        if not empirical_measure:
            support, probs = self._support(self.rewards, r_dist)
            return np.array(np.random.choice(support, k, p=probs))
        counts = self._sample_counts(self.rewards, r_dist, k)
        return self._to_value_frequency_pair(counts, k)

    def get_sa_pairs(self):
        """Return ``self.sa_pairs`` (kept for backward compatibility)."""
        return self.sa_pairs
    
