import numpy as np
import torch
from gym.spaces import Discrete
from torch.distributions import Categorical
from games.graphon_mfg import FiniteGraphonMeanFieldGame
from solver.omd_solver_array import normalize_at_last_axis
import pdb


class BeachGraphon(FiniteGraphonMeanFieldGame):
    """
    Models the Beach Bar Process.

    Agents move on a ring of ``N_states`` positions (indices wrap around
    modulo ``N_states``). They are rewarded for staying close to position
    ``N_states / 2``, for not moving, and for avoiding crowded positions.

    Args:
        time_steps: Horizon forwarded to the base game.
        N_states: Number of positions on the ring.
        noise_prob: Probability that a move lands on one of the two
            neighbours of the intended cell (split evenly between them).
        **kwargs: Accepted for interface compatibility; unused here.
    """

    def __init__(self, time_steps: int = 10, N_states=10,
                 noise_prob: float = 0.1, **kwargs):
        self.noise_prob = noise_prob
        self.N_states = N_states

        # States (N_states=10): 0 1 2 3 4 Bar 6 7 8 9
        # Actions: 0 = Left, 1 = Stay, 2 = Right
        def initial_state_distribution(x):
            # Uniform over all states; the argument is ignored.
            return Categorical(probs=torch.tensor([1/N_states] * N_states))
        agent_observation_space = Discrete(N_states)
        agent_action_space = Discrete(3)
        super().__init__(agent_observation_space, agent_action_space, time_steps, initial_state_distribution)

    def transition_probs(self, t, x, u, mu):
        """Return the next-state distribution for action ``u`` taken in state ``x``.

        The intended move shifts ``x`` by ``u - 1`` (0 = left, 1 = stay,
        2 = right) on a ring of ``N_states`` cells. The agent lands on the
        intended cell with probability ``1 - noise_prob``, and on each
        neighbour of that cell with probability ``noise_prob / 2``.
        ``t`` and ``mu`` are not used.

        Returns:
            1-D ``np.ndarray`` of length ``N_states``.
        """
        n = self.N_states
        target = x + (u - 1)
        # Wrap with the actual ring size (previously hard-coded to 10).
        next_states = [(target - 1) % n, target % n, (target + 1) % n]
        probs = np.zeros(n)
        # Accumulate instead of assigning: when N_states < 3 the three cells
        # coincide, and plain fancy-index assignment would drop probability mass.
        np.add.at(probs, next_states,
                  [self.noise_prob / 2, 1 - self.noise_prob, self.noise_prob / 2])
        return probs

    def reward(self, t, x, u, mu):
        """Return the reward for taking action ``u`` in state ``x`` at time ``t``.

        Sum of three terms:
          * position cost ``-|x - N/2| / (N/2)``;
          * movement cost ``-1 / (N/2)`` when ``u`` is 0 or 2, else 0;
          * crowd term ``-log(p)`` with ``p = mu.mu_alphas[0].pmf(t)[x]``.

        NOTE(review): ``p`` is assumed to be the population mass at ``x`` at
        time ``t``. If it is 0, the crowd term becomes ``+inf``. Only
        ``mu.mu_alphas[0]`` is read; confirm that this is intended.
        """
        r_x = -abs(x - self.N_states/2) / (self.N_states/2)
        r_a = -((u==0) + (u==2)) / (self.N_states/2)
        r_mu = -np.log(mu.mu_alphas[0].pmf(t)[x])
        return r_x + r_a + r_mu

class SquareBeach(BeachGraphon):
    """Beach bar variant whose crowd penalty is the squared local density."""

    def reward(self, t, x, u, mu):
        """Return position cost + movement cost - (density at ``x``)**2.

        The density is read as ``mu.mu_alphas[0].pmf(t)[x]``.
        """
        half = self.N_states / 2
        position_cost = -abs(x - half) / half
        movement_cost = -((u == 0) + (u == 2)) / half
        crowd_cost = -(mu.mu_alphas[0].pmf(t)[x] ** 2)
        return position_cost + movement_cost + crowd_cost
    
class RandomBeach(SquareBeach):
    """Beach variant whose dynamics are a fixed, randomly drawn transition kernel.

    The kernel is drawn once at construction and stored in
    ``self.transition_matrix``. Its shape is ``(50, N_states, 3, N_states)``,
    indexed as ``[t, x, u, next_x]``, and each row over ``next_x`` sums to 1.
    The reward is the position cost plus a squared-density crowd penalty,
    with no movement cost.
    """

    def __init__(self,
                 time_steps: int = 3,
                 N_states=10,
                 noise_prob: float = 0,
                 **kwargs):
        # NOTE(review): ``time_steps`` is ignored; the game is always built
        # with a 50-step horizon. Kept unchanged for backward compatibility.
        # Confirm whether the argument should be honoured instead.
        horizon = 50
        # ``transition_probs`` needs this attribute; without it every
        # transition query raised AttributeError. A private generator seeded
        # with 0 keeps the kernel reproducible without consuming the global
        # numpy RNG stream. The 3 actions match the Left/Stay/Right action space.
        rng = np.random.RandomState(0)
        weights = rng.rand(horizon, N_states, 3, N_states)
        self.transition_matrix = weights / weights.sum(axis=-1, keepdims=True)
        super().__init__(time_steps=horizon, N_states=N_states, noise_prob=noise_prob, **kwargs)
        # Pre-existing side effect: seeds numpy's global RNG.
        np.random.seed(0)

    def transition_probs(self, t, x, u, mu):
        """Return the next-state distribution row ``transition_matrix[t, x, u]``.

        ``t`` must be below the 50-step horizon; ``mu`` is not used.
        """
        return self.transition_matrix[t, x, u, :]

    def reward(self, t, x, u, mu):
        """Return position cost - (density at ``x``)**2; actions are free.

        The density is read as ``mu.mu_alphas[0].pmf(t)[x]``.
        """
        r_x = -abs(x - self.N_states/2) / (self.N_states/2)
        r_a = 0
        r_mu = - mu.mu_alphas[0].pmf(t)[x] ** 2
        return r_x + r_a + r_mu

class PowerBeach(RandomBeach):
    """Random-dynamics beach whose crowd penalty is the density raised to the 101st power."""

    def reward(self, t, x, u, mu):
        """Return position cost - (density at ``x``)**101; ``u`` is ignored."""
        half = self.N_states / 2
        distance_term = -abs(x - half) / half
        occupancy = mu.mu_alphas[0].pmf(t)[x]
        crowd_term = -(occupancy ** 101)
        return distance_term + crowd_term
class ZeroSumBeach(RandomBeach):
    """Two-group beach variant with deterministic moves and a 3-step horizon.

    The state indices are split into two equal halves: ``R`` (first half)
    and ``L`` (second half). A state in one half is rewarded according to
    how much mass the mean field puts on the other half.
    """

    def __init__(self, N_states=10, *args, **kwargs):
        super().__init__(N_states=N_states, *args, **kwargs)
        # Override the horizon and noise chosen by the parent constructor.
        self.time_steps = 3
        self.noise_prob = 0
        # np.split needs an even N_states; otherwise it raises ValueError.
        halves = np.split(np.arange(self.N_states), 2)
        self.R, self.L = halves

    def transition_probs(self, t, x, u, mu):
        """Move deterministically from ``x`` by ``u - 1`` cells on the ring.

        Returns a length-``N_states`` array with all mass on the target cell.
        ``t`` and ``mu`` are not used.
        """
        n = self.N_states
        shift = (u - 1) % n
        neighbourhood = [(x + offset + shift) % n for offset in (-1, 0, 1)]
        probs = np.zeros(n)
        probs[neighbourhood] = [0, 1, 0]
        return probs

    def indicator_R(self, x):
        """Return whether ``x`` lies in ``R``; for a tuple, its last element is tested."""
        state = x[-1] if isinstance(x, tuple) else x
        return state in self.R

    def indicator_L(self, x):
        """Return whether ``x`` lies in ``L``; for a tuple, its last element is tested."""
        state = x[-1] if isinstance(x, tuple) else x
        return state in self.L

    def reward(self, t, x, u, mu):
        """Return the cross-group reward for state ``x`` at time ``t``.

        States in ``L`` get ``mass(R) - 0.25``; states in ``R`` get
        ``mass(L) - 0.75``. Masses come from ``mu.evaluate_integral``.
        The distance-to-bar term is kept but multiplied by 0.
        """
        distance_term = -abs(x - self.N_states/2) / (self.N_states/2) * 0
        mass_R = mu.evaluate_integral(t, self.indicator_R)
        mass_L = mu.evaluate_integral(t, self.indicator_L)
        crowd_term = (mass_R - 0.25) * self.indicator_L(x)
        crowd_term += (mass_L - 0.75) * self.indicator_R(x)
        return distance_term + crowd_term

class MDPBeach(BeachGraphon):
    """Beach bar variant with no mean-field interaction in the reward."""

    def reward(self, t, x, u, mu):
        """Return position cost + movement cost; ``t`` and ``mu`` are ignored."""
        half = self.N_states / 2
        position_cost = -abs(x - half) / half
        movement_cost = -((u == 0) + (u == 2)) / half
        return position_cost + movement_cost

class time_step_only_one_Beach(MDPBeach):
    """MDPBeach fixed to a single time step (``time_steps=1``)."""

    def __init__(self, N_states=10, noise_prob: float = 0.1, **kwargs):
        forwarded = dict(kwargs, N_states=N_states, noise_prob=noise_prob)
        super().__init__(time_steps=1, **forwarded)