# this file incorporates code from FACTOR and AVH-Align.
import argparse
import os
import librosa
import csv
import numpy as np
import torch
import torch.nn.functional as F
from python_speech_features import logfbank
from tqdm import tqdm
from torchvision import transforms

# Fix deprecation in numpy: NumPy 1.24 removed the long-deprecated `np.float`
# and `np.int` aliases. They are re-created here, before the imports below,
# presumably because one of those dependencies (AV-HuBERT `utils` / fairseq)
# still references the old names -- TODO confirm which one needs it.
# NOTE(review): the removed aliases pointed at the builtins `float`/`int`;
# here they are bound to the NumPy scalar types `np.float64` / `np.int_`.
np.float = np.float64
np.int = np.int_
import utils as avhubert_utils
from fairseq import checkpoint_utils
import cv2
import mediapipe as mp
import os
import alpha_clip
from PIL import Image

# Video frame rate constant. NOTE(review): not referenced elsewhere in this file.
FPS = 25
#Alpha-CLIP
# Preprocessing for the alpha (mask) channel: uint8 mask array -> float tensor,
# resized to 224x224 and normalized with mean 0.5 / std 0.26
# (presumably the values used by the Alpha-CLIP reference code -- verify).
mask_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(0.5, 0.26)
])

# Load Alpha-CLIP ViT-L/14 with the alpha-vision checkpoint named below
# (per the file name, presumably the GRIT-20M fine-tuned weights), without
# LoRA adapters. This runs at import time, so importing this module loads the
# model. `preprocess` is the matching image pipeline used by AlphaClip_feature.
model, preprocess = alpha_clip.load(
    "ViT-L/14",
    device='cpu', 
    alpha_vision_ckpt_pth="./clip_l14_grit20m_fultune_2xe.pth",
    lora_adapt=False, rank=-1
)
# Device for the Alpha-CLIP model and its inputs (also read by AlphaClip_feature).
device = "cuda:0"
# Loaded on CPU above, then cast to fp32 and moved to `device`.
# NOTE(review): this global `model` is Alpha-CLIP; functions that take a
# `model` parameter (extract_features, process_data) receive AV-HuBERT instead.
model = model.float().to(device)

def load_model(ckpt_path):
    """Load an AV-HuBERT checkpoint and return ``(model, task)``.

    A checkpoint whose model has a ``decoder`` attribute is treated as
    fine-tuned; in that case only its encoder backbone
    (``encoder.w2v_model``) is returned. The returned model is moved to
    CUDA and switched to eval mode.
    """
    ensemble, _, task = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
    net = ensemble[0]
    is_finetuned = hasattr(net, "decoder")
    print("Checkpoint: fine-tuned" if is_finetuned else "Checkpoint: pre-trained w/o fine-tuning")
    if is_finetuned:
        # Strip the seq2seq wrapper; feature extraction only needs the encoder.
        net = net.encoder.w2v_model
    net.cuda().eval()
    return net, task

def load_transforms(task):
    """Build the AV-HuBERT video preprocessing pipeline for ``task``.

    The pipeline has three steps:

    1. ``Normalize(0, 255)`` -- presumably ``(x - 0) / 255``, scaling pixels to [0, 1].
    2. A center crop to ``task.cfg.image_crop_size`` square.
    3. Standardization with ``task.cfg.image_mean`` / ``task.cfg.image_std``.
    """
    cfg = task.cfg
    crop_hw = (cfg.image_crop_size, cfg.image_crop_size)
    steps = [
        avhubert_utils.Normalize(0.0, 255.0),
        avhubert_utils.CenterCrop(crop_hw),
        avhubert_utils.Normalize(cfg.image_mean, cfg.image_std),
    ]
    return avhubert_utils.Compose(steps)
 

def load_audio(path, sample_rate=16000, stack_order_audio=4):
    """Load ``path`` and return layer-normalized, frame-stacked log filterbanks.

    The waveform is resampled to ``sample_rate`` and turned into log
    mel-filterbank features with ``logfbank``. The rows are zero-padded to a
    multiple of ``stack_order_audio``, and every ``stack_order_audio``
    consecutive rows are concatenated into one. Layer norm is then applied
    over the feature axis.

    Returns:
        torch.Tensor: 2-D float32 CPU tensor of shape
        ``(ceil(n_frames / stack_order_audio), stack_order_audio * n_bins)``.
    """
    waveform, loaded_sr = librosa.load(path, sr=sample_rate)
    assert loaded_sr == sample_rate and len(waveform.shape) == 1
    fbank = logfbank(waveform, samplerate=sample_rate).astype(np.float32)

    n_frames, n_bins = fbank.shape
    remainder = n_frames % stack_order_audio
    if remainder:
        # Zero rows so the frame count divides evenly by the stack order.
        filler = np.zeros((stack_order_audio - remainder, n_bins), dtype=fbank.dtype)
        fbank = np.concatenate([fbank, filler])

    stacked = fbank.reshape(-1, stack_order_audio * n_bins)
    feats = torch.from_numpy(stacked.astype(np.float32))
    with torch.no_grad():
        feats = F.layer_norm(feats, feats.shape[1:])
    return feats

def frams_and_mask(video_path, MASK_DIR, FRAME_DIR):
    """Write every frame of ``video_path`` and a face mask for it to disk.

    For the ``i``-th decoded frame this writes two files:

    * ``{FRAME_DIR}/frame_{i:05d}.png`` -- the frame as decoded by OpenCV (BGR).
    * ``{MASK_DIR}/mask_{i:05d}.png`` -- a single-channel uint8 mask. It is 255
      inside the convex hull of the MediaPipe FaceMesh landmarks (at most one
      face) and 0 elsewhere, and all zeros when no face is detected.

    Both output directories are created if missing.

    Raises:
        OSError: if the video cannot be opened or a PNG cannot be written.
    """
    os.makedirs(MASK_DIR, exist_ok=True)
    os.makedirs(FRAME_DIR, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Previously this silently produced zero frames, and the failure only
        # surfaced later as an opaque error while averaging features.
        cap.release()
        raise OSError(f"Cannot open video: {video_path}")

    face_mesh = mp.solutions.face_mesh.FaceMesh(
        static_image_mode=False,  # treat input as a video stream (tracking)
        max_num_faces=1,
        refine_landmarks=True,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    )
    # try/finally so the capture handle and the MediaPipe graph are released
    # even when processing or writing a frame fails.
    try:
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            h, w, _ = frame.shape
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # MediaPipe expects RGB
            results = face_mesh.process(rgb)

            mask = np.zeros((h, w), dtype=np.uint8)
            if results.multi_face_landmarks:
                for face_landmarks in results.multi_face_landmarks:
                    # Landmark x/y are normalized by image width/height; scale to pixels.
                    points = np.array(
                        [(int(lm.x * w), int(lm.y * h)) for lm in face_landmarks.landmark],
                        dtype=np.int32,
                    )
                    hull = cv2.convexHull(points)
                    cv2.fillConvexPoly(mask, hull, 255)

            mask_path = f"{MASK_DIR}/mask_{frame_idx:05d}.png"
            frame_path = f"{FRAME_DIR}/frame_{frame_idx:05d}.png"
            # cv2.imwrite signals failure through its return value, not an exception.
            if not cv2.imwrite(mask_path, mask):
                raise OSError(f"Failed to write {mask_path}")
            if not cv2.imwrite(frame_path, frame):
                raise OSError(f"Failed to write {frame_path}")

            frame_idx += 1
    finally:
        cap.release()
        face_mesh.close()


def AlphaClip_feature(FRAME_DIR, MASK_DIR):
    """Return the Alpha-CLIP global embedding averaged over a clip's frames.

    Frames in ``FRAME_DIR`` and masks in ``MASK_DIR`` are paired by sorted
    file name. For each pair, the mask is resized and cropped with the first
    two steps of Alpha-CLIP's ``preprocess`` pipeline. It is then binarized
    (pixels exactly equal to 255 are foreground) and passed to
    ``model.visual`` as the alpha channel alongside the preprocessed frame.

    Args:
        FRAME_DIR: directory that contains only the frame images.
        MASK_DIR: directory that contains only the matching mask images.

    Returns:
        torch.Tensor: element-wise mean of the per-frame global features
        returned by ``model.visual``, on ``device``.

    Raises:
        ValueError: if the directories hold different numbers of files, or
            no files at all.
    """
    frame_files = sorted(os.listdir(FRAME_DIR))
    mask_files = sorted(os.listdir(MASK_DIR))

    # Explicit checks instead of `assert` (stripped under `python -O`). An
    # empty clip used to fail later with an opaque `None /= 0` TypeError.
    if len(frame_files) != len(mask_files):
        raise ValueError("diffrent number of frames and masks")
    if not frame_files:
        raise ValueError(f"No frames found in {FRAME_DIR}")

    to_tensor = transforms.ToTensor()  # stateless: build once, not per frame
    gl_att = None
    count = 0

    for frame_name, mask_name in tqdm(zip(frame_files, mask_files), total=len(frame_files)):
        # Context managers release the underlying file handles immediately.
        with Image.open(os.path.join(FRAME_DIR, frame_name)) as frame_img:
            raw_image = frame_img.convert("RGB")
        with Image.open(os.path.join(MASK_DIR, mask_name)) as wb_mask_pil:
            # Reuse the image pipeline's first two (presumably spatial) steps so
            # the mask stays geometrically aligned with the preprocessed frame.
            wb_mask_image = preprocess.transforms[0](wb_mask_pil)
            wb_mask_image = preprocess.transforms[1](wb_mask_image)
        wb_mask = np.array(wb_mask_image)
        if wb_mask.ndim == 2:
            wb_mask = np.stack([wb_mask] * 3, axis=-1)

        binary_mask = (wb_mask[:, :, 0] == 255)
        alpha = mask_transform((binary_mask * 255).astype(np.uint8)).unsqueeze(0).to(device)

        # NOTE(review): a stock CLIP-style `preprocess` (with a convert-to-RGB
        # step) expects a PIL image, not a tensor -- confirm the local
        # alpha_clip version accepts the tensor passed here.
        image = preprocess(to_tensor(raw_image)).unsqueeze(0).to(device)

        with torch.no_grad():
            global_feature, _ = model.visual(image, alpha, return_attn=True)

        gl_att = global_feature if gl_att is None else gl_att + global_feature
        count += 1

    return gl_att / count

def extract_log_mel_spectrogram(audio_path, sr=16000, n_mels=128):
    """Return the log-scaled (dB) mel spectrogram of ``audio_path``.

    The audio is resampled to ``sr`` on load. The ``n_mels``-band power
    spectrogram is converted to decibels relative to its own peak
    (``ref=np.max``), so the loudest bin maps to 0 dB. Following librosa's
    convention, the result is shaped ``(n_mels, n_frames)``.
    """
    signal, rate = librosa.load(audio_path, sr=sr)
    power_mel = librosa.feature.melspectrogram(y=signal, sr=rate, n_mels=n_mels)
    return librosa.power_to_db(power_mel, ref=np.max)


def extract_features(model, video_path, audio_path, transform):
    """Extract all per-clip features for one mouth-ROI video and its audio.

    Args:
        model: AV-HuBERT model as returned by ``load_model`` (on CUDA).
        video_path: path to the mouth-ROI video.
        audio_path: path to the matching audio file.
        transform: video preprocessing pipeline from ``load_transforms``.

    Returns:
        tuple: ``(f_audio, f_video, f_mm, gl_att, mel)``, where:

        * ``f_audio``, ``f_video`` and ``f_mm`` are the AV-HuBERT
          ``extract_finetune`` outputs for audio-only, video-only and
          audio-visual input (leading batch axis squeezed).
        * ``gl_att`` is the frame-averaged Alpha-CLIP embedding.
        * ``mel`` is the log-mel spectrogram. All entries are NumPy arrays.
    """
    import tempfile  # only needed here, for the per-clip scratch directories

    frames = avhubert_utils.load_video(video_path)

    # Frames and masks go to fresh per-clip scratch dirs. The former shared
    # "Masks"/"Frames" dirs were never emptied, so when a clip was shorter than
    # a previously processed one, AlphaClip_feature also averaged in the
    # leftover higher-numbered frames of the earlier clip.
    with tempfile.TemporaryDirectory() as scratch_dir:
        mask_dir = os.path.join(scratch_dir, "Masks")
        frame_dir = os.path.join(scratch_dir, "Frames")
        os.makedirs(mask_dir)
        os.makedirs(frame_dir)
        frams_and_mask(video_path, mask_dir, frame_dir)
        gl_att = AlphaClip_feature(frame_dir, mask_dir)

    mel_feature = extract_log_mel_spectrogram(audio_path)

    frames = transform(frames)
    # Add leading batch and channel axes (presumably (T, H, W) -> (1, 1, T, H, W);
    # the layout returned by load_video is not visible here -- TODO confirm).
    frames = torch.FloatTensor(frames).unsqueeze(0).unsqueeze(0).cuda()

    # load_audio returns a 2-D (T, F) CPU tensor. The trimming below needs a
    # 3-D tensor with time on the last axis, and AV-HuBERT's audio input is
    # (B, F, T). Without this reshape, the 3-index slice raised IndexError and
    # every clip was skipped.
    audio = load_audio(audio_path).transpose(0, 1).unsqueeze(0).cuda()

    # Trim both streams to a common number of time steps (axis 2 of the video,
    # last axis of the audio).
    min_len = min(frames.shape[2], audio.shape[-1])
    frames, audio = frames[:, :, :min_len], audio[:, :, :min_len]

    with torch.no_grad():
        f_audio, _ = model.extract_finetune({"video": None, "audio": audio}, None, None)
        f_video, _ = model.extract_finetune({"video": frames, "audio": None}, None, None)
        f_mm, _ = model.extract_finetune({"video": frames, "audio": audio}, None, None)

    return (
        f_audio.squeeze(0).cpu().numpy(),
        f_video.squeeze(0).cpu().numpy(),
        f_mm.squeeze(0).cpu().numpy(),
        gl_att.cpu().detach().numpy(),
        mel_feature,
    )

def process_data(args, model, transform, category):
    """Extract and save features for every metadata entry of one category.

    Reads the CSV at ``args.metadata`` (it must have ``type`` and ``path``
    columns) and keeps the distinct ``path`` values whose ``type`` equals
    ``category``. For each such path with stem ``s`` (the path minus its
    extension), it reads ``{args.data_path}/s_roi.mp4`` and
    ``{args.data_path}/s.wav`` and writes the features to
    ``{args.save_path}/s.npz``. Clips whose extraction fails are reported on
    stdout and skipped.
    """
    # newline="" is what the csv module documentation requires for readers.
    with open(args.metadata, "r", newline="") as f:
        reader = csv.DictReader(f)
        # dict.fromkeys de-duplicates while keeping CSV order; iterating a set
        # of strings gives a different order on each run (hash randomization).
        file_paths = list(dict.fromkeys(
            row["path"] for row in reader if row["type"] == category
        ))

    for file_path in tqdm(file_paths):
        # splitext instead of [:-4] / .replace(".mp4", ...): robust to other
        # extension lengths and to ".mp4" appearing elsewhere in the path.
        stem = os.path.splitext(file_path)[0]
        mouth_roi_path = args.data_path + '/' + stem + "_roi.mp4"
        audio_path = args.data_path + '/' + stem + ".wav"

        try:
            feature_audio, feature_vid, feature_multimodal, aclip_global, mel = extract_features(
                model, mouth_roi_path, audio_path, transform
            )
        except Exception as e:
            # Best effort: one bad clip must not abort the whole category.
            print(f"Unprocessed for file: {mouth_roi_path}; error {e}")
            continue

        save_dict = {
            "visual": feature_vid,
            "audio": feature_audio,
            "multimodal": feature_multimodal,
            "global": aclip_global,
            "mel": mel,
        }
        save_path = os.path.join(args.save_path, stem + ".npz")
        save_dir = os.path.dirname(save_path)
        if save_dir:  # os.makedirs("") raises when the path has no directory part
            os.makedirs(save_dir, exist_ok=True)
        np.savez(save_path, **save_dict)

def main():
    """Parse CLI arguments, load AV-HuBERT and extract features per category."""
    parser = argparse.ArgumentParser(description="Extract features")
    # NOTE(review): --dataset and --split are parsed but not used in this file.
    parser.add_argument("--dataset", type=str, default="SHDF", help="Dataset to extract features")
    parser.add_argument("--metadata", type=str,default="./train_metadata.csv", help="Path to the dataset metadata")
    parser.add_argument("--split", default="train", help="data split to process")
    parser.add_argument("--ckpt_path", type=str, default="self_large_vox_433h.pt", help="Path to AVHubert checkpoint")
    parser.add_argument("--data_path", type=str, default="preprocessed/", help="Path to the root folder pf preprocessed data")
    parser.add_argument("--save_path", type=str, default="features/", help="Output directory for saving features")
    # "all" added to the choices: it was already the default but could not be
    # passed explicitly.
    parser.add_argument("--category", choices=["0_real", "1_fake", "all"], default="all", help="select category")
    args = parser.parse_args()

    # model
    model, task = load_model(args.ckpt_path)
    transform = load_transforms(task)
    # --category used to be parsed and then ignored. It is now honored, and
    # "all" (the default) keeps the previous behavior of processing both.
    all_categories = ['0_real', '1_fake']
    categories = all_categories if args.category == "all" else [args.category]
    for category in categories:
        process_data(args, model, transform, category)


if __name__ == "__main__":
    main()



