import math
import torch
import torch.nn.functional as F
from torch import nn
from utils import *




class RFM(nn.Module):
    """Rectified-flow (flow-matching) wrapper around a velocity-prediction network.

    Training (``forward``) draws Gaussian noise ``x_0`` and a per-sample time
    ``t`` in (0, 1). It builds the straight-line interpolant
    ``x_t = t * x_1 + (1 - t) * x_0`` and returns the model's predicted velocity
    together with the ground-truth velocity ``x_1 - x_0``.

    Sampling (``generation``) integrates the predicted velocity field from
    ``t = 0`` to ``t = 1`` with a fixed-step explicit Euler scheme.

    Two calling conventions for ``model`` are supported:

    * ``model_type == "VP"``: ``model(x, cond, lr, t, target_mask[, use=...])``
      returning ``(velocity, cond_s, cond_t)``.
    * any other ``model_type``: ``model(x, t, cond_i)`` returning
      ``(velocity, cond_hr)``. Here ``cond_i`` is ``cond[0]`` when
      ``target_type == "s"`` and ``cond[1]`` when ``target_type == "t"``.

    Args:
        model: The wrapped velocity network.
        model_type: ``"VP"`` selects the first calling convention above.
        target_type: ``"s"`` or ``"t"``. Selects which conditioning branch is
            fed to / returned from the model.
        sr_ratio: Forwarded to ``match_variance`` (defined outside this class)
            when ``generation(..., match_vol=True)`` is used.
        use_both: VP convention only. When true, the model is called without
            the ``use=`` keyword.
        device: Device that conditions, timesteps and ``x_0`` are placed on
            during ``generation``.
        debug: When true, ``v_pred`` prints the sizes of its inputs.

    Raises:
        ValueError: If ``target_type`` is not ``"s"`` or ``"t"``.
    """

    def __init__(self, model, model_type, target_type, sr_ratio, use_both=False, device="cuda", debug=False):
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if target_type not in ("s", "t"):
            raise ValueError(f"target_type must be 's' or 't', got {target_type!r}")
        self.model = model
        self.device = device
        self.debug = debug
        self.sr_ratio = sr_ratio
        self.target_type = target_type
        self.model_type = model_type
        self.use_both = use_both

    @staticmethod
    def _broadcast_t(t, ref):
        """Reshape a per-sample vector ``t`` of shape (B,) to broadcast against ``ref`` of any rank."""
        return t.view(-1, *([1] * (ref.dim() - 1)))

    def _model_step(self, x, t, cond, lr, target_mask=None):
        """Call the wrapped model once; return ``(velocity, cond_hr)`` for ``target_type``."""
        if self.model_type != "VP":
            branch_cond = cond[0] if self.target_type == "s" else cond[1]
            v_hat, cond_hr = self.model(x, t, branch_cond)
            return v_hat, cond_hr
        if self.use_both:
            v_hat, cs, ct = self.model(x, cond, lr, t, target_mask)
        else:
            v_hat, cs, ct = self.model(x, cond, lr, t, target_mask, use=self.target_type)
        return v_hat, (cs if self.target_type == "s" else ct)

    def v_pred(self, x, t, cond, lr, target_mask=None):
        """Predict the velocity at ``(x, t)``.

        Args:
            x: Current interpolant ``x_t``.
            t: Per-sample times, shape (B,).
            cond: Indexable pair ``(cond_s, cond_t)``.
            lr: Forwarded to the model (VP convention only).
            target_mask: Forwarded to the model (VP convention only).

        Returns:
            ``(v_hat, cond_hr)``: the predicted velocity and the conditioning
            output selected by ``target_type``.
        """
        if self.debug:
            print("size of x, t, cond : ", x.size(), t.size(), cond[0].size(), cond[1].size())
        return self._model_step(x, t, cond, lr, target_mask)

    def forward(self, x_1, cond_lr, lr, target_mask=None):
        """Run one flow-matching training step's forward pass.

        Args:
            x_1: Data sample with batch on dim 0; any rank >= 1.
            cond_lr: Conditioning pair passed to ``v_pred``.
            lr: Forwarded to the model (VP convention only).
            target_mask: Forwarded to the model (VP convention only).

        Returns:
            ``(pred_v, gt_v, cond_hr)``: the predicted velocity, the target
            velocity ``x_1 - x_0``, and the model's conditioning output.
        """
        x_0 = torch.randn_like(x_1)
        # Logit-normal timestep sampling: t = sigmoid(n) with n ~ N(0, 1) covers
        # all of (0, 1) and concentrates mass around t = 0.5. Applying the
        # sigmoid to a *uniform* draw instead would confine t to (0.5, 0.731),
        # so the model would never be trained on the early or late times that
        # `generation` integrates over.
        t = torch.sigmoid(torch.randn(x_1.size(0), device=x_1.device))
        t_b = self._broadcast_t(t, x_1)
        x_t = t_b * x_1 + (1 - t_b) * x_0
        gt_v = x_1 - x_0
        pred_v, cond_hr = self.v_pred(x_t, t, cond_lr, lr, target_mask)
        return pred_v, gt_v, cond_hr

    @torch.no_grad()
    def generation(self, x_0, lr, condition, num_timesteps=4, target_mask=None, match_vol=False):
        """Integrate the learned velocity field from ``t = 0`` to ``t = 1``.

        Takes ``num_timesteps`` explicit Euler steps of size ``1 / num_timesteps``.
        The model is evaluated at ``t = k / num_timesteps`` for
        ``k = 0 .. num_timesteps - 1``.

        Args:
            x_0: Starting sample (noise). It is moved to ``self.device`` and is
                not modified in place.
            lr: Forwarded to the model (VP convention only).
            condition: Pair ``(cond_s, cond_t)``. ``condition[0].size(0)`` gives
                the batch size.
            num_timesteps: Number of Euler steps; must be >= 1.
            target_mask: Forwarded to the model (VP convention only).
            match_vol: If true, post-process the result with
                ``match_variance`` (defined outside this class).

        Returns:
            ``(x_1, pred_cond)``: the integrated sample and the conditioning
            output of the last model evaluation.

        Raises:
            ValueError: If ``num_timesteps < 1``. Otherwise there would be no
                model output to return.
        """
        if num_timesteps < 1:
            raise ValueError(f"num_timesteps must be >= 1, got {num_timesteps}")
        batch = condition[0].size(0)

        # Conditions do not change between steps, so move them to the device once.
        # Only the entries the model actually consumes are transferred.
        if self.model_type == "VP":
            cond = [condition[0].to(self.device), condition[1].to(self.device)]
        else:
            cond = list(condition)
            sel = 0 if self.target_type == "s" else 1
            cond[sel] = cond[sel].to(self.device)

        # A scalar step size broadcasts against a velocity of any rank.
        dt = 1.0 / num_timesteps
        x_t = x_0.to(self.device)
        pred_cond = None
        for step in range(num_timesteps):
            t = torch.full((batch,), step / num_timesteps, device=self.device)
            velocity, pred_cond = self._model_step(x_t, t, cond, lr, target_mask)
            # Out-of-place update: the caller's x_0 is never mutated.
            x_t = x_t + velocity * dt

        if match_vol:
            x_1 = match_variance(pred_cond, x_t, sr_ratio=self.sr_ratio)
            return x_1, pred_cond
        return x_t, pred_cond
