import os
import sys
import shutil
from datetime import datetime

import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

from absl import app, flags
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm


FLAGS = flags.FLAGS

# Input / output configuration.
flags.DEFINE_string("input_dir", None, help="Root directory scanned recursively for images to reclassify (required).")
flags.DEFINE_string("output_dir", None, help="Output directory (default: input_dir/classwise_images_rec).")
flags.DEFINE_string("dataset_name", "cifar10", help="Dataset whose classifier is used [cifar10, cifar100, cifar10_lt, cifar100_lt].")
flags.DEFINE_string("device", "cuda:0", help="Device [cuda:0, cuda:1, cpu]; falls back to cpu when CUDA is unavailable.")
flags.DEFINE_integer("batch_size", 256, help="Classifier batch size.")
flags.DEFINE_integer("num_workers", 4, help="DataLoader num_workers")
flags.DEFINE_bool("copy_files", True, help="True: copy images into class folders; False: move them.")

# Example (absl boolean flags take `--copy_files=true|false` or `--copy_files` / `--nocopy_files`;
# a space-separated value such as `--copy_files False` is NOT read as the flag's value):
#python reclassify_images.py --input_dir results_cifar10_lt/sinkhorn_otwfm_cifar10_lt_reg0.05+tauinf1.0_inv_tnu^10.0_fixsrc/classwise_images --output_dir results_cifar10_lt/sinkhorn_otwfm_cifar10_lt_reg0.05+tauinf1.0_inv_tnu^10.0_fixsrc/classwise_images_rec --dataset_name cifar10_lt --device cuda:0 --batch_size 256 --num_workers 4 --copy_files=true


def _collect_image_paths(root_dir: str):
    """Recursively gather image files below ``root_dir``.

    A file qualifies when its lower-cased name ends with .png, .jpg, .jpeg,
    .bmp or .tiff. Paths are built with ``os.path.join`` from the walk's
    directory names, so they are absolute only when ``root_dir`` is.

    Returns:
        A lexicographically sorted list of matching file paths.
    """
    suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".tiff")
    found = [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(root_dir)
        for filename in filenames
        if filename.lower().endswith(suffixes)
    ]
    return sorted(found)


class RecursiveImageDataset(Dataset):
    """Map-style dataset over an explicit list of image file paths.

    Each item is ``(image, path)``. The image is loaded as RGB and passed
    through ``transform`` when one is given. ``path`` is returned unchanged so
    callers can trace every prediction back to its source file.
    """

    def __init__(self, image_paths, transform=None, device=None):
        """
        Args:
            image_paths: Sequence of image file paths, indexed by ``__getitem__``.
            transform: Optional callable applied to the loaded PIL image.
            device: Accepted for backward compatibility; not used by this class.
        """
        self.image_paths = image_paths
        self.transform = transform
        self.device = device  # unused; kept so existing callers/attribute access keep working

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        # Use the context manager so the file handle is released promptly even
        # if decoding raises (matters when many DataLoader workers are open).
        with Image.open(path) as im:
            img = im.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, path


def _load_classifier(dataset_name: str, device: torch.device):
    """Fetch a pretrained CIFAR RepVGG-A2 classifier from torch.hub.

    The model is moved to ``device`` and switched to eval mode.

    Returns:
        ``(model, num_classes)``: 10 classes for cifar10/cifar10_lt, 100 for
        cifar100/cifar100_lt.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported variants.
    """
    # NOTE(review): the original comment referenced utils_cifar.classify_generated_images;
    # presumably the same hub models are used there — confirm before changing them.
    hub_specs = {
        "cifar10": ("cifar10_repvgg_a2", 10),
        "cifar10_lt": ("cifar10_repvgg_a2", 10),
        "cifar100": ("cifar100_repvgg_a2", 100),
        "cifar100_lt": ("cifar100_repvgg_a2", 100),
    }
    spec = hub_specs.get(dataset_name)
    if spec is None:
        raise ValueError(f"Unsupported dataset_name: {dataset_name}")

    entrypoint, n_classes = spec
    net = torch.hub.load("chenyaofo/pytorch-cifar-models", entrypoint, pretrained=True).to(device)
    net.eval()
    return net, n_classes


def _class_names(dataset_name: str, num_classes: int):
    """Return display names for ``num_classes`` classes.

    Exactly ten classes get the standard CIFAR-10 labels. Any other count
    (e.g. CIFAR-100) gets generic ``class_<i>`` placeholders. ``dataset_name``
    is accepted but not consulted.
    """
    if num_classes != 10:
        # CIFAR-100 label names are not spelled out here; fall back to indices.
        return [f'class_{i}' for i in range(num_classes)]
    return [
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ]


def _class_label(class_id: int, class_names) -> str:
    """Return ``class_names[class_id]``, or ``class_<id>`` when the list is too short."""
    return class_names[class_id] if class_id < len(class_names) else f"class_{class_id}"


def _class_dir_name(class_id: int, class_names) -> str:
    """Return the per-class output folder name, e.g. ``class_03_cat``."""
    return f"class_{class_id:02d}_{_class_label(class_id, class_names)}"


def _is_inside(path: str, root: str) -> bool:
    """Return True if ``path`` lies strictly below directory ``root`` (both made absolute)."""
    path = os.path.normcase(os.path.abspath(path))
    root = os.path.normcase(os.path.abspath(root))
    if path == root:
        return False
    try:
        return os.path.commonpath([path, root]) == root
    except ValueError:
        # commonpath raises for paths on different drives (Windows).
        return False


def _unique_dst_path(class_dir: str, stem: str, conf_str: str, ext: str) -> str:
    """Return ``<class_dir>/<stem>_<conf_str><ext>``, appending ``_1``, ``_2``, ... until unused."""
    dst_path = os.path.join(class_dir, f"{stem}_{conf_str}{ext}")
    attempt = 0
    while os.path.exists(dst_path):
        attempt += 1
        dst_path = os.path.join(class_dir, f"{stem}_{conf_str}_{attempt}{ext}")
    return dst_path


def _transfer(src: str, dst: str, copy: bool) -> None:
    """Copy (``copy=True``) or move ``src`` to ``dst``; a failed move falls back to copying."""
    if copy:
        shutil.copy2(src, dst)
        return
    try:
        shutil.move(src, dst)
    except OSError as err:
        # Best effort: leave the source in place and copy instead, but report it
        # rather than silently leaving the file un-moved.
        print(f"[warn] move failed for {src} ({err}); copying instead.")
        shutil.copy2(src, dst)


def main(_):
    """Reclassify every image under --input_dir with a pretrained CIFAR classifier.

    Each image is copied (or moved, with --copy_files=false) into
    ``<output_dir>/class_<id>_<name>/`` under a name that embeds the top-1
    confidence. The run writes these files into ``output_dir``:

    - ``reclassify_log.txt``: appended run header and footer.
    - ``reclassify_mapping.csv``: one row per image (src -> predicted class -> dst).
    - ``classwise_counts.csv``: number of images per predicted class.

    Raises:
        ValueError: if --input_dir is missing or is not a directory, or if
            --dataset_name is unsupported.
    """
    import csv

    if FLAGS.input_dir is None:
        raise ValueError("--input_dir is required.")
    input_dir = os.path.abspath(FLAGS.input_dir)
    if not os.path.isdir(input_dir):
        raise ValueError(f"--input_dir is not a directory: {input_dir}")

    # Honour a cuda device only when CUDA is actually available; otherwise use cpu.
    use_cuda = torch.cuda.is_available() and FLAGS.device.startswith("cuda")
    device = torch.device(FLAGS.device if use_cuda else "cpu")

    output_dir = FLAGS.output_dir or os.path.join(input_dir, "classwise_images_rec")
    os.makedirs(output_dir, exist_ok=True)

    # Append a run header to the log (the file accumulates every run).
    log_path = os.path.join(output_dir, "reclassify_log.txt")
    with open(log_path, "a", encoding="utf-8") as f:
        f.write("===== Reclassification =====\n")
        f.write(f"input_dir: {input_dir}\n")
        f.write(f"output_dir: {output_dir}\n")
        f.write(f"dataset_name: {FLAGS.dataset_name}\n")
        f.write(f"device: {device}\n")
        f.write(f"batch_size: {FLAGS.batch_size}\n")
        f.write(f"num_workers: {FLAGS.num_workers}\n")
        f.write(f"copy_files: {FLAGS.copy_files}\n")
        f.write(f"start_time: {datetime.now()}\n\n")

    model, num_classes = _load_classifier(FLAGS.dataset_name, device)
    class_names = _class_names(FLAGS.dataset_name, num_classes)

    # NOTE(review): these are ImageNet mean/std. The chenyaofo CIFAR models are
    # believed to be trained with CIFAR-specific statistics — confirm this matches
    # the project's evaluation code before changing it.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    image_paths = _collect_image_paths(input_dir)
    if _is_inside(output_dir, input_dir):
        # The default output_dir lives inside input_dir. Without this filter, a
        # rerun would re-collect (and re-copy) the previous run's results.
        image_paths = [p for p in image_paths if not _is_inside(p, output_dir)]
    if not image_paths:
        print("No images found.")
        return
    print(f"Found {len(image_paths)} images under {input_dir}")

    dataset = RecursiveImageDataset(image_paths, transform=transform)
    loader = DataLoader(dataset, batch_size=FLAGS.batch_size, shuffle=False, num_workers=FLAGS.num_workers)

    # Create one destination folder per class up front.
    class_dirs = [os.path.join(output_dir, _class_dir_name(i, class_names)) for i in range(num_classes)]
    for class_dir in class_dirs:
        os.makedirs(class_dir, exist_ok=True)

    class_counts = [0] * num_classes
    mapping_csv = os.path.join(output_dir, "reclassify_mapping.csv")
    counts_csv = os.path.join(output_dir, "classwise_counts.csv")

    # Keep the mapping CSV open for the whole run instead of reopening it per
    # image, and let the csv module quote paths containing commas or quotes.
    with open(mapping_csv, "w", newline="", encoding="utf-8") as map_file, torch.no_grad():
        mapping = csv.writer(map_file, lineterminator="\n")
        mapping.writerow(["index", "src_path", "class_id", "class_name", "confidence", "dst_path"])
        idx_global = 0
        for imgs, paths in tqdm(loader, ncols=80, desc="Classifying"):
            probs = torch.softmax(model(imgs.to(device)), dim=1)
            confs, preds = torch.max(probs, dim=1)

            for pth, pred, conf in zip(paths, preds.tolist(), confs.tolist()):
                stem, ext = os.path.splitext(os.path.basename(pth))
                dst_path = _unique_dst_path(class_dirs[pred], stem, f"conf{conf:.4f}", ext)
                _transfer(pth, dst_path, copy=FLAGS.copy_files)

                class_counts[pred] += 1
                mapping.writerow([
                    idx_global, pth, pred, _class_label(pred, class_names), f"{conf:.6f}", dst_path,
                ])
                idx_global += 1

    with open(counts_csv, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['class_id', 'class_name', 'count'])
        for i in range(num_classes):
            writer.writerow([i, _class_label(i, class_names), class_counts[i]])

    with open(log_path, "a", encoding="utf-8") as f:
        f.write(f"end_time: {datetime.now()}\n")
        f.write(f"total_images: {len(image_paths)}\n")
        f.write(f"output: {output_dir}\n")

    print(f"Reclassification done. Output: {output_dir}")
    print(f"Mapping CSV: {mapping_csv}")
    print(f"Counts CSV: {counts_csv}")


if __name__ == "__main__":
    # absl's app.run parses the command-line flags defined in this module, then calls main(argv).
    app.run(main)


