import argparse
import torch
from tqdm import tqdm
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
import pandas as pd
import os
import sys
from model import TAVFD


def process_video(data, fusion_model, device):
    """Compute the fusion model's score for a single sample.

    Args:
        data: Mapping that provides NumPy arrays under the keys
            ``"visual"``, ``"audio"`` and ``"global"``.
        fusion_model: Callable invoked as
            ``fusion_model(visual, audio, global)``. It must return four
            values; only the first one is used.
        device: Torch device the input tensors are moved to.

    Returns:
        A detached CPU tensor: ``logsumexp`` of the negated model output
        over dim 0, taken after ``squeeze(0)`` on the output.
    """
    # Pure inference: disabling autograd avoids building a graph that would
    # be thrown away by ``detach()`` anyway.
    with torch.no_grad():
        visual_tensor = torch.from_numpy(data["visual"]).to(device)
        audio_tensor = torch.from_numpy(data["audio"]).to(device)
        gl_tensor = torch.from_numpy(data["global"]).to(device)

        # L2-normalise the visual and audio features along the last dim.
        # NOTE(review): an all-zero feature vector has norm 0 and produces
        # NaN here -- confirm the extracted features can never be all zero.
        visual_tensor = visual_tensor / torch.linalg.norm(
            visual_tensor, ord=2, dim=-1, keepdim=True
        )
        audio_tensor = audio_tensor / torch.linalg.norm(
            audio_tensor, ord=2, dim=-1, keepdim=True
        )

        output, _, _, _ = fusion_model(visual_tensor, audio_tensor, gl_tensor)
        output = output.squeeze(0)
        score = torch.logsumexp(-output, dim=0).detach().cpu().squeeze()
    return score

def main(args):
    """Evaluate a pretrained T-AVFD checkpoint and print AP / AUC.

    Everything printed during the run is mirrored into a timestamped
    ``eval_log_*.txt`` file inside ``args.txtpath``.

    Args:
        args: Parsed command-line namespace. Reads ``txtpath``, ``dataset``,
            ``checkpoint_path``, ``tm_weights`` and ``metadata``.
    """
    import datetime

    class Logger(object):
        """File-like object that duplicates writes to the terminal and a log file."""

        def __init__(self, filename):
            self.terminal = sys.stdout
            self.log = open(filename, "w", encoding="utf-8")

        def write(self, message):
            self.terminal.write(message)
            self.log.write(message)

        def flush(self):
            self.terminal.flush()
            self.log.flush()

        def close(self):
            self.log.close()

    log_file_path = os.path.join(
        args.txtpath,
        f"eval_log_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
    )

    original_stdout = sys.stdout
    logger = Logger(log_file_path)
    sys.stdout = logger
    try:
        print(f"Evaluating T-AVFD on {args.dataset} with pretrained weights saved at {args.checkpoint_path} ...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        # NOTE(review): weights_only=False unpickles arbitrary objects -- only
        # load checkpoints from a trusted source.
        fusion_model_weights = torch.load(
            args.checkpoint_path, map_location=device, weights_only=False
        )

        fusion_model = TAVFD(tm_weights=args.tm_weights).to(device)
        fusion_model.load_state_dict(fusion_model_weights["state_dict"])
        fusion_model.eval()

        # Honour --metadata when it points to an existing file; otherwise keep
        # the previous behaviour of reading ./test.csv.
        metadata_path = getattr(args, "metadata", None)
        if not metadata_path or not os.path.isfile(metadata_path):
            print(f"Metadata file {metadata_path!r} not found; falling back to ./test.csv")
            metadata_path = "./test.csv"
        metadata = pd.read_csv(metadata_path)

        outputs = []
        ground_truths = []
        for _, row in tqdm(metadata.iterrows(), total=len(metadata)):
            # NOTE(review): allow_pickle=True can execute code embedded in the
            # file -- only evaluate feature files from a trusted source.
            # The archive is read lazily, so scoring happens inside the
            # ``with`` block; the file handle is closed afterwards.
            with np.load(row["path"], allow_pickle=True) as data:
                score = process_video(data, fusion_model, device)
            outputs.append(float(score))
            ground_truths.append(row["label"])

        outputs = np.array(outputs)
        ground_truths = np.array(ground_truths)

        auc = roc_auc_score(ground_truths, outputs)
        ap = average_precision_score(ground_truths, outputs)

        print(f"AP: {ap}")
        print(f"AUC: {auc}")
    finally:
        # Always undo the stdout redirection and release the log file, even
        # when the evaluation fails part-way through.
        sys.stdout = original_stdout
        logger.close()

if __name__ == "__main__":
    # Command-line entry point: parse the options and run the evaluation.
    parser = argparse.ArgumentParser(description="Evaluate Fusion Model on Deepfake Dataset")

    parser.add_argument("--checkpoint_path", type=str, default="checkpoints/T-AVFD.pt",
                        help="Path to the pretrained fusion model checkpoint.")
    # NOTE(review): --features_path is not read by the code visible in this
    # file; it is kept so existing command lines keep working.
    parser.add_argument("--features_path", type=str,
                        default="/val/",
                        help="Path to the root folder of test data.")
    parser.add_argument("--metadata", type=str,
                        default="/test_metadata.csv",
                        help="CSV file containing ground truth labels.")
    parser.add_argument("--dataset", type=str, default="SHDF",
                        help="Dataset name")
    parser.add_argument("--txtpath", type=str, default="./",
                        help="Directory where the evaluation log file is written.")
    parser.add_argument('--tm_weights', type=float, nargs=3, default=[0.1, 0.1, -0.1],
                        help='Three modulation vectors, e.g. --tm_weights 0.1 0.1 -0.1')

    args = parser.parse_args()
    main(args)