#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 24 18:53:16 2025

@author: zhou.junkai
"""


import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import time
import numpy as np
import random
import os
import torch
import pickle
import torch.utils.data
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
torch.manual_seed(1)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
import torch.nn as nn
import audtorch.metrics.functional
device = torch.device('cuda')
import torch.nn.functional as F
from torchvision.utils import save_image
from torchvision import datasets, transforms, models
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

    
class ConvAutoencoder_recon(nn.Module):
    """Convolutional autoencoder with a fully connected bottleneck.

    The encoder applies three stride-2 convolutions (8, 16, 32 channels) and
    an MLP that maps the flattened feature map to ``embed_num`` values. The
    decoder mirrors this with an MLP followed by three stride-2 transposed
    convolutions and a final sigmoid.

    Args:
        embed_num: Size of the latent vector produced by ``encoder_fc``.
        channel: Number of channels of the input (and reconstructed) image.
        height: Height of the input images, in pixels.
        width: Width of the input images, in pixels.

    Heights and widths that are not multiples of 8 are supported: the
    reconstruction is cropped back to the spatial size of the input.
    """

    def __init__(self, embed_num, channel, height, width):
        super().__init__()
        self.embed_num = embed_num
        self.channel = channel
        self.height = height
        self.width = width

        # A kernel-3 / stride-2 / padding-1 convolution maps a side of length
        # n to ceil(n / 2); three of them in a row give ceil(n / 8).  Using
        # floor division here would under-size the linear layers whenever a
        # side is not a multiple of 8 and make forward() fail.
        enc_h = (self.height + 7) // 8
        enc_w = (self.width + 7) // 8
        flat_features = 32 * enc_h * enc_w

        self.encoder = nn.Sequential(
            nn.Conv2d(self.channel, 8, 3, stride=2, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),

            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),

            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
        )
        self.encoder_fc = nn.Sequential(
            nn.Linear(flat_features, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 512),
            nn.ReLU(True),
            nn.Linear(512, self.embed_num),
        )

        self.decoder_fc = nn.Sequential(
            nn.Linear(self.embed_num, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, flat_features),
            nn.ReLU(True),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),

            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),

            nn.ConvTranspose2d(8, self.channel, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x`` and reconstruct it.

        Args:
            x: Batch of images whose channel count, height and width match
                the values given to the constructor.

        Returns:
            A ``(latent, x_recon)`` tuple: the ``embed_num``-wide output of
            ``encoder_fc`` and the sigmoid-activated reconstruction, which
            has the same spatial size as ``x``.
        """
        x_enc = self.encoder(x)
        batch_size = x.size(0)
        x_flat = x_enc.contiguous().view(batch_size, -1)
        latent = self.encoder_fc(x_flat)

        x_dec_fc = self.decoder_fc(latent)
        # Residual connection: the convolutional features are added to the
        # decoded latent, so the decoder input does not depend on the latent
        # alone.
        x_dec = x_dec_fc.view_as(x_enc) + x_enc
        x_recon = self.decoder(x_dec)

        # Each transposed convolution exactly doubles the spatial size, so
        # the output is 8 * ceil(n / 8) per side; trim the surplus rows and
        # columns when the input side is not a multiple of 8.
        if x_recon.shape[-2:] != x.shape[-2:]:
            x_recon = x_recon[..., :x.size(-2), :x.size(-1)]
        return latent, x_recon

class MyDataset_recon(Dataset):
    """Thin ``Dataset`` adapter around an indexable, sized collection.

    Items are handed back exactly as stored; no transform or copy is applied.
    """

    def __init__(self, data_set):
        # Keep a reference to the caller's collection (no copy is made).
        self.data_set = data_set

    def __len__(self):
        """Return the number of items, as reported by ``len(data_set)``."""
        size = len(self.data_set)
        return size

    def __getitem__(self, index):
        """Return the item stored at ``index`` in the wrapped collection."""
        sample = self.data_set[index]
        return sample


class DiscriminatorMLP(nn.Module):
    """Fully connected discriminator that scores whole images.

    The image is flattened, projected to 2048 features and passed through a
    2048 -> 512 -> 32 -> 1 MLP that ends in a sigmoid.

    Args:
        img_channels: Number of channels of the input images.
        img_h: Height of the input images, in pixels.
        img_w: Width of the input images, in pixels.
    """

    def __init__(self, img_channels, img_h, img_w):
        super().__init__()
        self.img_size = img_channels * img_h * img_w
        # NOTE(review): no activation sits between this projection and the
        # first layer of ``mlp``, so the two linear layers compose into a
        # single linear map; kept as-is to preserve existing checkpoints.
        self.flatten_to_2048 = nn.Linear(self.img_size, 2048)

        self.mlp = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(True),
            nn.Linear(512, 32),
            nn.ReLU(True),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions hold ``img_channels * img_h * img_w`` values.

        Returns:
            Tensor of shape ``(B,)`` with one sigmoid output per image.
        """
        batch_size = x.size(0)
        # reshape() rather than view(): view() raises on non-contiguous
        # inputs (e.g. permuted or sliced images).  Naming the feature count
        # explicitly also keeps empty batches valid, where -1 is ambiguous.
        x_flat = x.reshape(batch_size, self.img_size)
        z = self.flatten_to_2048(x_flat)
        out = self.mlp(z).squeeze(dim=-1)  # shape (B,)
        return out  # probability of anomaly
