import torch
import torch.nn as nn



class ResnetBlock(nn.Module):
    """Pre-activation residual block with an optional projected skip path.

    Main branch: BatchNorm -> ReLU -> Conv3x3 -> BatchNorm -> ReLU -> Dropout
    -> Conv3x3. Spatial size is preserved (3x3 convs use stride 1, padding 1).

    The skip path is the identity when ``in_channels == out_channels``.
    Otherwise it is projected to ``out_channels`` by a 3x3 convolution
    (``conv_shortcut=True``) or a 1x1 convolution (default).

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of output channels; ``None`` means ``in_channels``.
        conv_shortcut: choose a 3x3 instead of a 1x1 skip projection. It is
            only consulted when the channel counts differ.
        dropout: dropout probability applied before the second convolution.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # First stage of the main branch.
        self.norm1 = nn.BatchNorm2d(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels,
                                     kernel_size=3, stride=1, padding=1)

        # Second stage of the main branch.
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels,
                                     kernel_size=3, stride=1, padding=1)

        # The skip path needs a projection only when channel counts differ.
        if in_channels == out_channels:
            return
        if conv_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels,
                                                 kernel_size=3, stride=1,
                                                 padding=1)
        else:
            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels,
                                                kernel_size=1, stride=1,
                                                padding=0)

    def forward(self, x):
        """Return ``skip(x) + main(x)``; the output has ``out_channels`` channels."""
        residual = self.conv1(self.norm1(x).relu())
        residual = self.conv2(self.dropout(self.norm2(residual).relu()))

        skip = x
        if self.in_channels != self.out_channels:
            project = (self.conv_shortcut if self.use_conv_shortcut
                       else self.nin_shortcut)
            skip = project(skip)
        return skip + residual



def fix_phase(x):
    """Shift negative angles up by 2*pi.

    For inputs in ``[-pi, pi]`` (e.g. ``Tensor.angle()`` output) the result
    lies in ``[0, 2*pi)``. Non-negative values pass through unchanged. The
    selection is done with boolean masks multiplied into the values.
    """
    magnitude = abs(x)
    negative_mask = x < 0
    non_negative_mask = x >= 0
    # magnitude + 2*(pi - magnitude) == 2*pi - |x| == x + 2*pi for x < 0
    shifted = magnitude + 2 * (torch.pi - magnitude)
    return negative_mask * shifted + non_negative_mask * x

def fix_phase2(x):
    """Shift angles greater than pi down by 2*pi.

    For inputs in ``[0, 2*pi)`` the result lies in ``(-pi, pi]``. Values
    ``<= pi`` pass through unchanged. The selection is done with boolean
    masks multiplied into the values.
    """
    above_pi = x > torch.pi
    at_most_pi = x <= torch.pi
    return above_pi * (x - 2 * torch.pi) + at_most_pi * x


class CNN_HEAD(nn.Module):
    """Fully-convolutional head: Conv3x3 -> ReLU -> ResnetBlocks -> Conv3x3.

    All convolutions are 3x3 with stride 1 and padding 1, so the spatial size
    is preserved. The sub-layers live in ``self.resnet`` (an
    ``nn.Sequential``); its layer order and shapes match the original
    hard-coded configuration when the defaults are used.

    Args:
        in_ch: channels of the input tensor (default 2).
        out_ch: channels of the output tensor (default 2).
        nch: width of the stem convolution and of the first residual block.
        ch_mult: per-stage channel multipliers, one ResnetBlock per entry.
            NOTE(review): multiplier ``i`` is applied *after* block ``i`` is
            built, so block ``i`` outputs ``nch * prod(ch_mult[:i])``
            channels and the last multiplier has no effect. With the default
            ``(1, 2, 2)`` the blocks are 32->32, 32->32 and 32->64. This is
            kept as-is so existing checkpoints keep loading.
    """

    def __init__(self, in_ch=2, out_ch=2, nch=32, ch_mult=(1, 2, 2)):
        super().__init__()
        layers = [
            torch.nn.Conv2d(in_ch, nch, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
        ]
        block_in = block_out = nch
        for mult in ch_mult:
            layers.append(ResnetBlock(in_channels=block_in,
                                      out_channels=block_out,
                                      dropout=0))
            block_in = block_out
            # Lagged update: this multiplier sizes the *next* block.
            block_out = block_out * mult
        layers.append(torch.nn.Conv2d(block_in, out_ch,
                                      kernel_size=3, stride=1, padding=1))
        self.resnet = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers; ``x`` is presumably (B, in_ch, H, W)."""
        return self.resnet(x)


class network_F(nn.Module):
    """Predict per-pixel complex weights ``alpha``, ``beta`` and return ``a*alpha + b*beta``.

    Two ``CNN_HEAD`` networks run side by side:
      * ``mag_pred`` sees the magnitudes ``|a|`` and ``|b|`` and predicts
        ``|alpha|`` and ``|beta|``. ReLU keeps them non-negative.
      * ``phase_pred`` sees the phase of ``a`` and the phase of
        ``a * conj(b)``, both wrapped by ``fix_phase``. It predicts the angles
        of ``alpha`` and ``beta``; these are ReLU'd and then wrapped by
        ``fix_phase2``.
    The magnitudes and angles are combined in polar form.
    """

    def __init__(self):
        super().__init__()
        self.mag_pred = CNN_HEAD()
        self.phase_pred = CNN_HEAD()

    def forward(self, a, b):
        """Combine ``a`` and ``b`` with learned complex weights.

        Args:
            a, b: tensors of identical shape, presumably complex (B, H, W).
                They must be 3-D because a channel dim is added at dim 1
                before the 2-D convolutions.
        Returns:
            Complex tensor with the same shape as ``a``.
        """
        # Magnitude branch: 2 input channels, |a| and |b|.
        phasors = torch.stack((a, b), dim=1)
        mag = self.mag_pred(phasors.abs()).relu()

        # Phase branch: 2 input channels, phase(a) and phase(a * conj(b)).
        # Both features need the channel dim. Previously only the second was
        # unsqueezed, so the concatenation always failed with a shape error.
        phase_features = torch.stack(
            (fix_phase(a.angle()), fix_phase((a * b.conj()).angle())), dim=1)
        phase = fix_phase2(self.phase_pred(phase_features).relu())

        mag_alpha, mag_beta = mag[:, 0], mag[:, 1]
        phase_alpha, phase_beta = phase[:, 0], phase[:, 1]
        alpha = torch.complex(mag_alpha * torch.cos(phase_alpha),
                              mag_alpha * torch.sin(phase_alpha))
        beta = torch.complex(mag_beta * torch.cos(phase_beta),
                             mag_beta * torch.sin(phase_beta))
        return a * alpha + beta * b