#!/usr/bin/env python3
import argparse
import os
import random
import shutil
import csv
import json
import glob

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import timm
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np
from open_clip import create_model_and_transforms

# -------------------------------
# PART 1: Subset Generation
# -------------------------------

def load_meta_file(filepath):
    """Parse a meta file into a list of (image_name, label) tuples.

    Every line is stripped and split on whitespace. Only lines that yield
    exactly two tokens are kept; all other lines (blank, one token, three or
    more tokens) are ignored.
    """
    with open(filepath, 'r') as meta:
        tokenized = [raw.strip().split() for raw in meta]
    return [(tokens[0], tokens[1]) for tokens in tokenized if len(tokens) == 2]

def write_meta_file(filepath, data):
    """Serialise (image_name, label) tuples to filepath, one "name label" line each.

    The file is opened in write mode, so existing content is replaced.
    """
    with open(filepath, 'w') as sink:
        # Generator keeps the write lazy: entries are emitted in input order.
        sink.writelines(f"{image_name} {label}\n" for image_name, label in data)

def copy_images(image_list, source_dir, target_dir):
    """Copy each listed image from source_dir to target_dir.

    Args:
        image_list: Iterable of (image_name, label) tuples; only the name is used.
            A name may include a relative sub-directory (e.g. "cls/img.jpg").
        source_dir: Directory the image names are resolved against.
        target_dir: Destination directory; created if missing.

    Images that do not exist in source_dir are reported with a warning on
    stdout and skipped.
    """
    os.makedirs(target_dir, exist_ok=True)
    for img_name, _ in image_list:
        src = os.path.join(source_dir, img_name)
        dst = os.path.join(target_dir, img_name)
        if not os.path.exists(src):
            print(f"Warning: Source image {src} not found.")
            continue
        # The name may carry a sub-directory; copy2 does not create missing
        # parent directories, so make sure the destination folder exists.
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        shutil.copy2(src, dst)

def generate_subsets(dataset_dir, out_dir, n_subsets, n_class, base_seed=None):
    """
    Create subset folders by randomly selecting n_class classes.
    Each subset folder will have subfolders: meta, train, val.
    The meta files in each subset are filtered from the original meta files.
    Also, the corresponding images are copied into the train/val folders.

    Args:
        dataset_dir: Root directory of the dataset (expects meta/train.txt,
            meta/val.txt, train/ and val/ underneath)
        out_dir: Output directory for subsets
        n_subsets: Number of subsets to generate
        n_class: Number of classes per subset
        base_seed: Base seed number. Each subset will use (base_seed + subset_idx) as its seed.
            When None, the random generators are left unseeded.

    Raises:
        ValueError: raised by random.sample when n_class is negative or larger
            than the number of distinct labels.
    """
    meta_dir = os.path.join(dataset_dir, "meta")
    train_images_dir = os.path.join(dataset_dir, "train")
    val_images_dir = os.path.join(dataset_dir, "val")

    # Load original meta files
    train_data = load_meta_file(os.path.join(meta_dir, 'train.txt'))
    val_data = load_meta_file(os.path.join(meta_dir, 'val.txt'))

    # Union of the labels (as strings) found in either file. Sorting matters:
    # set iteration order for strings changes between interpreter runs (hash
    # randomisation), so sampling from an unsorted set would not be
    # reproducible even with a fixed seed.
    all_classes = sorted({label for _, label in train_data + val_data})
    print("Total classes found:", len(all_classes))

    os.makedirs(out_dir, exist_ok=True)

    # Generate subsets
    for subset_idx in range(n_subsets):
        # Use sequential seeds starting from base_seed
        if base_seed is not None:
            current_seed = base_seed + subset_idx
            print(f"Using seed {current_seed} for subset {subset_idx}")
            random.seed(current_seed)
            np.random.seed(current_seed)

        chosen_classes = random.sample(all_classes, n_class)
        print(f"Subset {subset_idx}: chosen classes: {chosen_classes}")

        # Set membership keeps the filtering linear in the number of entries.
        chosen_lookup = set(chosen_classes)
        subset_train = [(img, label) for img, label in train_data if label in chosen_lookup]
        subset_val = [(img, label) for img, label in val_data if label in chosen_lookup]

        subset_dir = os.path.join(out_dir, f"subset_{subset_idx}")
        meta_dir_subset = os.path.join(subset_dir, "meta")
        train_dir_subset = os.path.join(subset_dir, "train")
        val_dir_subset = os.path.join(subset_dir, "val")

        for folder in (meta_dir_subset, train_dir_subset, val_dir_subset):
            os.makedirs(folder, exist_ok=True)

        # Write filtered meta files for the subset
        write_meta_file(os.path.join(meta_dir_subset, 'train.txt'), subset_train)
        write_meta_file(os.path.join(meta_dir_subset, 'val.txt'), subset_val)
        print(f"Subset {subset_idx}: {len(subset_train)} train and {len(subset_val)} val entries written.")

        # Copy corresponding images
        copy_images(subset_train, train_images_dir, train_dir_subset)
        copy_images(subset_val, val_images_dir, val_dir_subset)
        print(f"Subset {subset_idx}: Images copied into {subset_dir}")

# -------------------------------
# PART 2: DB & Support Set Generation
# -------------------------------

def generate_csv_from_meta(subset_dir, db_dir):
    """
    Build db_dir/DB.csv from the subset's meta/train.txt and meta/val.txt.

    Columns: Label, Zone, Path, DB, ImageName.
      - Rows from meta/train.txt get DB "Train" and a Path under subset_dir/train.
      - Rows from meta/val.txt get DB "Test" and a Path under subset_dir/val.
      - Zone is always written as "0".
    A missing meta file contributes no rows; lines that do not split into
    exactly two whitespace-separated tokens are skipped.

    Returns the path of the written CSV file.
    """
    meta_root = os.path.join(subset_dir, "meta")
    # (meta file, image folder, DB tag) per split; Train rows are written first.
    splits = (
        (os.path.join(meta_root, "train.txt"), os.path.join(subset_dir, "train"), "Train"),
        (os.path.join(meta_root, "val.txt"), os.path.join(subset_dir, "val"), "Test"),
    )
    csv_file = os.path.join(db_dir, "DB.csv")

    with open(csv_file, 'w', newline='') as fout:
        writer = csv.writer(fout)
        writer.writerow(["Label", "Zone", "Path", "DB", "ImageName"])
        for meta_file, images_folder, db_tag in splits:
            if not os.path.exists(meta_file):
                continue
            with open(meta_file, 'r') as fin:
                for raw in fin:
                    fields = raw.strip().split()
                    if len(fields) != 2:
                        continue
                    image_name, label = fields
                    image_path = os.path.join(images_folder, image_name)
                    writer.writerow([label, "0", image_path, db_tag, image_name])
    return csv_file

# A simple dataset that reads our CSV file.
class SimpleCSVDataset(Dataset):
    """Dataset backed by a CSV with columns Label, Zone, Path, DB, ImageName.

    Rows are loaded eagerly in __init__. When `split` is given, only rows
    whose DB column equals it are kept. Images are opened lazily in
    __getitem__ from the row's Path column.
    """

    def __init__(self, csv_file, images_folder, split=None, transform=None):
        # Stored for callers only; __getitem__ reads the image path from the CSV row.
        self.images_folder = images_folder
        self.transform = transform
        self.split = split
        with open(csv_file, newline='') as handle:
            self.samples = [
                record
                for record in csv.DictReader(handle)
                if split is None or record["DB"] == split
            ]
        print(f"Loaded {len(self.samples)} samples from {csv_file}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (image, Label, Zone, Path, ImageName, DB) for row idx."""
        record = self.samples[idx]
        with Image.open(record["Path"]) as opened:
            image = opened.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return (
            image,
            record["Label"],
            record["Zone"],
            record["Path"],
            record["ImageName"],
            record["DB"],
        )

    def get_samples(self):
        """Return the list of loaded CSV rows (dicts)."""
        return self.samples

    def get_labels_map(self):
        """Return a dict mapping every label present in the samples to itself."""
        return {record["Label"]: record["Label"] for record in self.samples}

def save_samples_to_csv(samples, csv_file):
    """Write sample dicts to csv_file under the header Label, Zone, Path, DB, ImageName."""
    columns = ["Label", "Zone", "Path", "DB", "ImageName"]
    with open(csv_file, 'w', newline='') as handle:
        table = csv.DictWriter(handle, fieldnames=columns)
        table.writeheader()
        table.writerows(samples)

def save_labels_to_csv(labels_map, csv_file):
    """Write a label mapping to csv_file as rows of (Label, Value) under a header row."""
    with open(csv_file, 'w', newline='') as handle:
        out = csv.writer(handle)
        out.writerow(["Label", "Value"])
        out.writerows([key, value] for key, value in labels_map.items())

def compute_embeddings_from_loader(dataloader, model, device):
    """
    Run the model's image encoder over every batch of a dataloader.

    The model is switched to eval mode and gradients are disabled. Each batch
    is expected to unpack as (images, labels, zones, paths, image names, dbs).

    Returns:
      - embeddings: the per-batch outputs of model.encode_image concatenated along dim 0
      - metadata: a list of tuples (Label, Zone, Path, DB, ImageName), one per image,
        in the order the dataloader produced them
    """
    model.eval()
    batch_embeddings = []
    metadata = []
    progress = tqdm(dataloader, desc="Computing embeddings")
    with torch.no_grad():
        for batch in progress:
            images, labels, zones, paths, img_file_names, dbs = batch
            batch_embeddings.append(model.encode_image(images.to(device)))
            # Note the reordering: DB comes before ImageName in the metadata tuple.
            metadata += zip(labels, zones, paths, dbs, img_file_names)
    return torch.cat(batch_embeddings, dim=0), metadata

def select_diverse_subset(embeddings_subset, num_samples_per_class):
    """Pick num_samples_per_class distinct row indices of embeddings_subset at random.

    Despite the name, the selection is a uniform random permutation prefix;
    the embedding values themselves are not inspected.

    Returns an index tensor on the same device as embeddings_subset.

    Raises:
        ValueError: if there are fewer rows than requested samples.
    """
    total = embeddings_subset.size(0)
    if num_samples_per_class > total:
        raise ValueError("Not enough samples")
    shuffled = torch.randperm(total, device=embeddings_subset.device)
    return shuffled[:num_samples_per_class]

def generate_train_similarities(train_embeddings, train_metadata, out_folder):
    """
    Compute pairwise similarities among training images only.
    Similarity is computed as cosine similarity between L2-normalized embeddings.
    Saves the result to 'train_similarities.json' in out_folder.

    Args:
        train_embeddings: 2-D tensor, one row per training image.
        train_metadata: list of (Label, Zone, Path, DB, ImageName) tuples aligned
            with the embedding rows; Label must be convertible with int().
        out_folder: existing directory that receives the JSON file.

    Output layout: {file: {class: [[other_file, similarity], ...]}} where files
    are identified by the basename of their Path, each per-class list is sorted
    by descending similarity, and an image is never listed against itself.
    NOTE(review): keys are basenames, so two training images sharing a basename
    overwrite each other's entry — confirm names are unique within a subset.
    """
    normed = F.normalize(train_embeddings, p=2, dim=1)
    sims = torch.matmul(normed, normed.t())
    files = [os.path.basename(m[2]) for m in train_metadata]
    classes = [int(m[0]) for m in train_metadata]
    sim_dict = {}
    for i, f in enumerate(tqdm(files, desc="Generating train similarities")):
        # One host transfer per row instead of one .item() call per pair.
        row = sims[i].tolist()
        per_class = {}
        for j, (f2, c) in enumerate(zip(files, classes)):
            if i == j:
                continue  # skip self-similarity
            per_class.setdefault(c, []).append((f2, row[j]))
        for neighbours in per_class.values():
            neighbours.sort(key=lambda pair: pair[1], reverse=True)
        sim_dict[f] = per_class
    with open(os.path.join(out_folder, 'train_similarities.json'), 'w') as fp:
        json.dump(sim_dict, fp, indent=2)

def generate_test_similarities(test_embeddings, test_metadata, train_embeddings, train_metadata, out_folder):
    """
    For each test query, compute cosine similarities with all training images.
    Saves the result to 'test_similarities.json' in out_folder.

    Args:
        test_embeddings: 2-D tensor, one row per test image.
        test_metadata: list of (Label, Zone, Path, DB, ImageName) tuples aligned
            with test_embeddings.
        train_embeddings: 2-D tensor, one row per training image.
        train_metadata: list of metadata tuples aligned with train_embeddings;
            Label must be convertible with int().
        out_folder: existing directory that receives the JSON file.

    Output layout: {query ImageName: {class: [[train_file, similarity], ...]}}
    where train files are identified by the basename of their Path and each
    per-class list is sorted by descending similarity.
    """
    normed_test = F.normalize(test_embeddings, p=2, dim=1)
    normed_train = F.normalize(train_embeddings, p=2, dim=1)
    train_files = [os.path.basename(m[2]) for m in train_metadata]
    train_classes = [int(m[0]) for m in train_metadata]
    sim_dict = {}
    for i, test_meta in enumerate(test_metadata):
        query_image_name = test_meta[4]  # ImageName field from test
        # One host transfer per query row instead of one .item() call per pair.
        row = torch.matmul(normed_test[i].unsqueeze(0), normed_train.T).squeeze(0).tolist()
        per_class = {}
        for train_img, c, score in zip(train_files, train_classes, row):
            per_class.setdefault(c, []).append((train_img, score))
        for neighbours in per_class.values():
            neighbours.sort(key=lambda pair: pair[1], reverse=True)
        sim_dict[query_image_name] = per_class
    with open(os.path.join(out_folder, 'test_similarities.json'), 'w') as fp:
        json.dump(sim_dict, fp, indent=2)


def select_support_set(query, query_targets, support, support_targets, n):
    """
    For each query image, find the top-n most similar support images per class.

    Similarity is the cosine similarity between L2-normalized embeddings.

    Args:
        query: 2-D tensor of query embeddings, one row per query image.
        query_targets: list of metadata tuples (Label, Zone, Path, DB, ImageName)
            aligned with the rows of query.
        support: 2-D tensor of support (training) embeddings.
        support_targets: list of metadata tuples aligned with the rows of
            support; Label must be convertible with int().
        n: number of support images to return per class.

    Returns:
        dict mapping each query ImageName to {class (int): [support ImageName, ...]}.
        Every class present in support_targets gets an entry. Each list holds n
        names ordered by descending similarity; when a class offers fewer than n
        candidates the ranked candidates are repeated cyclically to reach n, and
        when it offers none the list is empty.

    A support image is never returned for a query with the same ImageName.
    NOTE(review): this exclusion compares image names, not paths, so a query
    from another split that shares a file name with a training image also
    loses that training image as a candidate — confirm this is intended.
    """
    support_names = [target[4] for target in support_targets]

    # Support row indices grouped by integer class, in original row order.
    class_to_indices = {}
    for row_idx, target in enumerate(support_targets):
        class_to_indices.setdefault(int(target[0]), []).append(row_idx)
    unique_classes = sorted(class_to_indices)

    # Normalise once up front instead of once per (query, class) pair.
    support_normed = F.normalize(support, p=2, dim=-1)
    query_normed = F.normalize(query, p=2, dim=-1)

    result_dict = {}
    for i in tqdm(range(query.size(0)), desc="Processing query samples for support set"):
        query_image_name = query_targets[i][4]
        query_vector = query_normed[i].unsqueeze(0)
        per_class = {}
        for c in unique_classes:
            # Drop the query itself (matched by image name) from the candidates.
            candidates = [j for j in class_to_indices[c] if support_names[j] != query_image_name]
            if not candidates:
                per_class[c] = []
                continue
            candidate_rows = torch.tensor(candidates, device=support.device, dtype=torch.long)
            sim = torch.matmul(query_vector, support_normed[candidate_rows].T).squeeze(0)
            ranked = torch.topk(sim, k=min(n, sim.size(0)), dim=0).indices.tolist()
            # Cycle through the ranked candidates when fewer than n are available.
            per_class[c] = [
                support_names[candidates[ranked[slot % len(ranked)]]]
                for slot in range(n)
            ]
        result_dict[query_image_name] = per_class
    return result_dict


def create_support_and_similarities(args, csv_file, train_images_folder, test_images_folder, out_folder):
    """
    Embed the Train and Test rows of csv_file and write the derived artefacts.

    Steps:
      1. Load the open_clip model "ViT-H/14-quickgelu" (pretrained "dfn5b") and
         build the image transform: a 224 center crop followed by the model's
         validation transforms with any Resize step removed.
      2. Compute embeddings for the "Train" rows, then for the "Test" rows,
         using args.batch_size.
      3. For both Train and Test queries, select args.k support images per class
         from the training embeddings and save them to db_support_set.json
         in out_folder under the keys "Train" and "Test".
      4. Write the train and test similarity JSON files into out_folder.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, preprocess_train, preprocess_val = create_model_and_transforms(
        model_name="ViT-H/14-quickgelu", pretrained="dfn5b"
    )
    non_resize_steps = [t for t in preprocess_val.transforms if not isinstance(t, transforms.Resize)]
    transform = transforms.Compose([transforms.CenterCrop(224)] + non_resize_steps)
    model.to(device)

    def embed_split(split_name, images_folder):
        # Embeds every CSV row whose DB column equals split_name.
        dataset = SimpleCSVDataset(csv_file, images_folder, split=split_name, transform=transform)
        loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
        return compute_embeddings_from_loader(loader, model, device)

    train_embeddings, train_metadata = embed_split("Train", train_images_folder)
    test_embeddings, test_metadata = embed_split("Test", test_images_folder)

    # Training embeddings act as the support pool for both query splits.
    support_set = {
        "Train": select_support_set(train_embeddings, train_metadata, train_embeddings, train_metadata, args.k),
        "Test": select_support_set(test_embeddings, test_metadata, train_embeddings, train_metadata, args.k),
    }
    support_set_file = os.path.join(out_folder, "db_support_set.json")
    with open(support_set_file, 'w') as sink:
        json.dump(support_set, sink, indent=2)
    print(f"Saved support set to {support_set_file}")

    # Similarity JSON files are written last.
    generate_train_similarities(train_embeddings, train_metadata, out_folder)
    generate_test_similarities(test_embeddings, test_metadata, train_embeddings, train_metadata, out_folder)

def create_db_for_subset(subset_dir, args):
    """
    Build the 'db' folder of one subset.

    Creates subset_dir/db, writes DB.csv there from the subset's meta files,
    then computes embeddings and writes the support set and similarity JSON
    files into the same folder. As a side effect, args.out_folder is set to
    the db folder.
    """
    db_dir = os.path.join(subset_dir, "db")
    os.makedirs(db_dir, exist_ok=True)

    csv_file = generate_csv_from_meta(subset_dir, db_dir)
    print(f"Generated CSV: {csv_file}")

    args.out_folder = db_dir
    create_support_and_similarities(
        args,
        csv_file,
        os.path.join(subset_dir, "train"),  # Train images
        os.path.join(subset_dir, "val"),    # Test images come from the val folder
        db_dir,
    )

# -------------------------------
# MAIN FUNCTION: Loop Over n_class Values
# -------------------------------

def main():
    """
    Command-line entry point.

    For every dataset in the built-in list and every n_class value, generate
    args.n_subsets class subsets under <out_root>/<dataset>/ncl_<n_class>, then
    build the DB (CSV, embeddings, support set, similarity files) for each
    subset folder found there.
    """
    parser = argparse.ArgumentParser(
        description="Generate subsets for multiple n_class values and create a DB with support set and similarity files for each subset."
    )
    parser.add_argument("--datasets_root", type=str,
                        default="/media/ai-ubuntu/data1/projects/fca/datasets",
                        help="Directory holding one folder per dataset, each with meta, train, and val folders")
    parser.add_argument("--out_root", type=str, default="./datasets_reproduced",
                        help="Base output directory. Subdirectories {dataset}/ncl_{n_class} will be created.")
    parser.add_argument("--n_subsets", type=int, default=5,
                        help="Number of subsets to generate for each n_class value")
    parser.add_argument("--k", type=int, default=5,
                        help="Number of nearest neighbors for support set")
    parser.add_argument("--batch_size", type=int, default=128,
                        help="Batch size for embedding computation")
    # Default 0 matches the base seed that used to be hard-coded at the call site.
    parser.add_argument("--seed", type=int, default=0,
                        help="Base random seed; subset i is sampled with seed + i")
    args = parser.parse_args()

    dataset_list = [
                    'caltech101',
                    # 'caltech256',
                    # 'cifar100',
                    # 'cub200',
                    # 'food101',
                    # 'gtsrb43',
                    # 'imagenet',
                    # 'indoor67',
                    # 'textures47',
                    ]

    n_class_values = [2, 3, 4, 5, 10]

    for ds in dataset_list:
        args.dataset_dir = os.path.join(args.datasets_root, ds)
        args.out_base_dir = os.path.join(args.out_root, ds)
        for n_class in n_class_values:
            out_dir = os.path.join(args.out_base_dir, f"ncl_{n_class}")
            print(f"\n=== Processing n_class = {n_class}, output directory: {out_dir} ===")
            # Generate subsets for this n_class value with sequential seeds starting from args.seed
            generate_subsets(args.dataset_dir, out_dir, args.n_subsets, n_class, base_seed=args.seed)

            # For each subset folder, create the DB (CSV, embeddings, support set, similarity files).
            # Sorted so the processing order is deterministic (glob order is arbitrary).
            for subset_dir in sorted(glob.glob(os.path.join(out_dir, "subset_*"))):
                print(f"\n--- Processing subset: {subset_dir} ---")
                create_db_for_subset(subset_dir, args)

    print("\nAll processing complete!")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
