
import torch
import torch.nn as nn
import torch.nn.functional as F

from .Q2Q import Q2QNet
from .Q2R import Q2RNet



def create_model():
    """Construct a ``tsanet`` with this module's fixed hyper-parameters.

    The keyword arguments are grouped by the sub-network they configure.
    ``img_channel`` is listed under the Q2Q group, but ``tsanet`` also passes it
    to Q2RNet as ``in_nc``.

    Returns:
        tsanet: a newly constructed model instance.
    """
    # Arguments that tsanet forwards to its Q2QNet stage.
    q2q_kwargs = dict(
        img_channel=1,
        qbe_channels=4,
        q2q_dim=32,
        middle_blk_num=2,
        enc_blk_nums=[2, 2, 2],
        dec_blk_nums=[2, 2, 2],
        fuse_before_downsample=True,
        num_heads=[1, 2, 4],
    )
    # Arguments that tsanet forwards to its Q2RNet stage.
    q2r_kwargs = dict(
        qb_channel=3,
        out_nc=3,
        q2r_config=[6] * 7,
        q2r_dim=48,
        drop_path_rate=0.0,
        input_resolution=256,
    )
    return tsanet(**q2q_kwargs, **q2r_kwargs)



class tsanet(nn.Module):
    """Two-stage network: a ``Q2QNet`` stage followed by a ``Q2RNet`` stage.

    ``forward`` runs ``Q2QNet`` on ``x`` and ``qbe_map``. Its output
    (``q_fixed``) is then passed, together with ``qb_map``, to ``Q2RNet``.

    Args:
        img_channel: Passed to Q2QNet as ``img_channel`` and to Q2RNet as
            ``in_nc``. Q2RNet consumes Q2QNet's output, so this presumably must
            also match Q2QNet's output channel count (TODO confirm).
        qbe_channels: Passed to Q2QNet as ``qb_channels``.
        q2q_dim: Passed to Q2QNet as ``width``.
        middle_blk_num: Forwarded to Q2QNet.
        enc_blk_nums: Forwarded to Q2QNet. ``None`` (the default) means ``[]``.
        dec_blk_nums: Forwarded to Q2QNet. ``None`` (the default) means ``[]``.
        fuse_before_downsample: Forwarded to Q2QNet.
        num_heads: Forwarded to Q2QNet. ``None`` (the default) means ``[1, 2, 4]``.
        qb_channel: Passed to Q2RNet as ``qb_channel``.
        out_nc: Passed to Q2RNet as ``out_nc``.
        q2r_config: Passed to Q2RNet as ``config``. ``None`` (the default)
            means ``[2, 2, 2, 2, 2, 2, 2]``.
        q2r_dim: Passed to Q2RNet as ``dim``.
        drop_path_rate: Forwarded to Q2RNet.
        input_resolution: Forwarded to Q2RNet.
    """

    def __init__(self, img_channel=1, qbe_channels=4, q2q_dim=16, middle_blk_num=1,
                 enc_blk_nums=None, dec_blk_nums=None, fuse_before_downsample=True,
                 num_heads=None, qb_channel=3, out_nc=3, q2r_config=None, q2r_dim=32,
                 drop_path_rate=0.0, input_resolution=256):
        super().__init__()

        # None sentinels replace the former mutable list defaults. Those were
        # single list objects shared by every instance, so any in-place change
        # made by a sub-network would have leaked into later instances. A fresh
        # list with the same values is now built on each call.
        if enc_blk_nums is None:
            enc_blk_nums = []
        if dec_blk_nums is None:
            dec_blk_nums = []
        if num_heads is None:
            num_heads = [1, 2, 4]
        if q2r_config is None:
            q2r_config = [2, 2, 2, 2, 2, 2, 2]

        self.q2q = Q2QNet(img_channel=img_channel, qb_channels=qbe_channels, width=q2q_dim,
                          middle_blk_num=middle_blk_num, enc_blk_nums=enc_blk_nums,
                          dec_blk_nums=dec_blk_nums, fuse_before_downsample=fuse_before_downsample,
                          num_heads=num_heads)
        self.q2r = Q2RNet(in_nc=img_channel, qb_channel=qb_channel, out_nc=out_nc,
                          config=q2r_config, dim=q2r_dim, drop_path_rate=drop_path_rate,
                          input_resolution=input_resolution)

    def forward(self, x, qbe_map, qb_map):
        """Run both stages.

        Args:
            x: Input to the Q2Q stage. Shape and dtype are defined by Q2QNet
                (not visible here).
            qbe_map: Auxiliary input to the Q2Q stage.
            qb_map: Auxiliary input to the Q2R stage.

        Returns:
            list: ``[r, q_fixed]``, where ``r`` is the Q2R output and
            ``q_fixed`` is the intermediate Q2Q output that was fed into it.
        """
        q_fixed = self.q2q(x, qbe_map)
        r = self.q2r(q_fixed, qb_map)
        return [r, q_fixed]


    