import numpy as np
import torch


class QModel:
    """Base interface for action-value (Q) estimators over discrete spaces.

    Every hook is a no-op here so subclasses only override what they need.
    Both spaces must expose ``.n`` (their number of discrete elements).
    """

    def __init__(self, observation_space, action_space):
        self.observation_space, self.action_space = observation_space, action_space
        self.num_states = observation_space.n
        self.num_actions = action_space.n

    def __call__(self, state, action=None):
        """Shorthand for ``self.get_values(state, action)``."""
        return self.get_values(state, action)

    def update(self, s, a, r, sp, ap, done, gamma=0.99):
        """Primary learning step; returns a dict of diagnostics (empty by default)."""
        return dict()

    def secondary_update(self, s, a, r, sp, ap, done, gamma=0.99):
        """Auxiliary learning step; returns a dict of diagnostics (empty by default)."""
        return dict()

    def update_reweight(self):
        """Hook for refreshing per-(state, action) loss weights; no-op by default."""
        return None

    def get_values(self, state, action=None):
        """Return Q-values for ``state`` (optionally ``action``); the base returns None."""
        return None


class TabularQModel(QModel):
    """Q-model storing one value per (state, action) pair in a flat numpy table.

    Entry ``s * num_actions + a`` of ``self.table`` holds Q(s, a).
    Inputs to ``update`` are presumably integer numpy arrays of equal batch
    length (the code uses ``np.arange`` and numpy fancy indexing) — confirm
    against callers.
    """

    def __init__(self, observation_space, action_space, learning_rate=1e-1):
        super().__init__(observation_space, action_space)
        self.table = np.zeros((self.num_states * self.num_actions), np.float32)
        self.learning_rate = learning_rate

    def update(self, s, a, r, sp, ap, done, gamma=0.99):
        """Apply one batched TD(0) step towards ``r + gamma * (1 - done) * Q(sp, ap)``.

        The bootstrap uses the supplied next actions ``ap``. Each sample
        contributes ``learning_rate / batch_size`` times its TD error.

        Returns:
            dict with ``'q_loss'``: the mean squared TD error before the step.
        """
        batch_size = s.shape[0]
        next_q = self.get_values(sp)[np.arange(sp.shape[0]), ap]
        target_v = r + gamma * (1 - done) * next_q
        idx = s * self.num_actions + a
        current_value = self.table[idx]  # fancy indexing -> copy of pre-update values
        td_error = target_v - current_value

        # np.add.at accumulates repeated indices. ``table[idx] += ...`` is
        # buffered and would keep only one contribution per duplicated (s, a)
        # in the batch, which contradicts the per-sample 1/batch_size scaling.
        np.add.at(self.table, idx, td_error * self.learning_rate / batch_size)

        q_loss = (td_error ** 2).mean()
        return {'q_loss': q_loss}

    def get_values(self, s, a=None):
        """Return Q(s, ·) as a ``(..., num_actions)`` array, or Q(s, a) when ``a`` is given."""
        if a is None:
            return self.table.reshape(self.num_states, self.num_actions)[s]
        return self.table[s * self.num_actions + a]


class LinearQModel(QModel):
    """Linear Q-model with an ensemble of weight vectors over fixed features.

    Q-values are ``Phi @ w.T``. Row ``s * num_actions + a`` of ``Phi`` holds the
    features of (s, a), and each row of ``w`` is one ensemble member. A target
    copy ``w_target`` is Polyak-averaged towards ``w`` after every update with
    rate ``soft_target_update_rate`` (1 means a hard copy).

    Args:
        observation_space, action_space: discrete spaces exposing ``.n``.
        Phi: torch tensor of unit-norm feature rows, one per flattened
            (state, action) index.
        learning_rate: Adam learning rate for ``w``.
        reweight: optional per-(state, action) loss weights as a torch tensor
            indexed by the flat index; defaults to all ones.
        reweight_min, reweight_max: clamp range applied after the batch weights
            are normalized to mean 1.
        ensemble_size: number of weight vectors (only used when ``w_init`` is None).
        w_init: optional initial weights, presumably ``(ensemble_size, feature_dim)``.
        soft_target_update_rate: Polyak rate for ``w_target``.
    """

    def __init__(self, observation_space, action_space, Phi, learning_rate=1e-1, reweight=None,
                 reweight_min=-np.inf, reweight_max=np.inf,
                 ensemble_size=1, w_init=None, soft_target_update_rate=1):
        super().__init__(observation_space, action_space)

        assert isinstance(Phi, torch.Tensor), "Phi must be a torch tensor"
        assert np.isclose(torch.norm(Phi, dim=-1).detach().cpu().numpy(), 1).all(), "Phi must be normalized"
        self.Phi = Phi
        self.feature_dim = Phi.shape[1]
        if w_init is None:
            w_init = 1e-1 * np.random.uniform(-1, 1, (ensemble_size, self.feature_dim))
        self.w = torch.tensor(w_init, dtype=torch.float32, requires_grad=True, device=Phi.device)
        self.w_target = torch.tensor(w_init, dtype=torch.float32, requires_grad=False, device=Phi.device)

        self.learning_rate = learning_rate
        self.soft_target_update_rate = soft_target_update_rate

        if reweight is None:
            self.reweight = torch.ones(self.num_states * self.num_actions, dtype=torch.float32, device=Phi.device)
        else:
            self.reweight = reweight
        self.reweight_min = reweight_min
        self.reweight_max = reweight_max

        self.num_updates = 0
        self.optim = torch.optim.Adam([self.w], lr=learning_rate)

    def update(self, s, a, r, sp, ap, done, gamma=0.99):
        """Take one Adam step on the reweighted squared TD error.

        The target is ``r + gamma * (1 - done) * min_k Q_target_k(sp, ap)``,
        i.e. the minimum over ensemble members of the target weights.

        Returns:
            dict with ``'q_loss'`` as a numpy scalar.
        """
        x = s * self.num_actions + a
        current_value = (self.Phi @ self.w.T)[x]  # (batch, ensemble)

        # Normalize the batch weights to mean 1, then clamp. Out-of-place ops
        # guarantee ``self.reweight`` can never be mutated through a view.
        reweight = self.reweight[x]
        reweight = reweight / reweight.mean()
        reweight = torch.clamp(reweight, self.reweight_min, self.reweight_max).detach()

        xp = sp * self.num_actions + ap
        with torch.no_grad():
            next_q = (self.Phi[xp] @ self.w_target.T).min(-1)[0]
            target_v = r + gamma * (1 - done.float()) * next_q
        diff = target_v[:, None] - current_value
        q_loss = (reweight[:, None] * (diff ** 2)).mean()

        self.optim.zero_grad()
        q_loss.backward()
        self.optim.step()

        self._update_target()
        self.num_updates += 1

        return {'q_loss': q_loss.detach().cpu().numpy()}

    def _update_target(self):
        """Polyak-average ``w`` into ``w_target`` without recording autograd history."""
        # Without no_grad, each update would chain w_target to the previous
        # w_target through ``self.w`` (requires_grad=True), so the graph and
        # its memory would grow with every update.
        with torch.no_grad():
            rate = self.soft_target_update_rate
            self.w_target = self.w_target * (1 - rate) + self.w * rate

    def get_values(self, s, a=None, reduction='mean'):
        """Return Q-values reduced over the ensemble.

        Args:
            s: state index/indices (torch tensor, or anything ``torch.tensor`` accepts).
            a: optional action index/indices; when None all actions are returned.
            reduction: ``'mean'`` or ``'min'`` over ensemble members.

        Returns:
            Same container type as ``s``: a torch tensor (with autograd history)
            if ``s`` was a tensor, otherwise a numpy array.

        Raises:
            ValueError: on an unknown ``reduction``.
        """
        is_torch = isinstance(s, torch.Tensor)
        if not is_torch:
            s = torch.tensor(s, dtype=torch.long, device=self.Phi.device)

        q = (self.Phi @ self.w.T).reshape(self.num_states, self.num_actions, -1)
        if a is None:
            q = q[s, :, :]
        else:
            q = q[s, a, :]

        if reduction == 'mean':
            q = q.mean(-1)
        elif reduction == 'min':
            q = q.min(-1)[0]
        else:
            raise ValueError(f"Unknown reduction: {reduction}")

        if not is_torch:
            q = q.detach().cpu().numpy()
        return q


class CQL(LinearQModel):
    """Linear Q-model trained with the discrete-action CQL(H) penalty.

    Loss = squared TD error (as in ``LinearQModel``, unweighted) plus
    ``alpha_prime * mean(logsumexp_a' Q(s, a') - Q(s, a))`` over the batch and
    ensemble. This penalizes the soft-maximum over all actions of a state
    relative to the value of the dataset action.
    """

    def __init__(self, observation_space, action_space, Phi, learning_rate=1e-1, reweight=None,
                 ensemble_size=1, w_init=None, soft_target_update_rate=1, alpha_prime=1.0):
        # Fail fast, before the base class allocates weights and an optimizer.
        if reweight is not None:
            raise NotImplementedError("reweighting not implemented for CQL")
        super().__init__(observation_space, action_space, Phi, learning_rate=learning_rate, reweight=reweight,
                         ensemble_size=ensemble_size, w_init=w_init, soft_target_update_rate=soft_target_update_rate)
        self.alpha_prime = alpha_prime

    def update(self, s, a, r, sp, ap, done, gamma=0.99):
        """Take one Adam step on TD error + CQL penalty.

        Returns:
            dict with ``'q_loss'`` (TD loss plus CQL term) as a numpy scalar.
        """
        q_all = self.Phi @ self.w.T  # (num_states * num_actions, ensemble)
        x = s * self.num_actions + a
        current_value = q_all[x]  # (batch, ensemble)

        xp = sp * self.num_actions + ap
        with torch.no_grad():
            # Bootstrap from the minimum over ensemble members of the target weights.
            next_q = (self.Phi[xp] @ self.w_target.T).min(-1)[0]
            target_v = r + gamma * (1 - done.float()) * next_q
        td_loss = ((target_v[:, None] - current_value) ** 2).mean()

        # Exact, numerically stable log-sum-exp over every action of each sampled
        # state. Rows are laid out as s * num_actions + a, so this reshape gives [s, a].
        q_state = q_all.reshape(self.num_states, self.num_actions, -1)[s]  # (batch, num_actions, ensemble)
        cql_loss = self.alpha_prime * (torch.logsumexp(q_state, dim=1) - current_value).mean()
        q_loss = td_loss + cql_loss

        self.optim.zero_grad()
        q_loss.backward()
        self.optim.step()

        # Polyak target update without autograd history (otherwise w_target
        # accumulates an ever-growing graph through self.w).
        with torch.no_grad():
            rate = self.soft_target_update_rate
            self.w_target = self.w_target * (1 - rate) + self.w * rate
        self.num_updates += 1

        return {'q_loss': q_loss.detach().cpu().numpy()}
        

class POPQ(LinearQModel):
    """Linear Q-model whose TD loss is reweighted by learned, exponentiated dual weights.

    Dual parameters ``a`` and ``b`` (each ``(feature_dim, rank)``) are rescaled
    to spectral norms ``a_mag`` / ``b_mag``. Together with the auxiliary
    ``g_model`` they define a weight ``exp(...)`` per (state, action); see
    ``_pop_values``. ``secondary_update`` trains the dual parameters and
    ``g_model``. ``update`` refreshes the cached weights when stale and then
    runs the base-class reweighted TD step.

    Args:
        g_model: auxiliary model, called as ``g_model(state, action)`` and
            trained via ``g_model.update(state, action, target, done, reweight=...)``.
        done_vec: termination flags aligned with the state-major enumeration of
            all (state, action) pairs, presumably of length
            ``num_states * num_actions`` — confirm against caller.
        rank: number of columns of ``a`` and ``b``.
        q_lr, dual_lr: Adam learning rates for Q weights and dual parameters.
        pop_margin: coefficient of the ``(a**2 + b**2).sum()`` penalty.
        pop_gamma: discount used inside the dual objective.
    """

    def __init__(self, observation_space, action_space, Phi, g_model, done_vec, w_init=None,
                 rank=4, q_lr=1e-1, dual_lr=1e-1, pop_margin=0.0, pop_gamma=1.0,
                 num_pop_updates=int(1e2), pop_weights_update_freq=1e2,
                 reweight_min=-np.inf, reweight_max=np.inf,
                 grad_clip=np.inf):
        super().__init__(observation_space, action_space, Phi, learning_rate=q_lr, w_init=w_init,
                         reweight_min=reweight_min, reweight_max=reweight_max)
        self.pop_margin = pop_margin
        self.pop_gamma = pop_gamma

        # Enumerate every (state, action) pair in state-major order, matching the
        # flat index ``state * num_actions + action`` used for rows of Phi.
        state, action = np.meshgrid(np.arange(self.num_states), np.arange(self.num_actions), indexing="ij")
        self.state_vec = torch.tensor(state.flatten(), dtype=torch.long, device=self.Phi.device)
        self.action_vec = torch.tensor(action.flatten(), dtype=torch.long, device=self.Phi.device)
        self.done_vec = torch.tensor(done_vec, dtype=torch.float32, device=self.Phi.device)

        self.rank = rank
        a_init = 1e-1 * np.random.uniform(-1, 1, (self.feature_dim, rank))
        b_init = 1e-1 * np.random.uniform(-1, 1, (self.feature_dim, rank))
        self.a = torch.tensor(a_init, dtype=torch.float32, requires_grad=True, device=Phi.device)
        self.a_mag = torch.tensor(1.0, dtype=torch.float32, requires_grad=True, device=Phi.device)
        self.b = torch.tensor(b_init, dtype=torch.float32, requires_grad=True, device=Phi.device)
        self.b_mag = torch.tensor(1.0, dtype=torch.float32, requires_grad=True, device=Phi.device)
        # NOTE(review): grad_clip is stored but not applied anywhere in this class — confirm intent.
        self.grad_clip = grad_clip

        self.dual_optim = torch.optim.Adam([self.a, self.b, self.a_mag, self.b_mag], lr=dual_lr)

        self.g_model = g_model
        # NOTE(review): not read in this class; presumably consumed by the training loop.
        self.pop_weights_update_freq = pop_weights_update_freq
        self.num_pop_updates = num_pop_updates

        self.reweight_needs_update = True

    def update(self, s, a, r, sp, ap, done, gamma=0.99):
        """Refresh stale weights, then run the reweighted TD step of ``LinearQModel``."""
        if self.reweight_needs_update:
            self.update_reweight()

        return super().update(s, a, r, sp, ap, done, gamma)

    @staticmethod
    def _spectral_norm(m):
        """Largest singular value of ``m``, as a constant (no gradient flows through it)."""
        return torch.linalg.svdvals(m.detach()).max()

    def _pop_values(self, state, action, done, state_prime=None, action_prime=None):
        """Compute the POP weights, and optionally the g target and dual loss.

        Args:
            state, action: index tensors of equal length.
            done: termination flags aligned with ``state``/``action``.
            state_prime, action_prime: next indices; pass both or neither.

        Returns:
            ``reweight`` alone when no next pair is given, otherwise the tuple
            ``(reweight, g_target, dual_loss)``.

        Raises:
            ValueError: if exactly one of ``state_prime``/``action_prime`` is given.
        """
        if (state_prime is None) != (action_prime is None):
            raise ValueError("state_prime and action_prime must both be None or both be not None")

        x = state * self.num_actions + action

        # Rescale a and b so that their spectral norms equal a_mag / b_mag. The
        # norm is a stop-gradient, so gradients do not flow through it.
        a = self.a_mag * (self.a / self._spectral_norm(self.a))
        b = self.b_mag * (self.b / self._spectral_norm(self.b))
        m_a = (self.Phi @ a)[x]
        m_b = (self.Phi @ b)[x]
        g = self.g_model(state, action)

        not_done = 1 - done.float()
        margin = self.pop_margin * (a ** 2 + b ** 2).sum()
        reweight = torch.exp(
            (m_a ** 2 + m_b ** 2).sum(-1) + 2 * not_done * self.pop_gamma * self.a_mag * self.b_mag * g
            - margin
        )

        if state_prime is None:
            return reweight

        xp = state_prime * self.num_actions + action_prime
        m_a_prime = (self.Phi @ a)[xp]
        inner = (m_b * m_a_prime).sum(-1)
        # Target for g in units of a_mag * b_mag (1e-6 guards against division by zero).
        g_target = inner / (self.a_mag * self.b_mag + 1e-6)
        dual_loss = (reweight.detach() *
                     ((m_a ** 2 + m_b ** 2).sum(-1)
                      + 2 * not_done * self.pop_gamma * inner
                      - margin)
                     ).mean()

        return reweight, g_target, dual_loss

    def secondary_update(self, state, action, r, state_prime, action_prime, done, gamma=0.99):
        """One Adam step on the dual parameters, then one ``g_model`` update.

        Raises:
            ValueError: if the dual loss is NaN.
        """
        reweight, g_target, dual_loss = self._pop_values(state, action, done, state_prime, action_prime)

        if torch.isnan(dual_loss.detach()).any():
            raise ValueError("NaN dual loss")

        self.dual_optim.zero_grad()
        dual_loss.backward()
        self.dual_optim.step()

        # The dual graph was just consumed and its parameters updated in place,
        # so hand g_model plain constant weights.
        g_loss = self.g_model.update(state, action, g_target, done, reweight=reweight.detach())

        # Dual parameters changed, so the cached weights are stale.
        self.reweight_needs_update = True

        return {
            'dual_loss': dual_loss.detach().cpu().numpy(),
            'g_loss': g_loss,
            'a_2_norm': self.a_mag.detach().cpu().numpy(),
            'b_2_norm': self.b_mag.detach().cpu().numpy(),
        }

    def update_reweight(self):
        """Recompute the cached weights for every (state, action) pair."""
        # The cached weights are only used as constants in the Q loss. Building
        # them without autograd history avoids keeping the graph of this full
        # (num_states * num_actions) computation alive until the next refresh.
        with torch.no_grad():
            self.reweight = self._pop_values(self.state_vec, self.action_vec, self.done_vec)

        self.reweight_needs_update = False


class GModel:
    """Base interface for an auxiliary scalar model over (state, action) pairs.

    The hooks are no-ops here; subclasses provide ``get_value`` and ``update``.
    Both spaces must expose ``.n`` (their number of discrete elements).
    """

    def __init__(self, observation_space, action_space):
        self.observation_space = observation_space
        self.action_space = action_space
        self.num_states, self.num_actions = observation_space.n, action_space.n

    def __call__(self, state, action):
        """Shorthand for ``self.get_value(state, action)``."""
        return self.get_value(state, action)

    def update(self, s, a, target_g, done, reweight=None):
        """Fit towards ``target_g``; the base implementation does nothing and returns None."""
        return None

    def get_value(self, state, action):
        """Return the model output for (state, action); the base returns None."""
        return None


class LinearGModel(GModel):
    """Linear regressor ``g(s, a) = Phi[s * num_actions + a] @ w`` with one weight vector.

    ``w`` has shape ``(1, feature_dim)`` and is trained with Adam.

    NOTE(review): ``update`` accepts ``reweight`` but does not use it — confirm
    whether the regression was meant to be weighted.
    """

    def __init__(self, observation_space, action_space, Phi, lr=1e-2):
        super().__init__(observation_space, action_space)

        self.Phi = Phi
        self.feature_dim = Phi.shape[1]
        initial_w = 1e-1 * np.random.uniform(-1, 1, (1, self.feature_dim))
        self.w = torch.tensor(initial_w, dtype=torch.float32, requires_grad=True, device=Phi.device)
        self.optim = torch.optim.Adam([self.w], lr=lr)

    def update(self, s, a, target_g, done, reweight=None):
        """One Adam step on the squared error to ``target_g``, masked out where ``done``.

        Returns:
            The loss as a numpy scalar.
        """
        prediction = self.get_value(s, a)
        squared_error = (target_g.detach() - prediction) ** 2
        not_done = 1 - done.float()
        loss = (not_done * squared_error).mean()

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

        return loss.detach().cpu().numpy()

    def get_value(self, s, a):
        """Return ``g`` for each (s, a) pair as a flat tensor."""
        flat_index = s * self.num_actions + a
        values = (self.Phi @ self.w.T)[flat_index]
        return values.flatten()
