#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 24 19:25:33 2025

@author: zhou.junkai
"""

import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import time
import numpy as np
import random
import os
import torch
import pickle
import torch.utils.data
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
torch.manual_seed(1)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
import torch.nn as nn
import audtorch.metrics.functional
device = torch.device('cuda')
import torch.nn.functional as F
from torchvision.utils import save_image
from torchvision import datasets, transforms, models
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import datetime
import glob
from datetime import datetime, timedelta
from tqdm import tqdm
import scipy.io
from pathlib import Path
import os
import matplotlib.pyplot as plt
import scipy.signal 
from sklearn import metrics
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_fscore_support
import model
from PIL import Image
import cv2  
from tqdm import tqdm  
import scipy.io
from scipy.io import loadmat
from scipy.ndimage import median_filter
from scipy.ndimage import uniform_filter1d
from sklearn.metrics import roc_auc_score
import seaborn as sns
import torch
import matplotlib.ticker as ticker


class TemporalCubeDataset(Dataset):
    """Sliding-window dataset of ``T``-frame cubes cut from one video array.

    Window ``i`` covers frames ``i .. i + T - 1`` (stride 1). Each item also
    carries the within-window positions of the frames to mask and the
    absolute frame indices spanned by the window.

    Args:
        video_array: 4-D array of frames indexed along its first axis
            (presumably (N, C, H, W) -- confirm against the caller).
        T: number of consecutive frames per cube.
        mask_strategy: ``"interval"`` masks the odd positions 1, 3, 5, ...;
            any other value masks ``T // 2`` positions chosen at random.
    """

    def __init__(self, video_array, T=6, mask_strategy="interval"):
        assert video_array.ndim == 4
        self.X, self.T, self.mask_strategy = video_array, T, mask_strategy
        # One window per valid start frame; empty when the video is shorter than T.
        self.starts = np.arange(len(video_array) - T + 1)

    def __len__(self):
        """Return the number of windows."""
        return self.starts.shape[0]

    def _mask_indices(self, T):
        """Return the sorted within-window positions to mask as an int array."""
        if self.mask_strategy != "interval":
            # Random strategy: shuffle all positions and mask the first half.
            order = np.arange(T)
            np.random.shuffle(order)
            return np.sort(order[: T // 2]).astype(int)
        # Interval strategy: even positions stay visible, odd ones are masked.
        return np.arange(1, T, 2).astype(int)

    def __getitem__(self, idx):
        """Return ``(cube, mask_idx, global_idx)`` for window ``idx``.

        ``cube`` is the float tensor of the ``T`` frames, ``mask_idx`` the
        long tensor of masked positions inside the window and ``global_idx``
        the long tensor of absolute frame indices of the window.
        """
        first = self.starts[idx]
        last = first + self.T
        window = self.X[first:last]                  # (T, C, H, W)
        hidden = self._mask_indices(self.T)          # e.g. [1, 3, 5]
        frame_ids = np.arange(first, last)
        as_long = [torch.from_numpy(a).long() for a in (hidden, frame_ids)]
        return (torch.from_numpy(window).float(), *as_long)



class FrameCNN(nn.Module):
    """Per-frame convolutional encoder producing one ``d``-dim embedding per frame.

    Three stride-2 convolutions (16, 32, 64 channels) downsample the frame,
    after which the feature map is flattened and projected linearly to ``d``.

    Args:
        in_ch: number of input channels.
        h: input frame height.
        w: input frame width.
        d: embedding dimension.
    """

    def __init__(self, in_ch=3, h=80, w=80, d=256):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(in_ch, 16, 3, stride=2, padding=1), nn.BatchNorm2d(16), nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),    nn.BatchNorm2d(32), nn.ReLU(True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),    nn.BatchNorm2d(64), nn.ReLU(True),
        )
        # A 3x3 / stride-2 / padding-1 conv maps n -> ceil(n / 2), so three of
        # them yield ceil(n / 8). Floor division only matches when n is a
        # multiple of 8 and otherwise under-sizes the linear layer.
        feat_h, feat_w = -(-h // 8), -(-w // 8)
        self.proj = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * feat_h * feat_w, d),
            nn.ReLU(True)
        )
        self.h, self.w = h, w

    def forward(self, x):  # x: (B, C, H, W)
        """Embed a batch of frames ``(B, C, H, W)`` into ``(B, d)``."""
        feat = self.backbone(x)
        emb = self.proj(feat)  # (B, d)
        return emb

class PosEncoding1D(nn.Module):
    """Learnable additive positional embedding for sequences of ``T`` tokens.

    Args:
        T: sequence length the embedding table is created for.
        d: token dimension.
    """

    def __init__(self, T, d):
        super().__init__()
        # Small-scale Gaussian initialisation of the (1, T, d) table.
        table = torch.randn(1, T, d) * 0.02
        self.pe = nn.Parameter(table)

    def forward(self, x):  # (B, T, d)
        """Return ``x`` with the positional table added (broadcast over the batch)."""
        shifted = x + self.pe
        return shifted

class TemporalMAE(nn.Module):
    """Masked autoencoder over the time axis of short frame cubes.

    Every frame is embedded by ``FrameCNN`` and receives a learnable
    positional embedding. The visible frames are encoded by a Transformer;
    each masked frame is then reconstructed in pixel space from a mask token
    that carries the frame's position and the sample's encoder context.

    Args:
        T: number of frames per cube.
        in_ch: number of image channels.
        h: frame height.
        w: frame width.
        d: token / embedding dimension.
        nhead: attention heads used by the encoder and decoder layers.
        nlayers: number of Transformer layers in the encoder (the decoder
            always has 2).
    """

    def __init__(self, T=6, in_ch=3, h=80, w=80, d=256, nhead=8, nlayers=4):
        super().__init__()
        self.T, self.d, self.h, self.w, self.in_ch = T, d, h, w, in_ch
        self.encoder_f = FrameCNN(in_ch=in_ch, h=h, w=w, d=d)

        self.pos = PosEncoding1D(T=T, d=d)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=nlayers)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, d))

        decoder_layer = nn.TransformerEncoderLayer(d_model=d, nhead=nhead, batch_first=True)
        self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=2)

        # The three transposed convolutions upsample by exactly 8x, so the
        # seed grid is ceil(size / 8) and ``forward`` crops back to (h, w).
        # For sizes divisible by 8 this equals size // 8 and the crop is a no-op.
        grid_h, grid_w = -(-h // 8), -(-w // 8)
        self.pixel_head = nn.Sequential(
            nn.Linear(d, 64 * grid_h * grid_w),
            nn.ReLU(True),
            nn.Unflatten(1, (64, grid_h, grid_w)),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1), nn.ReLU(True),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1), nn.ReLU(True),
            nn.ConvTranspose2d(16, in_ch, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def forward(self, cubes, mask_idx):
        """Reconstruct the masked frames of every cube.

        Args:
            cubes: tensor of shape ``(B, T, C, H, W)`` with ``T == self.T``.
            mask_idx: long tensor ``(B, m)`` of within-cube frame positions to
                mask; every row must mark the same number of distinct frames.

        Returns:
            ``(recon_masked, masked_idx)`` where ``recon_masked`` has shape
            ``(B, m, C, h, w)`` with values in (0, 1) and ``masked_idx`` is a
            list of ``B`` 1-D long tensors holding the masked positions of
            each sample in ascending order (matching ``recon_masked[b]``).

        Raises:
            ValueError: if the samples do not all mask the same number of frames.
        """
        B, T, C, H, W = cubes.shape
        assert T == self.T, f"expected cubes with {self.T} frames, got {T}"

        frames = cubes.reshape(B * T, C, H, W)
        z = self.encoder_f(frames).reshape(B, T, self.d)
        z = self.pos(z)

        dev = z.device
        mask = torch.zeros(B, T, dtype=torch.bool, device=dev)
        mask.scatter_(1, mask_idx, True)

        per_sample = mask.sum(dim=1)
        if not bool((per_sample == per_sample[0]).all()):
            raise ValueError(
                "every sample must mask the same number of distinct frames, got "
                f"{per_sample.tolist()}"
            )
        m = int(per_sample[0])

        # nonzero() enumerates True entries row by row, so with a constant
        # count per row the column indices reshape cleanly to (B, count).
        masked_pos = mask.nonzero(as_tuple=False)[:, 1].reshape(B, m)
        keep_pos = (~mask).nonzero(as_tuple=False)[:, 1].reshape(B, T - m)
        batch_ix = torch.arange(B, device=dev).unsqueeze(1)             # (B, 1)

        z_enc = self.encoder(z[batch_ix, keep_pos])                     # (B, T-m, d)

        # Per-sample context: averaging over the batch as well would make a
        # sample's reconstruction depend on whatever else is in the batch.
        context = z_enc.mean(dim=1, keepdim=True)                       # (B, 1, d)

        # Mask tokens receive the positional embedding of the frame they
        # stand for; otherwise all of them are identical and the decoder
        # cannot tell the masked frames apart.
        pos_masked = self.pos.pe[0, masked_pos]                         # (B, m, d)
        dec_in = self.mask_token + pos_masked + context                 # (B, m, d)

        # One sequence per sample, so attention never crosses samples.
        dec_out = self.decoder(dec_in)                                  # (B, m, d)

        rec_pixels = self.pixel_head(dec_out.reshape(B * m, self.d))
        rec_pixels = rec_pixels[..., :self.h, :self.w]                  # drop upsampling overshoot
        recon_masked = rec_pixels.reshape(B, m, *rec_pixels.shape[1:])
        return recon_masked, list(masked_pos.unbind(0))
    


def safe_batch_pcc(x_flat: torch.Tensor, y_flat: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Return the row-wise Pearson correlation between two 2-D batches.

    Rows whose centred norms have a product of at most ``eps`` (for example a
    constant row) get a correlation of 0 instead of a division by zero.

    Args:
        x_flat: tensor of shape ``(N, D)``.
        y_flat: tensor of shape ``(N, D)``.
        eps: threshold on the product of the two centred norms below which
            the correlation is reported as 0.

    Returns:
        Tensor of shape ``(N,)`` in the promoted dtype of the inputs.
    """
    x_c = x_flat - x_flat.mean(dim=1, keepdim=True)
    y_c = y_flat - y_flat.mean(dim=1, keepdim=True)

    cov = (x_c * y_c).sum(dim=1)
    denom = x_c.norm(dim=1) * y_c.norm(dim=1)
    valid = denom > eps

    # torch.where keeps the input dtype (a float32 buffer filled through a
    # boolean mask rejects float64/half values) and needs no mask indexing.
    # Degenerate rows divide by 1 so no inf/nan is produced before selection.
    safe_denom = torch.where(valid, denom, torch.ones_like(denom))
    return torch.where(valid, cov / safe_denom, torch.zeros_like(cov))

