import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from multihead_edm.networks import EDMPrecond
from multihead_edm.loss import MultiHeadEDMLoss
from edm.sampler import edm_sampler

class MultiHeadEDM(nn.Module):
    """EDM-preconditioned network with several output heads plus an EMA copy.

    The wrapped network emits the channels of every head stacked along the
    channel axis. :meth:`forward` splits them per head and keeps, for each
    sample, only the head selected by ``argmax(y, dim=1)``.
    """

    def __init__(self, **kwargs):
        """Build the online network, its EMA copy and the loss.

        All keyword arguments are forwarded unchanged to both ``EDMPrecond``
        and ``MultiHeadEDMLoss``.
        """
        super().__init__()

        # network
        self.net = EDMPrecond(**kwargs)
        # Frozen copy whose parameters track an exponential moving average of
        # ``self.net``'s parameters (see ``update_ema``).
        self.ema = copy.deepcopy(self.net).eval().requires_grad_(False)

        # loss function
        self.loss = MultiHeadEDMLoss(**kwargs)

    def train(self, mode: bool = True):
        """Set training/eval mode on this module while keeping ``self.ema`` in eval mode.

        ``nn.Module.train`` recurses into every child module. Because
        ``self.ema`` is registered as a child, ``model.train()`` would
        otherwise switch it back to training mode, undoing the ``.eval()``
        applied in ``__init__``. That would change the behavior of any
        mode-dependent layers it may contain, such as dropout.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.ema.eval()
        return self

    def loss_func(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the scalar training loss: the mean of ``self.loss(self.net, x, y)``."""
        return self.loss(self.net, x, y).mean()

    @torch.no_grad()
    def update_ema(self, step: int, train_batch_size: int, ema_halflife_kimg: int, ema_rampup_ratio: float=None, **kwargs):
        """Blend the online parameters into the EMA parameters in place.

        Only ``parameters()`` are averaged; buffers of ``self.ema`` are left as is.

        Args:
            step: Current training step. It is multiplied by
                ``train_batch_size`` and is only used for the ramp-up cap.
            train_batch_size: Number of samples per step. The half-life is
                measured in samples.
            ema_halflife_kimg: EMA half-life in thousands of samples.
            ema_rampup_ratio: If not ``None``, the half-life is capped at
                ``step * train_batch_size * ema_rampup_ratio`` samples, so the
                EMA follows the network closely early in training.
            **kwargs: Ignored. This lets callers pass a whole config dict.
        """
        ema_halflife_nimg = ema_halflife_kimg * 1000
        if ema_rampup_ratio is not None:
            ema_halflife_nimg = min(ema_halflife_nimg, step * train_batch_size * ema_rampup_ratio)
        # Per-step decay factor, chosen so old contributions halve every
        # ``ema_halflife_nimg`` samples. The 1e-8 floor avoids division by
        # zero when the cap is 0; in that case beta underflows to 0 and the
        # EMA is reset to the current weights.
        ema_beta = 0.5 ** (train_batch_size / max(ema_halflife_nimg, 1e-8))
        for p_ema, p_net in zip(self.ema.parameters(), self.net.parameters()):
            # p_net.lerp(p_ema, beta) == p_net + beta * (p_ema - p_net)
            p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta))

    def forward(self, x: torch.Tensor, sigma: torch.Tensor, y: torch.Tensor, use_ema=False):
        """Run the (EMA) network and return, per sample, the head selected by ``y``.

        Args:
            x: Input passed to the network together with ``sigma``.
            sigma: Noise level(s) passed to the network.
            y: Head selector of shape ``(batch, num_heads)``. For each sample,
                the head with the largest entry is used. This is presumably
                one-hot; confirm with callers.
            use_ema: If True, run ``self.ema`` instead of ``self.net``.

        Returns:
            Tensor of shape ``(batch, C, H, W)``, given a network output of
            shape ``(batch, num_heads * C, H, W)``.
        """
        model = self.ema if use_ema else self.net
        outputs = model(x, sigma)

        batch = outputs.shape[0]
        # (batch, num_heads * C, H, W) -> (batch, num_heads, C, H, W)
        outputs = outputs.view(batch, y.size(1), -1, outputs.shape[2], outputs.shape[3])

        # Select one head per sample. Build both index tensors on the output's
        # device to avoid host/device index mixing.
        rows = torch.arange(batch, device=outputs.device)
        heads = torch.argmax(y, dim=1).to(outputs.device)
        return outputs[rows, heads]

    @torch.inference_mode()
    def inference(self, y: torch.Tensor, num_steps: int=18, latents: torch.Tensor = None) -> torch.Tensor:
        """Sample with ``edm_sampler`` for the head selectors ``y`` and return the result on the CPU.

        Args:
            y: Head selector, one row per sample (see :meth:`forward`).
            num_steps: Number of sampler steps, forwarded to ``edm_sampler``.
            latents: Optional initial noise. If ``None``, standard normal noise
                of shape ``(len(y), img_channels, img_resolution, img_resolution)``
                is drawn on ``y.device``.

        NOTE(review): the sampler is given ``self``. If it calls
        ``self(x, sigma, y)``, the default ``use_ema=False`` means the online
        weights, not the EMA weights, are used. Confirm this is intended.
        """
        if latents is None:
            latents = torch.randn(
                (y.shape[0], self.net.img_channels, self.net.img_resolution, self.net.img_resolution),
                device=y.device,
            )
        images = edm_sampler(self, latents, y, num_steps=num_steps)
        return images.detach().cpu()