#!/usr/bin/env python3
import argparse
import os
import random
import shutil
import csv
import json
import glob

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import timm
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np
from open_clip import create_model_and_transforms

# -------------------------------
# PART 1: Utility Functions for Meta and Image Copying
# -------------------------------

def load_meta_file(filepath):
    """Parse a whitespace-separated meta file into ``(image_name, label)`` pairs.

    Every line that does not split into exactly two whitespace-separated
    fields (blank lines, malformed rows, names containing spaces) is skipped.
    Both fields are returned as strings.
    """
    with open(filepath, 'r') as handle:
        # str.split() with no argument already discards surrounding whitespace.
        rows = [entry.split() for entry in handle.readlines()]
    return [(fields[0], fields[1]) for fields in rows if len(fields) == 2]

def write_meta_file(filepath, data):
    """Write ``(image_name, label)`` pairs as ``"<image_name> <label>"`` lines.

    The file is truncated first; each pair produces one newline-terminated line.
    """
    with open(filepath, 'w') as handle:
        handle.writelines(f"{name} {label}\n" for name, label in data)

def copy_images(image_list, source_dir, target_dir):
    """
    Copy each image (by name) from source_dir to target_dir.

    Args:
        image_list: iterable of ``(image_name, label)`` pairs; only the name is
            used. Names may contain sub-directory components (e.g.
            ``"cls/img.jpg"``); the matching sub-directory is created under
            ``target_dir``.
        source_dir: directory the image names are relative to.
        target_dir: destination directory (created if missing).

    Missing source files are reported with a warning and skipped. Files are
    copied with ``shutil.copy2``, which also copies file metadata.
    """
    os.makedirs(target_dir, exist_ok=True)
    for img_name, _ in image_list:
        src = os.path.join(source_dir, img_name)
        dst = os.path.join(target_dir, img_name)
        if not os.path.exists(src):
            print(f"Warning: Source image {src} not found.")
            continue
        # shutil.copy2 does not create parent directories, so a name with a
        # sub-directory component would otherwise raise FileNotFoundError.
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        shutil.copy2(src, dst)

# -------------------------------
# PART 2: DB & Support Set Generation
# -------------------------------

def generate_csv_from_meta(subset_dir, db_dir):
    """
    Build ``db_dir/DB.csv`` from ``subset_dir/meta/train.txt`` and ``meta/val.txt``.

    Columns: Label, Zone, Path, DB, ImageName.
    - Train meta rows are tagged ``DB="Train"`` with paths under ``subset_dir/train``.
    - Val meta rows are tagged ``DB="Test"`` with paths under ``subset_dir/val``.
    - Zone is always ``"0"``.
    A missing meta file contributes no rows. Meta lines that do not split into
    exactly two fields are skipped. ``db_dir`` must already exist.

    Returns:
        The path of the written CSV file.
    """
    csv_file = os.path.join(db_dir, "DB.csv")
    # (meta file name, image sub-folder, DB tag) for each split, in output order.
    split_table = (
        ("train.txt", "train", "Train"),
        ("val.txt", "val", "Test"),
    )

    with open(csv_file, 'w', newline='') as fout:
        writer = csv.writer(fout)
        writer.writerow(["Label", "Zone", "Path", "DB", "ImageName"])
        for meta_name, folder_name, db_tag in split_table:
            meta_path = os.path.join(subset_dir, "meta", meta_name)
            if not os.path.exists(meta_path):
                continue
            image_folder = os.path.join(subset_dir, folder_name)
            with open(meta_path, 'r') as fin:
                for raw_line in fin:
                    fields = raw_line.split()
                    if len(fields) != 2:
                        continue
                    image_name, label = fields
                    writer.writerow(
                        [label, "0", os.path.join(image_folder, image_name), db_tag, image_name]
                    )
    return csv_file

class SimpleCSVDataset(Dataset):
    """Dataset over the rows of a DB.csv file (columns Label, Zone, Path, DB, ImageName).

    Args:
        csv_file: path of the CSV to read.
        images_folder: stored on the instance but not used for loading; each
            row's ``Path`` column is opened directly.
        split: if given, keep only rows whose ``DB`` column equals it
            (e.g. ``"Train"`` or ``"Test"``); ``None`` keeps every row.
        transform: optional callable applied to the RGB PIL image.

    Each item is ``(image, Label, Zone, Path, ImageName, DB)``; all metadata
    fields are the raw CSV strings.
    """

    def __init__(self, csv_file, images_folder, split=None, transform=None):
        self.images_folder = images_folder  # not read by __getitem__; paths come from the CSV
        self.transform = transform
        self.split = split
        with open(csv_file, newline='') as handle:
            self.samples = [
                row
                for row in csv.DictReader(handle)
                if split is None or row["DB"] == split
            ]
        print(f"Loaded {len(self.samples)} samples from {csv_file}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        row = self.samples[idx]
        # convert() yields an independent image, so the file can be closed
        # before the transform runs.
        with Image.open(row["Path"]) as pil_image:
            rgb = pil_image.convert("RGB")
        if self.transform:
            rgb = self.transform(rgb)
        return rgb, row["Label"], row["Zone"], row["Path"], row["ImageName"], row["DB"]

    def get_samples(self):
        """Return the list of kept row dicts (the internal list, not a copy)."""
        return self.samples

    def get_labels_map(self):
        """Return ``{label: label}`` for every distinct Label, in first-seen order."""
        return {row["Label"]: row["Label"] for row in self.samples}

def save_samples_to_csv(samples, csv_file):
    """Write sample dicts (keys Label, Zone, Path, DB, ImageName) to ``csv_file`` with a header row."""
    columns = ["Label", "Zone", "Path", "DB", "ImageName"]
    with open(csv_file, 'w', newline='') as handle:
        out = csv.DictWriter(handle, fieldnames=columns)
        out.writeheader()
        out.writerows(samples)

def save_labels_to_csv(labels_map, csv_file):
    """Write a mapping as a two-column CSV (header ``Label,Value``), one row per item."""
    with open(csv_file, 'w', newline='') as handle:
        out = csv.writer(handle)
        out.writerow(["Label", "Value"])
        out.writerows(labels_map.items())

def compute_embeddings_from_loader(dataloader, model, device):
    """
    Embed every image yielded by ``dataloader`` and collect its metadata.

    The loader must yield ``(images, labels, zones, paths, image_names, dbs)``
    batches (the order produced by ``SimpleCSVDataset``). ``model`` is put in
    eval mode and must expose ``encode_image``; it runs under ``torch.no_grad``.

    Returns:
      - embeddings: tensor concatenating all batch outputs along dim 0 (left on
        ``device``).
      - metadata: list of ``(Label, Zone, Path, DB, ImageName)`` tuples, one per
        image, in loader order. Note the DB/ImageName swap relative to the
        batch order.
    ``torch.cat`` fails if the loader yields no batches.
    """
    model.eval()
    chunks, metadata = [], []
    with torch.no_grad():
        progress = tqdm(dataloader, desc="Computing embeddings")
        for images, labels, zones, paths, img_file_names, dbs in progress:
            chunks.append(model.encode_image(images.to(device)))
            metadata.extend(zip(labels, zones, paths, dbs, img_file_names))
    return torch.cat(chunks, dim=0), metadata

def select_diverse_subset(embeddings_subset, num_samples_per_class):
    """Return ``num_samples_per_class`` distinct row indices chosen uniformly at random.

    Despite the name, the selection ignores the embedding values: it takes the
    first ``num_samples_per_class`` entries of ``torch.randperm`` over the row
    count, on the same device as ``embeddings_subset``.

    Raises:
        ValueError: if there are fewer rows than requested.
    """
    total_rows = embeddings_subset.size(0)
    if total_rows >= num_samples_per_class:
        shuffled = torch.randperm(total_rows, device=embeddings_subset.device)
        return shuffled[:num_samples_per_class]
    raise ValueError("Not enough samples")

def generate_train_similarities(train_embeddings, train_metadata, out_folder):
    """
    Compute pairwise cosine similarities among training images only.

    Writes ``train_similarities.json`` to ``out_folder`` with the layout
    ``{file: {class: [[other_file, similarity], ...]}}`` where:
    - ``file`` is the basename of the metadata ``Path``;
    - ``class`` is ``int(Label)``, which JSON serialises as a string key;
    - each list is sorted by descending similarity.
    An image is never listed against itself.

    Args:
        train_embeddings: 2-D tensor, one row per entry of ``train_metadata``.
        train_metadata: ``(Label, Zone, Path, DB, ImageName)`` tuples.
        out_folder: existing directory for the JSON file.
    """
    normed = F.normalize(train_embeddings, p=2, dim=1)
    # Convert the whole matrix to Python floats in one transfer. Calling .item()
    # per pair is one device-to-host sync per element when on a GPU.
    sims = torch.matmul(normed, normed.t()).tolist()
    files = [os.path.basename(m[2]) for m in train_metadata]
    classes = [int(m[0]) for m in train_metadata]
    sim_dict = {}
    for i, f in tqdm(enumerate(files), total=len(files), desc="Generating train similarities"):
        row = sims[i]
        per_class = {}
        for j, (f2, c) in enumerate(zip(files, classes)):
            # Self-pairs are skipped outright, so no diagonal masking is needed.
            if j == i:
                continue
            per_class.setdefault(c, []).append((f2, row[j]))
        for ranked in per_class.values():
            ranked.sort(key=lambda pair: pair[1], reverse=True)
        sim_dict[f] = per_class
    with open(os.path.join(out_folder, 'train_similarities.json'), 'w') as fp:
        json.dump(sim_dict, fp, indent=2)

def generate_test_similarities(test_embeddings, test_metadata, train_embeddings, train_metadata, out_folder):
    """
    For each test query, compute cosine similarities with all training images.

    Writes ``test_similarities.json`` to ``out_folder`` with the layout
    ``{query_image_name: {class: [[train_file, similarity], ...]}}``. The fields are:
    - ``query_image_name``: the test metadata's ImageName;
    - ``train_file``: the basename of the train metadata Path;
    - ``class``: ``int(Label)``, which JSON serialises as a string key.
    Each list is sorted by descending similarity.

    Args:
        test_embeddings / train_embeddings: 2-D tensors on the same device, one
            row per entry of the matching metadata list.
        test_metadata / train_metadata: ``(Label, Zone, Path, DB, ImageName)`` tuples.
        out_folder: existing directory for the JSON file.
    """
    normed_test = F.normalize(test_embeddings, p=2, dim=1)
    normed_train = F.normalize(train_embeddings, p=2, dim=1)
    # One matmul for all queries and a single conversion to Python floats,
    # instead of a matmul plus one .item() sync per (query, train) pair.
    sims = torch.matmul(normed_test, normed_train.T).tolist()
    train_files = [os.path.basename(m[2]) for m in train_metadata]
    train_classes = [int(m[0]) for m in train_metadata]
    sim_dict = {}
    for test_meta, row in zip(test_metadata, sims):
        per_class = {}
        for train_img, c, score in zip(train_files, train_classes, row):
            per_class.setdefault(c, []).append((train_img, score))
        for ranked in per_class.values():
            ranked.sort(key=lambda pair: pair[1], reverse=True)
        sim_dict[test_meta[4]] = per_class  # keyed by the test ImageName field
    with open(os.path.join(out_folder, 'test_similarities.json'), 'w') as fp:
        json.dump(sim_dict, fp, indent=2)

def select_support_set(query, query_targets, support, support_targets, n):
    """
    For each query image, pick the ``n`` most similar support images of every class.

    ``query`` / ``support`` are 2-D embedding tensors (on the same device) whose
    rows align with ``query_targets`` / ``support_targets``. Both are lists of
    metadata tuples ``(Label, Zone, Path, DB, ImageName)``. Similarity is cosine
    similarity.

    Self-exclusion: a support image is excluded for a query when it is the same
    file, i.e. has the same ``Path``. Identity is decided by path rather than by
    ``ImageName``, because a val image and a train image may share a file name
    while being different files.

    When a class has fewer than ``n`` eligible candidates, the ranked candidates
    are repeated cyclically to fill ``n`` slots.

    Returns:
        dict mapping query ImageName -> {int class -> list of support ImageNames}.
        A class with no eligible candidate maps to an empty list.
    """
    allowed_images = {s[4] for s in support_targets}
    support_paths = [s[2] for s in support_targets]

    # Class -> support row indices, built once up front.
    class_to_indices = {}
    for j, meta in enumerate(support_targets):
        class_to_indices.setdefault(int(meta[0]), []).append(j)
    unique_classes = sorted(class_to_indices)

    # Normalise once instead of re-normalising candidates per (query, class) pair.
    query_normed = F.normalize(query, p=2, dim=-1)
    support_normed = F.normalize(support, p=2, dim=-1)

    result_dict = {}
    for i in tqdm(range(query.size(0)), desc="Processing query samples for support set"):
        query_image_name = query_targets[i][4]
        query_path = query_targets[i][2]
        query_vector = query_normed[i]
        per_class = {}
        for c in unique_classes:
            candidates = [j for j in class_to_indices[c] if support_paths[j] != query_path]
            if not candidates:
                per_class[c] = []
                continue
            candidate_ids = torch.tensor(candidates, device=support.device, dtype=torch.long)
            sim = torch.matmul(support_normed[candidate_ids], query_vector)
            k = min(n, sim.size(0))
            ranked = torch.topk(sim, k=k, dim=0).indices
            if k < n:
                # Too few candidates: cycle through the ranked ones to fill n slots.
                ranked = ranked[torch.arange(n, device=ranked.device) % k]
            selected_ids = candidate_ids[ranked].tolist()
            # --- Sanity checks on the selected support images ---
            for j in selected_ids:
                support_file = support_targets[j][4]
                assert support_file in allowed_images, (
                    f"Support image {support_file} is not in the allowed training set."
                )
                assert support_paths[j] != query_path, (
                    f"Image {query_image_name} is found in its own support set."
                )
            per_class[c] = [support_targets[j][4] for j in selected_ids]
        result_dict[query_image_name] = per_class

    # Every query has an entry (possibly empty) for every support class.
    for name in result_dict:
        assert len(result_dict[name]) == len(unique_classes), (
            f"Query {name} does not have support entries for all classes."
        )
    return result_dict

# Process-wide cache of loaded open_clip models, keyed by (model_name, pretrained,
# device). create_support_and_similarities runs once per subset. Caching avoids
# rebuilding the model and reloading its pretrained weights each time. The cached
# model stays resident for the lifetime of the process.
_CLIP_MODEL_CACHE = {}


def _load_clip_model(model_name, pretrained, device):
    """Return ``(model, preprocess_val)`` for an open_clip config, building it on first use."""
    key = (model_name, pretrained, str(device))
    if key not in _CLIP_MODEL_CACHE:
        model, _preprocess_train, preprocess_val = create_model_and_transforms(
            model_name=model_name, pretrained=pretrained
        )
        model.to(device)
        _CLIP_MODEL_CACHE[key] = (model, preprocess_val)
    return _CLIP_MODEL_CACHE[key]


def create_support_and_similarities(args, csv_file, train_images_folder, test_images_folder, out_folder):
    """
    Embed the Train and Test rows of ``csv_file`` and build support sets and similarity files.

    Steps:
    1. Embed the Train and Test rows with an open_clip model.
    2. For each query image (Train and Test), build a support set by comparing
       its embedding against the training embeddings only (``args.k`` per class).
    3. Write these files to ``out_folder``:
       - ``db_support_set.json``, with top-level keys ``"Train"`` and ``"Test"``;
       - ``train_similarities.json``;
       - ``test_similarities.json``.

    Args:
        args: must provide ``k`` and ``batch_size``. These attributes are optional:
            ``model_name`` (default ``"ViT-H/14-quickgelu"``),
            ``pretrained`` (default ``"dfn5b"``) and
            ``num_workers`` (default 4).
        csv_file: DB.csv with a ``DB`` column of ``"Train"`` / ``"Test"``.
        train_images_folder / test_images_folder: forwarded to
            ``SimpleCSVDataset``, which reads image paths from the CSV instead.
        out_folder: existing output directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_name = getattr(args, "model_name", "ViT-H/14-quickgelu")
    pretrained = getattr(args, "pretrained", "dfn5b")
    num_workers = getattr(args, "num_workers", 4)
    model, preprocess_val = _load_clip_model(model_name, pretrained, device)

    # Force a square 224x224 resize (aspect ratio is not preserved). Every other
    # transform open_clip returned is kept in its original order.
    transform = transforms.Compose(
        [transforms.Resize((224, 224))]
        + [t for t in preprocess_val.transforms if not isinstance(t, transforms.Resize)]
    )

    # Compute train embeddings from images with DB label "Train"
    train_dataset = SimpleCSVDataset(csv_file, train_images_folder, split="Train", transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=num_workers)
    train_embeddings, train_metadata = compute_embeddings_from_loader(train_loader, model, device)

    # Compute test embeddings from images with DB label "Test"
    test_dataset = SimpleCSVDataset(csv_file, test_images_folder, split="Test", transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=num_workers)
    test_embeddings, test_metadata = compute_embeddings_from_loader(test_loader, model, device)

    # For both Train and Test queries, use training embeddings as support.
    train_support = select_support_set(train_embeddings, train_metadata, train_embeddings, train_metadata, args.k)
    test_support = select_support_set(test_embeddings, test_metadata, train_embeddings, train_metadata, args.k)

    support_set = {
        "Train": train_support,
        "Test": test_support
    }
    support_set_file = os.path.join(out_folder, "db_support_set.json")
    with open(support_set_file, 'w') as f:
        json.dump(support_set, f, indent=2)
    print(f"Saved support set to {support_set_file}")

    generate_train_similarities(train_embeddings, train_metadata, out_folder)
    generate_test_similarities(test_embeddings, test_metadata, train_embeddings, train_metadata, out_folder)

def create_db_for_subset(subset_dir, args):
    """
    Build the ``db`` folder for one subset directory.

    ``subset_dir`` must contain ``meta/``, ``train/`` and ``val/``. This function:
    1. creates ``subset_dir/db``;
    2. writes ``DB.csv`` there (Train and Test rows);
    3. sets ``args.out_folder`` to that folder;
    4. delegates embedding, support-set and similarity generation to
       ``create_support_and_similarities``.
    """
    db_dir = os.path.join(subset_dir, "db")
    os.makedirs(db_dir, exist_ok=True)

    csv_file = generate_csv_from_meta(subset_dir, db_dir)
    print(f"Generated CSV: {csv_file}")

    # Exposes the db folder on args; nothing in this function reads it back.
    args.out_folder = db_dir
    create_support_and_similarities(
        args,
        csv_file,
        os.path.join(subset_dir, "train"),
        os.path.join(subset_dir, "val"),
        db_dir,
    )

# -------------------------------
# PART 3: Creating Subset Folders from Existing Meta Files
# -------------------------------

def create_subset_from_existing_meta(dataset_dir, out_dir, n_class, subset_idx, args):
    """
    Create ``out_dir/subset_<subset_idx>`` from pre-created meta files.

    The meta files are read from ``dataset_dir/meta`` and follow the naming
    pattern ``{split}_ncls_{n_class}_s_{subset_idx}.txt`` (split = train / val).

    Steps:
      1. Verify both source meta files exist. Otherwise print a message and
         return without creating anything on disk.
      2. Create the subset folder with ``meta``, ``train`` and ``val`` subfolders.
      3. Write the loaded meta data as ``meta/train.txt`` and ``meta/val.txt``.
      4. Copy the listed images from ``dataset_dir/train`` and ``dataset_dir/val``.
      5. Call ``create_db_for_subset()`` to build the CSV, embeddings and support set.

    Returns:
        None.
    """
    meta_src_dir = os.path.join(dataset_dir, "meta")
    train_meta_src = os.path.join(meta_src_dir, f"train_ncls_{n_class}_s_{subset_idx}.txt")
    val_meta_src = os.path.join(meta_src_dir, f"val_ncls_{n_class}_s_{subset_idx}.txt")

    # Check inputs before touching the output tree so that a skipped subset
    # does not leave an empty subset_<idx>/ skeleton behind.
    if not os.path.exists(train_meta_src) or not os.path.exists(val_meta_src):
        print(f"Meta files for n_class {n_class} subset {subset_idx} not found in {meta_src_dir}. Skipping.")
        return

    subset_dir = os.path.join(out_dir, f"subset_{subset_idx}")
    meta_dir_subset = os.path.join(subset_dir, "meta")
    train_dir_subset = os.path.join(subset_dir, "train")
    val_dir_subset = os.path.join(subset_dir, "val")
    for folder in (subset_dir, meta_dir_subset, train_dir_subset, val_dir_subset):
        os.makedirs(folder, exist_ok=True)

    # Load the meta data from the pre-created files
    train_data = load_meta_file(train_meta_src)
    val_data = load_meta_file(val_meta_src)

    # Write the meta data into the subset's meta folder as train.txt and val.txt
    write_meta_file(os.path.join(meta_dir_subset, "train.txt"), train_data)
    write_meta_file(os.path.join(meta_dir_subset, "val.txt"), val_data)
    print(f"Subset n_class {n_class} s {subset_idx}: {len(train_data)} train and {len(val_data)} val entries written.")

    # Copy corresponding images to the subset folders.
    copy_images(train_data, os.path.join(dataset_dir, "train"), train_dir_subset)
    copy_images(val_data, os.path.join(dataset_dir, "val"), val_dir_subset)
    print(f"Subset n_class {n_class} s {subset_idx}: Images copied into {subset_dir}")

    # Create the DB files (CSV, embeddings, support set, similarity files) for this subset.
    create_db_for_subset(subset_dir, args)

# -------------------------------
# PART 4: Main Entry Point
# -------------------------------

def main():
    """
    Parse CLI options and build subset DBs for every (dataset, n_class, subset) combination.

    The dataset folders come from ``--datasets`` under ``--datasets_root``, and the
    n_class values from ``--n_classes``. For each combination, subsets
    ``0 .. n_subsets-1`` are created under
    ``<out_root>/<dataset>/ncl_<n_class>/subset_<idx>`` from the pre-created meta
    files. The dataset/path options default to the values previously hard-coded
    here. ``--seed`` seeds Python's ``random``, NumPy and PyTorch.
    """
    parser = argparse.ArgumentParser(
        description="Create subsets from pre-created meta files and then generate a DB (CSV, embeddings, support set, similarity files) for each subset."
    )
    parser.add_argument("--n_subsets", type=int, default=5,
                        help="Number of subsets (s indices) to process per n_class value")
    parser.add_argument("--k", type=int, default=5,
                        help="Number of nearest neighbors for support set")
    parser.add_argument("--batch_size", type=int, default=128,
                        help="Batch size for embedding computation")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility")
    # Other dataset folder names that were listed (commented out) here before:
    # caltech101, caltech256, cifar100, cub200, food101, gtsrb43, imagenet, indoor67.
    parser.add_argument("--datasets", nargs="+", default=["textures47", "quickdraw345"],
                        help="Dataset folder names under --datasets_root to process")
    parser.add_argument("--n_classes", type=int, nargs="+", default=[2, 3, 4, 5, 10],
                        help="n_class values whose pre-created meta files are used")
    parser.add_argument("--datasets_root", default="/media/ai-ubuntu/data1/projects/fca/datasets",
                        help="Directory containing one folder per dataset")
    parser.add_argument("--out_root", default="./dataset_ds_10_ncls_2_3_4_5_10",
                        help="Root directory for the generated subsets")
    args = parser.parse_args()

    # --seed used to be parsed but never applied; seed every RNG library imported here.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    for ds in args.datasets:
        args.dataset_dir = os.path.join(args.datasets_root, ds)
        args.out_base_dir = os.path.join(args.out_root, ds)
        for n_class in args.n_classes:
            out_dir = os.path.join(args.out_base_dir, f"ncl_{n_class}")
            os.makedirs(out_dir, exist_ok=True)
            for subset_idx in range(args.n_subsets):
                print(f"\n=== Processing dataset: {ds}, n_class = {n_class}, subset index: {subset_idx} ===")
                create_subset_from_existing_meta(args.dataset_dir, out_dir, n_class, subset_idx, args)

    print("\nAll processing complete!")

if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not on import.
    main()
