#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 24 19:20:27 2025

@author: zhou.junkai
"""

import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import time
import numpy as np
import random
import os
import torch
import pickle
import torch.utils.data
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
torch.manual_seed(1)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
import torch.nn as nn
import audtorch.metrics.functional
device = torch.device('cuda')
import torch.nn.functional as F
from torchvision.utils import save_image
from torchvision import datasets, transforms, models
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import datetime
import glob
from datetime import datetime, timedelta
from tqdm import tqdm
import scipy.io
from pathlib import Path
import os
import matplotlib.pyplot as plt
import scipy.signal 
from sklearn import metrics
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_fscore_support
from PIL import Image
import cv2  
from tqdm import tqdm  
import scipy.io
from scipy.io import loadmat
from scipy.ndimage import median_filter
from scipy.ndimage import uniform_filter1d
from sklearn.metrics import roc_auc_score
import seaborn as sns
import torch
import matplotlib.ticker as ticker


class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Global-average-pools the input to one value per channel, passes that
    through a two-layer 1x1-conv bottleneck (``ch -> hidden -> ch``) with a
    ReLU in between and a final sigmoid, and rescales every channel of the
    input by the resulting weight.

    Args:
        ch: number of input (and output) channels.
        r: bottleneck reduction ratio. The hidden width is ``ch // r``,
            clamped to at least 1 so that ``ch < r`` still builds a usable
            bottleneck instead of zero-width convolutions.
    """

    def __init__(self, ch, r=8):
        super().__init__()
        # ch // r is 0 whenever ch < r, which would create degenerate
        # zero-channel conv layers; keep at least one hidden channel.
        hidden = max(1, ch // r)
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(ch, hidden, 1, bias=False), nn.ReLU(inplace=True),
            nn.Conv2d(hidden, ch, 1, bias=False), nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` scaled per channel by gates in (0, 1); shape unchanged."""
        w = self.fc(x)  # (B, ch, 1, 1), broadcast over the spatial dims
        return x * w

class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers, an SE gate, and an
    identity shortcut.

    Both convolutions map ``ch`` channels to ``ch`` channels with 3x3 kernels
    and padding 1, so the block's input can be added back unchanged.
    """

    def __init__(self, ch):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(ch, ch, 3, padding=1)

        # Submodules are created in the order conv1, bn1, conv2, bn2, se;
        # this order fixes parameter initialisation under a global seed.
        self.conv1 = conv3x3()
        self.bn1 = nn.BatchNorm2d(ch)
        self.conv2 = conv3x3()
        self.bn2 = nn.BatchNorm2d(ch)
        self.se = SEBlock(ch, r=8)

    def forward(self, x):
        """conv-BN-ReLU, conv-BN, SE gating, add shortcut, final ReLU."""
        shortcut = x
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.se(self.bn2(self.conv2(out)))
        return F.relu(out + shortcut, inplace=True)

class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM cell.

    All four gates come from one convolution over the channel-wise
    concatenation of the current input and the previous hidden state.

    Args:
        in_ch: channels of the input frame ``x_t``.
        hid_ch: channels of the hidden and cell states.
        k: kernel size of the gate convolution.
        padding: padding of the gate convolution. ``None`` (the default)
            uses ``k // 2``, which keeps the spatial size unchanged for odd
            ``k``. That is required because the returned ``h`` is
            concatenated with the next input frame. For ``k=3`` this is 1.
    """

    def __init__(self, in_ch, hid_ch, k=3, padding=None):
        super().__init__()
        if padding is None:
            # "Same" padding for odd kernels so h keeps the input's H x W.
            padding = k // 2
        self.hid_ch = hid_ch
        self.conv = nn.Conv2d(in_ch + hid_ch, 4 * hid_ch, k, padding=padding)

    def forward(self, x_t, h_c):
        """Advance the cell by one time step.

        Args:
            x_t: input of shape (B, in_ch, H, W).
            h_c: tuple ``(h, c)`` of previous states. Pass ``(None, None)``
                on the first step to start from all-zero states.

        Returns:
            Tuple ``(h, c)`` with the new hidden and cell states.
        """
        h, c = h_c
        if h is None:
            # Zero initial states matching the input's batch, spatial size,
            # device and dtype.
            batch, _, height, width = x_t.shape
            h = torch.zeros(batch, self.hid_ch, height, width,
                            device=x_t.device, dtype=x_t.dtype)
            c = torch.zeros_like(h)
        z = self.conv(torch.cat([x_t, h], dim=1))
        # Channel blocks: input gate, forget gate, candidate, output gate.
        i, f, g, o = torch.chunk(z, 4, dim=1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)
        c = f * c + i * g
        h = o * torch.tanh(c)
        return h, c

class ConvLSTM(nn.Module):
    """Unroll one ConvLSTMCell over the time axis of a sequence.

    Args:
        in_ch: channels of each input frame.
        hid_ch: channels of the hidden state.
        steps: kept as ``self.steps``. ``forward`` does not read it and
            instead unrolls over however many frames the input contains.
    """

    def __init__(self, in_ch, hid_ch, steps):
        super().__init__()
        self.cell = ConvLSTMCell(in_ch, hid_ch)
        self.steps = steps

    def forward(self, seq):
        """Run the cell across ``seq`` of shape (B, T, C, H, W).

        Returns:
            ``(h_last, h_mean)``: the hidden state after the final step and
            the hidden state averaged over all steps, each
            (B, hid_ch, H, W).
        """
        _, num_steps, _, _, _ = seq.shape  # also rejects non-5-D input
        state = (None, None)
        history = []
        for step in range(num_steps):
            state = self.cell(seq[:, step], state)
            history.append(state[0])
        h_last = state[0]
        h_mean = torch.stack(history, 1).mean(1)
        return h_last, h_mean


class DASTLite(nn.Module):
    """Lightweight spatio-temporal frame predictor.

    The model works in three stages:

    * Each frame of a clip is encoded independently by three stride-2 conv
      stages (channels ``in_ch -> base -> 2*base -> 4*base``).
    * The per-frame bottleneck features are run through a ConvLSTM.
    * The last hidden state is decoded back to image space by three stride-2
      transposed convolutions ending in a sigmoid.

    Args:
        in_ch: channels per input frame, and of the reconstructed output.
        base: width of the first encoder stage.
        T: kept as ``self.T`` and passed to the ConvLSTM as ``steps``.
            ``forward`` unrolls over the actual clip length instead.
        target: kept as ``self.target`` ("last" or "center").
            NOTE(review): ``forward`` does not read it, so the output is
            always decoded from the last ConvLSTM hidden state. Confirm this
            is intended when "center" is used.
    """

    def __init__(self, in_ch=3, base=32, T=6, target="last"):
        super().__init__()
        self.T = T
        self.target = target  # "last" or "center"; not read by forward()

        # Encoder: stride-2 x 3, channels in_ch -> base -> 2*base -> 4*base
        # (3 -> 32 -> 64 -> 128 with the defaults). Submodules are created in
        # a fixed order, which determines seeded parameter initialisation.
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_ch, base, 3, stride=2, padding=1), nn.BatchNorm2d(base), nn.ReLU(True),
            ResBlock(base)
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(base, base*2, 3, stride=2, padding=1), nn.BatchNorm2d(base*2), nn.ReLU(True),
            ResBlock(base*2)
        )
        self.enc3 = nn.Sequential(
            nn.Conv2d(base*2, base*4, 3, stride=2, padding=1), nn.BatchNorm2d(base*4), nn.ReLU(True),
            ResBlock(base*4), SEBlock(base*4, r=8)
        )

        # Temporal model over the per-frame bottleneck features.
        self.convlstm = ConvLSTM(in_ch=base*4, hid_ch=base*4, steps=T)

        # Decoder: each ConvTranspose2d(k=4, s=2, p=1) exactly doubles H, W.
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(base*4, base*2, 4, stride=2, padding=1), nn.BatchNorm2d(base*2), nn.ReLU(True)
        )
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(base*2, base, 4, stride=2, padding=1), nn.BatchNorm2d(base), nn.ReLU(True)
        )
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(base, in_ch, 4, stride=2, padding=1),
            nn.Sigmoid()  # output values in (0, 1)
        )

    def encode_frame(self, x):
        """Encode one frame of shape (B, C, H, W) through the encoder stages.

        Returns:
            ``(x1, x2, x3)``: stage outputs with ``base``, ``2*base`` and
            ``4*base`` channels at 1/2, 1/4 and 1/8 resolution. Odd sizes
            are rounded up at each stage.
        """
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x3 = self.enc3(x2)  # (B, 4*base, ceil(H/8), ceil(W/8))
        return x1, x2, x3

    def forward(self, x_seq):
        """Predict a frame from a clip ``x_seq`` of shape (B, T, C, H, W).

        Returns:
            ``(y, h_mean)``. ``y`` is the (B, in_ch, H, W) reconstruction
            with values in (0, 1). ``h_mean`` is the time-averaged ConvLSTM
            hidden state, (B, 4*base, ceil(H/8), ceil(W/8)).
        """
        B, T, C, H, W = x_seq.shape
        xs3 = []
        for t in range(T):
            # Frames are encoded one time step at a time, so BatchNorm batch
            # statistics in training mode cover B samples of a single step.
            _, _, x3 = self.encode_frame(x_seq[:, t])
            xs3.append(x3)
        feat_seq = torch.stack(xs3, dim=1)  # (B, T, 4*base, ceil(H/8), ceil(W/8))

        # ConvLSTM; only the last hidden state is decoded.
        h_last, h_mean = self.convlstm(feat_seq)
        feat = h_last

        y = self.dec3(feat)
        y = self.dec2(y)
        y = self.dec1(y)  # (B, in_ch, 8*ceil(H/8), 8*ceil(W/8))
        # The stride-2 encoder convs round odd sizes up while the decoder
        # exactly doubles, so for H or W not divisible by 8 the decoder
        # overshoots. Crop back to the input size (no-op for multiples of 8).
        if y.shape[-2:] != (H, W):
            y = y[..., :H, :W].contiguous()
        return y, h_mean


class VideoWindowDataset(Dataset):
    """Sliding-window dataset over a single video.

    Item ``idx`` is the window of ``T`` consecutive frames starting at frame
    ``idx`` (stride 1), plus its target frame and that frame's absolute
    index in the video.

    Args:
        video_array: 4-D NumPy array or torch tensor indexed frame-first.
            Windows are treated as (T, C, H, W), so this is presumably
            (N, C, H, W); confirm against the caller.
        T: window length, at least 1.
        target: "last" targets the final frame of the window; "center"
            targets frame ``T // 2`` of the window.

    Raises:
        ValueError: if ``video_array`` is not 4-D, ``T < 1``, or ``target``
            is not "last" or "center".
    """

    _TARGETS = ("last", "center")

    def __init__(self, video_array, T=6, target="last"):
        # Explicit checks instead of assert: asserts are stripped under -O.
        if video_array.ndim != 4:
            raise ValueError(
                f"video_array must be 4-D, got ndim={video_array.ndim}")
        if T < 1:
            raise ValueError(f"T must be >= 1, got {T}")
        if target not in self._TARGETS:
            # Previously any unknown value silently behaved like "center".
            raise ValueError(
                f"target must be 'last' or 'center', got {target!r}")
        self.X = video_array
        self.T = T
        self.target = target
        # One window per valid start frame; empty if the video is shorter than T.
        self.starts = np.arange(0, len(self.X) - T + 1, 1)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        """Return ``(seq, y_t, t_idx)`` for window ``idx``.

        * ``seq`` is the float32 window of ``T`` frames.
        * ``y_t`` is the target frame, taken from ``seq``.
        * ``t_idx`` is a 0-d long tensor holding the target frame's index in
          the full video.
        """
        s = self.starts[idx]
        seq = self.X[s:s + self.T]  # (T, C, H, W)
        if isinstance(seq, np.ndarray):
            seq = torch.from_numpy(seq).float()
        elif torch.is_tensor(seq):
            # Match the NumPy path so windows are always float32.
            seq = seq.float()
        offset = self.T - 1 if self.target == "last" else self.T // 2
        t_idx = s + offset
        y_t = seq[offset]
        return seq, y_t, torch.tensor(t_idx, dtype=torch.long)



def _safe_pcc_batch(y_hat, y):
    """Per-sample Pearson correlation between two batches.

    Each sample is flattened to a vector and correlated with its counterpart
    via ``audtorch.metrics.functional.pearsonr``.

    Args:
        y_hat: predictions with the batch dimension first, e.g. (B, C, H, W).
            Non-contiguous tensors are accepted.
        y: targets with the same shape as ``y_hat``.

    Returns:
        Tensor of shape (B,) with one coefficient per sample. Non-finite
        coefficients (NaN or +/-inf, e.g. from a zero-variance sample) are
        replaced by 0.
    """
    # reshape rather than view: view raises on non-contiguous inputs
    # (e.g. after permute or spatial cropping); reshape copies only if needed.
    y_hat_f = y_hat.reshape(y_hat.size(0), -1)
    y_f = y.reshape(y.size(0), -1)
    p = audtorch.metrics.functional.pearsonr(y_hat_f, y_f).squeeze(-1)
    # By default nan_to_num maps +/-inf to the dtype's extreme finite values
    # (~3.4e38 for float32); send them to 0 like NaN so the result stays sane.
    p = torch.nan_to_num(p, nan=0.0, posinf=0.0, neginf=0.0)
    return p  # (B,)