import os
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from .util import read_csv
from torch.utils.data import DataLoader, Dataset


class ReidDataset(Dataset):
    """Map-style dataset yielding ``(transform(image), camera_id)`` pairs.

    Args:
        paths: indexable sequence of image file paths.
        cids: indexable sequence of camera ids, aligned index-by-index with ``paths``.
        transform: callable applied to each image opened with ``PIL.Image.open``.
    """

    def __init__(self, paths, cids, transform):
        super().__init__()
        self.paths = paths
        self.cids = cids
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # Open the file in a context manager so the OS handle is released as
        # soon as the transform finishes, rather than whenever the PIL image is
        # garbage-collected (multi-frame formats keep the file open even after
        # the pixel data has been loaded).
        # NOTE: the PIL image is closed when this block exits, so ``transform``
        # must return a new object (e.g. a tensor), not the PIL image itself.
        with Image.open(self.paths[index]) as pil_img:
            img = self.transform(pil_img)
        return img, self.cids[index]


def _id_or_missing(value):
    """Return ``value`` unchanged, or ``-1`` when it is the empty string (missing id)."""
    return -1 if value == "" else value


def extract_features(model, preprocessor, csv_path, data_dir_path, batch_size=256, num_workers=12):
    """Extract model features for every image listed in a CSV file.

    The CSV is loaded with ``read_csv`` and its first row is skipped as the
    header. Every remaining row must unpack into exactly four values:
    ``path, pid, cid, tid``. Empty ``pid`` / ``cid`` / ``tid`` cells are
    replaced by ``-1`` before the columns are converted to integer arrays.

    Args:
        model: forwarded to ``extract_features_from_paths``.
        preprocessor: transform applied to each opened image.
        csv_path: path of the CSV file to read.
        data_dir_path: directory joined (``os.path.join``) in front of each CSV path.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count (default 12, matching
            ``extract_features_from_paths``).

    Returns:
        Tuple ``(features, pids, cids, tids, paths)``: ``features`` as returned
        by ``extract_features_from_paths``; ``pids``/``cids``/``tids`` as
        ``int`` numpy arrays; ``paths`` as a numpy array of the joined paths.
    """
    # Presumably read_csv returns a sliceable list of rows; the first row holds
    # the column names.
    rows = read_csv(csv_path)[1:]
    paths, pids, cids, tids = [], [], [], []
    for path_sub, pid, cid, tid in rows:
        paths.append(os.path.join(data_dir_path, path_sub))
        pids.append(_id_or_missing(pid))
        cids.append(_id_or_missing(cid))
        tids.append(_id_or_missing(tid))
    pids = np.array(pids, dtype=int)
    # NOTE(review): a missing camera id becomes -1 and is passed to the model as
    # ``cam_label``; confirm the model tolerates a negative camera index.
    cids = np.array(cids, dtype=int)
    tids = np.array(tids, dtype=int)
    features = extract_features_from_paths(
        model, preprocessor, paths, cids, batch_size, num_workers=num_workers
    )
    return features, pids, cids, tids, np.array(paths)


@torch.no_grad()
def extract_features_from_paths(
    model, preprocessor, img_paths, cids, batch_size=256, num_workers=12, device="cuda"
):
    """Run ``model`` over a list of images and concatenate the outputs.

    Images are loaded through ``ReidDataset`` with ``shuffle=False``, so row
    order of the result follows ``img_paths``. Runs under ``torch.no_grad()``.

    Args:
        model: called as ``model(img, cam_label=cid)`` on each batch. It is
            switched to eval mode and left in eval mode.
        preprocessor: transform applied to each image by ``ReidDataset``.
        img_paths: image file paths.
        cids: per-image camera ids aligned with ``img_paths``; passed to the
            model as ``cam_label``.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count.
        device: device each batch is moved to before the forward pass. The
            default ``"cuda"`` matches the previously hard-coded ``.cuda()``.
            The model is assumed to already live on this device.

    Returns:
        The per-batch model outputs, each moved to CPU, concatenated along
        dim 0 (assumes the model returns a tensor).

    Raises:
        RuntimeError: raised by ``torch.concat`` when ``img_paths`` is empty
            (no batches are produced).
    """
    # Pinned host memory only benefits host->CUDA copies, so enable it only then.
    to_cuda = torch.device(device).type == "cuda"
    loader = DataLoader(
        ReidDataset(img_paths, cids, preprocessor),
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=to_cuda,
        shuffle=False,
    )

    model.eval()
    features = []
    for img, cid in tqdm(loader, desc="Extract features"):
        # non_blocking=True is what lets a copy from pinned memory run
        # asynchronously; it is still ordered before the model call on the
        # same CUDA stream.
        img = img.to(device, non_blocking=True)
        cid = cid.to(device, non_blocking=True)
        features.append(model(img, cam_label=cid).cpu())
    return torch.concat(features)
