#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 24 19:22:36 2025

@author: zhou.junkai
"""

import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import time
import numpy as np
import random
import os
import torch
import pickle
import torch.utils.data
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
torch.manual_seed(1)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
import torch.nn as nn
import audtorch.metrics.functional
device = torch.device('cuda')
import torch.nn.functional as F
from torchvision.utils import save_image
from torchvision import datasets, transforms, models
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import datetime
import glob
from datetime import datetime, timedelta
from tqdm import tqdm
import scipy.io
from pathlib import Path
import os
import matplotlib.pyplot as plt
import scipy.signal 
from sklearn import metrics
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_fscore_support
from PIL import Image
import cv2  
from tqdm import tqdm  
import scipy.io
from scipy.io import loadmat
from scipy.ndimage import median_filter
from scipy.ndimage import uniform_filter1d
from sklearn.metrics import roc_auc_score
import seaborn as sns
import torch
import matplotlib.ticker as ticker


class ConvGRUCell(nn.Module):
    """One step of a convolutional GRU over 2-D feature maps.

    The update/reset gates and the candidate state are each produced by a
    convolution over the channel-wise concatenation of input and state.

    Args:
        in_ch: channel count of the input ``x_t``.
        hid_ch: channel count of the hidden state.
        k: kernel size shared by both convolutions.
        padding: padding shared by both convolutions. The state update
            combines conv outputs element-wise with ``h_prev``, so the
            padding is expected to preserve the spatial size.
        dilation: dilation shared by both convolutions.
    """

    def __init__(self, in_ch, hid_ch, k=3, padding=1, dilation=1):
        super().__init__()
        self.hid_ch = hid_ch
        conv_kwargs = {"padding": padding, "dilation": dilation}
        joint_ch = in_ch + hid_ch
        # One conv emits both gates (update z and reset r) stacked on channels.
        self.zr = nn.Conv2d(joint_ch, 2 * hid_ch, k, **conv_kwargs)
        # Separate conv for the candidate hidden state.
        self.hc = nn.Conv2d(joint_ch, hid_ch, k, **conv_kwargs)

    def forward(self, x_t, h_prev):
        """Advance the hidden state by one time step.

        Args:
            x_t: input of shape (B, in_ch, H, W).
            h_prev: previous state of shape (B, hid_ch, H, W), or ``None``
                to start from an all-zero state.

        Returns:
            The new hidden state, same shape as ``h_prev``.
        """
        if h_prev is None:
            batch, _, height, width = x_t.shape
            # Zero state on the same device / dtype as the input.
            h_prev = x_t.new_zeros(batch, self.hid_ch, height, width)

        gates = torch.sigmoid(self.zr(torch.cat((x_t, h_prev), dim=1)))
        z, r = gates.chunk(2, dim=1)

        # The reset gate scales the old state before it enters the candidate.
        candidate = torch.tanh(self.hc(torch.cat((x_t, r * h_prev), dim=1)))
        return (1 - z) * h_prev + z * candidate


class ConvGRU(nn.Module):
    """Runs a single ``ConvGRUCell`` over a sequence of feature maps.

    Args:
        in_ch: channel count of every frame in the sequence.
        hid_ch: channel count of the hidden state.
        steps: stored as ``self.steps``; ``forward`` does not read it and
            always iterates over the actual sequence length.
        dilation: dilation of the cell's convolutions. The same value is
            used as padding.
    """

    def __init__(self, in_ch, hid_ch, steps, dilation=1):
        super().__init__()
        self.cell = ConvGRUCell(in_ch, hid_ch, padding=dilation, dilation=dilation)
        self.steps = steps

    def forward(self, seq):
        """Unroll the cell along dim 1 of ``seq``.

        Args:
            seq: tensor of shape (B, T, C, H, W).

        Returns:
            Tuple ``(last_state, mean_state)``: the hidden state after the
            final frame and the average of the hidden states of all frames.
        """
        state = None  # the cell creates a zero state on the first frame
        history = []
        for frame in seq.unbind(dim=1):
            state = self.cell(frame, state)
            history.append(state)
        mean_state = torch.stack(history, dim=1).mean(dim=1)
        return state, mean_state
    
    
# ---------- ROADMAP-lite ----------
class RoadmapLite(nn.Module):
    """Frame-sequence autoencoder with three parallel dilated ConvGRUs.

    Each frame goes through a three-stage stride-2 convolutional encoder.
    The encoded sequence is fed to three ConvGRUs (dilation 1, 2 and 4),
    whose time-averaged states are blended with learned softmax weights.
    Three transposed-conv stages then decode the blend into one frame.

    Args:
        in_ch: channel count of the input frames (and of the output frame).
        base: channel count after the first encoder stage; later stages use
            ``2 * base`` and ``4 * base``.
        T: stored as ``self.T`` and passed to the ConvGRUs as ``steps``.
    """

    def __init__(self, in_ch=3, base=32, T=6):
        super().__init__()
        self.T = T
        mid_ch, deep_ch = base * 2, base * 4

        # Encoder: every stage halves H and W.
        self.enc1 = self._down_stage(in_ch, base)
        self.enc2 = self._down_stage(base, mid_ch)
        self.enc3 = self._down_stage(mid_ch, deep_ch)

        # Same input, three temporal branches with growing dilation.
        self.gru1 = ConvGRU(in_ch=deep_ch, hid_ch=deep_ch, steps=T, dilation=1)
        self.gru2 = ConvGRU(in_ch=deep_ch, hid_ch=deep_ch, steps=T, dilation=2)
        self.gru3 = ConvGRU(in_ch=deep_ch, hid_ch=deep_ch, steps=T, dilation=4)

        # Zero logits -> the three branches start with equal weight.
        self.fuse_logits = nn.Parameter(torch.zeros(3))

        # Decoder: every stage doubles H and W.
        self.dec3 = self._up_stage(deep_ch, mid_ch)
        self.dec2 = self._up_stage(mid_ch, base)
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(base, in_ch, 4, stride=2, padding=1),
            nn.Sigmoid(),  # output values lie in (0, 1)
        )

    @staticmethod
    def _down_stage(c_in, c_out):
        """Stride-2 3x3 convolution followed by BatchNorm and in-place ReLU."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, stride=2, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True),
        )

    @staticmethod
    def _up_stage(c_in, c_out):
        """Stride-2 4x4 transposed convolution, BatchNorm and in-place ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True),
        )

    def _encode_frame(self, x):
        """Encode one frame (B, C, H, W) to (B, 4*base, H/8, W/8)."""
        return self.enc3(self.enc2(self.enc1(x)))

    def forward(self, x_seq):
        """Reconstruct one frame and build an embedding from a frame sequence.

        Args:
            x_seq: tensor of shape (B, T, C, H, W).

        Returns:
            Tuple ``(y, emb)``: ``y`` is the decoded frame and ``emb`` is
            the concatenation of the globally average-pooled outputs of
            the three GRU branches, shape (B, 3 * 4 * base).
        """
        # Frames are encoded one by one, each in its own encoder call.
        encoded = [self._encode_frame(frame) for frame in x_seq.unbind(dim=1)]
        feat_seq = torch.stack(encoded, dim=1)  # (B, T, 4*base, H/8, W/8)

        # Only the time-averaged state of each branch is used.
        branch_means = tuple(
            gru(feat_seq)[1] for gru in (self.gru1, self.gru2, self.gru3)
        )
        f1, f2, f3 = branch_means

        w1, w2, w3 = torch.softmax(self.fuse_logits, dim=0)
        bottleneck = w1 * f1 + w2 * f2 + w3 * f3  # (B, 4*base, H/8, W/8)

        y = self.dec1(self.dec2(self.dec3(bottleneck)))

        pooled = [F.adaptive_avg_pool2d(f, 1) for f in branch_means]
        emb = torch.cat(pooled, dim=1).flatten(1)
        return y, emb


class VideoWindowDataset(Dataset):
    """Sliding windows of ``T`` consecutive frames over a video array.

    Windows advance by one frame, so a video of ``N`` frames yields
    ``N - T + 1`` items.

    Args:
        video_array: frames indexed along the first axis; each item is
            sliced as ``video_array[s:s + T]`` (expected (N, C, H, W)).
        T: window length in frames.
        target: ``"last"`` selects the final frame of the window as target;
            any other value selects the frame at index ``T // 2``.
    """

    def __init__(self, video_array, T=6, target="last"):
        self.X = video_array
        self.T = T
        self.target = target
        # First-frame index of every window that fits entirely in the video.
        self.starts = np.arange(len(self.X) - T + 1)

    def __len__(self):
        """Number of available windows."""
        return len(self.starts)

    def __getitem__(self, idx):
        """Return ``(window, target_frame, global_frame_index)`` for item ``idx``.

        NumPy windows are converted to float tensors; other array types are
        passed through as sliced. The index is a ``torch.long`` scalar giving
        the target frame's position in the full video.
        """
        start = self.starts[idx]
        window = self.X[start:start + self.T]  # (T, C, H, W)
        if isinstance(window, np.ndarray):
            window = torch.from_numpy(window).float()

        if self.target == "last":
            offset, frame = self.T - 1, window[-1]
        else:
            offset = self.T // 2
            frame = window[offset]

        return window, frame, torch.tensor(start + offset, dtype=torch.long)


def _pcc_batch(yhat, y):
    """Per-sample Pearson correlation between two batches.

    Every sample (index along dim 0) of both inputs is flattened to a vector
    and the pair is passed to ``audtorch.metrics.functional.pearsonr``.

    Args:
        yhat: predicted batch; dim 0 is the batch dimension.
        y: reference batch; it is flattened with the batch size of ``yhat``.

    Returns:
        The correlation tensor with its last dimension squeezed (presumably
        one value per sample; confirm against the audtorch output shape).
        NaN entries are replaced by 0.0.
    """
    batch = yhat.size(0)
    # reshape instead of view: view raises on non-contiguous inputs (e.g.
    # permuted or channels_last tensors); results are identical otherwise.
    pred_flat = yhat.reshape(batch, -1)
    true_flat = y.reshape(batch, -1)
    corr = audtorch.metrics.functional.pearsonr(pred_flat, true_flat).squeeze(-1)
    return torch.nan_to_num(corr, nan=0.0)