import time
import torch
import torch.nn as nn

from nns import get_1d_sincos_pos_embed, DiTBlock, FinalLayer

class CNNEmbedder(nn.Module):
    """
    Embeds an image-like conditioning input into a vector with a small CNN.
    Also handles dropout of whole embeddings for classifier-free guidance.

    The input goes through two Conv2d -> LeakyReLU -> AvgPool2d stages, a 1x1
    convolution that reduces the channels to one, and a linear layer that maps
    the flattened feature map to ``output_size`` values.

    Args:
        input_shape: ``(in_channels, height, width)`` of the conditioning input.
        hidden_size: number of channels produced by both 3x3 convolutions.
        output_size: length of the returned embedding vector.
        dropout_prob: probability of zeroing a sample's embedding during training.
        **block_kwargs: accepted for call-site compatibility; not used.

    Raises:
        ValueError: if ``height``/``width`` are too small to leave at least one
            pixel after the two conv/pool stages.
    """
    def __init__(self, input_shape, hidden_size, output_size, dropout_prob, **block_kwargs):
        super().__init__()

        in_channels, height, width = input_shape
        self.kernel_size = (3, 3)
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, hidden_size, self.kernel_size),
            nn.LeakyReLU(negative_slope=0.3),
            nn.AvgPool2d(self.kernel_size),
            nn.Conv2d(hidden_size, hidden_size, self.kernel_size),
            nn.LeakyReLU(negative_slope=0.3),
            nn.AvgPool2d(self.kernel_size)
        )
        H, W = self.output_shape(height, width)
        if H < 1 or W < 1:
            # Fail at construction time instead of building a zero-width linear
            # layer that can never process an input of this size.
            raise ValueError(
                f"input_shape {tuple(input_shape)} is too small: the CNN would "
                f"produce an empty {H}x{W} feature map"
            )

        self.act = nn.LeakyReLU(negative_slope=0.3)
        self.proj = nn.Conv2d(hidden_size, 1, (1, 1))
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(H*W, output_size)

        self.dropout_prob = dropout_prob

    def output_shape(self, H, W):
        """
        Return the ``(height, width)`` of the feature map produced by ``self.cnn``
        for an input of spatial size ``H`` x ``W``.

        Sizes that are too small for a stage are reported as 0.
        """
        k_h, k_w = self.kernel_size
        for _ in range(2):  # two identical conv -> pool stages
            # unpadded, stride-1 convolution
            H, W = H - k_h + 1, W - k_w + 1
            # average pooling with stride equal to the kernel size (AvgPool2d default)
            H = max(0, (H - k_h) // k_h + 1)
            W = max(0, (W - k_w) // k_w + 1)

        return H, W

    def token_drop(self, labels, force_drop_ids=None):
        """
        Zeroes whole rows of ``labels`` to enable classifier-free guidance.

        Args:
            labels: 2-D tensor of embeddings, one row per sample.
            force_drop_ids: optional per-sample tensor; rows whose entry equals 1
                are dropped. When ``None``, each row is dropped independently
                with probability ``self.dropout_prob``.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        # Broadcast the per-sample keep mask over the embedding dimension.
        return labels * (1.0 - drop_ids.unsqueeze(1).float())

    def forward(self, x, train, force_drop_ids=None):
        """
        Embed ``x`` and optionally drop embeddings.

        Args:
            x: conditioning input of shape ``(batch, in_channels, height, width)``.
            train: whether random dropout may be applied.
            force_drop_ids: optional explicit drop selection, see ``token_drop``.

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        x = self.cnn(x)
        x = self.proj(x)
        x = self.act(x)
        x = self.flatten(x)
        x = self.linear(x)

        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            x = self.token_drop(x, force_drop_ids)

        return x

class TrF(nn.Module):
    """
    Transformer model driven by a CNN embedding of the conditioning input.

    The conditioning input is embedded by ``CNNEmbedder`` into a vector of
    length ``input_dim``. Each of its entries becomes one token: the scalar is
    broadcast across ``num_channels`` and added to a frozen positional table.
    The same embedding, projected to ``num_channels``, is passed to every
    ``DiTBlock`` and to the ``FinalLayer`` as their second argument.

    Args:
        input_dim: length of the CNN embedding, i.e. the number of tokens.
        input_shape: ``(channels, height, width)`` of the conditioning input.
        num_channels: token width used by the blocks and the CNN hidden size.
        num_blocks: number of ``DiTBlock`` layers.
        num_heads: attention heads passed to each ``DiTBlock``.
        mlp_ratio: ``mlp_ratio`` passed to each ``DiTBlock``.
        class_dropout_prob: dropout probability passed to ``CNNEmbedder``.
    """
    def __init__(
        self,
        input_dim=256,
        input_shape=(1, 100, 100),
        num_channels=16,
        num_blocks=2,
        num_heads=1,
        mlp_ratio=4.0,
        class_dropout_prob=0.0,
    ):
        super().__init__()

        # Frozen positional table; the helper returns a NumPy array
        # (presumably of shape (input_dim, num_channels) -- confirm in nns).
        sincos_table = get_1d_sincos_pos_embed(num_channels, input_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, input_dim, num_channels), requires_grad=False)
        self.pos_embed.data.copy_(torch.from_numpy(sincos_table).float().unsqueeze(0))

        self.y_embedder = CNNEmbedder(input_shape, num_channels, input_dim, class_dropout_prob)

        block_list = [
            DiTBlock(num_channels, num_heads, mlp_ratio=mlp_ratio)
            for _ in range(num_blocks)
        ]
        self.blocks = nn.ModuleList(block_list)

        self.final = FinalLayer(num_channels, 1, 1)
        self.x_to_t = nn.Linear(input_dim, num_channels)

    def forward(self, c):
        """
        Run the model on the conditioning input ``c``.

        A 3-D ``c`` gets a channel axis inserted at position 1 before embedding.
        Returns the ``FinalLayer`` output with its last axis squeezed
        (presumably shape (batch, input_dim) -- depends on ``FinalLayer``).
        """
        if c.dim() == 3:
            c = c.unsqueeze(1)  # (B, H, W) -> (B, 1, H, W)

        embedding = self.y_embedder(c, self.training)  # (B, input_dim)
        # One token per embedding entry: (B, input_dim, 1) + (1, input_dim, num_channels)
        tokens = embedding.unsqueeze(-1) + self.pos_embed
        cond = self.x_to_t(embedding)  # (B, num_channels)

        for block in self.blocks:
            tokens = block(tokens, cond)

        return self.final(tokens, cond).squeeze(-1)


def train_step(model, batch, optim, lr_scheduler, device):
    """
    Run one optimisation step on a single batch and return its loss.

    Args:
        model: module called as ``model(feature)``.
        batch: mapping with tensors under the keys ``'label'`` and ``'feature'``.
        optim: optimizer that updates ``model``'s parameters.
        lr_scheduler: scheduler stepped once per call, i.e. once per batch.
        device: device the batch tensors are moved to.

    Returns:
        The mean squared error of the batch as a Python float.
    """
    optim.zero_grad()
    target = batch['label'].to(device)
    features = batch['feature'].to(device)

    residual = model(features) - target
    mse = (residual ** 2).mean()

    mse.backward()
    optim.step()
    lr_scheduler.step()

    return mse.item()

def test_step(model, batch, device):
    x_1 = batch['label'].to(device)
    c = batch['feature'].to(device) 

    pred = model(c)
    loss = torch.pow(pred - x_1, 2).mean() 

    return loss.item()

def train(model, train_dataloader, test_dataloader, optim, lr_scheduler, device, config, validation=True):
    """
    Train ``model`` for ``config.num_epochs`` epochs, logging every
    ``config.print_every`` epochs.

    Args:
        model: module to optimise.
        train_dataloader: iterable of training batches passed to ``train_step``.
        test_dataloader: iterable of validation batches passed to ``test_step``.
        optim: optimizer passed to ``train_step``.
        lr_scheduler: scheduler passed to ``train_step`` and queried for logging.
        device: device passed to the step functions.
        config: object providing ``num_epochs`` and ``print_every``.
        validation: when False the validation pass is skipped and the logged
            ``val_loss`` stays 0.0.

    The logged ``loss`` and ``val_loss`` are the per-batch losses summed over
    the most recent epoch. The logged time is the wall time since the previous
    log line divided by ``config.print_every``.
    """
    start_time = time.time()
    for i in range(config.num_epochs):
        loss = 0.0
        val_loss = 0.0
        model.train()  # set the model to training mode
        for batch in train_dataloader:
            # Add to the running total; do not overwrite it with the batch loss.
            loss += train_step(model, batch, optim, lr_scheduler, device)
        if validation:
            model.eval()  # set the model to evaluation mode
            # No gradients are needed while scoring the validation set.
            with torch.no_grad():
                for batch in test_dataloader:
                    val_loss += test_step(model, batch, device)

        # log loss
        if (i+1) % config.print_every == 0:
            elapsed = time.time() - start_time
            print('| iter {:6d} | {:5.2f} ms/step | loss {:8.6f} | val_loss {:8.6f} | lr {:.6f} |' 
                .format(i+1, elapsed*1000/config.print_every, loss, val_loss, lr_scheduler.get_last_lr()[0])) 
            start_time = time.time()

def generate_samples(dataset, model, device):
    """
    Predict labels for the dataset's test split and map both the predictions
    and the reference labels back through the dataset's label transformation.

    Inference runs under ``torch.no_grad()`` so the prediction handed to
    ``inverse_transform`` does not require grad (a tensor that requires grad
    cannot be converted to a NumPy array).

    Args:
        dataset: object exposing ``test_data.label``, ``test_data.feature`` and
            ``label_transformation.inverse_transform``.
        model: module called as ``model(feature)``; it is switched to eval mode
            and left in that mode.
        device: device the test features are moved to for inference.

    Returns:
        ``(x_test, x_pred)``: the inverse-transformed reference labels and
        predictions, in whatever type ``inverse_transform`` returns.
    """
    model.eval()  # set the model to evaluation mode
    c_test = dataset.test_data.feature.to(device)
    with torch.no_grad():
        pred = model(c_test)

    x_test = dataset.label_transformation.inverse_transform(dataset.test_data.label.cpu())
    x_pred = dataset.label_transformation.inverse_transform(pred.cpu())
    return x_test, x_pred
