import torch
import numpy as np
import numpy.random as nr
import pandas as pd
import torch.nn.functional as F
from tqdm import tqdm
from torch import nn
from torch.nn.functional import relu
from torch.utils.checkpoint import checkpoint

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# torch.cuda.set_device raises ValueError when given a CPU device, so only
# select the current CUDA device when a GPU is actually in use.
if device.type == "cuda":
    torch.cuda.set_device(device)

class TemporalAffineCoupling1D(nn.Module):
    """Affine coupling layer that splits a sequence by time-step parity.

    Odd time steps are transformed by an affine map whose shift and log-scale
    are predicted from the even time steps; even time steps pass through
    unchanged. This makes the transform invertible in closed form, with a
    triangular Jacobian whose log-determinant is the sum of the log-scales.

    Args:
        input_channels: number of channels ``C`` of the input tensor.
        hidden_channels: width of the pointwise (kernel_size=1) conditioner.
    """

    def __init__(self, input_channels, hidden_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(input_channels, hidden_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(hidden_channels, input_channels * 2, kernel_size=1),
        )
        # Zero-initialising the last layer gives shift = 0 and log_scale = 0,
        # so the layer starts out as the identity map.
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def _shift_and_log_scale(self, cond, num_odd):
        """Predict (shift, clamped log_scale) for ``num_odd`` odd time steps from ``cond``."""
        h = self.net(cond)
        # For an odd-length sequence there is one more even step than odd
        # step. The conditioner is pointwise (kernel_size=1), so each output
        # position depends only on its own input position and dropping the
        # trailing one is exact. For even lengths this slice is a no-op.
        h = h[:, :, :num_odd]
        shift, log_scale = h.chunk(2, dim=1)
        # Bound the scale to [e^-2, e^2] for numerical stability.
        log_scale = torch.clamp(log_scale, -2, 2)
        return shift, log_scale

    def forward(self, x):
        """Apply the coupling transform.

        Args:
            x: tensor of shape (batch, C, T); T may be even or odd.

        Returns:
            (y, log_det): ``y`` has the shape of ``x``; ``log_det`` has shape
            (batch,) and holds log|det J| per sample.
        """
        x_even = x[:, :, ::2]
        x_odd = x[:, :, 1::2]

        shift, log_scale = self._shift_and_log_scale(x_even, x_odd.shape[2])
        y_odd = x_odd * torch.exp(log_scale) + shift

        # Re-interleave: even steps unchanged, odd steps transformed.
        y = torch.zeros_like(x)
        y[:, :, ::2] = x_even
        y[:, :, 1::2] = y_odd

        # log|det J| = sum of log-scales (avoids a log(exp(.)) round trip).
        log_det = torch.sum(log_scale, dim=[1, 2])
        return y, log_det

    def inverse(self, y):
        """Invert :meth:`forward`: recover ``x`` from ``y`` of shape (batch, C, T)."""
        y_even = y[:, :, ::2]
        y_odd = y[:, :, 1::2]

        # Even steps were left unchanged, so they reproduce the forward conditioner input.
        shift, log_scale = self._shift_and_log_scale(y_even, y_odd.shape[2])
        x_odd = (y_odd - shift) / torch.exp(log_scale)

        x = torch.zeros_like(y)
        x[:, :, ::2] = y_even
        x[:, :, 1::2] = x_odd

        return x


class Invertible1x1Conv1D(nn.Module):
    """Learnable invertible channel-mixing layer implemented as a 1x1 convolution.

    The same ``C x C`` weight matrix is applied at every time step, so the
    log-determinant of the full Jacobian is ``T * log|det W|`` for a
    sequence of length ``T``.

    Args:
        channels: number of input/output channels ``C``.
    """

    def __init__(self, channels):
        super().__init__()
        # Initialise W as a random orthogonal matrix (the Q factor of a QR
        # decomposition): |det W| = 1 at the start and W is well conditioned
        # for inversion. torch.linalg.qr replaces the deprecated torch.qr.
        W = torch.linalg.qr(torch.randn(channels, channels))[0]
        self.weight = nn.Parameter(W)

    def forward(self, x):
        """Mix the channels of ``x`` (batch, C, T).

        Returns:
            (z, log_det): ``z`` has the shape of ``x``. ``log_det`` is a 0-dim
            tensor, identical for every sample. It broadcasts when added to a
            per-sample (batch,) accumulator.
        """
        weight = self.weight.unsqueeze(2)  # (C, C, 1) conv kernel
        z = F.conv1d(x, weight)
        log_det = x.shape[2] * torch.linalg.slogdet(self.weight)[1]
        return z, log_det

    def inverse(self, z):
        """Undo :meth:`forward` by convolving with the inverse weight matrix."""
        weight_inv = torch.linalg.inv(self.weight).unsqueeze(2)
        x = F.conv1d(z, weight_inv)
        return x

class Flow(nn.Module):
    """Normalizing flow of alternating channel-mixing and coupling layers.

    Each of the ``num_layers`` stages is an ``Invertible1x1Conv1D`` followed
    by a ``TemporalAffineCoupling1D``. A learned per-channel log-scale is
    applied after the last stage.

    Args:
        input_channels: number of channels ``C`` of the input sequences.
        output_channels: not used by this class (kept for API compatibility);
            the output always has ``input_channels`` channels.
        num_layers: number of (1x1 conv, coupling) stages.
        hidden_channels: width of each coupling layer's conditioner network
            (default 2, the value previously hard-coded).
    """

    def __init__(self, input_channels=1, output_channels=1, num_layers=4, hidden_channels=2):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(Invertible1x1Conv1D(input_channels))
            self.layers.append(
                TemporalAffineCoupling1D(input_channels, hidden_channels=hidden_channels)
            )
        # Per-channel log-scale, broadcast over batch and time; zero init => identity.
        self.scale = nn.Parameter(torch.zeros(1, input_channels, 1))

    def forward(self, x):
        """Map ``x`` (batch, C, T) to the latent ``z``.

        Returns:
            (z, log_det_total): ``z`` has the shape of ``x``; ``log_det_total``
            has shape (batch,) and sums every layer's log|det J|.
        """
        # Match x's dtype as well as its device, so float64 inputs accumulate in float64.
        log_det_total = torch.zeros(x.size(0), device=x.device, dtype=x.dtype)
        z = x
        for layer in self.layers:
            z, log_det = layer(z)
            log_det_total += log_det
        z = z * torch.exp(self.scale)
        # Elementwise scaling by exp(scale) at each of the T steps contributes
        # T * sum(scale) to log|det J|.
        log_det_total += self.scale.sum() * z.shape[2]
        return z, log_det_total

    def inverse(self, z):
        """Invert :meth:`forward`: undo the final scaling, then the layers in reverse order."""
        x = z / torch.exp(self.scale)
        for layer in reversed(self.layers):
            x = layer.inverse(x)
        return x

class s_theta(nn.Module):
    """Two-layer convolutional network whose output gets additive Gaussian noise.

    The input is whatever ``nn.Conv1d`` accepts with ``input_channels``
    channels, i.e. (batch, C, T) or unbatched (C, T). The output has the same
    shape (kernel_size=3 with padding=1 preserves the length).

    Args:
        input_channels: number of channels ``C`` of the input and output.
        noise_std: standard deviation of the Gaussian noise added to the
            output on every call (default 0.05, the previously hard-coded
            value). Use 0 for a deterministic output.
    """

    def __init__(self, input_channels, noise_std=0.05):
        super().__init__()
        self.noise_std = noise_std
        self.net = nn.Sequential(
            nn.Conv1d(input_channels, input_channels * 2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(input_channels * 2, input_channels, kernel_size=3, padding=1),
            # NOTE(review): this trailing ReLU makes the network output
            # non-negative; if a signed quantity is intended, confirm it is wanted.
            nn.ReLU(),
        )

    def forward(self, z):
        """Run the network on ``z`` and add Gaussian noise with std ``self.noise_std``."""
        eps = self.net(z)
        # NOTE(review): noise is added in both train and eval mode (there is
        # no self.training check) — confirm that is intended.
        if self.noise_std:
            eps = eps + torch.randn_like(eps) * self.noise_std
        return eps