import torch.nn as nn
import torch.nn.functional as F
from utils.utils import *
import network.modal_trans.trans as modal_trans


class DCFlow(nn.Module):
    """Flow-estimation wrapper with an optional modality-translation front end.

    The translator (``self.modal_trans``) is built only when ``args.trans ==
    "UNet"``; otherwise it is ``None`` and images are fed to the flow
    estimator unchanged. Judging by the ``pseudo_imgA`` naming, the translator
    presumably maps images into the appearance of "modality A" — confirm
    against the training code.
    """

    def __init__(self, args):
        """Build the translator (optional) and the flow estimator.

        Args:
            args: Config object; reads ``args.trans`` and ``args.model_flow``
                and is also forwarded to the flow-estimator constructor.

        Raises:
            ValueError: if ``args.model_flow`` names an unsupported estimator
                (only ``"RAFT"`` is currently wired up).
        """
        super().__init__()
        self.args = args

        # Always define the attribute so forward() can test it; previously a
        # non-"UNet" setting left it undefined and forward() raised
        # AttributeError instead of taking the "no translator" path.
        if args.trans == "UNet":
            self.modal_trans = getattr(modal_trans, args.trans)()
        else:
            self.modal_trans = None

        if args.model_flow == "RAFT":
            # Local import keeps the RAFT dependency optional at module load.
            from network.flow_estimator.RAFT.raft import RAFT
            self.flow_estimator = RAFT(args)
        else:
            # Fail at construction rather than with an AttributeError in forward().
            raise ValueError(f"unsupported flow estimator: {args.model_flow!r}")

    def forward(self, image_warp, image, mode):
        """Estimate flow from ``image_warp`` to ``image`` according to ``mode``.

        Modes:
            ``"trainAA"``: if a translator exists, both inputs are translated
                and detached before flow estimation; returns the flow only.
            ``"trainBB"``: raw inputs go straight to the estimator; returns the
                flow only.
            ``"trainBA"``, ``"trainBA_aug"``, ``"test"``: only ``image`` is
                translated (if a translator exists). Returns
                ``(pseudo_imgA, flow)``, where ``pseudo_imgA`` is the
                translated image, or ``image`` itself when there is no
                translator. In ``"trainBA"`` the translated image is detached
                before it reaches the estimator; the other two keep the graph.

        Raises:
            ValueError: for any other ``mode``.
        """
        has_trans = self.modal_trans is not None

        if mode == "trainAA":
            if has_trans:
                pseudo_imgA_warp = self.modal_trans(image_warp)
                pseudo_imgA = self.modal_trans(image)
                # Detached: flow loss here must not update the translator.
                return self.flow_estimator(pseudo_imgA_warp.detach(), pseudo_imgA.detach())
            return self.flow_estimator(image_warp, image)

        if mode == "trainBB":
            return self.flow_estimator(image_warp, image)

        if mode in ("trainBA", "trainBA_aug", "test"):
            if not has_trans:
                return image, self.flow_estimator(image_warp, image)
            pseudo_imgA = self.modal_trans(image)
            # Only "trainBA" blocks gradients from the estimator into the
            # translator; the returned pseudo image keeps its graph either way.
            target = pseudo_imgA.detach() if mode == "trainBA" else pseudo_imgA
            return pseudo_imgA, self.flow_estimator(image_warp, target)

        raise ValueError(f"ERROR : mode error ({mode!r})")