import os
import numpy as np
from PIL import Image
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import joblib

import torch
import torchvision.transforms as transforms
import torchvision.models as models

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
DATA_DIR = 'data'                              # root folder, one sub-folder per class
IMAGE_SIZE = (256, 256)                        # size every frame is resized to
USE_GRAYSCALE = True                           # collapse colour before feature extraction
TEST_SIZE = 0.2                                # fraction of videos held out for testing
RANDOM_STATE = 1                               # seed for the split and the SVM
MODEL_SAVE_PATH = 'video_classifier.joblib'    # where the trained artifact is written
ROI_SIZE = 100                                 # side length of the centred crop

# Crop box: a ROI_SIZE square centred on the midpoint of the resized frame.
_center_x = IMAGE_SIZE[0] // 2
_center_y = IMAGE_SIZE[1] // 2
_half_roi = ROI_SIZE / 2

left, top = int(_center_x - _half_roi), int(_center_y - _half_roi)
right, bottom = int(_center_x + _half_roi), int(_center_y + _half_roi)

# Run on the GPU when one is visible to torch, otherwise on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

# 1) Load an ImageNet-pretrained ResNet50 and replace its final `fc` classifier
#    with a 6-unit linear projection.  No Softmax layer is applied: the
#    extractor emits the raw 6-D linear outputs.
# NOTE: newer torchvision releases deprecate `pretrained=True` in favour of the
# `weights=` argument; the call is kept as-is for older installations.
base_model = models.resnet50(pretrained=True)
in_feats = base_model.fc.in_features

backbone = list(base_model.children())[:-1]  # all layers up through avgpool
projection = torch.nn.Linear(in_feats, 6)    # 6-dim embedding

# The projection is never trained (every weight is frozen below), so its
# initial values *define* the feature space the downstream classifier is
# fitted on.  Draw them from a private generator seeded with RANDOM_STATE so
# that every run -- and any later script that rebuilds this extractor -- gets
# the same embedding.  U(-1/sqrt(in_features), 1/sqrt(in_features)) is the
# distribution nn.Linear's default initialisation uses for weight and bias.
_init_gen = torch.Generator().manual_seed(RANDOM_STATE)
_init_bound = in_feats ** -0.5
with torch.no_grad():
    projection.weight.uniform_(-_init_bound, _init_bound, generator=_init_gen)
    projection.bias.uniform_(-_init_bound, _init_bound, generator=_init_gen)

feature_extractor = torch.nn.Sequential(
    *backbone,
    torch.nn.Flatten(),
    projection,
)

# Pure feature extraction: no parameter should ever receive a gradient.
for param in feature_extractor.parameters():
    param.requires_grad = False

feature_extractor.to(device)
# eval() switches BatchNorm layers to their stored running statistics.
feature_extractor.eval()

# 2) Standard ImageNet pre-processing: scale to IMAGE_SIZE, convert to a
#    tensor, then normalise each channel with the ImageNet mean / std.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

_preprocessing_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
transform = transforms.Compose(_preprocessing_steps)

def extract_feature(path) -> np.ndarray:
    """Return the feature vector of the centre ROI of one image file.

    The image is resized to ``IMAGE_SIZE``, cropped to the box
    ``(left, top, right, bottom)``, optionally reduced to grayscale
    (replicated over three channels so the network still receives RGB),
    passed through ``transform`` and finally through ``feature_extractor``.

    Args:
        path: Location of the image file to read.

    Returns:
        The extractor output for the image, flattened to a 1-D NumPy array.

    Raises:
        Any exception ``PIL.Image.open`` or the later processing steps raise
        for a missing or unreadable image; nothing is caught here.
    """
    # Image.open is lazy and keeps the underlying file open.  The context
    # manager releases the handle as soon as the pixels have been read, so a
    # long run over many frames cannot exhaust file descriptors.  resize()
    # loads the data and returns an independent in-memory image.
    with Image.open(path) as src:
        img = src.resize(IMAGE_SIZE).crop((left, top, right, bottom))

    if USE_GRAYSCALE:
        img = img.convert('L')
        # Replicate the single luminance band into three identical channels.
        img = Image.merge('RGB', (img, img, img))
    else:
        img = img.convert('RGB')

    # Add a batch dimension of one and move to the model's device.
    x = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        feat = feature_extractor(x)
    return feat.cpu().numpy().flatten()

# NOTE(review): the message below says "Softmax-normalized", but the
# feature_extractor defined in this file contains no Softmax layer -- confirm
# which of the two is intended.
print("Extracting 6-D Softmax-normalized features from video frames…")
video_features = []
labels = []

# Accepted frame suffixes.  Each entry carries its leading dot so that a name
# merely *ending* in the letters "jpeg" is not mistaken for an image.
FRAME_EXTENSIONS = ('.png', '.jpg', '.jpeg')

# Layout: DATA_DIR/<class_name>/<video_folder>/<frame files>.  The class index
# (0 = 'real', 1 = 'manikin') becomes the label of every video in that class.
for class_idx, class_name in enumerate(['real', 'manikin']):
    cd = os.path.join(DATA_DIR, class_name)
    # os.listdir returns entries in arbitrary order.  Sorting fixes the sample
    # order, which the seeded train/test split depends on for reproducibility.
    for video_folder in sorted(os.listdir(cd)):
        vp = os.path.join(cd, video_folder)
        if not os.path.isdir(vp):
            continue
        frames = []
        for fn in sorted(os.listdir(vp)):
            if not fn.lower().endswith(FRAME_EXTENSIONS):
                continue
            # Best effort: one unreadable frame must not abort the whole run.
            try:
                frames.append(extract_feature(os.path.join(vp, fn)))
            except Exception as e:
                print(f"Skipping {fn}: {e}")
        # A video is represented by the mean of its per-frame feature vectors;
        # folders without any usable frame contribute no sample.
        if frames:
            video_features.append(np.mean(frames, axis=0))
            labels.append(class_idx)

X = np.array(video_features)
y = np.array(labels)

# Hold out TEST_SIZE of the videos, keeping the class ratio equal in both parts.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=TEST_SIZE, stratify=y, random_state=RANDOM_STATE
)
# No `degree` argument: it only applies to the 'poly' kernel and is ignored
# by 'rbf', so passing it would suggest a tuning knob that does nothing.
svm = SVC(kernel='rbf', probability=True, random_state=RANDOM_STATE)
svm.fit(X_train, y_train)

# The extractor's last layer is a randomly initialised, untrained Linear
# projection, so its weights are part of the model: without them the feature
# space the SVM was fitted on cannot be rebuilt at inference time.  They are
# stored as plain NumPy arrays so the artifact loads without a GPU.
projection_state = {
    name: tensor.detach().cpu().numpy()
    for name, tensor in feature_extractor[-1].state_dict().items()
}

joblib.dump({
    'svm': svm,
    'classes': ['real','manikin'],
    'config': {
        'image_size': IMAGE_SIZE,
        'grayscale': USE_GRAYSCALE,
        'roi_size': ROI_SIZE
    },
    'projection_state': projection_state,
}, MODEL_SAVE_PATH)

# Report accuracy on both splits; a large gap between them hints at overfitting.
train_pred = svm.predict(X_train)
test_pred  = svm.predict(X_test)
print(f"\nTraining accuracy: {accuracy_score(y_train, train_pred):.4f}")
print(f"Test accuracy:     {accuracy_score(y_test,  test_pred):.4f}")
print(f"\nSVM saved to {MODEL_SAVE_PATH}")

