import os
import numpy as np
from PIL import Image
import joblib
import torch
import torchvision.transforms as transforms
import torchvision.models as models

def build_feature_extractor(device):
    """Build a ResNet-50 based image feature extractor in eval mode.

    The ImageNet-pretrained ResNet-50 has its final ``fc`` layer dropped; the
    remaining backbone output is flattened and projected through a new
    ``torch.nn.Linear(in_features, 6)`` layer, so each image yields a
    6-element feature vector.

    NOTE(review): the appended Linear layer is freshly initialised from the
    global torch RNG and is not loaded from any checkpoint here, so its
    weights differ between calls unless the RNG is seeded identically.
    Confirm the SVM was trained on features from an identically-initialised
    head; otherwise predictions are not reproducible.

    Args:
        device: torch device the extractor is moved to.

    Returns:
        A ``torch.nn.Sequential`` module in eval mode on ``device``.
    """
    resnet = models.resnet50(pretrained=True)
    # The last child of a torchvision ResNet is its classification ``fc``.
    *body, head = resnet.children()
    projection = torch.nn.Linear(head.in_features, 6)
    extractor = torch.nn.Sequential(*body, torch.nn.Flatten(), projection)
    extractor = extractor.to(device)
    extractor.eval()
    return extractor

def _list_frames(folder):
    """Return the sorted paths of image frames (.png/.jpg/.jpeg) directly in ``folder``.

    The extension check is case-insensitive; subdirectories are not searched.
    """
    return sorted(
        os.path.join(folder, fn)
        for fn in os.listdir(folder)
        if fn.lower().endswith(('.png', '.jpg', '.jpeg'))
    )


def _roi_box(img_size, roi):
    """Return the (left, top, right, bottom) box of a ``roi``-wide centred square.

    ``img_size`` is interpreted as (width, height), the convention used by
    PIL's ``Image.resize`` / ``Image.crop``. Returns ``None`` when no ROI is
    configured, meaning "do not crop".
    """
    if roi is None:
        return None
    half = roi / 2
    cx = img_size[0] // 2
    cy = img_size[1] // 2
    return (int(cx - half), int(cy - half), int(cx + half), int(cy + half))


def _load_frame(frame_path, img_size, box, grayscale):
    """Open ``frame_path`` and return it resized, optionally cropped, as RGB.

    In grayscale mode the image is reduced to luminance and replicated into
    three channels so the ImageNet backbone still receives RGB input.
    """
    # The context manager closes the file handle; resize() returns a new,
    # fully-loaded image that stays valid after the ``with`` exits.
    with Image.open(frame_path) as raw:
        img = raw.resize(img_size)
    if box is not None:
        img = img.crop(box)
    if grayscale:
        return img.convert('L').convert('RGB')
    return img.convert('RGB')


def _classify_video(video_name, manikin_frames, generated_frames,
                    extract, svm, classes, rng):
    """Swap manikin frames for generated ones until the SVM predicts 'real'.

    Each manikin slot is replaced at most once, in random order. The video
    feature is the mean of the per-frame features, and the SVM is queried
    after every replacement.

    Returns:
        A result dict. On success it holds ``video``, ``replaced_frames``,
        ``max_replacements``, ``final_class``, ``confidence`` and
        ``probabilities``. If no frame could be processed, it is
        ``{'video': ..., 'error': ...}``.
    """
    # Per-video cache: each frame runs through the CNN at most once, instead
    # of being re-extracted on every replacement iteration.
    cache = {}

    def feature(path):
        if path not in cache:
            cache[path] = extract(path)
        return cache[path]

    current_frames = list(manikin_frames)
    max_replacements = len(current_frames)
    # A permutation guarantees distinct slots, so after max_replacements steps
    # every manikin frame has been swapped and replaced_frames is accurate.
    replace_order = rng.permutation(max_replacements)
    replaced_count = 0
    final_class = 'manikin'

    for iteration in range(max_replacements + 1):
        frame_feats = [f for f in map(feature, current_frames) if f is not None]
        if not frame_feats:
            return {'video': video_name, 'error': 'No valid frames during processing'}

        proba = svm.predict_proba([np.mean(frame_feats, axis=0)])[0]
        best = int(np.argmax(proba))
        if classes[best] == 'real':
            final_class = 'real'
            break

        if iteration < max_replacements:
            current_frames[replace_order[iteration]] = rng.choice(generated_frames)
            replaced_count += 1

    # The loop runs at least once (manikin_frames is non-empty), so proba and
    # best are always bound here.
    return {
        'video': video_name,
        'replaced_frames': replaced_count,
        'max_replacements': max_replacements,
        'final_class': final_class,
        'confidence': round(float(proba[best]), 4),
        'probabilities': {
            cls: round(float(p), 4) for cls, p in zip(classes, proba)
        },
    }


def classify_videos(input_folder, model_path, rng=None):
    """Measure how many generated frames it takes for manikin videos to look 'real'.

    ``input_folder`` must contain ``manikin/`` and ``generated/``
    subdirectories, each holding one frame directory per video under the same
    name. For each manikin video, frames are progressively replaced by random
    frames from the matching generated video until the saved SVM predicts
    ``'real'`` or every frame has been replaced.

    Args:
        input_folder: Root folder containing ``manikin`` and ``generated``.
        model_path: joblib file with keys ``'svm'``, ``'classes'`` and
            ``'config'``. The config must contain ``'image_size'`` and may
            contain ``'roi_size'`` and ``'grayscale'``.
        rng: Optional random source providing ``permutation`` and ``choice``
            (e.g. ``np.random.default_rng(seed)``). Defaults to the global
            ``np.random`` state.

    Returns:
        A list of per-video result dicts, ordered by video name. Videos that
        cannot be processed get an ``{'video': ..., 'error': ...}`` entry.
        An empty list is returned if either subfolder is missing.
    """
    # SECURITY: joblib.load unpickles arbitrary objects; only load trusted models.
    data = joblib.load(model_path)
    svm = data['svm']
    classes = data['classes']
    cfg = data['config']
    rng = np.random if rng is None else rng

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    feat_extractor = build_feature_extractor(device)

    img_size = tuple(cfg['image_size'])
    box = _roi_box(img_size, cfg.get('roi_size'))
    grayscale = bool(cfg.get('grayscale', False))

    # NOTE(review): PIL's resize takes (width, height) but torchvision's
    # Resize takes (height, width); for a non-square image_size the two
    # disagree. Left unchanged to stay consistent with the training pipeline.
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    def extract(frame_path):
        """Return the flattened feature vector of one frame, or None on failure."""
        try:
            img = _load_frame(frame_path, img_size, box, grayscale)
            x = transform(img).unsqueeze(0).to(device)
            with torch.no_grad():
                feat = feat_extractor(x)
        except Exception as e:  # best-effort: one bad frame must not abort the run
            print(f"Skipping {frame_path}: {e}")
            return None
        return feat.cpu().numpy().flatten()

    results = []

    manikin_folder = os.path.join(input_folder, 'manikin')
    generated_folder = os.path.join(input_folder, 'generated')

    if not os.path.isdir(manikin_folder):
        print(f"Manikin folder not found: {manikin_folder}")
        return results
    if not os.path.isdir(generated_folder):
        print(f"Generated folder not found: {generated_folder}")
        return results

    # Sorted for a deterministic result order across platforms.
    for video_name in sorted(os.listdir(manikin_folder)):
        manikin_vid_path = os.path.join(manikin_folder, video_name)
        generated_vid_path = os.path.join(generated_folder, video_name)

        if not os.path.isdir(manikin_vid_path):
            continue
        if not os.path.isdir(generated_vid_path):
            results.append({'video': video_name, 'error': 'Generated video not found'})
            continue

        manikin_frames = _list_frames(manikin_vid_path)
        generated_frames = _list_frames(generated_vid_path)

        if not manikin_frames:
            results.append({'video': video_name, 'error': 'No frames in manikin video'})
            continue
        if not generated_frames:
            results.append({'video': video_name, 'error': 'No frames in generated video'})
            continue

        # Exactly one entry per video, either a result or an error; the
        # original appended both when no frame could be processed.
        results.append(_classify_video(
            video_name, manikin_frames, generated_frames,
            extract, svm, classes, rng,
        ))

    return results

if __name__ == "__main__":
    def _report(entry):
        """Print a single classify_videos result entry in human-readable form."""
        name = entry['video']
        if 'error' in entry:
            print(f"{name}: {entry['error']}")
            return
        print(f"{name}:")
        print(f"  Final Class: {entry['final_class']}")
        print(f"  Frames Replaced: {entry['replaced_frames']}")
        if 'confidence' in entry:
            print(f"  Confidence: {entry['confidence']:.2%}")
            print("  Probabilities:")
            for label, prob in entry['probabilities'].items():
                print(f"    {label}: {prob:.2%}")
        print()

    # Default locations: frames under ./test, trained model next to the script.
    all_results = classify_videos('test', 'video_classifier.joblib')

    print("\nResults:")
    for entry in all_results:
        _report(entry)
