import math

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat, reduce, einsum, unpack
from einops.layers.torch import Rearrange


def get_regular_rotation(theta):
    """Return 2x2 rotation matrices for every angle in ``theta``.

    The result has shape ``theta.shape + (2, 2)`` and each trailing block is
        [[cos(theta), -sin(theta)],
         [sin(theta),  cos(theta)]].
    """
    cos_t, sin_t = torch.cos(theta), torch.sin(theta)
    top_row = torch.stack((cos_t, -sin_t), dim=-1)
    bottom_row = torch.stack((sin_t, cos_t), dim=-1)
    return torch.stack((top_row, bottom_row), dim=-2)


class Evolver_so2(nn.Module):
    """Evolve latents with a learned SO(2) action: Z_{t+1} = exp(delta_t A) Z_t.

    Latent features are reshaped by ``self.actionable`` into
    ``... freq param h a`` (with ``a == action_dim == 2``) so that every
    trailing 2-vector can be rotated; ``self.deactionable`` flattens them back.
    """
    action_dim = 2
    slot_axis = 2  # slot axis of time-major tensors [seq_len, batch, slot, ...]

    @property
    def delta_dim(self):
        """``delta_dim`` of the wrapped action predictor."""
        return self.ap.delta_dim

    @property
    def num_actions(self):
        """``num_actions`` of the wrapped action predictor (one generator each)."""
        return self.ap.num_actions
    
    def __init__(self, action_predictor, num_freqs, burnin, fix_lastslot=False):
        """
        action_predictor: module exposing ``num_actions`` / ``delta_dim`` and
            called as ``ap(latents, actions) -> (latents, deltas)``.
        num_freqs: number of rotation frequencies per action generator.
        burnin: number of leading frames dropped from both latents and deltas
            before evolving.
        fix_lastslot: if True, the last slot of the output is the unrotated
            input slot instead of the evolved one.
        """
        super().__init__()
        
        self.ap = action_predictor
        self.alignment = nn.Identity()   # Will be replaced when fine-tuning
        self.fix_lastslot = fix_lastslot
        
        # Geometric frequency ladder: freqs[k] = b ** (-k / num_freqs),
        # k = 0 .. num_freqs - 1.
        b = 10000.0
        freqs = torch.exp(torch.arange(0., 2*num_freqs, 2) *
                          -(math.log(b) / (2*num_freqs)))
        self.register_buffer('freqs', freqs)
        
        # One learnable rotation rate a_i per action (see time_evolve).
        self.A = nn.Parameter(torch.randn(self.num_actions))
        self.burnin = burnin

        self.actionable = nn.Sequential(
            # Rearrange('... (param freq h a) -> ... param freq h a', 
            Rearrange('... (freq param h a) -> ... freq param h a', 
                      param=self.num_actions, freq=num_freqs, a=self.action_dim),
        )
        self.deactionable = nn.Sequential(
            # Rearrange('... param freq h a -> ... (param freq h a)'),
            Rearrange('... freq param h a -> ... (freq param h a)'),
        )


    def time_evolve(self, latents, deltas):
        '''
        latents: [seq_len, batch, slot, freq, num_actions, hidden, action_dim]
                     or   [batch, slot, freq, num_actions, hidden, action_dim]
        deltas:  [seq_len, batch, slot, num_actions]
                 NOTE(review): a trailing size-1 axis (presumably delta_dim == 1)
                 is removed by ``squeeze(-1)``; if num_actions == 1 and no such
                 axis is present, the action axis itself would be squeezed.

        Evolve latents by the exponential map with deltas:
            Z_{t+1} = exp(delta_t A) Z_t
        where A is the generator of the group action. We restrict A to be block diagonal of 2x2 matrices and i-th block has the form of 
            A_i = [0 -a_i
                   a_i 0]
        so that the exponential map is the direct sum of rotation matrices:
            exp(delta_t A_i) = [cos(a_i delta_t) -sin(a_i delta_t)
                                sin(a_i delta_t) cos(a_i delta_t)].
        Each frequency additionally scales the angle by ``self.freqs``.
        '''
        deltas = deltas.squeeze(-1)

        # delta_a = deltas * F.softplus(self.A)
        delta_a = deltas * self.A
        # Outer product with the frequency ladder: angle = freq * a_i * delta.
        delta_a = einsum(delta_a, self.freqs, '... action, freq -> ... freq action')
        rotation_matrix = get_regular_rotation(delta_a)
        out_latents = einsum(latents, rotation_matrix, '... h in, ... in out-> ... h out')

        if self.fix_lastslot:
            # An un-timed (6-D) input gets a time axis so it can be concatenated
            # with the time-major output.
            if latents.ndim == 6:
                latents = repeat(latents, '... -> t ...', t=len(out_latents))
            # Keep the last slot un-evolved.
            out_latents = torch.cat([
                out_latents[:, :, :-1], 
                latents[:, :, -1:],
            ], dim=self.slot_axis)
        
        return out_latents

    def forward(self, latents, actions=None):
        '''
        Given Z_0, ..., Z_T, predict Z_t0+1, ..., Z_T and Z_t0, ..., Z_T-1 with t0 burn-in frames.
        Z_0, ..., Z_t0 is used to compute deltas.

        latents: Z_0, ..., Z_T where Z_t is the latent representation at time t
        actions: optional, passed through to the action predictor.

        Returns (forward predictions, reverse predictions, predictions using
        every other delta held for two steps, deltas after burn-in); the three
        predictions are flattened back with ``self.deactionable``.
        '''
        latents = self.actionable(latents)

        latents = self.alignment(latents)

        latents, deltas = self.ap(latents, actions)
        latents, deltas = latents[self.burnin:], deltas[self.burnin:]
        
        init_latent = latents[0]
        cum_deltas = deltas.cumsum(dim=0)
        pred_latents = self.time_evolve(init_latent, cum_deltas)

        # Indices [0, 0, 2, 2, 4, 4, ...] of length len(deltas): each even-step
        # delta is reused for the following step. Integer arithmetic directly on
        # the deltas' device avoids a float round-trip and a host->device copy.
        skipped_indices = torch.arange(len(deltas), device=deltas.device) // 2 * 2

        deltas_skipped = deltas[skipped_indices]
        cum_deltas_skipped = deltas_skipped.cumsum(dim=0)
        pred_latents_skipped = self.time_evolve(init_latent, cum_deltas_skipped)

        last_latent = latents[-1]
        # Suffix sums: rev_cum_deltas[t] = sum_{k >= t} deltas[k].
        rev_cum_deltas = deltas - cum_deltas + cum_deltas[-1:]
        rev_pred_latents = self.time_evolve(last_latent, -rev_cum_deltas)

        return (
            self.deactionable(pred_latents), 
            self.deactionable(rev_pred_latents), 
            self.deactionable(pred_latents_skipped), 
            deltas
        )


class Evolver_sim2(Evolver_so2):
    """Evolver combining per-frequency rotation with an exponential scaling.

    Each trailing 2-vector is rotated by ``freq * A * theta`` (as in
    ``Evolver_so2``) and multiplied by ``exp(min(B * lambda, max_lambda))``.
    """
    def __init__(self, max_lambda=2.0, **kwargs):
        super().__init__(**kwargs)
        # del self.A
        # One learnable scaling rate per action.
        self.B = nn.Parameter(torch.randn(self.num_actions))
        # freqs are exp(...) and hence all positive, so this is an all-ones
        # integer mask used only to broadcast the scale over the freq axis.
        self.register_buffer('freqs_scale', (self.freqs > 0).to(int))
        # Upper bound on the log-scale, i.e. scale <= exp(max_lambda).
        self.max_lambda = max_lambda
    
    def time_evolve(self, latents, deltas):
        """
        latents: same layout as for ``Evolver_so2.time_evolve``.
        deltas: last axis is unbound into exactly (theta, lambda): theta drives
            the rotation, lambda the log-scale.

        Honors ``self.fix_lastslot`` exactly like ``Evolver_so2.time_evolve``
        (previously the flag was silently ignored by this override).
        """
        theta, lmbda = deltas.unbind(dim=-1)
        # lmbda = lmbda / math.sqrt(self.kernel_size)
        
        delta_a = theta * self.A
        delta_a = einsum(delta_a, self.freqs, '... action, freq -> ... freq action')
        delta_a = delta_a.float()   # create rot mat in fp32
        rotation_matrix = get_regular_rotation(delta_a)
        
        delta_b = lmbda * self.B
        delta_b = einsum(delta_b, self.freqs_scale, '... action, freq -> ... freq action')
        # Clamp only from above so the scale cannot blow up.
        delta_b = torch.clamp(delta_b, None, self.max_lambda)
        delta_b = delta_b.float()   # compute the scale in fp32
        scale = torch.exp(delta_b)

        out_latents = einsum(latents, rotation_matrix, scale, '... h in, ... in out, ... -> ... h out')

        if self.fix_lastslot:
            # An un-timed (6-D) input gets a time axis so it can be concatenated
            # with the time-major output.
            if latents.ndim == 6:
                latents = repeat(latents, '... -> t ...', t=len(out_latents))
            # Keep the last slot un-evolved.
            out_latents = torch.cat([
                out_latents[:, :, :-1],
                latents[:, :, -1:],
            ], dim=self.slot_axis)

        return out_latents


class Evolver_scale(Evolver_so2):
    """Evolver that only rescales latents: Z_t = exp(min(A * lambda, max_lambda)) Z_0.

    No rotation is applied; ``self.A`` is reused as the per-action scaling rate.
    """
    def __init__(self, max_lambda=2.0, **kwargs):
        super().__init__(**kwargs)
        # freqs are exp(...) and hence all positive, so this is an all-ones
        # integer mask used only to broadcast the scale over the freq axis.
        self.register_buffer('freqs_scale', (self.freqs > 0).to(int))
        # Upper bound on the log-scale, i.e. scale <= exp(max_lambda).
        self.max_lambda = max_lambda
    
    def time_evolve(self, latents, deltas):
        """
        latents: same layout as for ``Evolver_so2.time_evolve``.
        deltas: log-scale increments; a trailing size-1 axis is squeezed.

        Honors ``self.fix_lastslot`` exactly like ``Evolver_so2.time_evolve``
        (previously the flag was silently ignored by this override).
        """
        lmbda = deltas.squeeze(-1)
        
        delta_b = lmbda * self.A
        delta_b = einsum(delta_b, self.freqs_scale, '... action, freq -> ... freq action')
        # Clamp only from above so the scale cannot blow up.
        delta_b = torch.clamp(delta_b, None, self.max_lambda)
        delta_b = delta_b.float()   # compute the scale in fp32
        scale = torch.exp(delta_b)

        out_latents = einsum(latents, scale, '... h in, ... -> ... h in')

        if self.fix_lastslot:
            # An un-timed (6-D) input gets a time axis so it can be concatenated
            # with the time-major output.
            if latents.ndim == 6:
                latents = repeat(latents, '... -> t ...', t=len(out_latents))
            # Keep the last slot un-evolved.
            out_latents = torch.cat([
                out_latents[:, :, :-1],
                latents[:, :, -1:],
            ], dim=self.slot_axis)

        return out_latents


# def cumprod_matrix(U):
#     # U : tensor of dimension (d ... n n)
#     U_cum = [U[0]]
#     for u in U[1:]:
#         U_cum.append(torch.matmul(U_cum[-1], u))
#     U_cum = torch.stack(U_cum)  # tensor of dimension (d ... n n)
#     return U_cum

        

class Evolver_se2(Evolver_so2):
    """Work-in-progress SE(2) evolver: rotation plus accumulated translation.

    NOTE(review): this class looks unfinished. ``action_dim = 3`` while the
    rotations built in ``time_evolve`` are 2x2. Also, ``forward`` neither applies
    ``self.actionable`` nor drops ``self.burnin`` frames the way
    ``Evolver_so2.forward`` does. Verify shapes before relying on it.
    """
    action_dim = 3

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Translation rates, one per action. The parent stores the predictor as
        # ``self.ap`` (exposed via ``self.num_actions``); ``self.to_delta`` does
        # not exist and raised AttributeError here.
        self.B = nn.Parameter(torch.randn(self.num_actions))
    
    def time_evolve(self, latents, deltas):
        """
        deltas: time-major, last axis unbound into (theta, beta_x, beta_y).

        Returns Phi(t) Z + sum_{s <= t} Phi(t, s) b(s), where Phi are rotations
        and b the frequency-scaled translations.
        """
        time_step = F.softplus(self.A)

        # deltas = [thetas, betas] where thetas in R for rotation, betas in R^2 for translation
        thetas, betas_x, betas_y = deltas.unbind(dim=-1)
        betas = torch.stack([betas_x, betas_y], dim=-1)

        # compute Phi(t) = R(∫_0^t theta_tau dtau)
        cum_thetas = thetas.cumsum(dim=0)
        delta_a = einsum(cum_thetas, self.freqs, time_step, 
                         '... anum, freq, anum -> ... anum freq')
        Phi_t = get_regular_rotation(delta_a)
        out_latents = einsum(latents, Phi_t, '... h in, ... in out -> ... h out')
        
        # compute Phi(t, s) = R(∫_s^t theta_tau dtau)
        phi = rearrange(delta_a, 't ... -> t () ...') \
            - rearrange(delta_a, 't ... -> () t ... ') 
        Phi = get_regular_rotation(phi)
        # Causal mask over the two leading time axes so that Phi[i, j] = 0 if i < j.
        # (torch.tril on Phi itself would act on the trailing 2x2 rotation block
        # and zero the -sin entry instead.)
        num_steps = Phi.shape[0]
        mask = torch.ones(num_steps, num_steps, dtype=Phi.dtype, device=Phi.device).tril()
        Phi = Phi * mask.reshape(num_steps, num_steps, *([1] * (Phi.ndim - 2)))

        # compute translation b = ∫_t0^t  Phi(t, s) b(s) ds
        time_step = F.softplus(self.B)
        delta_b = einsum(betas, self.freqs, time_step, 
                         '... anum adim, freq, anum -> ... anum freq adim')
        translation = einsum(delta_b, Phi, '... in, t ... in out -> t ... out')
        translation = reduce(translation, 't s ... adim -> t ... () adim', 'sum')

        return out_latents + translation
    
    def forward(self, latents, actions=None):
        """Return (forward predictions, reverse predictions, deltas).

        ``actions`` is optional and passed to the action predictor, matching
        ``Evolver_so2.forward``.
        """
        # The parent calls the predictor as ``ap(latents, actions) -> (latents, deltas)``;
        # the original code called a non-existent ``self.to_delta``.
        # NOTE(review): confirm this predictor contract for the SE(2) variant.
        latents, deltas = self.ap(latents, actions)

        init_latent = latents[0]
        pred_latents = self.time_evolve(init_latent, deltas)

        last_latent = latents[-1]
        # Reverse-time rollout: flip along time and negate the increments.
        rev_pred_latents = self.time_evolve(last_latent, -deltas.flip(dims=(0,)))

        return pred_latents, rev_pred_latents, deltas