# this file incorporates code from FACTOR and AVH-Align.

import argparse
import os

import librosa
import csv
import numpy as np
import torch 
import torch.nn.functional as F
from python_speech_features import logfbank
from tqdm import tqdm
import subprocess
from dino_base import DINOBase
from torchvision import transforms

# Re-create the `np.float` / `np.int` aliases, which newer NumPy releases no
# longer provide, by pointing them at `np.float64` / `np.int_`.
# NOTE(review): presumably required by the modules imported just below (which
# would explain why this runs before those imports) - confirm before moving it.
np.float = np.float64
np.int = np.int_
import math
import hubert_pretraining, hubert, hubert_asr
import utils as avhubert_utils
from fairseq import checkpoint_utils

# Presumably the video frame rate in frames per second; not referenced by the
# code visible in this file.
FPS = 25

# Geometry constants for the DINO branch. `input_size` is an exact multiple of
# `patch_size` (518 = 37 * 14). `patch_size` itself is not referenced by the
# code visible in this file.
patch_size = 14
input_size = 518

# Image preprocessing applied before the DINO model: convert to a tensor,
# resize to `input_size` x `input_size`, then centre-crop to that same size
# (so the crop leaves the resized dimensions unchanged).
dino_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(input_size, input_size)),
        transforms.CenterCrop(size=(input_size, input_size)),
    ]
)

def load_model(ckpt_path):
    """Load a checkpoint through fairseq and prepare it for inference.

    The first model of the loaded ensemble is used. If it exposes a
    ``decoder`` attribute it is treated as a fine-tuned checkpoint and
    replaced by its ``encoder.w2v_model``; otherwise it is used as is.
    The resulting model is moved to CUDA and switched to eval mode.

    Args:
        ckpt_path: Path of the checkpoint file to load.

    Returns:
        A ``(model, task)`` tuple, where ``task`` is the fairseq task
        returned alongside the ensemble.
    """
    ensemble, _saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        [ckpt_path]
    )
    model = ensemble[0]

    if not hasattr(model, "decoder"):
        print("Checkpoint: pre-trained w/o fine-tuning")
    else:
        print("Checkpoint: fine-tuned")
        # Keep only the wrapped encoder model of the fine-tuned checkpoint.
        model = model.encoder.w2v_model

    model.cuda().eval()
    return model, task

def load_transforms(task):
    """Build the video preprocessing pipeline from the task configuration.

    The pipeline normalises with mean 0 / std 255, centre-crops to a square
    of side ``task.cfg.image_crop_size``, then normalises with
    ``task.cfg.image_mean`` / ``task.cfg.image_std``.

    Args:
        task: Task object whose ``cfg`` provides ``image_crop_size``,
            ``image_mean`` and ``image_std``.

    Returns:
        The composed ``avhubert_utils.Compose`` transform.
    """
    cfg = task.cfg
    crop_side = cfg.image_crop_size
    steps = [
        avhubert_utils.Normalize(0.0, 255.0),
        avhubert_utils.CenterCrop((crop_side, crop_side)),
        avhubert_utils.Normalize(cfg.image_mean, cfg.image_std),
    ]
    return avhubert_utils.Compose(steps)

def load_audio(path, sample_rate=16000, stack_order_audio=4):
    """Load an audio file and turn it into stacked, normalised log-filterbank features.

    Steps: load the waveform with librosa at ``sample_rate``, compute log
    filterbank features, zero-pad the frame axis up to a multiple of
    ``stack_order_audio``, fold every ``stack_order_audio`` consecutive frames
    into a single row, and apply layer normalisation over the feature axis.

    Args:
        path: Path of the audio file.
        sample_rate: Sample rate requested from librosa and passed to
            ``logfbank``.
        stack_order_audio: Number of consecutive filterbank frames
            concatenated into one output row.

    Returns:
        A float32 ``torch.Tensor`` with one row per group of stacked frames.
    """
    waveform, loaded_rate = librosa.load(path, sr=sample_rate)
    assert loaded_rate == sample_rate and len(waveform.shape) == 1
    fbank = logfbank(waveform, samplerate=sample_rate).astype(np.float32)

    # Zero-pad so the number of frames divides evenly into stacks.
    leftover = len(fbank) % stack_order_audio
    if leftover != 0:
        filler = np.zeros(
            (stack_order_audio - leftover, fbank.shape[1]), dtype=fbank.dtype
        )
        fbank = np.concatenate([fbank, filler])

    stacked = fbank.reshape(-1, stack_order_audio * fbank.shape[1])
    stacked = torch.from_numpy(stacked.astype(np.float32))
    with torch.no_grad():
        return F.layer_norm(stacked, stacked.shape[1:])

import cv2
import numpy as np

def compute_hsv_red_mask(frame):
    """Return a mask of the red pixels of a frame.

    The frame is converted from BGR to HSV and thresholded against two hue
    bands (0-10 and 170-180), each with saturation and value between 50 and
    255. The two band masks are OR-ed together.

    Args:
        frame: Image array interpreted as BGR by ``cv2.cvtColor``.

    Returns:
        The combined mask produced by ``cv2.inRange`` / ``cv2.bitwise_or``.
    """
    hsv_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)

    # Red sits at both ends of the hue axis, so two bands are needed.
    red_bands = (
        ([0, 50, 50], [10, 255, 255]),
        ([170, 50, 50], [180, 255, 255]),
    )
    low_hue_mask, high_hue_mask = (
        cv2.inRange(hsv_frame, np.array(lower), np.array(upper))
        for lower, upper in red_bands
    )
    return cv2.bitwise_or(low_hue_mask, high_hue_mask)

def process_video_red_mask(video_path):
    """Average the per-frame red mask over every frame of a video.

    Each frame is passed through ``compute_hsv_red_mask``, scaled from
    0/255 to 0.0/1.0, and accumulated; the sum is divided by the number of
    frames read, so every output value lies in [0, 1].

    The capture handle is released in a ``finally`` clause so it is closed
    even when opening fails or a frame raises during processing.

    Args:
        video_path: Path of the video file to read.

    Returns:
        A ``(avg_mask, frame_count)`` tuple: the float32 averaged mask and
        the number of frames that were read.

    Raises:
        IOError: If the video cannot be opened.
        ValueError: If the video yields no frames.
    """
    cap = cv2.VideoCapture(video_path)
    frame_count = 0
    avg_mask = None

    try:
        if not cap.isOpened():
            raise IOError(f"error: {video_path}")

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            red_mask_float = compute_hsv_red_mask(frame).astype(np.float32) / 255.0

            # Allocate the accumulator lazily, once the frame size is known.
            if avg_mask is None:
                avg_mask = np.zeros_like(red_mask_float)
            avg_mask += red_mask_float
            frame_count += 1
    finally:
        cap.release()

    if frame_count == 0:
        raise ValueError("no video!")

    avg_mask /= frame_count
    return avg_mask, frame_count



def extract_features(model, video_path, audio_path, transform):
    """Extract audio, video, multimodal and DINO-based features for one clip.

    The video frames and stacked audio features are truncated to a common
    length and passed through ``model.extract_finetune`` three times
    (audio only, video only, both). The averaged red mask of the video is
    then fed through a freshly constructed ``DINOBase`` model.

    Args:
        model: Model exposing ``extract_finetune``.
        video_path: Path of the video; read both by
            ``avhubert_utils.load_video`` and by ``process_video_red_mask``.
        audio_path: Path of the audio file passed to ``load_audio``.
        transform: Callable applied to the loaded video frames.

    Returns:
        A 6-tuple of NumPy arrays:
        ``(audio, video, multimodal, local, avg_red_mask, global)`` where
        ``local`` / ``global`` are the first / second ``DINOBase`` outputs
        (``local`` reshaped to ``(f_video.shape[1], 1024)``).

    Raises:
        Exception: Any error from loading, the model, or the mask/DINO
            stage. A failure in the mask/DINO stage is printed and then
            re-raised, so the caller sees the real cause rather than an
            ``UnboundLocalError`` from the return statement.
    """
    frames = avhubert_utils.load_video(video_path)
    frames = transform(frames)
    frames = torch.FloatTensor(frames).unsqueeze(0).unsqueeze(0).cuda()
    # load_audio returns (time, feat); reshape to (1, feat, time).
    audio = load_audio(audio_path)[None, :, :].transpose(1, 2).cuda()

    # Trim both streams to the shorter one along their time axis.
    min_len = min(frames.shape[2], audio.shape[-1])
    frames, audio = frames[:, :, :min_len], audio[:, :, :min_len]

    with torch.no_grad():
        f_audio, _ = model.extract_finetune({"video": None, "audio": audio}, None, None)
        f_video, _ = model.extract_finetune({"video": frames, "audio": None}, None, None)
        f_mm, _ = model.extract_finetune({"video": frames, "audio": audio}, None, None)

    try:
        avg_mask_, _frame_count = process_video_red_mask(video_path)

        # DinoV2
        base_model = DINOBase(output_dim=f_video.shape[1]).to('cuda')
        for param in base_model.dino_model.parameters():
            param.requires_grad = False

        local_dim = 1024
        output_size = int(math.sqrt(local_dim))

        # Replicate the single-channel mask into an (H, W, 3) image.
        # NOTE(review): avg_mask_ holds values in [0, 1], so the uint8 cast
        # truncates everything below 1.0 to 0 before the transform. This looks
        # unintended (a scale by 255 seems to be missing) but is kept as is so
        # that extracted features stay compatible with existing ones - confirm.
        mask_image = np.repeat(avg_mask_[:, :, np.newaxis], 3, axis=2).astype(np.uint8)
        dino_input = dino_transform(mask_image).cuda()

        # Inference only: the outputs are detached below, so no graph is needed.
        with torch.no_grad():
            f_feature0, f_feature1 = base_model(dino_input.unsqueeze(0), output_size=output_size)
        local_feature = f_feature0.view(1, f_video.shape[1], local_dim)
    except Exception as e:
        print(f"error: {str(e)}")
        raise

    return (
        f_audio.squeeze(0).cpu().numpy(),
        f_video.squeeze(0).cpu().numpy(),
        f_mm.squeeze(0).cpu().numpy(),
        local_feature.squeeze(0).cpu().detach().numpy(),
        avg_mask_,
        f_feature1.cpu().detach().numpy(),
    )

def process_avlips(args, model, transform, category):
    """Extract and save features for every metadata entry of one category.

    Rows of the CSV at ``args.metadata`` whose ``type`` column equals
    ``category`` are collected (duplicates removed) and processed in sorted
    order so that runs are reproducible. For each entry the mouth-ROI video
    (``<stem>_roi.mp4``) and audio (``<stem>.wav``) under ``args.data_path``
    are passed to ``extract_features`` and the results are written to
    ``<args.save_path>/<stem>.npz``. Entries that fail are reported on stdout
    and skipped.

    Args:
        args: Parsed arguments providing ``metadata``, ``data_path`` and
            ``save_path``.
        model: Model forwarded to ``extract_features``.
        transform: Video transform forwarded to ``extract_features``.
        category: Value of the metadata ``type`` column to select.
    """
    file_paths = set()

    # Load metadata CSV and filter by category
    with open(args.metadata, "r", newline="") as f:
        for row in csv.DictReader(f):
            if row["type"] == category:
                path = os.path.join(row["path"].replace("*/Test/", ""), row["filename"])
                file_paths.add(path)

    for file_path in tqdm(sorted(file_paths)):
        # Strip only the real extension; slicing a fixed 4 characters or
        # replacing ".mp4" anywhere in the path can corrupt other parts of it.
        stem = os.path.splitext(file_path)[0]
        mouth_roi_path = args.data_path + '/' + stem + "_roi.mp4"
        audio_path = args.data_path + '/' + stem + ".wav"

        try:
            feature_audio, feature_vid, feature_multimodal, feature_local, hsv, feature_global = extract_features(
                model, mouth_roi_path, audio_path, transform
            )
        except Exception as e:
            print(f"Unprocessed for file: {mouth_roi_path}; error {e}")
            continue

        save_dict = {
            "visual": feature_vid,
            "audio": feature_audio,
            "multimodal": feature_multimodal,
            "local": feature_local,  # Add local feature
            "hsv": hsv,  # Add hsv feature
            "global": feature_global,  # Add global feature
        }
        save_path = os.path.join(args.save_path, stem + ".npz")
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        np.savez(save_path, **save_dict)

def main():
    """Parse command-line arguments and run feature extraction.

    Only the ``AVLips`` dataset is handled. Any other ``--dataset`` value is
    reported and the function returns before the checkpoint is loaded,
    instead of loading the model and then silently doing nothing.

    With ``--category all`` the categories ``'0_real'`` and ``'1_fake'`` are
    processed; any other choice is passed to ``process_avlips`` unchanged.
    """
    parser = argparse.ArgumentParser(description="Extract AVHubert features")
    parser.add_argument("--dataset", type=str, default="AVLips", help="Dataset to extract features for")
    parser.add_argument("--metadata", type=str, default="AVLips/train_metadata.csv", help="Path to the dataset metadata")
    # NOTE(review): --split is parsed but not read anywhere in this function.
    parser.add_argument("--split", default="train", help="data split to process (e.g., val, train)")
    parser.add_argument("--ckpt_path", type=str, default="self_large_vox_433h.pt", help="Path to AVHubert checkpoint")
    parser.add_argument("--data_path", type=str, default="avlips_preprocessed/", help="Path to the root folder pf preprocessed data")
    parser.add_argument("--save_path", type=str, default="avlips_features/", help="Output directory for saving features")
    parser.add_argument("--category", choices=["RealVideo-RealAudio", "RealVideo-FakeAudio", "FakeVideo-RealAudio", "FakeVideo-FakeAudio", "all"], default="all", help="select category")
    args = parser.parse_args()

    # Bail out before the expensive checkpoint load when there is nothing to do.
    if args.dataset != "AVLips":
        print(f"Unsupported dataset: {args.dataset!r}; nothing to do.")
        return

    # model
    model, task = load_model(args.ckpt_path)
    transform = load_transforms(task)

    # NOTE(review): "all" expands to '0_real' / '1_fake', whereas the other
    # choices are matched verbatim against the metadata "type" column -
    # confirm the metadata actually uses both vocabularies.
    if args.category == 'all':
        categories = ['0_real', '1_fake']
    else:
        categories = [args.category]

    for category in categories:
        process_avlips(args, model, transform, category)

if __name__ == "__main__":
    main()
