import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary
import math 
from timm.models.vision_transformer import Attention, Mlp

from flow_matching.utils import ModelWrapper
from flow_matching.solver import ODESolver

def modulate(x, shift, scale):
    """
    Apply a per-sample affine modulation: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` each get a singleton axis inserted at position 1
    so that they broadcast along the second axis of ``x``.
    """
    scale_b = scale[:, None]
    shift_b = shift[:, None]
    return x * (1 + scale_b) + shift_b


class TimestepEmbedder(nn.Module):
    """
    Turns scalar timesteps into vectors: sinusoidal features followed by a two-layer MLP.
    """
    def __init__(self, hidden_size, frequency_embedding_size=32):
        super().__init__()
        # Layers are created in the order they are applied.
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal embeddings for a batch of (possibly fractional) timesteps.
        :param t: a 1-D Tensor of N values, one per batch element.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor; the first half of the columns are cosines, the
                 second half sines, plus one trailing zero column when dim is odd.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * exponents / half).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        pieces = [torch.cos(angles), torch.sin(angles)]
        if dim % 2:
            # Odd target width: pad with a zero column to reach exactly `dim`.
            pieces.append(torch.zeros_like(angles[:, :1]))
        return torch.cat(pieces, dim=-1)

    def forward(self, t):
        """Embed timesteps `t` (1-D tensor) into (N, hidden_size) vectors."""
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features)

class CNNEmbedder(nn.Module):
    """
    Embeds an image-like conditioning input into a vector with a small CNN.

    Two Conv2d -> LeakyReLU -> AvgPool2d stages are followed by a convolution whose
    kernel covers the whole remaining feature map, so each sample is reduced to a
    vector of length ``hidden_size``. Also handles dropping (zeroing) embeddings
    for classifier-free guidance.
    """
    def __init__(self, input_shape, hidden_size, dropout_prob, **block_kwargs):
        """
        :param input_shape: (channels, height, width) of a single input.
        :param hidden_size: channel count of every convolution and length of the output vector.
        :param dropout_prob: probability of zeroing a sample's embedding while training.
        :param block_kwargs: accepted for call compatibility; not used.
        :raises ValueError: if height/width are too small for the conv/pool stack.
        """
        super().__init__()

        in_channels, height, width = input_shape
        self.kernel_size = (3, 3)
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, hidden_size, self.kernel_size),
            nn.LeakyReLU(negative_slope=0.3),
            nn.AvgPool2d(self.kernel_size),
            nn.Conv2d(hidden_size, hidden_size, self.kernel_size),
            nn.LeakyReLU(negative_slope=0.3),
            nn.AvgPool2d(self.kernel_size)
        )
        H, W = self.output_shape(height, width)
        if H < 1 or W < 1:
            # Fail here with a clear message instead of building a projection
            # with an empty/negative kernel that only breaks later.
            raise ValueError(
                f"input of size {height}x{width} is too small for the conv/pool stack "
                f"(resulting feature map would be {H}x{W})"
            )

        # The kernel spans the whole feature map, collapsing it to 1x1.
        self.proj = nn.Conv2d(hidden_size, hidden_size, (H, W))
        self.act = nn.LeakyReLU(negative_slope=0.3)
        self.flatten = nn.Flatten()

        self.dropout_prob = dropout_prob

    def output_shape(self, H, W):
        """
        Spatial size of the feature map produced by ``self.cnn`` for an H x W input.

        :return: tuple (H_out, W_out).
        """
        k_h, k_w = self.kernel_size
        for _ in range(2):  # two identical conv + pool stages
            # unpadded, stride-1 convolution
            H, W = H - k_h + 1, W - k_w + 1
            # average pooling whose stride equals its kernel size
            H, W = (H - k_h) // k_h + 1, (W - k_w) // k_w + 1
        return H, W

    def token_drop(self, labels, force_drop_ids=None):
        """
        Zeroes whole rows of ``labels`` to enable classifier-free guidance.

        :param labels: 2-D tensor of embeddings, one row per sample.
        :param force_drop_ids: optional 1-D tensor with one entry per row; rows where
                               it equals 1 are dropped. When None, every row is dropped
                               independently with probability ``self.dropout_prob``.
        :return: tensor shaped like ``labels`` with the dropped rows set to zero.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        keep = 1.0 - drop_ids.unsqueeze(1).float()  # 0 for dropped rows, 1 otherwise
        return labels * keep

    def forward(self, x, train, force_drop_ids=None):
        """
        :param x: batch of inputs accepted by the CNN, i.e. (N, channels, height, width).
        :param train: whether random dropping may be applied.
        :param force_drop_ids: optional explicit drop selection, see ``token_drop``.
        :return: (N, hidden_size) tensor of embeddings.
        """
        x = self.cnn(x)
        x = self.proj(x)
        x = self.act(x)
        x = self.flatten(x)

        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            x = self.token_drop(x, force_drop_ids)

        return x

class DiTBlock(nn.Module):
    """
    Transformer block whose attention and MLP branches are modulated by a conditioning vector.

    Shift, scale and gate values for both branches are regressed from the
    conditioning vector by ``adaLN_modulation``.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()

        def make_tanh_gelu():
            # Factory handed to Mlp, which instantiates its own activation.
            return nn.GELU(approximate="tanh")

        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=make_tanh_gelu,
            drop=0,
        )
        self.adaLN_modulation = nn.Sequential(
            nn.GELU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """
        :param x: token sequence.
        :param c: per-sample conditioning vector; split along dim 1 into six modulation terms.
        :return: the updated token sequence (two gated residual updates).
        """
        terms = self.adaLN_modulation(c).chunk(6, dim=1)
        shift_attn, scale_attn, gate_attn, shift_ff, scale_ff, gate_ff = terms

        # Attention branch, added residually through its gate.
        attn_out = self.attn(modulate(self.norm1(x), shift_attn, scale_attn))
        x = x + gate_attn.unsqueeze(1) * attn_out

        # Feed-forward branch, added residually through its gate.
        ff_out = self.mlp(modulate(self.norm2(x), shift_ff, scale_ff))
        x = x + gate_ff.unsqueeze(1) * ff_out
        return x
    
class FinalLayer(nn.Module):
    """
    Output head of DiT: a modulated LayerNorm followed by a linear projection.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        out_features = patch_size * patch_size * out_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_features, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.GELU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """
        :param x: token sequence.
        :param c: per-sample conditioning vector; split along dim 1 into shift and scale.
        :return: the projected tokens.
        """
        shift, scale = torch.chunk(self.adaLN_modulation(c), 2, dim=1)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))

class DiT(nn.Module):
    """
    Transformer over a 1-D sequence, conditioned on a timestep and an image-like input.

    Every scalar of the input sequence is lifted to ``num_channels`` features, a fixed
    sin/cos positional embedding is added, and the tokens pass through ``num_blocks``
    conditioned transformer blocks. The output has the same length as the input and
    its first ``mask_size`` entries are forced to zero.
    """
    def __init__(
        self,
        input_dim=256,
        input_shape=(1, 100, 100),
        num_channels=16,
        num_blocks=4,
        num_heads=1,
        mlp_ratio=4.0,
        class_dropout_prob=0.0,
        mask_size=10
    ):
        """
        :param input_dim: length T of the 1-D input/output sequence.
        :param input_shape: (channels, height, width) of the conditioning input.
        :param num_channels: token width used throughout the network.
        :param num_blocks: number of transformer blocks.
        :param num_heads: attention heads per block.
        :param mlp_ratio: hidden-width multiplier of each block's MLP.
        :param class_dropout_prob: probability of dropping the conditioning while training.
        :param mask_size: number of leading output entries forced to zero.
        """
        super().__init__()

        self.t_embedder = TimestepEmbedder(num_channels)
        # Fixed (non-trainable) positional embedding, filled with sin/cos values.
        self.pos_embed = nn.Parameter(torch.zeros(1, input_dim, num_channels), requires_grad=False)
        pos_embed = get_1d_sincos_pos_embed(num_channels, input_dim)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        self.y_embedder = CNNEmbedder(input_shape, num_channels, class_dropout_prob)

        self.blocks = nn.ModuleList([DiTBlock(num_channels, num_heads, mlp_ratio=mlp_ratio) for _ in range(num_blocks)])

        self.final = FinalLayer(num_channels, 1, 1)
        self.first = nn.Conv1d(1, num_channels, kernel_size=1, padding="same")

        mask = torch.ones(1, input_dim)
        mask[:, :mask_size] = 0.0  # zero out the first `mask_size` outputs
        # A buffer follows the module across .to()/.cuda() calls, so it is not
        # re-copied to the device on every forward. It is non-persistent so the
        # state_dict layout (and existing checkpoints) stay unchanged.
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x, t, c):
        """
        :param x: (B, T) input sequence with T == input_dim.
        :param t: 1-D tensor of B timesteps.
        :param c: conditioning input, (B, C, H, W); a 3-D (B, H, W) tensor gets a
                  singleton channel axis inserted.
        :return: (B, T) tensor whose first ``mask_size`` columns are zero.
        """
        h = x.unsqueeze(1)  # (B, 1, T)
        if c.ndim == 3:
            c = c.unsqueeze(1)  # (B, 1, H, W)
        h = self.first(h).transpose(1, 2) + self.pos_embed  # (B, T, C)

        # Timestep and conditioning embeddings are summed into one vector per sample.
        cond = self.t_embedder(t) + self.y_embedder(c, self.training)  # (B, C)

        for block in self.blocks:
            h = block(h, cond)

        out = self.final(h, cond).squeeze(-1)
        # .to() is a no-op once the module lives on the same device as x.
        return out * self.mask.to(device=x.device)
    
class InitializerNet(nn.Module):
    """
    Builds a starting sample by overwriting the leading entries of a noise vector.

    The first ``mask_size`` entries of the noise are zeroed and the given values
    are added in their place; the remaining entries keep the noise.
    """
    def __init__(self, input_dim=256, mask_size=10):
        """
        :param input_dim: length of the last axis of the noise tensor.
        :param mask_size: number of leading entries that are replaced.
        """
        super().__init__()
        mask = torch.ones(1, input_dim)
        mask[:, :mask_size] = 0
        # Non-persistent buffer: moves with the module but stays out of the
        # state_dict, so checkpoints keep their original layout.
        self.register_buffer("mask", mask, persistent=False)
        self.mask_size = mask_size

    def forward(self, z, x, c):
        """
        :param z: noise tensor whose last axis has length ``input_dim``.
        :param x: known leading values; zero-padded on the right to the length of ``z``.
        :param c: unused; accepted for interface compatibility.
        :return: ``z`` with its masked entries zeroed, plus the padded ``x``.
        """
        x = F.pad(x, (0, z.shape[-1] - x.shape[-1]), 'constant', 0.0)  # Pad x to match output length
        # .to() is kept because callers may never move this parameter-free module to the device.
        return z * self.mask.to(z.device) + x
    
class WrappedModel(ModelWrapper):
    """
    Adapter that calls the wrapped model as ``model(x, t, model_extras)`` with one time value per sample.
    """
    def forward(self, x, t, model_extras=None):
        """
        :param x: batch of states; its first axis is the batch size.
        :param t: time tensor. A single value is repeated for every sample in the
                  batch; a tensor with more than one element is assumed to be
                  per-sample already and is passed through unchanged.
        :param model_extras: conditioning forwarded to the wrapped model as its third argument.
        :return: whatever the wrapped model returns.
        """
        if t.numel() == 1:
            # reshape(1) accepts both 0-dim and one-element tensors.
            t = t.reshape(1).repeat(x.shape[0])
        return self.model(x, t, model_extras)

def get_1d_sincos_pos_embed(embed_dim, input_dim):
    """
    Build fixed 1-D sine/cosine embeddings for positions 0 .. input_dim - 1.

    embed_dim: output dimension D for each position; must be even
    input_dim: number of positions M to encode
    out: (M, D) float64 array; columns [0, D/2) are sines, columns [D/2, D) are cosines

    Raises ValueError if embed_dim is odd.
    """
    if embed_dim % 2 != 0:
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        raise ValueError(f"embed_dim must be even, got {embed_dim}")

    half = embed_dim // 2
    # Frequencies decay geometrically from 1 towards 1/10000.
    omega = np.arange(half, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = np.arange(input_dim, dtype=np.float32)  # (M,)
    angles = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)


def train_step(model, initializer, batch, path, optim, lr_scheduler, device):
    """
    Run one optimisation step of the flow-matching loss on a single batch.

    Reads ``batch['label']`` (target) and ``batch['feature']`` (conditioning),
    draws noise and random times, samples the probability path, and regresses
    the model output onto the path's ``dx_t``. Steps both the optimizer and the
    learning-rate scheduler.

    Returns the loss of this batch as a Python float.
    """
    optim.zero_grad()

    target = batch['label'].to(device)
    cond = batch['feature'].to(device)
    noise = torch.randn_like(target).to(device)

    # Source sample: noise combined with the leading `mask_size` target entries.
    source = initializer(noise, target[:, :initializer.mask_size], cond)

    # One random time per batch element.
    times = torch.rand(target.shape[0]).to(device)

    # Point on the probability path and its target velocity.
    sample = path.sample(t=times, x_0=source, x_1=target)

    # Mean squared error between predicted and target velocity.
    velocity = model(sample.x_t, sample.t, cond)
    loss = torch.pow(velocity - sample.dx_t, 2).mean()

    loss.backward()
    optim.step()
    lr_scheduler.step()

    return loss.item()

def test_step(model, initializer, batch, path, device):
    """
    Compute the flow-matching loss on one batch without updating anything.

    Reads ``batch['label']`` (target) and ``batch['feature']`` (conditioning),
    draws noise and random times, samples the probability path, and compares the
    model output with the path's ``dx_t``. Autograd is disabled because only the
    loss value is needed.

    Returns the loss of this batch as a Python float.
    """
    x_1 = batch['label'].to(device)
    c = batch['feature'].to(device)

    with torch.no_grad():  # no graph is needed for evaluation
        z = torch.randn_like(x_1)  # already on the device of x_1

        x_0 = initializer(z, x_1[:, :initializer.mask_size], c)

        t = torch.rand(x_1.shape[0]).to(device)

        # sample probability path
        path_sample = path.sample(t=t, x_0=x_0, x_1=x_1)

        # flow matching l2 loss
        pred = model(path_sample.x_t, path_sample.t, c)
        loss = torch.pow(pred - path_sample.dx_t, 2).mean()

    return loss.item()


# Despite the `test_` prefix this is a validation helper, not a pytest test:
# opt out of collection so importing it into a test module does not make pytest
# try to run it.
test_step.__test__ = False

def train(model, initializer, train_dataloader, test_dataloader, path, optim, lr_scheduler, device, config, validation=True):
    """
    Train ``model`` for ``config.num_epochs`` epochs, optionally validating after each one.

    Every ``config.print_every`` epochs a line is printed with the mean training
    loss per batch, the mean validation loss per batch (0.0 when ``validation``
    is False), the elapsed time and the current learning rate.

    :param train_dataloader: iterable of batches passed to ``train_step``.
    :param test_dataloader: iterable of batches passed to ``test_step``.
    :param config: object providing ``num_epochs`` and ``print_every``.
    :param validation: whether to evaluate on ``test_dataloader`` every epoch.
    """
    start_time = time.time()
    for i in range(config.num_epochs):
        model.train()  # set the model to training mode
        initializer.train()  # set the initializer to training mode
        # Keep the running sum separate from the per-batch value so it is not overwritten.
        train_total, train_batches = 0.0, 0
        for batch in train_dataloader:
            train_total += train_step(model, initializer, batch, path, optim, lr_scheduler, device)
            train_batches += 1
        loss = train_total / max(train_batches, 1)  # guard against an empty loader

        val_loss = 0.0
        if validation:
            model.eval()  # set the model to evaluation mode
            initializer.eval()  # set the initializer to evaluation mode
            val_total, val_batches = 0.0, 0
            with torch.no_grad():  # evaluation needs no gradients
                for batch in test_dataloader:
                    val_total += test_step(model, initializer, batch, path, device)
                    val_batches += 1
            val_loss = val_total / max(val_batches, 1)

        # log loss
        if (i+1) % config.print_every == 0:
            elapsed = time.time() - start_time
            print('| iter {:6d} | {:5.2f} ms/step | loss {:8.6f} | val_loss {:8.6f} | lr {:.6f} |' 
                .format(i+1, elapsed*1000/config.print_every, loss, val_loss, lr_scheduler.get_last_lr()[0])) 
            start_time = time.time()

def generate_samples(dataset, model, initializer, device, config, num_samples=10, return_records=False):
    """
    Integrate the learned flow for the whole test split and map results back to label space.

    :param dataset: object providing ``test_data.label``, ``test_data.feature`` and
                    ``label_transformation.inverse_transform``.
    :param model: velocity model; switched to eval mode here.
    :param initializer: callable building the initial state; must expose ``mask_size``.
    :param config: object providing ``sequence_length``.
    :param num_samples: currently unused; kept for interface compatibility.
    :param return_records: when True, also return every state returned by the solver.
    :return: ``(x_test, x_pred)`` or ``(x_test, x_pred, records)``, all after the
             inverse label transformation; ``records`` is a numpy array with one
             entry per solver output.
    """
    model.eval()  # set the model to evaluation mode
    wrapped_vf = WrappedModel(model).to(device)  # wrap the model to use it with ODESolver
    step_size = 0.001  # NOTE(review): 'dopri5' looks adaptive, so this may be ignored by the solver - confirm
    x_test = dataset.test_data.label.to(device)
    c_test = dataset.test_data.feature.to(device)  # conditioning for every test sample
    batch_size = x_test.shape[0]

    T = torch.linspace(0, 1, 10).to(device=device)  # times at which the solution is returned
    x_init = torch.randn((batch_size, config.sequence_length), dtype=torch.float32, device=device)
    # x_test already lives on `device`, so the slice needs no extra transfer.
    x_init = initializer(x_init, x_test[:, :initializer.mask_size], c_test)  # initialize the samples
    solver = ODESolver(velocity_model=wrapped_vf)  # create an ODESolver class
    sol = solver.sample(time_grid=T, x_init=x_init, method='dopri5', step_size=step_size, return_intermediates=True, model_extras=c_test)

    sol_cpu = sol.cpu()  # transfer once and reuse
    inverse = dataset.label_transformation.inverse_transform
    x_test = inverse(dataset.test_data.label.cpu())
    x_pred = inverse(sol_cpu[-1])  # last returned state

    if not return_records:
        return x_test, x_pred

    # Only pay for transforming every returned state when it is requested.
    records = np.array([inverse(rec) for rec in sol_cpu])
    return x_test, x_pred, records

if __name__ == "__main__":
    # Smoke test: build a model, run one random batch through it and print a summary.
    batch_size = 10
    model = DiT(input_dim=256, num_channels=16, num_blocks=4, num_heads=1)
    x = torch.randn(batch_size, 256)  # one 256-long sequence per sample
    t = torch.tensor([0.5]).repeat(batch_size)  # same time value for every sample
    c = torch.randn(batch_size, 1, 100, 100)  # conditioning input shaped (B, C, H, W)
    inputs = (x, t, c)
    model(*inputs)  # Forward pass
    print(summary(model, input_data=inputs, device='cpu'))
