#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 24 19:28:57 2025

@author: zhou.junkai
"""

import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import time
import numpy as np
import random
import os
import torch
import pickle
import torch.utils.data
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
torch.manual_seed(1)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
import torch.nn as nn
import audtorch.metrics.functional
device = torch.device('cuda')
import torch.nn.functional as F
from torchvision.utils import save_image
from torchvision import datasets, transforms, models
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import datetime
import glob
from datetime import datetime, timedelta
from tqdm import tqdm
import scipy.io
from pathlib import Path
import os
import matplotlib.pyplot as plt
import scipy.signal 
from sklearn import metrics
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_fscore_support
from PIL import Image
import cv2  
from tqdm import tqdm  
import scipy.io
from scipy.io import loadmat
from scipy.ndimage import median_filter
from scipy.ndimage import uniform_filter1d
from sklearn.metrics import roc_auc_score
import seaborn as sns
import torch
import matplotlib.ticker as ticker
import main

"""
run for mle-guided cae
"""
# prepare ubnormal data
def load_images_to_numpy(folder_path):
    """Load every .png/.jpg/.jpeg image in *folder_path* as a CHW array.

    Files are read in sorted filename order with ``cv2.imread``'s default
    colour mode (3 channels) and transposed from HWC to CHW. No resizing is
    done, so all images must share one height/width for ``np.stack``.
    Files OpenCV cannot decode are skipped, which shifts the index of every
    later frame.

    Returns:
        np.ndarray of shape (N, 3, H, W).

    Raises:
        ValueError: if no image in the folder could be loaded.
    """
    # normpath so a trailing slash does not yield an empty progress label.
    folder_name = os.path.basename(os.path.normpath(folder_path))
    file_list = sorted(f for f in os.listdir(folder_path)
                       if f.endswith(('.png', '.jpg', '.jpeg')))

    image_list = []
    for fname in tqdm(file_list, desc=f"Loading images from {folder_name}"):
        img = cv2.imread(os.path.join(folder_path, fname))
        if img is None:
            continue  # best-effort: skip files OpenCV cannot decode
        image_list.append(np.transpose(img, (2, 0, 1)))  # HWC -> CHW

    if not image_list:
        # np.stack([]) would fail with a message that does not name the folder.
        raise ValueError(f"No readable images found in {folder_path!r}")
    return np.stack(image_list)

# UBnormal evaluation set: every normal frame followed by every "fire"
# anomaly frame, stacked along the first axis.
normal = load_images_to_numpy(".../ubnormal/normal_scene_4_scenario_3")
abnormal = load_images_to_numpy(".../ubnormal/abnormal_scene_4_scenario_1_fire")
tensor = np.concatenate((normal, abnormal), axis=0)
print(tensor.shape)

# Min-max scale each frame independently into [0, 1); constant frames -> zeros.
combo = torch.from_numpy(tensor).float()
for i in range(combo.shape[0]):
    frame = combo[i]
    lo, hi = frame.min(), frame.max()
    combo[i] = (frame - lo) / (hi - lo + 1e-6) if hi > lo else torch.zeros_like(frame)
tensor = combo

# Grid search over (sigma, lambda); results are keyed RC_dict1[sigma][lam].
# Calls run sigma-outer / lambda-inner, the same order as a nested loop.
embed_num = 32
RC_dict1 = {
    sigma: {
        lam: main.aue_mle_recon(combo=tensor, sigma=sigma, lambd=lam, embed_num=embed_num,
                                batchsize=32, channel=3, height=360, width=520, command="True")
        for lam in [0.01, 0.1, 1.0, 10]
    }
    for sigma in [0.001, 0.01, 0.1, 1.0]
}

dict_name = 'mle_ubnormal.pickle'
save_dir = Path(".../ubnormal")
save_dir.mkdir(parents=True, exist_ok=True)
dict_path = os.path.join(save_dir, dict_name)
with open(dict_path, 'wb') as file:
    pickle.dump(RC_dict1, file)
print(f"Dictionary saved successfully to {dict_path}")

# prepare corridor data
def load_images_to_numpy(folder_path):
    """Load every .png/.jpg/.jpeg image in *folder_path* as a 1x280x480 array.

    Files are read in sorted filename order as grayscale and resized with
    ``cv2.INTER_AREA`` to width 480 and height 280 (``cv2.resize`` takes
    ``(width, height)``). Each image then gets a leading channel axis.
    Files OpenCV cannot decode are skipped, which shifts the index of every
    later frame.

    Returns:
        np.ndarray of shape (N, 1, 280, 480).

    Raises:
        ValueError: if no image in the folder could be loaded.
    """
    # normpath so a trailing slash does not yield an empty progress label.
    folder_name = os.path.basename(os.path.normpath(folder_path))
    file_list = sorted(f for f in os.listdir(folder_path)
                       if f.endswith(('.png', '.jpg', '.jpeg')))

    image_list = []
    for fname in tqdm(file_list, desc=f"Loading images from {folder_name}"):
        img = cv2.imread(os.path.join(folder_path, fname), cv2.IMREAD_GRAYSCALE)
        if img is None:
            continue  # best-effort: skip files OpenCV cannot decode
        img = cv2.resize(img, (480, 280), interpolation=cv2.INTER_AREA)
        image_list.append(img[np.newaxis, :, :])  # HW -> 1HW

    if not image_list:
        # np.stack([]) would fail with a message that does not name the folder.
        raise ValueError(f"No readable images found in {folder_path!r}")
    return np.stack(image_list)

# Corridor evaluation set: up to the first 1000 frames of the normal clip,
# frames 150-249 of clip 000219 and the last 100 frames of clip 000222.
normal = load_images_to_numpy(".../corridor/000002")
abnormal1 = load_images_to_numpy(".../corridor/000219")
abnormal2 = load_images_to_numpy(".../corridor/000222")
tensor = np.concatenate((normal[:1000], abnormal1[150:250],
                         abnormal2[-100:]), axis=0)
print(tensor.shape)

# Min-max scale each frame independently into [0, 1); constant frames -> zeros.
combo = torch.from_numpy(tensor).float()
for i in range(combo.shape[0]):
    img = combo[i]
    min_val, max_val = img.min(), img.max()
    if max_val > min_val:
        combo[i] = (img - min_val) / (max_val - min_val + 1e-6)
    else:
        combo[i] = torch.zeros_like(img)
tensor = combo

# Grid search over (sigma, lambda); results are keyed RC_dict1[sigma][lam].
embed_num = 32
RC_dict1 = {}
for sigma in [0.001, 0.01, 0.1, 1.0]:
    RC_dict1[sigma] = {}
    for lam in [0.01, 0.1, 1.0, 10]:
        temp_RC = main.aue_mle_recon(combo=tensor, sigma=sigma, lambd=lam, embed_num=embed_num,
                                     batchsize=128, channel=1, height=280, width=480, command="True")
        # Key by BOTH hyper-parameters. Previously this was RC_dict1[sigma],
        # which overwrote the entry on every lambda and kept only lam=10.
        RC_dict1[sigma][lam] = temp_RC
save_dir = Path(".../corridor")
save_dir.mkdir(parents=True, exist_ok=True)
dict_name = 'mle_corridor.pickle'
dict_path = os.path.join(save_dir, dict_name)
with open(dict_path, 'wb') as file:
    pickle.dump(RC_dict1, file)
print(f"Dictionary saved successfully to {dict_path}")

# prepare donkeycar data
def load_images_to_numpy(folder_path):
    """Load every .png/.jpg/.jpeg image in *folder_path* as a 3x80x80 array.

    Files are read in sorted filename order with ``cv2.imread``'s default
    colour mode (3 channels). Each image is resized to 80x80 with
    ``cv2.resize``'s default interpolation and transposed from HWC to CHW.
    Files OpenCV cannot decode are skipped, which shifts the index of every
    later frame.

    Returns:
        np.ndarray of shape (N, 3, 80, 80).

    Raises:
        ValueError: if no image in the folder could be loaded.
    """
    # normpath so a trailing slash does not yield an empty progress label.
    folder_name = os.path.basename(os.path.normpath(folder_path))
    file_list = sorted(f for f in os.listdir(folder_path)
                       if f.endswith(('.png', '.jpg', '.jpeg')))

    image_list = []
    for fname in tqdm(file_list, desc=f"Loading images from {folder_name}"):
        img = cv2.imread(os.path.join(folder_path, fname))
        if img is None:
            continue  # best-effort: skip files OpenCV cannot decode
        img = cv2.resize(img, (80, 80))
        image_list.append(np.transpose(img, (2, 0, 1)))  # HWC -> CHW

    if not image_list:
        # np.stack([]) would fail with a message that does not name the folder.
        raise ValueError(f"No readable images found in {folder_path!r}")
    return np.stack(image_list)

# Donkeycar evaluation set: the first 16000 normal frames followed by the
# first 2000 "dirty rain" anomaly frames.
normal = load_images_to_numpy(".../donkeycar/normal")
abnormal = load_images_to_numpy(".../donkeycar/abnormal_dirtyrain")
normal_tensor, abnormal_tensor = normal[:16000], abnormal[:2000]
tensor = np.concatenate((normal_tensor, abnormal_tensor), axis=0)
print(tensor.shape)

# Min-max scale each frame independently into [0, 1); constant frames -> zeros.
combo = torch.from_numpy(tensor).float()
for i in range(combo.shape[0]):
    frame = combo[i]
    lo, hi = frame.min(), frame.max()
    combo[i] = (frame - lo) / (hi - lo + 1e-6) if hi > lo else torch.zeros_like(frame)
tensor = combo

# Grid search over (sigma, lambda); results are keyed RC_dict1[sigma][lam].
# Calls run sigma-outer / lambda-inner, the same order as a nested loop.
embed_num = 32
RC_dict1 = {
    sigma: {
        lam: main.aue_mle_recon(combo=tensor, sigma=sigma, lambd=lam, embed_num=embed_num,
                                batchsize=32, channel=3, height=80, width=80, command="True")
        for lam in [0.01, 0.1, 1.0, 10]
    }
    for sigma in [0.001, 0.01, 0.1, 1.0]
}

dict_name = 'mle_donkeycar.pickle'
save_dir = Path(".../donkeycar")
save_dir.mkdir(parents=True, exist_ok=True)
dict_path = os.path.join(save_dir, dict_name)
with open(dict_path, 'wb') as file:
    pickle.dump(RC_dict1, file)
print(f"Dictionary saved successfully to {dict_path}")


"""
run for gcl
"""
# prepare ubnormal data
def load_images_to_numpy(folder_path):
    """Load every .png/.jpg/.jpeg image in *folder_path* as a CHW array.

    Files are read in sorted filename order with ``cv2.imread``'s default
    colour mode (3 channels) and transposed from HWC to CHW. No resizing is
    done, so all images must share one height/width for ``np.stack``.
    Files OpenCV cannot decode are skipped, which shifts the index of every
    later frame.

    Returns:
        np.ndarray of shape (N, 3, H, W).

    Raises:
        ValueError: if no image in the folder could be loaded.
    """
    # normpath so a trailing slash does not yield an empty progress label.
    folder_name = os.path.basename(os.path.normpath(folder_path))
    file_list = sorted(f for f in os.listdir(folder_path)
                       if f.endswith(('.png', '.jpg', '.jpeg')))

    image_list = []
    for fname in tqdm(file_list, desc=f"Loading images from {folder_name}"):
        img = cv2.imread(os.path.join(folder_path, fname))
        if img is None:
            continue  # best-effort: skip files OpenCV cannot decode
        image_list.append(np.transpose(img, (2, 0, 1)))  # HWC -> CHW

    if not image_list:
        # np.stack([]) would fail with a message that does not name the folder.
        raise ValueError(f"No readable images found in {folder_path!r}")
    return np.stack(image_list)

# UBnormal evaluation set: every normal frame followed by every "fire"
# anomaly frame, stacked along the first axis.
normal = load_images_to_numpy(".../ubnormal/normal_scene_4_scenario_3")
abnormal = load_images_to_numpy(".../ubnormal/abnormal_scene_4_scenario_1_fire")
tensor = np.concatenate((normal, abnormal), axis=0)
print(tensor.shape)

# Min-max scale each frame independently into [0, 1); constant frames -> zeros.
combo = torch.from_numpy(tensor).float()
for i in range(combo.shape[0]):
    frame = combo[i]
    lo, hi = frame.min(), frame.max()
    combo[i] = (frame - lo) / (hi - lo + 1e-6) if hi > lo else torch.zeros_like(frame)
tensor = combo

embed_num = 32
RC_dict1 = main.cae_gcl_recon(combo=tensor, embed_num=embed_num, batchsize=32,
                              channel=3, height=360, width=520, command="True")

dict_name = 'gcl_ubnormal.pickle'
save_dir = Path(".../ubnormal")
save_dir.mkdir(parents=True, exist_ok=True)
dict_path = os.path.join(save_dir, dict_name)
with open(dict_path, 'wb') as file:
    pickle.dump(RC_dict1, file)
print(f"Dictionary saved successfully to {dict_path}")

# prepare corridor data
def load_images_to_numpy(folder_path):
    """Load every .png/.jpg/.jpeg image in *folder_path* as a 1x280x480 array.

    Files are read in sorted filename order as grayscale and resized with
    ``cv2.INTER_AREA`` to width 480 and height 280 (``cv2.resize`` takes
    ``(width, height)``). Each image then gets a leading channel axis.
    Files OpenCV cannot decode are skipped, which shifts the index of every
    later frame.

    Returns:
        np.ndarray of shape (N, 1, 280, 480).

    Raises:
        ValueError: if no image in the folder could be loaded.
    """
    # normpath so a trailing slash does not yield an empty progress label.
    folder_name = os.path.basename(os.path.normpath(folder_path))
    file_list = sorted(f for f in os.listdir(folder_path)
                       if f.endswith(('.png', '.jpg', '.jpeg')))

    image_list = []
    for fname in tqdm(file_list, desc=f"Loading images from {folder_name}"):
        img = cv2.imread(os.path.join(folder_path, fname), cv2.IMREAD_GRAYSCALE)
        if img is None:
            continue  # best-effort: skip files OpenCV cannot decode
        img = cv2.resize(img, (480, 280), interpolation=cv2.INTER_AREA)
        image_list.append(img[np.newaxis, :, :])  # HW -> 1HW

    if not image_list:
        # np.stack([]) would fail with a message that does not name the folder.
        raise ValueError(f"No readable images found in {folder_path!r}")
    return np.stack(image_list)

# Corridor evaluation set: up to the first 1000 frames of the normal clip,
# frames 150-249 of clip 000219 and the last 100 frames of clip 000222.
normal = load_images_to_numpy(".../corridor/000002")
abnormal1 = load_images_to_numpy(".../corridor/000219")
abnormal2 = load_images_to_numpy(".../corridor/000222")
segments = (normal[:1000], abnormal1[150:250], abnormal2[-100:])
tensor = np.concatenate(segments, axis=0)
print(tensor.shape)

# Min-max scale each frame independently into [0, 1); constant frames -> zeros.
combo = torch.from_numpy(tensor).float()
for i in range(combo.shape[0]):
    frame = combo[i]
    lo, hi = frame.min(), frame.max()
    combo[i] = (frame - lo) / (hi - lo + 1e-6) if hi > lo else torch.zeros_like(frame)
tensor = combo

embed_num = 32
RC_dict1 = main.cae_gcl_recon(combo=tensor, embed_num=embed_num, batchsize=128,
                              channel=1, height=280, width=480, command="True")

dict_name = 'gcl_corridor.pickle'
save_dir = Path(".../corridor")
save_dir.mkdir(parents=True, exist_ok=True)
dict_path = os.path.join(save_dir, dict_name)
with open(dict_path, 'wb') as file:
    pickle.dump(RC_dict1, file)
print(f"Dictionary saved successfully to {dict_path}")

# prepare donkeycar data
def load_images_to_numpy(folder_path):
    """Load every .png/.jpg/.jpeg image in *folder_path* as a 3x80x80 array.

    Files are read in sorted filename order with ``cv2.imread``'s default
    colour mode (3 channels). Each image is resized to 80x80 with
    ``cv2.resize``'s default interpolation and transposed from HWC to CHW.
    Files OpenCV cannot decode are skipped, which shifts the index of every
    later frame.

    Returns:
        np.ndarray of shape (N, 3, 80, 80).

    Raises:
        ValueError: if no image in the folder could be loaded.
    """
    # normpath so a trailing slash does not yield an empty progress label.
    folder_name = os.path.basename(os.path.normpath(folder_path))
    file_list = sorted(f for f in os.listdir(folder_path)
                       if f.endswith(('.png', '.jpg', '.jpeg')))

    image_list = []
    for fname in tqdm(file_list, desc=f"Loading images from {folder_name}"):
        img = cv2.imread(os.path.join(folder_path, fname))
        if img is None:
            continue  # best-effort: skip files OpenCV cannot decode
        img = cv2.resize(img, (80, 80))
        image_list.append(np.transpose(img, (2, 0, 1)))  # HWC -> CHW

    if not image_list:
        # np.stack([]) would fail with a message that does not name the folder.
        raise ValueError(f"No readable images found in {folder_path!r}")
    return np.stack(image_list)

# Donkeycar evaluation set: the first 16000 normal frames followed by the
# first 2000 "dirty rain" anomaly frames.
normal = load_images_to_numpy(".../donkeycar/normal")
abnormal = load_images_to_numpy(".../donkeycar/abnormal_dirtyrain")
normal_tensor, abnormal_tensor = normal[:16000], abnormal[:2000]
tensor = np.concatenate((normal_tensor, abnormal_tensor), axis=0)
print(tensor.shape)

# Min-max scale each frame independently into [0, 1); constant frames -> zeros.
combo = torch.from_numpy(tensor).float()
for i in range(combo.shape[0]):
    frame = combo[i]
    lo, hi = frame.min(), frame.max()
    combo[i] = (frame - lo) / (hi - lo + 1e-6) if hi > lo else torch.zeros_like(frame)
tensor = combo

embed_num = 32
RC_dict1 = main.cae_gcl_recon(combo=tensor, embed_num=embed_num, batchsize=32,
                              channel=3, height=80, width=80, command="True")

dict_name = 'gcl_donkeycar.pickle'
save_dir = Path(".../donkeycar")
save_dir.mkdir(parents=True, exist_ok=True)
dict_path = os.path.join(save_dir, dict_name)
with open(dict_path, 'wb') as file:
    pickle.dump(RC_dict1, file)
print(f"Dictionary saved successfully to {dict_path}")


"""
run for dast
"""
# prepare ubnormal data
def load_images_to_numpy(folder_path):
    """Load every .png/.jpg/.jpeg image in *folder_path* as a CHW array.

    Files are read in sorted filename order with ``cv2.imread``'s default
    colour mode (3 channels) and transposed from HWC to CHW. No resizing is
    done, so all images must share one height/width for ``np.stack``.
    Files OpenCV cannot decode are skipped, which shifts the index of every
    later frame.

    Returns:
        np.ndarray of shape (N, 3, H, W).

    Raises:
        ValueError: if no image in the folder could be loaded.
    """
    # normpath so a trailing slash does not yield an empty progress label.
    folder_name = os.path.basename(os.path.normpath(folder_path))
    file_list = sorted(f for f in os.listdir(folder_path)
                       if f.endswith(('.png', '.jpg', '.jpeg')))

    image_list = []
    for fname in tqdm(file_list, desc=f"Loading images from {folder_name}"):
        img = cv2.imread(os.path.join(folder_path, fname))
        if img is None:
            continue  # best-effort: skip files OpenCV cannot decode
        image_list.append(np.transpose(img, (2, 0, 1)))  # HWC -> CHW

    if not image_list:
        # np.stack([]) would fail with a message that does not name the folder.
        raise ValueError(f"No readable images found in {folder_path!r}")
    return np.stack(image_list)

# UBnormal sequence: every normal frame followed by every "fire" anomaly frame.
normal = load_images_to_numpy(".../ubnormal/normal_scene_4_scenario_3")
abnormal = load_images_to_numpy(".../ubnormal/abnormal_scene_4_scenario_1_fire")
combo = np.concatenate((normal, abnormal), axis=0)
print(combo.shape)

# Min-max scale each frame independently into [0, 1); constant frames -> zeros.
combo = torch.from_numpy(combo).float()
for i in range(combo.shape[0]):
    img = combo[i]
    min_val, max_val = img.min(), img.max()
    if max_val > min_val:
        combo[i] = (img - min_val) / (max_val - min_val + 1e-6)
    else:
        combo[i] = torch.zeros_like(img)

# combo2 is the whole normalised sequence; combo1 is its normal-frame prefix
# (presumably the training split for dast_recon - confirm in main). The prefix
# length used to be hard-coded as 451. NOTE(review): this assumes 451 was meant
# to be the normal-frame count. It is derived from the loaded data so it stays
# aligned if the folder changes or unreadable files are skipped.
combo2 = combo.numpy()
combo1 = combo2[:len(normal)]

embed_num = 32
rc_dict_dast = main.dast_recon(combo1=combo1, combo2=combo2, T=5, batchsize=8, channel=3, height=360, width=520,
                               epochs=70, target="last", command="True", device="cuda")
save_dir = Path(".../ubnormal")
save_dir.mkdir(parents=True, exist_ok=True)
dict_name = 'dast_ubnormal.pickle'
dict_path = os.path.join(save_dir, dict_name)
with open(dict_path, 'wb') as file:
    pickle.dump(rc_dict_dast, file)
print(f"Dictionary saved successfully to {dict_path}")

# prepare corridor data
def load_images_to_numpy(folder_path):
    """Load every .png/.jpg/.jpeg image in *folder_path* as a 1x280x480 array.

    Files are read in sorted filename order as grayscale and resized with
    ``cv2.INTER_AREA`` to width 480 and height 280 (``cv2.resize`` takes
    ``(width, height)``). Each image then gets a leading channel axis.
    Files OpenCV cannot decode are skipped, which shifts the index of every
    later frame.

    Returns:
        np.ndarray of shape (N, 1, 280, 480).

    Raises:
        ValueError: if no image in the folder could be loaded.
    """
    # normpath so a trailing slash does not yield an empty progress label.
    folder_name = os.path.basename(os.path.normpath(folder_path))
    file_list = sorted(f for f in os.listdir(folder_path)
                       if f.endswith(('.png', '.jpg', '.jpeg')))

    image_list = []
    for fname in tqdm(file_list, desc=f"Loading images from {folder_name}"):
        img = cv2.imread(os.path.join(folder_path, fname), cv2.IMREAD_GRAYSCALE)
        if img is None:
            continue  # best-effort: skip files OpenCV cannot decode
        img = cv2.resize(img, (480, 280), interpolation=cv2.INTER_AREA)
        image_list.append(img[np.newaxis, :, :])  # HW -> 1HW

    if not image_list:
        # np.stack([]) would fail with a message that does not name the folder.
        raise ValueError(f"No readable images found in {folder_path!r}")
    return np.stack(image_list)

# Corridor sequence: up to the first 1000 normal frames, frames 150-249 of
# clip 000219 and the last 100 frames of clip 000222.
normal = load_images_to_numpy(".../corridor/000002")
abnormal1 = load_images_to_numpy(".../corridor/000219")
abnormal2 = load_images_to_numpy(".../corridor/000222")
normal_part = normal[:1000]
combo = np.concatenate((normal_part, abnormal1[150:250],
                        abnormal2[-100:]), axis=0)
print(combo.shape)

# Min-max scale each frame independently into [0, 1); constant frames -> zeros.
combo = torch.from_numpy(combo).float()
for i in range(combo.shape[0]):
    img = combo[i]
    min_val, max_val = img.min(), img.max()
    if max_val > min_val:
        combo[i] = (img - min_val) / (max_val - min_val + 1e-6)
    else:
        combo[i] = torch.zeros_like(img)

# combo2 is the whole normalised sequence; combo1 is its normal-frame prefix
# (presumably the training split for dast_recon - confirm in main). Slice by
# the number of normal frames actually used rather than a hard-coded 1000, so
# combo1 never picks up abnormal frames when the clip has fewer than 1000.
combo2 = combo.numpy()
combo1 = combo2[:len(normal_part)]

embed_num = 32
# Named after the method (was the copy-pasted `rc_dict_tmae`).
rc_dict_dast = main.dast_recon(combo1=combo1, combo2=combo2, T=5, batchsize=8, channel=1, height=280, width=480,
                               epochs=70, target="last", command="True", device="cuda")
save_dir = Path(".../corridor")
save_dir.mkdir(parents=True, exist_ok=True)
dict_name = 'dast_corridor.pickle'
dict_path = os.path.join(save_dir, dict_name)
with open(dict_path, 'wb') as file:
    pickle.dump(rc_dict_dast, file)
print(f"Dictionary saved successfully to {dict_path}")


# prepare donkeycar data
def load_images_to_numpy(folder_path):
    """Load every .png/.jpg/.jpeg image in *folder_path* as a 3x80x80 array.

    Files are read in sorted filename order with ``cv2.imread``'s default
    colour mode (3 channels). Each image is resized to 80x80 with
    ``cv2.resize``'s default interpolation and transposed from HWC to CHW.
    Files OpenCV cannot decode are skipped, which shifts the index of every
    later frame.

    Returns:
        np.ndarray of shape (N, 3, 80, 80).

    Raises:
        ValueError: if no image in the folder could be loaded.
    """
    # normpath so a trailing slash does not yield an empty progress label.
    folder_name = os.path.basename(os.path.normpath(folder_path))
    file_list = sorted(f for f in os.listdir(folder_path)
                       if f.endswith(('.png', '.jpg', '.jpeg')))

    image_list = []
    for fname in tqdm(file_list, desc=f"Loading images from {folder_name}"):
        img = cv2.imread(os.path.join(folder_path, fname))
        if img is None:
            continue  # best-effort: skip files OpenCV cannot decode
        img = cv2.resize(img, (80, 80))
        image_list.append(np.transpose(img, (2, 0, 1)))  # HWC -> CHW

    if not image_list:
        # np.stack([]) would fail with a message that does not name the folder.
        raise ValueError(f"No readable images found in {folder_path!r}")
    return np.stack(image_list)

# Donkeycar sequence: up to the first 16000 normal frames followed by up to
# the first 2000 "dirty rain" anomaly frames.
normal = load_images_to_numpy(".../donkeycar/normal")
abnormal = load_images_to_numpy(".../donkeycar/abnormal_dirtyrain")
normal_tensor = normal[:16000]
abnormal_tensor = abnormal[:2000]
combo = np.concatenate((normal_tensor, abnormal_tensor), axis=0)

print(combo.shape)

# Min-max scale each frame independently into [0, 1); constant frames -> zeros.
combo = torch.from_numpy(combo).float()
for i in range(combo.shape[0]):
    img = combo[i]
    min_val, max_val = img.min(), img.max()
    if max_val > min_val:
        combo[i] = (img - min_val) / (max_val - min_val + 1e-6)
    else:
        combo[i] = torch.zeros_like(img)

# combo2 is the whole normalised sequence; combo1 is its normal-frame prefix
# (presumably the training split for dast_recon - confirm in main). Slice by
# the number of normal frames actually used rather than a hard-coded 16000,
# so combo1 never picks up abnormal frames when fewer normal frames exist.
combo2 = combo.numpy()
combo1 = combo2[:len(normal_tensor)]

embed_num = 32
# Named after the method (was the copy-pasted `rc_dict_tmae`).
rc_dict_dast = main.dast_recon(combo1=combo1, combo2=combo2, T=5, batchsize=512, channel=3, height=80, width=80,
                               epochs=70, target="last", command="True", device="cuda")
save_dir = Path(".../donkeycar")
save_dir.mkdir(parents=True, exist_ok=True)
dict_name = 'dast_donkeycar.pickle'
dict_path = os.path.join(save_dir, dict_name)
with open(dict_path, 'wb') as file:
    pickle.dump(rc_dict_dast, file)
print(f"Dictionary saved successfully to {dict_path}")



"""
run for roadmap
"""
# prepare ubnormal data
def load_images_to_numpy(folder_path):
    """Load every .png/.jpg/.jpeg image in *folder_path* as a CHW array.

    Files are read in sorted filename order with ``cv2.imread``'s default
    colour mode (3 channels) and transposed from HWC to CHW. No resizing is
    done, so all images must share one height/width for ``np.stack``.
    Files OpenCV cannot decode are skipped, which shifts the index of every
    later frame.

    Returns:
        np.ndarray of shape (N, 3, H, W).

    Raises:
        ValueError: if no image in the folder could be loaded.
    """
    # normpath so a trailing slash does not yield an empty progress label.
    folder_name = os.path.basename(os.path.normpath(folder_path))
    file_list = sorted(f for f in os.listdir(folder_path)
                       if f.endswith(('.png', '.jpg', '.jpeg')))

    image_list = []
    for fname in tqdm(file_list, desc=f"Loading images from {folder_name}"):
        img = cv2.imread(os.path.join(folder_path, fname))
        if img is None:
            continue  # best-effort: skip files OpenCV cannot decode
        image_list.append(np.transpose(img, (2, 0, 1)))  # HWC -> CHW

    if not image_list:
        # np.stack([]) would fail with a message that does not name the folder.
        raise ValueError(f"No readable images found in {folder_path!r}")
    return np.stack(image_list)

# UBnormal sequence: every normal frame followed by every "fire" anomaly frame.
normal = load_images_to_numpy(".../ubnormal/normal_scene_4_scenario_3")
abnormal = load_images_to_numpy(".../ubnormal/abnormal_scene_4_scenario_1_fire")
combo = np.concatenate((normal, abnormal), axis=0)
print(combo.shape)

# Min-max scale each frame independently into [0, 1); constant frames -> zeros.
combo = torch.from_numpy(combo).float()
for i in range(combo.shape[0]):
    img = combo[i]
    min_val, max_val = img.min(), img.max()
    if max_val > min_val:
        combo[i] = (img - min_val) / (max_val - min_val + 1e-6)
    else:
        combo[i] = torch.zeros_like(img)

# combo2 is the whole normalised sequence; combo1 is its normal-frame prefix
# (presumably the training split for roadmap_recon - confirm in main). The
# prefix length used to be hard-coded as 451. NOTE(review): this assumes 451
# was meant to be the normal-frame count. It is derived from the loaded data
# so it stays aligned if the folder changes or unreadable files are skipped.
combo2 = combo.numpy()
combo1 = combo2[:len(normal)]

embed_num = 32
rc_dict_roadmap = main.roadmap_recon(combo1=combo1, combo2=combo2, T=5, batchsize=8, channel=3,
                                     height=360, width=520,
                                     epochs=70, target="last", command="True", device="cuda")
save_dir = Path(".../ubnormal")
save_dir.mkdir(parents=True, exist_ok=True)
dict_name = 'roadmap_ubnormal.pickle'
dict_path = os.path.join(save_dir, dict_name)
with open(dict_path, 'wb') as file:
    pickle.dump(rc_dict_roadmap, file)
print(f"Dictionary saved successfully to {dict_path}")


# prepare corridor data
def load_images_to_numpy(folder_path):
    """Load every .png/.jpg/.jpeg image in *folder_path* as a 1x280x480 array.

    Files are read in sorted filename order as grayscale and resized with
    ``cv2.INTER_AREA`` to width 480 and height 280 (``cv2.resize`` takes
    ``(width, height)``). Each image then gets a leading channel axis.
    Files OpenCV cannot decode are skipped, which shifts the index of every
    later frame.

    Returns:
        np.ndarray of shape (N, 1, 280, 480).

    Raises:
        ValueError: if no image in the folder could be loaded.
    """
    # normpath so a trailing slash does not yield an empty progress label.
    folder_name = os.path.basename(os.path.normpath(folder_path))
    file_list = sorted(f for f in os.listdir(folder_path)
                       if f.endswith(('.png', '.jpg', '.jpeg')))

    image_list = []
    for fname in tqdm(file_list, desc=f"Loading images from {folder_name}"):
        img = cv2.imread(os.path.join(folder_path, fname), cv2.IMREAD_GRAYSCALE)
        if img is None:
            continue  # best-effort: skip files OpenCV cannot decode
        img = cv2.resize(img, (480, 280), interpolation=cv2.INTER_AREA)
        image_list.append(img[np.newaxis, :, :])  # HW -> 1HW

    if not image_list:
        # np.stack([]) would fail with a message that does not name the folder.
        raise ValueError(f"No readable images found in {folder_path!r}")
    return np.stack(image_list)

# Corridor sequence: up to the first 1000 normal frames, frames 150-249 of
# clip 000219 and the last 100 frames of clip 000222.
normal = load_images_to_numpy(".../corridor/000002")
abnormal1 = load_images_to_numpy(".../corridor/000219")
abnormal2 = load_images_to_numpy(".../corridor/000222")
normal_part = normal[:1000]
combo = np.concatenate((normal_part, abnormal1[150:250],
                        abnormal2[-100:]), axis=0)
print(combo.shape)

# Min-max scale each frame independently into [0, 1); constant frames -> zeros.
combo = torch.from_numpy(combo).float()
for i in range(combo.shape[0]):
    img = combo[i]
    min_val, max_val = img.min(), img.max()
    if max_val > min_val:
        combo[i] = (img - min_val) / (max_val - min_val + 1e-6)
    else:
        combo[i] = torch.zeros_like(img)

# combo2 is the whole normalised sequence; combo1 is its normal-frame prefix
# (presumably the training split for roadmap_recon - confirm in main). Slice
# by the number of normal frames actually used rather than a hard-coded 1000,
# so combo1 never picks up abnormal frames when the clip has fewer than 1000.
combo2 = combo.numpy()
combo1 = combo2[:len(normal_part)]

embed_num = 32
rc_dict_roadmap = main.roadmap_recon(combo1=combo1, combo2=combo2, T=5, batchsize=8, channel=1,
                                     height=280, width=480,
                                     epochs=70, target="last", command="True", device="cuda")
save_dir = Path(".../corridor")
save_dir.mkdir(parents=True, exist_ok=True)
dict_name = 'roadmap_corridor.pickle'
dict_path = os.path.join(save_dir, dict_name)
with open(dict_path, 'wb') as file:
    pickle.dump(rc_dict_roadmap, file)
print(f"Dictionary saved successfully to {dict_path}")


# prepare donkeycar data
def load_images_to_numpy(folder_path):
    """Load every .png/.jpg/.jpeg image in *folder_path* as a 3x80x80 array.

    Files are read in sorted filename order with ``cv2.imread``'s default
    colour mode (3 channels). Each image is resized to 80x80 with
    ``cv2.resize``'s default interpolation and transposed from HWC to CHW.
    Files OpenCV cannot decode are skipped, which shifts the index of every
    later frame.

    Returns:
        np.ndarray of shape (N, 3, 80, 80).

    Raises:
        ValueError: if no image in the folder could be loaded.
    """
    # normpath so a trailing slash does not yield an empty progress label.
    folder_name = os.path.basename(os.path.normpath(folder_path))
    file_list = sorted(f for f in os.listdir(folder_path)
                       if f.endswith(('.png', '.jpg', '.jpeg')))

    image_list = []
    for fname in tqdm(file_list, desc=f"Loading images from {folder_name}"):
        img = cv2.imread(os.path.join(folder_path, fname))
        if img is None:
            continue  # best-effort: skip files OpenCV cannot decode
        img = cv2.resize(img, (80, 80))
        image_list.append(np.transpose(img, (2, 0, 1)))  # HWC -> CHW

    if not image_list:
        # np.stack([]) would fail with a message that does not name the folder.
        raise ValueError(f"No readable images found in {folder_path!r}")
    return np.stack(image_list)

# Donkeycar sequence: up to the first 16000 normal frames followed by up to
# the first 2000 "dirty rain" anomaly frames.
normal = load_images_to_numpy(".../donkeycar/normal")
abnormal = load_images_to_numpy(".../donkeycar/abnormal_dirtyrain")
normal_tensor = normal[:16000]
abnormal_tensor = abnormal[:2000]
combo = np.concatenate((normal_tensor, abnormal_tensor), axis=0)

print(combo.shape)

# Min-max scale each frame independently into [0, 1); constant frames -> zeros.
combo = torch.from_numpy(combo).float()
for i in range(combo.shape[0]):
    img = combo[i]
    min_val, max_val = img.min(), img.max()
    if max_val > min_val:
        combo[i] = (img - min_val) / (max_val - min_val + 1e-6)
    else:
        combo[i] = torch.zeros_like(img)

# combo2 is the whole normalised sequence; combo1 is its normal-frame prefix
# (presumably the training split for roadmap_recon - confirm in main). Slice
# by the number of normal frames actually used rather than a hard-coded
# 16000, so combo1 never picks up abnormal frames when fewer normal frames exist.
combo2 = combo.numpy()
combo1 = combo2[:len(normal_tensor)]

embed_num = 32
rc_dict_roadmap = main.roadmap_recon(combo1=combo1, combo2=combo2, T=5, batchsize=64, channel=3,
                                     height=80, width=80,
                                     epochs=70, target="last", command="True", device="cuda")
save_dir = Path(".../donkeycar")
save_dir.mkdir(parents=True, exist_ok=True)
dict_name = 'roadmap_donkeycar.pickle'
dict_path = os.path.join(save_dir, dict_name)
with open(dict_path, 'wb') as file:
    pickle.dump(rc_dict_roadmap, file)
print(f"Dictionary saved successfully to {dict_path}")


"""
run for tmae
"""
# prepare ubnormal data
def load_images_to_numpy(folder_path):
    """Load every .png/.jpg/.jpeg image in *folder_path* as a CHW array.

    Files are read in sorted filename order with ``cv2.imread``'s default
    colour mode (3 channels) and transposed from HWC to CHW. No resizing is
    done, so all images must share one height/width for ``np.stack``.
    Files OpenCV cannot decode are skipped, which shifts the index of every
    later frame.

    Returns:
        np.ndarray of shape (N, 3, H, W).

    Raises:
        ValueError: if no image in the folder could be loaded.
    """
    # normpath so a trailing slash does not yield an empty progress label.
    folder_name = os.path.basename(os.path.normpath(folder_path))
    file_list = sorted(f for f in os.listdir(folder_path)
                       if f.endswith(('.png', '.jpg', '.jpeg')))

    image_list = []
    for fname in tqdm(file_list, desc=f"Loading images from {folder_name}"):
        img = cv2.imread(os.path.join(folder_path, fname))
        if img is None:
            continue  # best-effort: skip files OpenCV cannot decode
        image_list.append(np.transpose(img, (2, 0, 1)))  # HWC -> CHW

    if not image_list:
        # np.stack([]) would fail with a message that does not name the folder.
        raise ValueError(f"No readable images found in {folder_path!r}")
    return np.stack(image_list)

# UBnormal sequence: every normal frame followed by every "fire" anomaly frame.
normal = load_images_to_numpy(".../ubnormal/normal_scene_4_scenario_3")
abnormal = load_images_to_numpy(".../ubnormal/abnormal_scene_4_scenario_1_fire")
combo = np.concatenate((normal, abnormal), axis=0)
print(combo.shape)

# Min-max scale each frame independently into [0, 1); constant frames -> zeros.
combo = torch.from_numpy(combo).float()
for i in range(combo.shape[0]):
    frame = combo[i]
    lo, hi = frame.min(), frame.max()
    combo[i] = (frame - lo) / (hi - lo + 1e-6) if hi > lo else torch.zeros_like(frame)

# tmae_recon is handed a NumPy view of the normalised tensor.
combo = combo.numpy()

embed_num = 32
tmae_kwargs = dict(T=6, batchsize=8,
                   channel=3, height=360, width=520,
                   epochs=70, mask_strategy="interval",
                   embed_num=embed_num, command="True", device="cuda")
rc_dict_tmae = main.tmae_recon(combo=combo, **tmae_kwargs)

dict_name = 'tmae_ubnormal.pickle'
save_dir = Path(".../ubnormal")
save_dir.mkdir(parents=True, exist_ok=True)
dict_path = os.path.join(save_dir, dict_name)
with open(dict_path, 'wb') as file:
    pickle.dump(rc_dict_tmae, file)
print(f"Dictionary saved successfully to {dict_path}")


# prepare corridor data
def load_images_to_numpy(folder_path):
    """Load every .png/.jpg/.jpeg image in *folder_path* as a 1x280x480 array.

    Files are read in sorted filename order as grayscale and resized with
    ``cv2.INTER_AREA`` to width 480 and height 280 (``cv2.resize`` takes
    ``(width, height)``). Each image then gets a leading channel axis.
    Files OpenCV cannot decode are skipped, which shifts the index of every
    later frame.

    Returns:
        np.ndarray of shape (N, 1, 280, 480).

    Raises:
        ValueError: if no image in the folder could be loaded.
    """
    # normpath so a trailing slash does not yield an empty progress label.
    folder_name = os.path.basename(os.path.normpath(folder_path))
    file_list = sorted(f for f in os.listdir(folder_path)
                       if f.endswith(('.png', '.jpg', '.jpeg')))

    image_list = []
    for fname in tqdm(file_list, desc=f"Loading images from {folder_name}"):
        img = cv2.imread(os.path.join(folder_path, fname), cv2.IMREAD_GRAYSCALE)
        if img is None:
            continue  # best-effort: skip files OpenCV cannot decode
        img = cv2.resize(img, (480, 280), interpolation=cv2.INTER_AREA)
        image_list.append(img[np.newaxis, :, :])  # HW -> 1HW

    if not image_list:
        # np.stack([]) would fail with a message that does not name the folder.
        raise ValueError(f"No readable images found in {folder_path!r}")
    return np.stack(image_list)

# Corridor sequence: up to the first 1000 normal frames, frames 150-249 of
# clip 000219 and the last 100 frames of clip 000222.
normal = load_images_to_numpy(".../corridor/000002")
abnormal1 = load_images_to_numpy(".../corridor/000219")
abnormal2 = load_images_to_numpy(".../corridor/000222")
segments = (normal[:1000], abnormal1[150:250], abnormal2[-100:])
combo = np.concatenate(segments, axis=0)
print(combo.shape)

# Min-max scale each frame independently into [0, 1); constant frames -> zeros.
combo = torch.from_numpy(combo).float()
for i in range(combo.shape[0]):
    frame = combo[i]
    lo, hi = frame.min(), frame.max()
    combo[i] = (frame - lo) / (hi - lo + 1e-6) if hi > lo else torch.zeros_like(frame)

# tmae_recon is handed a NumPy view of the normalised tensor.
combo = combo.numpy()

embed_num = 32
tmae_kwargs = dict(T=6, batchsize=8,
                   channel=1, height=280, width=480,
                   epochs=70, mask_strategy="interval",
                   embed_num=embed_num, command="True", device="cuda")
rc_dict_tmae = main.tmae_recon(combo=combo, **tmae_kwargs)

dict_name = 'tmae_corridor.pickle'
save_dir = Path(".../corridor")
save_dir.mkdir(parents=True, exist_ok=True)
dict_path = os.path.join(save_dir, dict_name)
with open(dict_path, 'wb') as file:
    pickle.dump(rc_dict_tmae, file)
print(f"Dictionary saved successfully to {dict_path}")


# prepare donkeycar data
def load_images_to_numpy(folder_path):
    """Load every .png/.jpg/.jpeg image in *folder_path* as a 3x80x80 array.

    Files are read in sorted filename order with ``cv2.imread``'s default
    colour mode (3 channels). Each image is resized to 80x80 with
    ``cv2.resize``'s default interpolation and transposed from HWC to CHW.
    Files OpenCV cannot decode are skipped, which shifts the index of every
    later frame.

    Returns:
        np.ndarray of shape (N, 3, 80, 80).

    Raises:
        ValueError: if no image in the folder could be loaded.
    """
    # normpath so a trailing slash does not yield an empty progress label.
    folder_name = os.path.basename(os.path.normpath(folder_path))
    file_list = sorted(f for f in os.listdir(folder_path)
                       if f.endswith(('.png', '.jpg', '.jpeg')))

    image_list = []
    for fname in tqdm(file_list, desc=f"Loading images from {folder_name}"):
        img = cv2.imread(os.path.join(folder_path, fname))
        if img is None:
            continue  # best-effort: skip files OpenCV cannot decode
        img = cv2.resize(img, (80, 80))
        image_list.append(np.transpose(img, (2, 0, 1)))  # HWC -> CHW

    if not image_list:
        # np.stack([]) would fail with a message that does not name the folder.
        raise ValueError(f"No readable images found in {folder_path!r}")
    return np.stack(image_list)

# Donkeycar sequence: up to the first 16000 normal frames followed by up to
# the first 2000 "dirty rain" anomaly frames.
normal = load_images_to_numpy(".../donkeycar/normal")
abnormal = load_images_to_numpy(".../donkeycar/abnormal_dirtyrain")
normal_tensor, abnormal_tensor = normal[:16000], abnormal[:2000]
combo = np.concatenate((normal_tensor, abnormal_tensor), axis=0)

print(combo.shape)

# Min-max scale each frame independently into [0, 1); constant frames -> zeros.
combo = torch.from_numpy(combo).float()
for i in range(combo.shape[0]):
    frame = combo[i]
    lo, hi = frame.min(), frame.max()
    combo[i] = (frame - lo) / (hi - lo + 1e-6) if hi > lo else torch.zeros_like(frame)

# tmae_recon is handed a NumPy view of the normalised tensor.
combo = combo.numpy()

embed_num = 32
tmae_kwargs = dict(T=6, batchsize=512,
                   channel=3, height=80, width=80,
                   epochs=70, mask_strategy="interval",
                   embed_num=embed_num, command="True", device="cuda")
rc_dict_tmae = main.tmae_recon(combo=combo, **tmae_kwargs)

dict_name = 'tmae_donkeycar.pickle'
save_dir = Path(".../donkeycar")
save_dir.mkdir(parents=True, exist_ok=True)
dict_path = os.path.join(save_dir, dict_name)
with open(dict_path, 'wb') as file:
    pickle.dump(rc_dict_tmae, file)
print(f"Dictionary saved successfully to {dict_path}")