import torch
import torch.nn as nn

class RFFEmbedding(nn.Module):
    """Random Fourier feature embedding with a fixed, reproducible frequency matrix.

    A ``(position_size, basis_dim)`` matrix is drawn from a normal distribution
    with seed 42 and scaled by ``freq_std``. ``forward`` projects the input onto
    it and returns the sine and cosine of the projection concatenated along the
    last dimension.

    Args:
        position_size: Size of the last dimension of the inputs to ``forward``.
        basis_dim: Number of random frequencies; the output has
            ``2 * basis_dim`` features.
        dtype: dtype of the frequency matrix.
        freq_std: Scale applied to the standard-normal draw.
    """

    def __init__(self, position_size, basis_dim, dtype=torch.float32, freq_std=8.0):
        super().__init__()
        self.position_size = position_size

        # Draw from a private generator so that building this module does not
        # reseed the process-wide RNG (which would silently make every later
        # random operation in the program deterministic).
        generator = torch.Generator().manual_seed(42)
        frequencies = torch.randn(
            self.position_size,
            basis_dim,
            generator=generator,
            dtype=dtype,
            device="cpu",
        ) * freq_std

        # Registered as a buffer so it follows ``module.to(...)`` and device
        # replication. Non-persistent: it is regenerated from the fixed seed,
        # and keeping it out of ``state_dict`` leaves checkpoint keys unchanged.
        self.register_buffer("frequencies", frequencies, persistent=False)

    def forward(self, position: torch.Tensor) -> torch.Tensor:
        """Return ``cat([sin(position @ F), cos(position @ F)], dim=-1)``.

        Args:
            position: Tensor whose last dimension matches ``position_size``.

        Returns:
            Tensor with last dimension ``2 * basis_dim``.
        """
        # No-op when the buffer already lives on the input's device; kept so the
        # module still works if only the input was moved.
        frequencies = self.frequencies.to(position.device)
        proj = torch.matmul(position, frequencies)
        return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)

class RPFMLP(nn.Module):
    """MLP that maps (time, noisy sample, position) to a 3-component output.

    The time value and the 3-component sample are each passed through a linear
    layer of width ``fourier_dim``; the 2-component position goes through an
    ``RFFEmbedding`` (width ``2 * fourier_dim``). The three embeddings are
    concatenated and fed to a four-layer ReLU MLP.

    Args:
        hidden_dim: Width of the hidden layers of the MLP.
        fourier_dim: Width of the time/sample embeddings and number of random
            Fourier frequencies.
        freq_std: Frequency scale forwarded to ``RFFEmbedding``.
    """

    def __init__(self, hidden_dim, fourier_dim, freq_std):
        super().__init__()
        self.freq_std = freq_std

        # Sub-modules are created in a fixed order so that parameter
        # initialisation consumes the RNG in the same sequence.
        self.time_embedding = nn.Linear(1, fourier_dim)
        self.noise_embedding = nn.Linear(3, fourier_dim)
        self.rff_embedding = RFFEmbedding(
            position_size=2,
            basis_dim=fourier_dim,
            freq_std=freq_std,
        )

        # fourier_dim (time) + fourier_dim (sample) + 2 * fourier_dim (sin/cos).
        widths = [4 * fourier_dim, hidden_dim, hidden_dim, hidden_dim]
        layers = []
        for in_width, out_width in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_width, out_width))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_dim, 3))
        self.mlp = nn.Sequential(*layers)

    def forward(self, t, xt, position):
        """Embed the three inputs, concatenate them and apply the MLP.

        Args:
            t: Time tensor; a trailing unit dimension is added before embedding.
            xt: Tensor whose last dimension is 3.
            position: Tensor whose last dimension is 2.

        Returns:
            Tensor whose last dimension is 3.
        """
        embeddings = (
            self.time_embedding(t.unsqueeze(-1)),
            self.noise_embedding(xt),
            self.rff_embedding(position),
        )
        return self.mlp(torch.cat(embeddings, dim=-1))
    

class FlowMatching(nn.Module):
    """Conditional flow-matching loss and explicit-Euler sampler around a velocity model.

    The wrapped ``model`` is called as ``model(t, xt, position=position)`` and
    is trained to regress the conditional velocity ``ut`` of the chosen
    probability path.

    Args:
        model: Velocity network.
        prob_path: ``'cfm'``, ``'icfm'`` or ``'otcfm'``. In this class
            ``'otcfm'`` is handled exactly like ``'icfm'``.
        noise_like_function: Callable used to draw ``x0`` when none is given;
            called with ``x1`` as its only argument.
        sigma: Noise level of the probability path.
        mode: Stored on the instance; not read by any method of this class.
    """

    def __init__(
        self,
        model: nn.Module,
        prob_path='icfm',
        noise_like_function=torch.randn_like,
        sigma = 0.0,
        mode='generation',
    ):
        super().__init__()
        self.model = model
        self.prob_path = prob_path
        self.noise_like_function = noise_like_function
        self.sigma = sigma
        self.mode = mode

    def pad_like_x(self, t, x):
        """Reshape ``t`` to ``(-1, 1, ..., 1)`` so it broadcasts against ``x``.

        Python scalars are returned unchanged.
        """
        if isinstance(t, (float, int)):
            return t
        return t.reshape(-1, *([1] * (x.dim() - 1)))

    def noise_like(self, x):
        """Draw noise shaped like ``x`` with the configured noise function."""
        return self.noise_like_function(x)

    def compute_mu_t(self, x0, x1, t):
        """Return the mean of the probability path at time ``t``.

        ``'cfm'`` gives ``t * x1``; ``'icfm'``/``'otcfm'`` give the linear
        interpolation ``(1 - t) * x0 + t * x1``.

        Raises:
            NotImplementedError: For any other ``prob_path``.
        """
        t = self.pad_like_x(t, x0)
        if self.prob_path == 'cfm':
            return t * x1
        elif self.prob_path in ('icfm', 'otcfm'):
            return (1 - t) * x0 + t * x1
        else:
            raise NotImplementedError("Only cfm and icfm are supported for the moment.")

    def compute_sigma_t(self, x0, x1, t):
        """Return the standard deviation of the probability path at time ``t``.

        ``'cfm'`` gives ``t * sigma - t + 1``; ``'icfm'``/``'otcfm'`` give the
        constant ``sigma``.

        Raises:
            NotImplementedError: For any other ``prob_path``.
        """
        t = self.pad_like_x(t, x0)
        if self.prob_path == 'cfm':
            # ``t`` is already broadcast-shaped, so no second reshape is needed.
            return t * self.sigma - t + 1
        elif self.prob_path in ('icfm', 'otcfm'):
            return self.pad_like_x(self.sigma, x0)
        else:
            raise NotImplementedError("Only cfm and icfm are supported for the moment.")

    def sample_x_t(self, x0, x1, t):
        """Sample ``xt = mu_t + sigma_t * eps`` with ``eps`` standard normal."""
        mu_t = self.compute_mu_t(x0, x1, t)
        sigma_t = self.compute_sigma_t(x0, x1, t)
        return mu_t + sigma_t * torch.randn_like(x1)

    def compute_conditional_flow(self, x0, x1, xt, t):
        """Return the regression target ``ut`` for the chosen probability path.

        ``'cfm'`` gives ``(x1 - (1 - sigma) * xt) / (1 - (1 - sigma) * t)``;
        ``'icfm'``/``'otcfm'`` give ``x1 - x0``.

        Raises:
            NotImplementedError: For any other ``prob_path``.
        """
        t = self.pad_like_x(t, x0)
        if self.prob_path == 'cfm':
            return (x1 - (1 - self.sigma) * xt) / (1 - (1 - self.sigma) * t)
        elif self.prob_path in ('icfm', 'otcfm'):
            return x1 - x0
        else:
            raise NotImplementedError("Only cfm and icfm are supported for the moment.")

    def sample_everything_at_random(self, x1, x0=None, position=None):
        """Draw one training tuple for a batch of targets ``x1``.

        ``x0`` is drawn with ``noise_like`` when not supplied. One time per
        batch element is drawn as the sigmoid of a standard-normal sample.

        Returns:
            ``(x0, x1, xt, ut, t, position)``.
        """
        if x0 is None:
            x0 = self.noise_like(x1)

        t = torch.sigmoid(torch.randn(x1.shape[0])).type_as(x1)
        xt = self.sample_x_t(x0, x1, t)
        ut = self.compute_conditional_flow(x0, x1, xt, t)
        return x0, x1, xt, ut, t, position

    def loss(self, vt, ut):
        """Mean squared error between predicted and target velocity."""
        return torch.mean((vt - ut) ** 2)

    def forward(self, x1, x0=None, position=None):
        """Compute the flow-matching loss for a batch.

        Returns:
            ``(loss, vt)`` where ``vt`` is the model's velocity prediction.
        """
        x0, x1, xt, ut, t, position = self.sample_everything_at_random(x1, x0, position)

        vt = self.model(t, xt, position=position)
        loss = self.loss(vt, ut)
        return loss, vt

    def _euler_integrate(self, x_start, t_start, t_end, euler_steps, position, return_all_ts):
        """Integrate the model's velocity field from ``t_start`` to ``t_end`` with explicit Euler."""
        if euler_steps < 1:
            raise ValueError("euler_steps must be at least 1.")

        ts = torch.linspace(t_start, t_end, euler_steps, device=x_start.device)
        # The grid has ``euler_steps`` points, hence ``euler_steps - 1``
        # intervals. The step must equal the grid spacing, otherwise the
        # integration stops short of ``t_end``.
        n_intervals = euler_steps - 1
        delta_t = (t_end - t_start) / n_intervals if n_intervals else 0.0

        xt = x_start.clone()
        xt_list = [x_start]

        # Sample in eval mode, then restore the previous mode so that sampling
        # in the middle of training does not leave the model stuck in eval.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for t in ts[:-1]:
                    vt = self.model(t.repeat(x_start.size(0)), xt, position=position)
                    xt = xt + vt * delta_t

                    if return_all_ts:
                        xt_list.append(xt.clone())
        finally:
            self.model.train(was_training)

        return xt if not return_all_ts else torch.stack(xt_list, dim=0)

    def sample(
            self,
            x0,
            euler_steps=100,
            position=None,
            return_all_ts=False,
        ):
        """Integrate from ``t=0`` to ``t=1`` starting at ``x0``.

        Args:
            x0: Starting batch.
            euler_steps: Number of time-grid points (``euler_steps - 1`` Euler
                updates are performed).
            position: Forwarded to the model as ``position=``.
            return_all_ts: When true, return every intermediate state stacked
                along a new leading dimension of size ``euler_steps``.

        Returns:
            The final state, or the stacked trajectory.

        Raises:
            ValueError: If ``euler_steps`` is smaller than 1.
        """
        return self._euler_integrate(x0, 0.0, 1.0, euler_steps, position, return_all_ts)

    def sample_reverse(
            self,
            x1,
            euler_steps=100,
            position=None,
            return_all_ts=False,
        ):
        """Integrate from ``t=1`` back to ``t=0`` starting at ``x1``.

        Arguments and return value mirror ``sample``.

        Raises:
            ValueError: If ``euler_steps`` is smaller than 1.
        """
        return self._euler_integrate(x1, 1.0, 0.0, euler_steps, position, return_all_ts)
