import torch
import numpy as np
import numpy.random as nr
import pandas as pd
import torch.nn.functional as F
from tqdm import tqdm
from torch import nn
from torch.nn.functional import relu
from torch.utils.checkpoint import checkpoint

# Use the GPU with index 1 when CUDA is available; otherwise fall back to CPU.
device = torch.device("cuda:" + str(1) if torch.cuda.is_available() else "cpu")
# torch.cuda.set_device rejects non-CUDA devices, so only pin the current CUDA
# device when a CUDA device was actually selected. This keeps the module
# importable on CPU-only machines.
if device.type == "cuda":
    torch.cuda.set_device(device)


class CouplingLayer(nn.Module):
    """Affine coupling layer that also swaps the two channel halves.

    The input is split along the channel dimension (dim 1) into a first and a
    second half. The second half passes through unchanged and conditions a
    small convolutional network ``f`` that predicts a shift and a clamped
    log-scale for the first half. The output places the unchanged half first
    and the transformed half second.

    NOTE(review): the channel split assumes ``input_channels`` is even so that
    both halves, the shift and the log-scale all have the same channel count;
    this is not validated here.

    Args:
        input_channels: Number of channels of the tensors passed to
            ``forward`` / ``inverse``.
        hidden_channels: Width of the hidden convolutions in ``f``.
    """

    def __init__(self, input_channels, hidden_channels):
        super(CouplingLayer, self).__init__()
        # Maps the conditioning half (input_channels // 2 channels) to
        # input_channels channels, later split into shift and log-scale.
        self.f = nn.Sequential(
            nn.Conv2d(input_channels // 2, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, input_channels, kernel_size=3, padding=1)
        )

    def _shift_and_log_scale(self, conditioner):
        """Return ``(shift, log_scale)`` predicted from the conditioning half."""
        # Non-reentrant checkpointing still produces parameter gradients when
        # the conditioning tensor itself does not require grad (e.g. raw data
        # fed to the first layer), unlike the reentrant variant.
        z = checkpoint(self.f, conditioner, use_reentrant=False)
        shift, log_scale = z.chunk(2, dim=1)
        # Bound the log-scale so exp() stays in a numerically safe range.
        log_scale = torch.clamp(log_scale, min=-7, max=7)
        return shift, log_scale

    def forward(self, x):
        """Apply the coupling transform.

        Args:
            x: 4-D tensor whose dim 1 is the channel dimension.

        Returns:
            Tuple ``(y, log_det)`` where ``y`` has the unchanged half first
            and the transformed half second, and ``log_det`` is the log-scale
            summed over dims 1-3 (one value per batch element).
        """
        x_2, x_1 = x.chunk(2, dim=1)
        shift, log_scale = self._shift_and_log_scale(x_1)
        y_2 = x_2 * torch.exp(log_scale) + shift

        # log|det J| is the sum of the log-scales; use them directly instead
        # of taking log(exp(.)) to avoid a needless round trip.
        log_det = torch.sum(log_scale, dim=[1, 2, 3])

        return torch.cat([x_1, y_2], dim=1), log_det

    def inverse(self, y):
        """Invert ``forward``: recover ``x`` from its output ``y``.

        Args:
            y: Tensor laid out as produced by ``forward`` (unchanged half
                first, transformed half second).

        Returns:
            Tensor with the original channel order of ``forward``'s input.
        """
        # forward() emits [unchanged, transformed]; the unchanged half is the
        # one the shift/scale were conditioned on.
        y_1, y_2 = y.chunk(2, dim=1)
        shift, log_scale = self._shift_and_log_scale(y_1)
        x_2 = (y_2 - shift) / torch.exp(log_scale)

        # Restore forward()'s input order: transformed-half source first,
        # conditioning half second.
        return torch.cat([x_2, y_1], dim=1)



class Flow(nn.Module):
    """Stack of four coupling layers followed by a learned per-channel scale.

    Args:
        input_channels: Channel count of the tensors passed through the flow;
            also determines the hidden width (``2 * input_channels``) of each
            coupling layer.
        output_channels: Stored on the instance; not otherwise used by this
            class.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        hidden_width = input_channels * 2
        couplings = [
            CouplingLayer(input_channels, hidden_channels=hidden_width)
            for _ in range(4)
        ]
        self.layers = nn.ModuleList(couplings)
        # Raw scale parameter; the effective scale is 1 + tanh(scale), which
        # equals 1 at initialisation.
        self.scale = nn.Parameter(torch.zeros(1, input_channels, 1, 1))

    def forward(self, x):
        """Run ``x`` through every layer, then apply the channel scale.

        Returns:
            Tuple ``(z, log_det_total)``: the transformed tensor and the
            accumulated log-determinant, one entry per batch element.
        """
        log_det_total = torch.zeros(x.size(0), device=x.device)
        out = x
        for coupling in self.layers:
            out, layer_log_det = coupling(out)
            log_det_total += layer_log_det

        channel_scale = torch.tanh(self.scale) + 1
        out = out * channel_scale

        # Each channel's scale is applied at every spatial position, so its
        # log contributes once per position (dims 2 and 3).
        positions = out.shape[2] * out.shape[3]
        log_det_total += torch.log(channel_scale).sum() * positions
        return out, log_det_total

    def inverse(self, z):
        """Divide out the channel scale, then call each layer's ``inverse``
        in reverse order."""
        channel_scale = 1 + torch.tanh(self.scale)
        out = z / channel_scale
        for coupling in reversed(self.layers):
            out = coupling.inverse(out)
        return out

class s_theta(nn.Module):
    """Two-layer convolutional network whose output is perturbed by noise.

    Args:
        input_channels: Channel count of both the input and the output; the
            hidden convolution uses twice as many channels.
    """

    def __init__(self, input_channels):
        super().__init__()
        widened = input_channels * 2
        stages = [
            nn.Conv2d(input_channels, widened, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(widened, input_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.s = nn.Sequential(*stages)

    def forward(self, z):
        """Return ``s(z)`` plus standard-normal noise of the same shape."""
        activations = self.s(z)
        noise = torch.randn_like(activations)
        return activations + noise



