
import math
import torch 
import torch.nn.functional as F
from torch_utils import persistence 

from einops import rearrange, einsum

import numpy as np


@persistence.persistent_class
class GCMLoss:
    """Consistency-style loss between a network jump from ``t`` and one from ``r``.

    Each call draws a time ``t`` per sample, forms the linear interpolant
    ``xt = (1 - t) * y + t * z`` between (optionally augmented) data ``y`` and
    Gaussian noise ``z`` scaled by ``sigma_data``, and moves ``xt`` a small step
    ``dt`` back along the interpolant to ``xr`` at time ``r``. The network is
    called at ``(xt, t)`` and ``(xr, r)`` with ``use_velocity=True``. Each output
    is used for an Euler-style jump to the fixed end time ``s``. The ``(xt, t)``
    jump is regressed onto the stop-gradient ``(xr, r)`` jump.

    Args:
        label_dropout: Probability of zeroing each sample's label (presumably
            for classifier-free guidance -- TODO confirm with the sampler).
        **kwargs: Accepted and ignored so configs may pass extra options.
    """

    def __init__(
        self,
        label_dropout=0.1,
        **kwargs,
    ):
        super().__init__()

        self.label_dropout = label_dropout
        self.sigma_data = 0.5

    def process_labels(self, labels):
        """Zero out a random subset of labels.

        Each sample's label row is zeroed independently with probability
        ``self.label_dropout``. The input tensor is not modified.

        Returns:
            ``(labels_clone, mask)``. ``mask`` is a boolean CPU tensor of shape
            ``(B,)`` marking dropped samples. Returns ``(None, None)`` when
            ``labels`` is None.
        """
        if labels is not None:
            labels_clone = labels.clone()
            mask = torch.rand(labels.shape[0]) < self.label_dropout
            labels_clone[mask] = 0
            return labels_clone, mask
        else:
            return labels, None

    @staticmethod
    def _sample_times(images, t0, tT, r_eps=0.005 / 3):
        """Sample per-sample times ``(s, t, r, dt)``, each shaped ``(B, 1, 1, 1)``.

        ``s`` is fixed at ``t0``. ``t`` is uniform between ``s + r_eps`` and
        ``tT``. ``dt`` is the constant ``r_eps``. ``r = t - dt`` is clamped so
        that ``r >= s``.
        """
        batch = images.shape[0]
        device = images.device
        s = t0 * torch.ones([batch, 1, 1, 1], device=device)
        t_start = s + r_eps
        u = torch.rand([batch, 1, 1, 1], device=device)
        t = t_start + (tT - t_start) * u
        dt = torch.full_like(t, r_eps)
        r = (t - dt).clamp(min=s)
        return s, t, r, dt

    @staticmethod
    def _get_rng_state(device):
        """Snapshot the generator for ``device``.

        Uses the CUDA generator of that specific device for CUDA tensors, and
        the default CPU generator otherwise.
        """
        if device.type == "cuda":
            return torch.cuda.get_rng_state(device)
        return torch.get_rng_state()

    @staticmethod
    def _set_rng_state(state, device):
        """Restore a snapshot taken by ``_get_rng_state`` for the same ``device``."""
        if device.type == "cuda":
            torch.cuda.set_rng_state(state, device)
        else:
            torch.set_rng_state(state)

    def __call__(
        self,
        net,
        images,
        labels=None,
        augment_pipe=None,
        device=torch.device("cuda"),
        cur_tick=None,
        **kwargs
    ):
        """Compute the per-sample loss for one batch.

        Args:
            net: Network, either bare or wrapped (e.g. by DDP) with the model in
                ``net.module``. The unwrapped model must expose ``eps`` (used as
                ``s``) and ``T`` (upper bound for ``t``). It is called as
                ``net(x, t, s, labels, augment_labels=..., use_velocity=True)``.
            images: Clean batch. Assumed 4-D, since the loss reduces over
                dims 1-3.
            labels: Optional conditioning labels. They are dropped per sample
                with probability ``label_dropout``.
            augment_pipe: Optional callable returning
                ``(images, augment_labels)``.
            device: Device the batch and labels are moved to.
            cur_tick: Unused; kept for interface compatibility.
            **kwargs: Ignored.

        Returns:
            ``(loss, logs)``. ``loss`` holds each sample's L2 residual norm,
            shaped ``(B, 1, 1, 1)`` for 4-D input, with NaNs replaced by zero.
            ``logs`` holds diagnostics; ``logs["loss"]`` is the loss *before*
            NaN replacement.
        """
        images = images.to(device)
        labels = labels.to(device) if labels is not None else None

        # Support both DDP-wrapped and bare networks.
        model = getattr(net, "module", net)
        s, t, r, dt = self._sample_times(images, model.eps, model.T)

        # Augmentation if needed
        y, augment_labels = (
            augment_pipe(images) if augment_pipe is not None else (images, None)
        )
        labels_drop, _ = self.process_labels(labels)

        z = torch.randn_like(y) * self.sigma_data
        xt = (1 - t) * y + t * z
        ut = z - y  # d(xt)/dt of the linear interpolant
        xr = xt + (r - t) * ut  # move xt back to time r along the interpolant

        # Shared Dropout Mask: replay the same RNG state for both passes so any
        # stochastic layers inside the network behave identically.
        rng_state = self._get_rng_state(xt.device)
        v_t = net(
            xt,
            t,
            s,
            labels_drop,
            augment_labels=augment_labels,
            use_velocity=True,
        )

        self._set_rng_state(rng_state, xt.device)
        with torch.no_grad():
            # Condition the target on the same (dropped) labels as the student
            # pass, so both sides of the residual see identical conditioning.
            v_r = net(
                xr,
                r,
                s,
                labels_drop,
                augment_labels=augment_labels,
                force_fp32=True,
                use_velocity=True,
            )
            target = xr + (s - r) * v_r

        # Forward value is xt + (s - r - dt) * v_t, which equals xt + (s - t) * v_t
        # whenever the clamp on r is inactive (r = t - dt). Gradient flows only
        # through the `-dt * v_t` term, so the gradient w.r.t. v_t is scaled by
        # dt rather than by (s - t).
        x_s_pred = xt + (s - r) * v_t.detach() - dt * v_t

        # Per-sample L2 norm of the residual. Unlike sqrt(sum(x ** 2)), the
        # backward of vector_norm yields a zero (not NaN) gradient for an
        # exactly-zero residual.
        loss = torch.linalg.vector_norm(
            x_s_pred - target, ord=2, dim=(1, 2, 3), keepdim=True
        )

        logs = {
            "r_t_ratio": r[:, 0] / t[:, 0],
            "s_t_ratio": s[:, 0] / t[:, 0],
            "t_r_diff": t[:, 0] - r[:, 0],
            'loss': loss,
        }

        if torch.isnan(loss).any():
            print("Nan in loss")
            loss = torch.nan_to_num(loss)

        logs["ts"] = t[:, 0].flatten()

        return loss, logs
