import torch
from typing import Tuple 
import torch.nn.functional as F

class diffusion_Flow:
    """Rectified-flow training objective for a denoising model.

    A clean batch ``x0`` and a noise batch are linearly interpolated as
    ``xt = (1 - t) * x0 + t * noise`` with ``t`` in (0, 1). The model is
    called on ``xt`` and ``t * 999`` and regressed onto the target selected
    by ``predict_type``:

    * ``"v"``       -- ``noise - x0`` (the derivative of ``xt`` w.r.t. ``t``),
    * ``"x0"``      -- the clean sample,
    * ``"epsilon"`` -- the noise.

    Optional auxiliary objectives are switched on by setting their
    weight/ratio above zero. Each one changes what ``model`` must return
    (see ``get_loss``).
    """

    def __init__(self, predict_type="v",
                 logit_normal=True,
                 timestep_shift=1,
                 mask_training_ratio=0,
                 mask_training_weight=0,
                 same_timestep_per_batch=False,
                 mid_v_supervise_weight=0,
                 sparse_loss_weight=0,
                 fix_data_noise_pair=False,
                 fix_noise_shape=(50_000, 3, 32, 32),
                 ):
        """Configure the objective.

        Args:
            predict_type: regression target, one of ``"v"``, ``"x0"``,
                ``"epsilon"``.
            logit_normal: if True, sample ``t = sigmoid(N(0, 1))``;
                otherwise ``t ~ U(0, 1)``.
            timestep_shift: shift ``s`` applied as ``s*t / (1 - t + s*t)``.
                ``1`` leaves ``t`` unchanged.
            mask_training_ratio: when non-zero, ``get_loss`` uses the masked
                training path and passes this ratio to the model.
            mask_training_weight: weight of the masked reconstruction term.
            same_timestep_per_batch: share a single ``t`` across the batch.
            mid_v_supervise_weight: when > 0, the model output is unpacked as
                ``(pred, mid_outputs)``. Each intermediate output gets its own
                MSE term with this weight.
            sparse_loss_weight: when > 0 (and ``mid_v_supervise_weight`` is
                not), the model output is unpacked as
                ``(pred, affinity_list, index_list)``. An L1 sparsity penalty
                with this weight is added.
            fix_data_noise_pair: draw noise from a fixed, pre-sampled bank
                indexed by ``batch_idx`` instead of fresh Gaussian noise.
            fix_noise_shape: shape ``(num_samples, *sample_shape)`` of that
                bank. The default is the shape that was previously hard-coded.
        """
        assert predict_type in ["v", "x0", "epsilon"]
        self.predict_type = predict_type
        self.timestep_shift = timestep_shift
        self.logit_normal = logit_normal
        self.mask_training_ratio = mask_training_ratio
        self.mask_training_weight = mask_training_weight

        self.same_timestep_per_batch = same_timestep_per_batch
        self.mid_v_supervise_weight = mid_v_supervise_weight
        self.sparse_loss_weight = sparse_loss_weight
        self.fix_data_noise_pair = fix_data_noise_pair

        if self.fix_data_noise_pair:
            # The bank is sampled once at construction, on CPU, in float32.
            self.fix_noise = torch.randn(tuple(fix_noise_shape))

    def get_noise(self, x0, batch_idx=None):
        """Return a noise batch matching ``x0``'s dtype and device.

        Args:
            x0: clean batch. Its dtype/device are copied onto the noise.
            batch_idx: indices into the fixed noise bank. Required when
                ``fix_data_noise_pair`` is set, ignored otherwise.

        Raises:
            ValueError: if ``fix_data_noise_pair`` is set but ``batch_idx``
                is None. Indexing with None would silently return the whole
                bank with an extra leading dimension.
        """
        if not self.fix_data_noise_pair:
            return torch.randn_like(x0)
        if batch_idx is None:
            raise ValueError("batch_idx is required when fix_data_noise_pair=True")
        if torch.is_tensor(batch_idx):
            # The bank is on CPU; an index tensor on another device
            # (e.g. CUDA) cannot be used to index it directly.
            batch_idx = batch_idx.to(self.fix_noise.device)
        x1 = self.fix_noise[batch_idx, ...].clone()
        return x1.to(dtype=x0.dtype, device=x0.device)

    def get_predict(self, x0, noise):
        """Return the regression target for ``self.predict_type``.

        Raises:
            ValueError: if ``predict_type`` was changed to an unsupported
                value after construction.
        """
        if self.predict_type == "v":
            return noise - x0
        if self.predict_type == "x0":
            return x0
        if self.predict_type == "epsilon":
            return noise
        raise ValueError(f"unknown predict_type: {self.predict_type!r}")

    def get_timestep(self, x0):
        """Sample one timestep per batch element, shape ``(x0.shape[0],)``.

        Samples are logit-normal (``sigmoid`` of a standard normal) or
        uniform, then shifted by ``timestep_shift``. When
        ``same_timestep_per_batch`` is set, the first sample is broadcast
        to the whole batch.
        """
        batch_size = x0.shape[0]
        if self.logit_normal:
            t = torch.sigmoid(torch.randn(batch_size, device=x0.device))
        else:
            t = torch.rand(batch_size, device=x0.device)

        t = self.timestep_shift * t / (1 - t + self.timestep_shift * t)
        if self.same_timestep_per_batch:
            t = t[:1].expand(batch_size)
        return t

    @staticmethod
    def _per_sample_mse(pred, target):
        """Return the float32 mean squared error of each sample, shape ``(batch,)``."""
        diff = pred.float() - target.float()
        return (diff ** 2).reshape(target.shape[0], -1).mean(dim=1)

    def get_loss(self, model, x0, model_kwargs, batch_idx=None):
        """Sample ``t`` and noise, run ``model`` on the interpolant, return losses.

        Args:
            model: callable. It is called as ``model(xt, t * 999, **model_kwargs)``,
                or as ``model(mask_training_ratio, xt, t * 999, **model_kwargs)``
                when ``mask_training_ratio`` is non-zero. Its return structure
                depends on the enabled auxiliary objective (see ``__init__``).
            x0: clean batch whose first dimension is the batch dimension.
            model_kwargs: extra keyword arguments forwarded to ``model``.
            batch_idx: indices into the fixed noise bank. Only needed with
                ``fix_data_noise_pair``.

        Returns:
            dict with ``"loss"`` (the total to optimise) and
            ``"diffusion_loss"``, both shaped ``(batch,)``, plus
            branch-specific entries: ``"mid_loss_{i}"`` (per sample),
            ``"sparse_loss_{i}"`` (scalar) or ``"mae_loss"`` (per sample).
        """
        noise = self.get_noise(x0, batch_idx)
        t = self.get_timestep(x0)
        # Broadcast t over every non-batch dimension, whatever the rank of x0.
        t_expand = t.reshape(-1, *([1] * (x0.dim() - 1)))
        xt = (1 - t_expand) * x0 + t_expand * noise
        target = self.get_predict(x0, noise)

        if self.mask_training_ratio != 0:
            # The model is expected to return (pred, mask). Positions where
            # mask == 0 get the diffusion loss. Positions where mask == 1 are
            # regressed onto the noisy input xt. Both means run over all
            # elements, so each term is scaled by the fraction of positions it
            # covers (assuming a binary mask).
            pred, mask = model(self.mask_training_ratio, xt, t * 999, **model_kwargs)
            diffusion_loss = self._per_sample_mse(pred * (1 - mask), target * (1 - mask))
            mae_loss = self._per_sample_mse(pred * mask, xt * mask)
            loss = diffusion_loss + mae_loss * self.mask_training_weight
            return {"loss": loss, "diffusion_loss": diffusion_loss, "mae_loss": mae_loss}

        pred = model(xt, t * 999, **model_kwargs)

        if self.mid_v_supervise_weight > 0:
            # Every intermediate output is supervised with the same target.
            pred, mid_v_rtn = pred
            diffusion_loss = self._per_sample_mse(pred, target)
            loss_dict = {"loss": diffusion_loss, "diffusion_loss": diffusion_loss}
            for i, mid_v in enumerate(mid_v_rtn):
                mid_loss = self._per_sample_mse(mid_v, target)
                loss_dict[f"mid_loss_{i}"] = mid_loss
                loss_dict["loss"] = loss_dict["loss"] + mid_loss * self.mid_v_supervise_weight
            return loss_dict

        if self.sparse_loss_weight > 0:
            # Push each affinity tensor toward a one-hot pattern at the
            # positions listed in the matching index tensor (L1, mean-reduced).
            pred, affinity_list, index_list = pred
            diffusion_loss = self._per_sample_mse(pred, target)
            loss_dict = {"loss": diffusion_loss, "diffusion_loss": diffusion_loss}
            for i, (affinity, index) in enumerate(zip(affinity_list, index_list)):
                mask = torch.zeros_like(affinity)
                mask.scatter_(-1, index, 1)
                sparse_loss = F.l1_loss(affinity, mask)
                loss_dict[f"sparse_loss_{i}"] = sparse_loss
                loss_dict["loss"] = loss_dict["loss"] + sparse_loss * self.sparse_loss_weight
            return loss_dict

        loss = self._per_sample_mse(pred, target)
        return {"loss": loss, "diffusion_loss": loss}
