import tensorflow as tf
from transformers import TFCLIPModel, CLIPProcessor
import transformers

from collections import defaultdict
import sys
import numpy as np
from PIL import Image
import random
import os
import glob
import json
from sklearn.model_selection import train_test_split

from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS


def _define_flags():
    """Register this script's command-line flags with absl, in a fixed order."""
    # (definer, flag name, default value, help text)
    flag_specs = (
        (flags.DEFINE_integer, 'batch_size', 32, 'Batch size for training.'),
        (flags.DEFINE_integer, 'testi', 1000, 'Starting index for test samples.'),
        (flags.DEFINE_integer, 'teste', 1400, 'Ending index for test samples.'),
        (flags.DEFINE_float, 'betaa', 0.1, 'Dirichlet distribution parameter.'),
        (flags.DEFINE_integer, 'maxsam', 16, 'Maximum samples for a class.'),
        (flags.DEFINE_integer, 'tek', 1, 'Number of training epochs.'),
        (flags.DEFINE_float, 'learning_rate', 1e-3, 'Learning rate for optimizer.'),
        (flags.DEFINE_float, 'alpha', 0.25, 'Alpha parameter for focal loss.'),
        (flags.DEFINE_float, 'gamma', 1.0, 'Gamma parameter for focal loss.'),
        (flags.DEFINE_string, 'sun_path', '', 'Path to the SUN dataset directory.'),
        (flags.DEFINE_integer, 'num_classes', 397, 'Number of SUN scene categories to use.'),
    )
    for define, name, default, help_text in flag_specs:
        define(name, default, help_text)


_define_flags()

# Global seed shared by Python, NumPy and TensorFlow RNGs.
SEED = 42
# NOTE(review): per the Python docs PYTHONHASHSEED is read at interpreter
# startup, so setting it here only influences child processes, not this one.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

def _find_image_paths(data_path):
    """Return a sorted, de-duplicated list of image paths found under ``data_path``.

    The ``.jpg`` patterns are tried in order and the first one that matches
    anything wins. Only if none matches are the other extensions tried, across
    every pattern. Those patterns overlap (``*/*`` vs. recursive ``**``), so
    results are collected in a set to avoid listing the same file twice.
    """
    possible_patterns = [
        os.path.join(data_path, "SUN397", "*", "*.jpg"),
        os.path.join(data_path, "SUN397", "*", "*", "*.jpg"),
        os.path.join(data_path, "*", "*.jpg"),
        os.path.join(data_path, "*", "*", "*.jpg"),
        os.path.join(data_path, "**", "*.jpg"),
    ]
    for pattern in possible_patterns:
        paths = glob.glob(pattern, recursive=True)
        if paths:
            print(f"Found {len(paths)} images using pattern: {pattern}")
            # glob's ordering depends on the file system; sort so the
            # train/test split below is reproducible.
            return sorted(paths)

    found = set()
    for ext in ['*.png', '*.jpeg', '*.JPEG', '*.JPG', '*.PNG']:
        for pattern_base in possible_patterns:
            found.update(glob.glob(pattern_base.replace('*.jpg', ext), recursive=True))
    return sorted(found)


def _extract_class_name(path, data_path):
    """Return the normalized class name for ``path``, or None if none can be derived.

    The path is taken relative to ``data_path`` (a true path operation, not a
    substring replace). When a ``SUN397`` component is present, the component
    right after it is the class; otherwise the image's parent directory is.
    Underscores and hyphens become spaces.

    NOTE(review): the official SUN397 archive nests classes under single-letter
    directories (``SUN397/a/abbey/...``); with that layout the component after
    ``SUN397`` is the letter, not the scene. Confirm against your copy.
    """
    parts = os.path.relpath(path, data_path).split(os.sep)
    if len(parts) < 2:
        return None
    if 'SUN397' in parts:
        sun_idx = parts.index('SUN397')
        if sun_idx + 1 >= len(parts):
            return None
        raw_name = parts[sun_idx + 1]
    else:
        raw_name = parts[-2]
    class_name = raw_name.replace('_', ' ').replace('-', ' ').strip()
    return class_name or None


def load_sun_dataset(data_path, num_classes=None):
    """
    Load the SUN dataset from the specified path.

    Images are discovered with a list of glob patterns. Paths are de-duplicated
    and sorted, so the split is reproducible. Each image's class comes from its
    directory. Classes with fewer than 4 images are dropped, and the rest are
    ranked by frequency. The data is then split 80/20 (stratified when
    possible, with ``random_state=42``).

    Args:
        data_path: Path to the SUN dataset
        num_classes: Number of classes to use (will take the top n by frequency)

    Returns:
        train_data: List of (image_path, label_index) pairs for training
        test_data: List of (image_path, label_index) pairs for testing
        label_names: List of class names; label_index indexes into this list

    Raises:
        ValueError: If no images, no class labels, or no usable samples are found.
    """
    print(f"Looking for SUN dataset in: {data_path}")
    image_paths = _find_image_paths(data_path)
    if not image_paths:
        raise ValueError(f"No images found in {data_path}. Please check the path and directory structure.")

    print(f"Found {len(image_paths)} total images")
    # Resolve each path's class once; reused for counting and for labelling.
    path_classes = [(path, _extract_class_name(path, data_path)) for path in image_paths]

    class_counts = defaultdict(int)
    for _, class_name in path_classes:
        if class_name:  # Only count non-empty class names
            class_counts[class_name] += 1

    if not class_counts:
        raise ValueError("No valid class labels could be extracted from the directory structure.")

    print(f"Found {len(class_counts)} classes")
    sorted_classes = sorted(class_counts.items(), key=lambda x: x[1], reverse=True)
    print("Top 10 classes by frequency:")
    for cls, count in sorted_classes[:10]:
        print(f"  {cls}: {count} images")

    # Drop rare classes (presumably so the stratified 80/20 split below can put
    # every class in both splits).
    sorted_classes = [(cls, count) for cls, count in sorted_classes if count >= 4]
    if num_classes is not None:
        sorted_classes = sorted_classes[:num_classes]

    label_names = [cls for cls, _ in sorted_classes]
    label_to_idx = {cls: i for i, cls in enumerate(label_names)}

    print(f"Using {len(label_names)} classes with sufficient samples")
    # Only include images whose class survived the filtering above.
    data = [(path, label_to_idx[name]) for path, name in path_classes if name in label_to_idx]

    if not data:
        raise ValueError("No valid data samples could be created. Check the directory structure.")

    print(f"Created {len(data)} valid data samples")

    # Split into train and test sets (80% train, 20% test)
    try:
        train_data, test_data = train_test_split(
            data, test_size=0.2, random_state=42, stratify=[label for _, label in data]
        )
    except ValueError as e:
        print(f"Stratified split failed: {e}")
        print("Trying non-stratified split...")
        train_data, test_data = train_test_split(
            data, test_size=0.2, random_state=42
        )

    print(f"Split into {len(train_data)} training samples and {len(test_data)} test samples")
    print(f"Using {len(label_names)} classes from SUN dataset")

    return train_data, test_data, label_names

def load_image(image_path):
    """Load an image file and return an RGB copy of it.

    Args:
        image_path: Path to an image file readable by Pillow.

    Returns:
        A ``PIL.Image`` in ``'RGB'`` mode produced by ``convert('RGB')``.
    """
    # Use a context manager so the file opened by Image.open is closed
    # deterministically. convert() reads the pixel data before the handle is
    # released, which matters when thousands of images are loaded in a row.
    with Image.open(image_path) as img:
        return img.convert('RGB')

def _freeze_all_variables(model):
    """Mark every variable of ``model`` as non-trainable, in place.

    NOTE(review): flips the private ``_trainable`` attribute; this may break
    across TensorFlow versions.
    """
    for var in model.variables:
        var._trainable = False


def _unfreeze_first_vision_layernorm(model):
    """Make only the first LayerNorm in vision encoder block 0 trainable again.

    Walks ``encoder.layers[0].submodules`` and unfreezes the variables of the
    first ``tf.keras.layers.LayerNormalization`` found, printing each name.
    """
    first_block = model.clip.vision_model.encoder.layers[0]
    for layer in first_block.submodules:
        if isinstance(layer, tf.keras.layers.LayerNormalization):
            for var in layer.variables:
                var._trainable = True
                print(f"Unfroze: {var.name}")
            break  # Stop after the first LayerNorm


def _dirichlet_sample_counts(num_classes, beta, max_samples,
                             buckets=(2, 3, 4, 5, 6, 8, 10, 12, 14, 16)):
    """Assign a per-class training-sample budget from a Dirichlet draw.

    Draws class proportions from ``Dirichlet([beta] * num_classes)`` with
    NumPy's global RNG. The class with the largest proportion gets
    ``max_samples``. The remaining classes, in ascending proportion order, are
    cut into ``len(buckets) - 1`` equal groups that receive ``buckets[0]``,
    ``buckets[1]``, and so on. The last group also takes any remainder, so with
    fewer classes than groups every non-max class lands in that last group.
    The final bucket value is never assigned.

    Returns:
        An int ``np.ndarray`` with one sample count per class index.
    """
    proportions = np.random.dirichlet([beta] * num_classes)
    max_index = np.argmax(proportions)
    counts = np.ones(num_classes, dtype=int)
    counts[max_index] = max_samples

    others = [idx for idx in np.argsort(proportions) if idx != max_index]
    num_groups = len(buckets) - 1
    group_size = len(others) // num_groups
    for i, bucket_value in enumerate(buckets[:-1]):
        start = i * group_size
        end = (i + 1) * group_size if i < num_groups - 1 else len(others)
        for idx in others[start:end]:
            counts[idx] = bucket_value
    return counts


def _select_train_samples(train_data, label_names, sample_counts):
    """Load the first ``sample_counts[i]`` training images of each class ``i``.

    Returns:
        ``(images, labels, per_class)``: images from ``load_image``, their class
        names (grouped class by class, in ``label_names`` order), and a
        ``defaultdict(int)`` counting the images kept per class.
    """
    paths_by_class = defaultdict(list)
    for image_path, label_idx in train_data:
        paths_by_class[label_names[label_idx]].append(image_path)

    images, labels = [], []
    per_class = defaultdict(int)
    for i, label_str in enumerate(label_names):
        for path in paths_by_class[label_str][:sample_counts[i]]:
            images.append(load_image(path))
            labels.append(label_str)
            per_class[label_str] += 1
    return images, labels, per_class


def _focal_loss_terms(logits, labels, gamma, alpha):
    """Per-row focal loss ``alpha * (1 - p_t)**gamma * -log(p_t)`` for integer ``labels``.

    Probabilities are clipped to [1e-7, 1] and the focal weight to [0, 10].
    """
    probs = tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-7, 1.0)
    labels_onehot = tf.one_hot(labels, depth=tf.shape(logits)[-1])
    pt = tf.reduce_sum(labels_onehot * probs, axis=-1)
    focal_weight = tf.clip_by_value(tf.pow(1. - pt, gamma), 0.0, 10.0)
    ce = -tf.math.log(tf.clip_by_value(pt, 1e-7, 1.0))
    return alpha * focal_weight * ce


def _contrastive_focal_loss(image_embeds, text_embeds, logit_scale, gamma, alpha):
    """Symmetric in-batch contrastive loss with focal weighting.

    Row ``i`` of ``image_embeds`` is the positive for row ``i`` of
    ``text_embeds``. The result is the mean of the image->text and
    text->image focal losses.
    """
    # Clip the temperature so the logits cannot become extreme.
    logit_scale = tf.clip_by_value(logit_scale, 1.0, 100.0)
    logits_per_image = tf.matmul(image_embeds, text_embeds, transpose_b=True) * logit_scale
    logits_per_text = tf.transpose(logits_per_image)
    labels = tf.range(tf.shape(logits_per_image)[0])

    loss_i2t = _focal_loss_terms(logits_per_image, labels, gamma, alpha)
    loss_t2i = _focal_loss_terms(logits_per_text, labels, gamma, alpha)
    loss = (tf.reduce_mean(loss_i2t) + tf.reduce_mean(loss_t2i)) / 2
    # NOTE(review): this replaces a NaN loss *value* with 1.0; it does not stop
    # NaNs produced inside the graph from reaching tape.gradient.
    return tf.where(tf.math.is_nan(loss), 1.0, loss)


def _encode_texts(model, processor, prompts, batch_size):
    """Return L2-normalized CLIP text embeddings, one row per prompt."""
    features = []
    for start in range(0, len(prompts), batch_size):
        text_inputs = processor(text=prompts[start:start + batch_size], return_tensors="tf", padding=True)
        seq_length = tf.shape(text_inputs['input_ids'])[1]
        position_ids = tf.range(0, seq_length, dtype=tf.int32)[tf.newaxis, :]
        text_outputs = model.clip.text_model(
            input_ids=text_inputs['input_ids'],
            attention_mask=text_inputs.get('attention_mask', None),
            position_ids=position_ids,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        batch_features = model.clip.text_projection(text_outputs.pooler_output)
        features.append(tf.nn.l2_normalize(batch_features, axis=-1))
    return tf.concat(features, axis=0)


def _encode_images(model, processor, images, batch_size):
    """Return L2-normalized CLIP image embeddings, one row per image."""
    features = []
    for start in range(0, len(images), batch_size):
        image_inputs = processor(images=images[start:start + batch_size], return_tensors="tf", padding=True)
        vision_outputs = model.clip.vision_model(
            image_inputs['pixel_values'],
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        batch_features = model.clip.visual_projection(vision_outputs.pooler_output)
        features.append(tf.nn.l2_normalize(batch_features, axis=-1))
    return tf.concat(features, axis=0)


def _count_zero_shot_correct(model, processor, images, labels, templates, batch_size):
    """Count images whose prompt-ensembled best-matching class equals their label.

    Candidate classes are the distinct values of ``labels`` in first-appearance
    order. Each class score is the mean cosine similarity over ``templates``.
    Ties go to the earliest candidate.
    """
    candidates = list(dict.fromkeys(labels))
    # Label-major, template-minor order, so the reshape below groups per class.
    prompts = [template.format(label) for label in candidates for template in templates]
    text_features = _encode_texts(model, processor, prompts, batch_size)
    image_features = _encode_images(model, processor, images, batch_size)
    sims = tf.matmul(image_features, text_features, transpose_b=True)
    class_scores = tf.reduce_mean(
        tf.reshape(sims, (-1, len(candidates), len(templates))), axis=-1)
    predictions = tf.argmax(class_scores, axis=-1).numpy()
    return sum(candidates[p] == true_label for p, true_label in zip(predictions, labels))


def main(argv):
    """Fine-tune CLIP on a Dirichlet-skewed SUN subset, then report zero-shot accuracy.

    Steps:
    1. Load and split SUN, and take the test slice ``[FLAGS.testi:FLAGS.teste]``.
    2. Freeze CLIP except the first LayerNorm of vision block 0.
    3. Build a class-imbalanced training set from a Dirichlet draw.
    4. Train with a symmetric contrastive focal loss for ``FLAGS.tek`` epochs.
    5. Evaluate with prompt-ensembled zero-shot classification.

    Raises:
        ValueError: if ``--sun_path`` is missing, the test slice is empty, or
            no training samples are selected.
    """
    if not FLAGS.sun_path:
        raise ValueError("Please provide the path to SUN dataset using --sun_path flag")
    train_data, test_data, label_names = load_sun_dataset(FLAGS.sun_path, FLAGS.num_classes)

    test_samples = test_data[FLAGS.testi:FLAGS.teste]
    if not test_samples:
        raise ValueError(
            f"Test slice [{FLAGS.testi}:{FLAGS.teste}] is empty; "
            f"the test split has {len(test_data)} samples.")
    test_images = [load_image(image_path) for image_path, _ in test_samples]
    test_labels = [label_names[label_idx] for _, label_idx in test_samples]

    model = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    # Dummy forward pass so all model variables exist before they are frozen.
    # It also consumes NumPy RNG state before the Dirichlet draw below.
    dummy_image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
    inputs = processor(images=dummy_image, text=["a scene"], return_tensors="tf", padding=True)
    _ = model(**inputs)

    _freeze_all_variables(model)
    _unfreeze_first_vision_layernorm(model)

    trainable_vars = [v for v in model.variables if v.trainable]
    print(f"\nTotal trainable variables: {len(trainable_vars)}")
    for v in trainable_vars:
        print(v.name, v.shape)

    sample_counts = _dirichlet_sample_counts(len(label_names), FLAGS.betaa, FLAGS.maxsam)
    train_images, train_labels, class_counter = _select_train_samples(
        train_data, label_names, sample_counts)

    print("Actual samples per class:")
    for label_str in label_names:
        print(f"  {label_str}: {class_counter[label_str]}")
    if not train_images:
        raise ValueError("No training samples were selected.")

    train_prompts = [f"a photo of a {label}" for label in train_labels]
    train_inputs = processor(text=train_prompts, images=train_images, return_tensors="tf", padding=True)
    # Samples were collected class by class. Without shuffling, each batch would
    # be a run of identical prompts, leaving the in-batch contrastive targets
    # ambiguous.
    dataset = (
        tf.data.Dataset.from_tensor_slices(train_inputs)
        .shuffle(len(train_labels), seed=SEED, reshuffle_each_iteration=True)
        .batch(FLAGS.batch_size)
    )
    optimizer = tf.keras.optimizers.Adam(FLAGS.learning_rate)
    gamma, alpha = FLAGS.gamma, FLAGS.alpha

    @tf.function
    def train_step(inputs):
        with tf.GradientTape() as tape:
            vision_outputs = model.clip.vision_model(
                inputs['pixel_values'],
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
            )
            image_embeds = model.clip.visual_projection(vision_outputs[1])
            seq_length = tf.shape(inputs['input_ids'])[1]
            position_ids = tf.range(0, seq_length, dtype=tf.int32)[tf.newaxis, :]
            text_outputs = model.clip.text_model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs.get('attention_mask', None),
                position_ids=position_ids,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
            )
            text_embeds = model.clip.text_projection(text_outputs[1])
            image_embeds = tf.nn.l2_normalize(image_embeds, axis=-1)
            text_embeds = tf.nn.l2_normalize(text_embeds, axis=-1)
            logit_scale = tf.exp(model.clip.logit_scale)
            loss = _contrastive_focal_loss(image_embeds, text_embeds, logit_scale, gamma, alpha)

        grads = tape.gradient(loss, trainable_vars)
        optimizer.apply_gradients(zip(grads, trainable_vars))
        return loss

    for epoch in range(FLAGS.tek):
        epoch_loss = 0
        for batch in dataset:
            epoch_loss += train_step(batch)
        epoch_loss /= len(dataset)
        print(f"Epoch {epoch+1}, Loss: {epoch_loss.numpy():.4f}")

    prompt_templates = [
        "a photo of a {}.", "a photograph of a {}.",
        "an image of a {}.", "a picture of a {} scene.",
        "a scene showing {}.", "a view of {}.",
        "a {} scene.", "a {} view.",
        "a photo showing {}.", "a snapshot of {}."
    ]
    # NOTE(review): candidates are only the classes present in the test slice,
    # not all of label_names. This is kept from the original evaluation, but it
    # is easier than classifying over every loaded class.
    correct = _count_zero_shot_correct(
        model, processor, test_images, test_labels, prompt_templates, FLAGS.batch_size)

    print(f"\nAccuracy on test set: {correct}/{len(test_images)} = {correct/len(test_images):.2%}")

if __name__ == "__main__":
    # Run through absl so command-line flags are parsed into FLAGS before main(argv) is called.
    app.run(main)