#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 24 19:23:40 2025

@author: zhou.junkai
"""

import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import time
import numpy as np
import random
import os
import torch
import pickle
import torch.utils.data
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
# Seed torch's global RNG with a fixed value. NumPy's and `random`'s
# generators are not seeded in this part of the file.
torch.manual_seed(1)
# Ask the CUDA runtime to enumerate devices in PCI bus order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
import torch.nn as nn
import audtorch.metrics.functional
# Module-level device used by aue_mle_recon and cae_gcl_recon; the other
# *_recon functions receive their own ``device`` argument instead.
device = torch.device('cuda')
import torch.nn.functional as F
from torchvision.utils import save_image
from torchvision import datasets, transforms, models
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import datetime
import glob
from datetime import datetime, timedelta
from tqdm import tqdm
import scipy.io
from pathlib import Path
import os
import matplotlib.pyplot as plt
import scipy.signal 
from sklearn import metrics
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_fscore_support
from PIL import Image
import cv2  
from tqdm import tqdm  
import scipy.io
from scipy.io import loadmat
from scipy.ndimage import median_filter
from scipy.ndimage import uniform_filter1d
from sklearn.metrics import roc_auc_score
import seaborn as sns
import torch
import matplotlib.ticker as ticker
import mle_model
import baseline_gcl
import baseline_dast
import baseline_roadmap
import baseline_tmae


def _aue_batch_metrics(x, decoded):
    """Compute per-sample Pearson correlation and MSE for one evaluation batch.

    Both tensors are flattened to ``(batch, -1)``. A sample whose input or
    reconstruction is (near-)constant (std <= 1e-6) has no defined
    correlation and is reported as 0.0. The correlation array always holds
    one entry per sample, in batch order, so it stays aligned with the MSE
    array.

    Args:
        x: Input batch tensor; the first dimension is the batch.
        decoded: Reconstruction with the same per-sample element count.

    Returns:
        Tuple ``(rc, mse)`` of 1-D NumPy arrays of length ``x.shape[0]``.
    """
    batch = x.shape[0]
    x_flat = x.reshape(batch, -1)
    decoded_flat = decoded.reshape(decoded.shape[0], -1)

    mse = torch.mean((decoded_flat - x_flat) ** 2, dim=1).detach().cpu().numpy()

    valid_mask = (x_flat.std(dim=1) > 1e-6) & (decoded_flat.std(dim=1) > 1e-6)
    if not bool(valid_mask.any()):
        return np.zeros(batch), mse

    rc_valid = audtorch.metrics.functional.pearsonr(
        decoded_flat[valid_mask], x_flat[valid_mask]
    )
    # reshape(-1) rather than squeeze(): a single valid sample must stay 1-D,
    # otherwise np.concatenate over the batches fails on a 0-d array.
    rc_valid = rc_valid.reshape(-1).detach().cpu().numpy()

    # Scatter the valid correlations back to their batch positions so that
    # invalid samples keep 0.0 instead of being silently dropped.
    rc = np.zeros(batch, dtype=rc_valid.dtype)
    rc[valid_mask.cpu().numpy()] = np.nan_to_num(rc_valid, nan=0.0)
    return rc, mse


def aue_mle_recon(combo, sigma, lambd, embed_num, batchsize, channel, height, width, command, epochs=70):
    """Train the MLE-regularised convolutional autoencoder and log reconstruction quality.

    Each epoch minimises ``MSE(decoded, x) + lambd * OptimizedMLELoss(encoded)``
    with Adam (lr 5e-4), then evaluates on the same dataset (unshuffled,
    batches of 500) and stores the per-sample Pearson correlation and MSE
    between input and reconstruction. The model runs on the module-level
    ``device``.

    Args:
        combo: Data handed to ``mle_model.MyDataset_recon``.
        sigma: Argument of ``mle_model.OptimizedMLELoss``.
        lambd: Weight of the MLE term in the training loss.
        embed_num: Embedding size given to the autoencoder; also the column
            count the embeddings are reshaped to before t-SNE.
        batchsize: Training batch size.
        channel, height, width: Input dimensions given to the autoencoder.
        command: The string ``"True"`` enables a 2-D t-SNE snapshot of the
            embeddings on every 5th epoch; any other value disables it.
        epochs: Number of training epochs (default 70).

    Returns:
        dict with keys ``'RC'`` and ``'MSE'`` (epoch -> 1-D array with one
        value per dataset sample) and ``'EMBED'`` (every 5th epoch -> t-SNE
        coordinates, or ``None`` when snapshots are disabled).
    """
    rc_dict = {
        'RC': {ep: None for ep in range(1, epochs + 1)},
        'MSE': {ep: None for ep in range(1, epochs + 1)},
        'EMBED': {ep: None for ep in range(5, epochs + 1, 5)},
    }

    data_set = mle_model.MyDataset_recon(combo)
    trainloader = DataLoader(dataset=data_set, batch_size=batchsize, shuffle=True)
    evaluloader = DataLoader(dataset=data_set, batch_size=500, shuffle=False)
    model = mle_model.ConvAutoencoder_recon(embed_num, channel, height, width).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=0)
    mse_loss = nn.MSELoss()
    mle_loss = mle_model.OptimizedMLELoss(sigma)

    for epoch in range(1, epochs + 1):
        start_time = time.time()

        # ---------- train ----------
        model.train()
        for x in trainloader:
            x = x.to(device)
            encoded, decoded = model(x)
            loss = mse_loss(decoded, x) + lambd * mle_loss(encoded)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # ---------- eval ----------
        want_embed = command == "True" and epoch in rc_dict['EMBED']
        epoch_rc, epoch_mse, embed = [], [], []

        model.eval()
        # Evaluation never back-propagates, so skip building autograd graphs.
        with torch.no_grad():
            for x in evaluloader:
                x = x.to(device)
                encoded, decoded = model(x)

                rc_batch, mse_batch = _aue_batch_metrics(x, decoded)
                epoch_rc.append(rc_batch)
                epoch_mse.append(mse_batch)

                if want_embed:
                    embed.append(encoded.detach().cpu().numpy())

        epoch_rc = np.concatenate(epoch_rc, axis=0)
        rc_dict['RC'][epoch] = epoch_rc

        epoch_mse = np.concatenate(epoch_mse, axis=0)
        rc_dict['MSE'][epoch] = epoch_mse

        if want_embed and embed:
            embed = np.concatenate(embed, axis=0).reshape(-1, embed_num)
            tsne = TSNE(n_components=2, random_state=42)
            rc_dict['EMBED'][epoch] = tsne.fit_transform(embed)

        epoch_time = time.time() - start_time
        print(f"Sigma {sigma} || Lambda {lambd} || Epoch {epoch} || Mean RC {np.mean(epoch_rc)} || Mean MSE {np.mean(epoch_mse)} || Time: {epoch_time:.2f} sec")

    return rc_dict


def _gcl_batch_metrics(x, decoded):
    """Compute per-sample Pearson correlation and MSE for one evaluation batch.

    Both tensors are flattened to ``(batch, -1)``. A sample whose input or
    reconstruction is (near-)constant (std <= 1e-6) has no defined
    correlation and is reported as 0.0. The correlation array always holds
    one entry per sample, in batch order, so it stays aligned with the MSE
    array.

    Args:
        x: Input batch tensor; the first dimension is the batch.
        decoded: Reconstruction with the same per-sample element count.

    Returns:
        Tuple ``(rc, mse)`` of 1-D NumPy arrays of length ``x.shape[0]``.
    """
    batch = x.shape[0]
    x_flat = x.reshape(batch, -1)
    decoded_flat = decoded.reshape(decoded.shape[0], -1)

    mse = torch.mean((decoded_flat - x_flat) ** 2, dim=1).detach().cpu().numpy()

    valid_mask = (x_flat.std(dim=1) > 1e-6) & (decoded_flat.std(dim=1) > 1e-6)
    if not bool(valid_mask.any()):
        return np.zeros(batch), mse

    rc_valid = audtorch.metrics.functional.pearsonr(
        decoded_flat[valid_mask], x_flat[valid_mask]
    )
    # reshape(-1) rather than squeeze(): a single valid sample must stay 1-D,
    # otherwise np.concatenate over the batches fails on a 0-d array.
    rc_valid = rc_valid.reshape(-1).detach().cpu().numpy()

    # Scatter the valid correlations back to their batch positions so that
    # invalid samples keep 0.0 instead of being silently dropped.
    rc = np.zeros(batch, dtype=rc_valid.dtype)
    rc[valid_mask.cpu().numpy()] = np.nan_to_num(rc_valid, nan=0.0)
    return rc, mse


def cae_gcl_recon(combo, embed_num, batchsize, channel, height, width, command, epochs=70):
    """Train the GCL baseline (autoencoder generator + MLP discriminator) and log reconstruction quality.

    Every training batch performs four steps:

    1. The generator reconstructs ``x``; samples whose reconstruction error
       is at least ``mean + std`` of the batch errors get pseudo-label 1.
    2. The discriminator is trained with BCE to predict those labels.
    3. The updated discriminator re-scores ``x``; samples scoring at least
       ``mean + 0.1 * std`` of the batch scores are flagged.
    4. The generator is trained with MSE towards a target in which flagged
       samples are replaced by all-ones and the rest are the input itself.

    After each epoch the generator is evaluated on the same dataset
    (unshuffled, batches of 500). Both networks run on the module-level
    ``device``.

    Args:
        combo: Data handed to ``baseline_gcl.MyDataset_recon``.
        embed_num: Embedding size given to the generator; also the column
            count the embeddings are reshaped to before t-SNE.
        batchsize: Training batch size.
        channel, height, width: Input dimensions given to both networks.
        command: The string ``"True"`` enables a 2-D t-SNE snapshot of the
            embeddings on every 5th epoch; any other value disables it.
        epochs: Number of training epochs (default 70).

    Returns:
        dict with keys ``'RC'`` and ``'MSE'`` (epoch -> 1-D array with one
        value per dataset sample) and ``'EMBED'`` (every 5th epoch -> t-SNE
        coordinates, or ``None`` when snapshots are disabled).
    """
    rc_dict = {
        'RC': {ep: None for ep in range(1, epochs + 1)},
        'MSE': {ep: None for ep in range(1, epochs + 1)},
        'EMBED': {ep: None for ep in range(5, epochs + 1, 5)},
    }

    data_set = baseline_gcl.MyDataset_recon(combo)
    trainloader = DataLoader(dataset=data_set, batch_size=batchsize, shuffle=True)
    evaluloader = DataLoader(dataset=data_set, batch_size=500, shuffle=False)
    generator = baseline_gcl.ConvAutoencoder_recon(embed_num, channel, height, width).to(device)
    discriminator = baseline_gcl.DiscriminatorMLP(channel, height, width).to(device)
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=5e-4, weight_decay=0)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=5e-4, weight_decay=0)
    mse_loss = nn.MSELoss(reduction='none')
    bce_loss = nn.BCELoss()

    for epoch in range(1, epochs + 1):
        start_time = time.time()

        # ---------- train ----------
        generator.train()
        discriminator.train()
        for x in trainloader:
            x = x.to(device)

            # Step 1: pseudo-labels from the generator's reconstruction error.
            # Only the thresholded labels are used, so no graph is needed.
            with torch.no_grad():
                _, recon = generator(x)
                sample_err = mse_loss(recon, x).mean(dim=[1, 2, 3])
                g_threshold = sample_err.mean() + sample_err.std(unbiased=False)
                g_pseudo_label = (sample_err >= g_threshold).float()

            # Step 2: train the discriminator on the generator's labels.
            d_pred = discriminator(x)
            bceloss = bce_loss(d_pred, g_pseudo_label)
            d_optimizer.zero_grad()
            bceloss.backward()
            d_optimizer.step()

            # Step 3: pseudo-labels from the freshly updated discriminator.
            with torch.no_grad():
                d_pred = discriminator(x)
            d_threshold = d_pred.mean() + 0.1 * d_pred.std(unbiased=False)
            flagged = d_pred >= d_threshold

            # Step 4: train the generator; flagged samples are pushed towards
            # an all-ones target instead of their own input.
            neg_target = x.clone()
            neg_target[flagged] = 1.0
            _, recon = generator(x)
            mseloss = mse_loss(recon, neg_target).mean(dim=[1, 2, 3]).mean()
            g_optimizer.zero_grad()
            mseloss.backward()
            g_optimizer.step()

        # ---------- eval ----------
        want_embed = command == "True" and epoch in rc_dict['EMBED']
        epoch_rc, epoch_mse, embed = [], [], []

        generator.eval()
        discriminator.eval()
        # Evaluation never back-propagates, so skip building autograd graphs.
        with torch.no_grad():
            for x in evaluloader:
                x = x.to(device)
                encoded, decoded = generator(x)

                rc_batch, mse_batch = _gcl_batch_metrics(x, decoded)
                epoch_rc.append(rc_batch)
                epoch_mse.append(mse_batch)

                if want_embed:
                    embed.append(encoded.detach().cpu().numpy())

        epoch_rc = np.concatenate(epoch_rc, axis=0)
        rc_dict['RC'][epoch] = epoch_rc

        epoch_mse = np.concatenate(epoch_mse, axis=0)
        rc_dict['MSE'][epoch] = epoch_mse

        if want_embed and embed:
            embed = np.concatenate(embed, axis=0).reshape(-1, embed_num)
            tsne = TSNE(n_components=2, random_state=42)
            rc_dict['EMBED'][epoch] = tsne.fit_transform(embed)

        epoch_time = time.time() - start_time
        print(f"Epoch {epoch} || Mean RC {np.mean(epoch_rc)} || Mean MSE {np.mean(epoch_mse)} || Time: {epoch_time:.2f} sec")

    return rc_dict


def dast_recon(combo1, combo2, T, batchsize, channel, height, width,
               epochs=70, target="last", command="True", device="cuda"):
    """Train the DAST-Lite baseline on ``combo1`` and score it on ``combo2``.

    Windows of ``T`` frames are drawn from each array through
    ``baseline_dast.VideoWindowDataset``. The model is trained with MSE
    between its prediction and the window's target frame. After every epoch
    each evaluation window yields a Pearson correlation and an MSE, and these
    are averaged per frame index of ``combo2``.

    Args:
        combo1: 4-D training array; its last three dims must equal
            ``(channel, height, width)``.
        combo2: 4-D evaluation array with the same trailing dims.
        T: Window length passed to the dataset and the model.
        batchsize: Training batch size (evaluation uses twice this).
        channel, height, width: Expected frame dimensions.
        epochs: Number of training epochs.
        target: Target selector forwarded to the dataset and the model.
        command: Accepted for signature parity with the other baselines;
            not used by this function.
        device: Device the model and batches are moved to.

    Returns:
        dict with ``'RC'`` and ``'MSE'`` (epoch -> float32 array of length
        ``combo2.shape[0]``; frames never predicted stay 0) and ``'EMBED'``
        (every 5th epoch -> ``None``; embedding export is disabled here).
    """
    _, C1, H1, W1 = combo1.shape
    assert C1 == channel and H1 == height and W1 == width
    N2, C2, H2, W2 = combo2.shape
    assert C2 == channel and H2 == height and W2 == width

    train_set = baseline_dast.VideoWindowDataset(combo1, T=T, target=target)
    eval_set = baseline_dast.VideoWindowDataset(combo2, T=T, target=target)
    trainloader = DataLoader(train_set, batch_size=batchsize, shuffle=True, drop_last=True)
    evalloader = DataLoader(eval_set, batch_size=batchsize * 2, shuffle=False)

    model = baseline_dast.DASTLite(in_ch=channel, base=32, T=T, target=target).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=0)
    criterion = nn.MSELoss()

    epoch_ids = range(1, epochs + 1)
    rc_dict = {
        'RC': dict.fromkeys(epoch_ids),
        'MSE': dict.fromkeys(epoch_ids),
        'EMBED': dict.fromkeys(range(5, epochs + 1, 5)),
    }

    for epoch in epoch_ids:
        t0 = time.time()

        # ---------- train ----------
        model.train()
        for x_seq, y_t, _ in trainloader:
            x_seq, y_t = x_seq.to(device), y_t.to(device)
            y_hat, _ = model(x_seq)
            loss = criterion(y_hat, y_t)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # ---------- eval ----------
        model.eval()
        # One bucket per combo2 frame; a frame may be the target of several
        # windows, or of none.
        rc_buckets = [[] for _ in range(N2)]
        mse_buckets = [[] for _ in range(N2)]

        with torch.no_grad():
            for x_seq, y_t, gidx in evalloader:
                x_seq, y_t = x_seq.to(device), y_t.to(device)
                y_hat, _ = model(x_seq)

                pcc = baseline_dast._safe_pcc_batch(y_hat, y_t)
                mse = ((y_hat - y_t) ** 2).flatten(1).mean(dim=1)

                scored = zip(gidx.reshape(-1).tolist(),
                             pcc.reshape(-1).tolist(),
                             mse.tolist())
                for frame, r, m in scored:
                    frame = int(frame)
                    rc_buckets[frame].append(float(r))
                    mse_buckets[frame].append(float(m))

        # Frames without any prediction report 0 for both metrics.
        rc = np.array(
            [float(np.mean(vals)) if vals else 0.0 for vals in rc_buckets],
            dtype=np.float32,
        )
        ms = np.array(
            [float(np.mean(errs)) if vals else 0.0
             for vals, errs in zip(rc_buckets, mse_buckets)],
            dtype=np.float32,
        )

        rc_dict['RC'][epoch] = rc
        rc_dict['MSE'][epoch] = ms

        # NOTE: rc_dict['EMBED'] is intentionally left as None; the t-SNE
        # embedding export is switched off for this baseline.

        print(f"[DAST] T {T} || Epoch {epoch} || Mean RC {rc.mean():.6f} || "
              f"Mean MSE {ms.mean():.6f} || Time: {time.time()-t0:.2f}s")

    return rc_dict


def roadmap_recon(combo1, combo2, T, batchsize, channel, height, width,
                  epochs=70, target="last", command="True", device="cuda"):
    """Train the Roadmap-Lite baseline on ``combo1`` and score it on ``combo2``.

    Windows of ``T`` frames are drawn from each array through
    ``baseline_roadmap.VideoWindowDataset``. The network is trained with MSE
    between its prediction and the window's target frame. After every epoch
    each evaluation window yields a Pearson correlation and an MSE, and these
    are averaged per frame index of ``combo2``.

    Args:
        combo1: 4-D training array; its last three dims must equal
            ``(channel, height, width)``.
        combo2: 4-D evaluation array with the same trailing dims.
        T: Window length passed to the dataset and the network.
        batchsize: Training batch size (evaluation uses twice this).
        channel, height, width: Expected frame dimensions.
        epochs: Number of training epochs.
        target: Target selector forwarded to the dataset.
        command: Accepted for signature parity with the other baselines;
            not used by this function.
        device: Device the network and batches are moved to.

    Returns:
        dict with ``'RC'`` and ``'MSE'`` (epoch -> float32 array of length
        ``combo2.shape[0]``; frames never predicted stay 0) and ``'EMBED'``
        (every 5th epoch -> ``None``; embedding export is disabled here).

    Raises:
        AssertionError: If either array's trailing dims differ from
            ``(channel, height, width)``.
    """
    _, C1, H1, W1 = combo1.shape
    assert C1 == channel and H1 == height and W1 == width, \
        "combo1 frame shape does not match (channel, height, width)"
    N2, C2, H2, W2 = combo2.shape
    assert C2 == channel and H2 == height and W2 == width, \
        "combo2 frame shape does not match (channel, height, width)"

    ds_tr1 = baseline_roadmap.VideoWindowDataset(combo1, T=T, target=target)
    ds_tr2 = baseline_roadmap.VideoWindowDataset(combo2, T=T, target=target)
    dl_tr = DataLoader(ds_tr1, batch_size=batchsize, shuffle=True, drop_last=True)
    dl_ev = DataLoader(ds_tr2, batch_size=batchsize * 2, shuffle=False)

    net = baseline_roadmap.RoadmapLite(in_ch=channel, base=32, T=T).to(device)
    opt = torch.optim.Adam(net.parameters(), lr=5e-4)
    mse = nn.MSELoss()

    rc_dict = {
        'RC':    {ep: None for ep in range(1, epochs + 1)},
        'MSE':   {ep: None for ep in range(1, epochs + 1)},
        'EMBED': {ep: None for ep in range(5, epochs + 1, 5)},
    }

    for ep in range(1, epochs + 1):
        t0 = time.time()

        # ---- train ----
        net.train()
        for x_seq, y_t, _ in dl_tr:
            x_seq = x_seq.to(device)
            y_t = y_t.to(device)
            y_hat, _ = net(x_seq)
            loss = mse(y_hat, y_t)
            opt.zero_grad()
            loss.backward()
            opt.step()

        # ---- eval ----
        net.eval()
        # One bucket per combo2 frame; a frame may be the target of several
        # windows, or of none.
        per_rc = [[] for _ in range(N2)]
        per_mse = [[] for _ in range(N2)]
        with torch.no_grad():
            for x_seq, y_t, gid in dl_ev:
                x_seq = x_seq.to(device)
                y_t = y_t.to(device)
                # The second output (embedding) is not collected: the t-SNE
                # export is disabled, so copying it to the host every batch
                # would only waste transfer time and memory.
                y_hat, _ = net(x_seq)

                pcc = baseline_roadmap._pcc_batch(y_hat, y_t)
                mm = ((y_hat - y_t) ** 2).flatten(1).mean(dim=1)

                # One host transfer per tensor instead of one .item() sync
                # per element.
                frames = gid.reshape(-1).tolist()
                pcc_vals = pcc.reshape(-1).tolist()
                mse_vals = mm.tolist()
                for g, r, m in zip(frames, pcc_vals, mse_vals):
                    g = int(g)
                    per_rc[g].append(float(r))
                    per_mse[g].append(float(m))

        # Frames without any prediction keep the initial 0.
        rc = np.zeros(N2, dtype=np.float32)
        ms = np.zeros(N2, dtype=np.float32)
        for i, (rc_vals, mse_vals) in enumerate(zip(per_rc, per_mse)):
            if rc_vals:
                rc[i] = float(np.mean(rc_vals))
                ms[i] = float(np.mean(mse_vals))

        rc_dict['RC'][ep] = rc
        rc_dict['MSE'][ep] = ms

        print(f"[ROADMAP] T {T} | Epoch {ep} | Mean PCC {rc.mean():.6f} | "
              f"Mean MSE {ms.mean():.6f} | {time.time()-t0:.2f}s")

    return rc_dict


def tmae_recon(combo, T, batchsize, channel, height, width,
               epochs=70, mask_strategy="interval",
               embed_num=256,  # forwarded to TemporalMAE as ``d``
               command="True", device="cuda"):
    """Train the temporal masked-autoencoder baseline and score masked-frame reconstruction.

    ``baseline_tmae.TemporalCubeDataset`` supplies cubes of ``T`` frames with
    the indices to mask and the global frame index of every cube position.
    The model reconstructs the masked frames and is trained with MSE against
    the original frames at those positions. After every epoch the same
    dataset is replayed unshuffled; each reconstructed frame yields a Pearson
    correlation and an MSE, which are averaged per global frame index.

    Args:
        combo: 4-D array whose last three dims must equal
            ``(channel, height, width)``.
        T: Cube length passed to the dataset and the model.
        batchsize: Training batch size (evaluation uses twice this).
        channel, height, width: Expected frame dimensions.
        epochs: Number of training epochs.
        mask_strategy: Forwarded to ``TemporalCubeDataset``.
        embed_num: Forwarded to ``TemporalMAE`` as ``d``.
        command: ``"True"`` activates the embedding snapshot hook on every
            5th epoch; since no embeddings are collected, the stored value
            is always ``None``.
        device: Device the model and batches are moved to.

    Returns:
        dict with ``'RC'`` and ``'MSE'`` (epoch -> float32 array of length
        ``combo.shape[0]``; frames never reconstructed stay 0) and
        ``'EMBED'`` (every 5th epoch -> ``None``).
    """
    N, C, H, W = combo.shape
    assert C == channel and H == height and W == width

    cube_set = baseline_tmae.TemporalCubeDataset(combo, T=T, mask_strategy=mask_strategy)
    trainloader = DataLoader(cube_set, batch_size=batchsize, shuffle=True, drop_last=True)
    evalloader = DataLoader(cube_set, batch_size=batchsize * 2, shuffle=False)

    model = baseline_tmae.TemporalMAE(T=T, in_ch=channel, h=height, w=width, d=embed_num).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=0)
    criterion = nn.MSELoss()

    epoch_ids = range(1, epochs + 1)
    rc_dict = {
        'RC': dict.fromkeys(epoch_ids),
        'MSE': dict.fromkeys(epoch_ids),
        'EMBED': dict.fromkeys(range(5, epochs + 1, 5)),
    }

    for epoch in epoch_ids:
        start = time.time()

        # -------------------- train --------------------
        model.train()
        for cubes, mask_idx, _ in trainloader:
            cubes = cubes.to(device)
            mask_idx = mask_idx.to(device)
            recon_masked, masked_idx = model(cubes, mask_idx)

            # Ground truth: the original frames at each sample's masked
            # positions, stacked to line up with recon_masked.
            truth = torch.stack(
                [cubes[b, masked_idx[b]] for b in range(cubes.size(0))],
                dim=0,
            )

            loss = criterion(recon_masked, truth)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # -------------------- eval --------------------
        model.eval()
        # One bucket per frame of combo; a frame can be masked in several
        # cubes, or in none.
        rc_buckets = [[] for _ in range(N)]
        mse_buckets = [[] for _ in range(N)]
        embed_collect = []

        with torch.no_grad():
            for cubes, mask_idx, global_idx in evalloader:
                cubes = cubes.to(device)
                mask_idx = mask_idx.to(device)
                global_idx = global_idx.to(device)
                recon_masked, masked_idx = model(cubes, mask_idx)

                for b in range(cubes.size(0)):
                    sel = masked_idx[b]
                    gt_frames = cubes[b, sel]
                    pr_frames = recon_masked[b]

                    gt_flat = gt_frames.reshape(gt_frames.size(0), -1)
                    pr_flat = pr_frames.reshape(pr_frames.size(0), -1)

                    pcc_vals = baseline_tmae.safe_batch_pcc(pr_flat, gt_flat)
                    mse_vals = ((pr_flat - gt_flat) ** 2).mean(dim=1)

                    # Map masked cube positions back to global frame ids.
                    frame_ids = global_idx[b][sel].tolist()
                    scored = zip(frame_ids,
                                 pcc_vals.reshape(-1).tolist(),
                                 mse_vals.tolist())
                    for g, r, m in scored:
                        rc_buckets[g].append(float(r))
                        mse_buckets[g].append(float(m))

        # Frames that were never masked report 0 for both metrics.
        rc_array = np.array(
            [float(np.mean(vals)) if len(vals) > 0 else 0.0 for vals in rc_buckets],
            dtype=np.float32,
        )
        mse_array = np.array(
            [float(np.mean(errs)) if len(vals) > 0 else 0.0
             for vals, errs in zip(rc_buckets, mse_buckets)],
            dtype=np.float32,
        )

        rc_dict['RC'][epoch] = rc_array
        rc_dict['MSE'][epoch] = mse_array

        if command == "True" and epoch in rc_dict['EMBED']:
            # embed_collect is never filled during evaluation, so this
            # currently always stores None.
            if embed_collect:
                embed_mat = np.concatenate(embed_collect, axis=0)
                rc_dict['EMBED'][epoch] = TSNE(
                    n_components=2, random_state=42
                ).fit_transform(embed_mat)
            else:
                rc_dict['EMBED'][epoch] = None

        elapse = time.time() - start
        print(f"[TMAE] T {T} || Epoch {epoch} || Mean RC {rc_array.mean():.6f} || "
              f"Mean MSE {mse_array.mean():.6f} || Time: {elapse:.2f}s")

    return rc_dict

