import tensorflow as tf
from transformers import TFCLIPModel, CLIPProcessor
import transformers

from collections import defaultdict
import tensorflow_datasets as tfds
import sys
import numpy as np
from PIL import Image
import random
import os
import urllib.request
import tarfile
import shutil

from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS

# Command-line configuration. The help strings describe how each flag is
# consumed by main() in this file.
flags.DEFINE_integer('batch_size', 32, 'Batch size for training and for evaluation-time encoding.')
flags.DEFINE_integer('testi', 1000, 'Start index (inclusive) of the test-set slice used for evaluation.')
flags.DEFINE_integer('teste', 1400, 'End index (exclusive) of the test-set slice used for evaluation.')
flags.DEFINE_float('betaa', 0.1, 'Concentration parameter of the Dirichlet distribution used to rank classes.')
flags.DEFINE_integer('maxsam', 16, 'Number of training samples for the class with the largest Dirichlet proportion.')
flags.DEFINE_integer('tek', 100, 'Number of training epochs.')

flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate of the Adam optimizer.')
flags.DEFINE_float('alpha', 0.25, 'Scaling factor (alpha) of the focal loss.')
flags.DEFINE_float('gamma', 1.0, 'Focusing exponent (gamma) of the focal loss.')

flags.DEFINE_string('data_dir', './data', 'Directory to store the dataset')

# Seed every random number generator used by the script so runs are repeatable.
SEED = 42
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

def download_and_extract_aircraft_dataset(data_dir):
    """Download and unpack the FGVC-Aircraft archive into ``data_dir``.

    Nothing is done when ``<data_dir>/fgvc-aircraft-2013b`` already exists.
    Otherwise the tarball is downloaded next to it, extracted, and removed.

    Args:
        data_dir: Directory that will contain the extracted
            ``fgvc-aircraft-2013b`` folder; created if missing.

    Raises:
        Whatever ``urllib.request.urlretrieve`` or ``tarfile`` raise when the
        download or the extraction fails. In that case the partially
        extracted folder is deleted so that the next call starts over.
    """
    os.makedirs(data_dir, exist_ok=True)
    dataset_url = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
    tar_file_path = os.path.join(data_dir, "fgvc-aircraft-2013b.tar.gz")
    extracted_dir = os.path.join(data_dir, "fgvc-aircraft-2013b")
    if os.path.exists(extracted_dir):
        return

    print("Downloading FGVC-Aircraft dataset...")
    try:
        urllib.request.urlretrieve(dataset_url, tar_file_path)
        with tarfile.open(tar_file_path, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # Extraction filters reject members that would escape
                # data_dir (absolute paths, "..", unsafe links).
                tar.extractall(path=data_dir, filter="data")
            else:
                # Older Python without extraction filters.
                tar.extractall(path=data_dir)
    except BaseException:
        # A half-extracted folder would make the existence check above skip
        # the download forever, so remove it before propagating the error.
        shutil.rmtree(extracted_dir, ignore_errors=True)
        raise
    finally:
        # The archive is only a temporary artifact, also on failure.
        if os.path.exists(tar_file_path):
            os.remove(tar_file_path)

def load_aircraft_data(data_dir, split='variant', subset='trainval'):
    """Read one FGVC-Aircraft annotation file.

    The file ``images_<split>_<subset>.txt`` under
    ``<data_dir>/fgvc-aircraft-2013b/data`` is parsed line by line: the first
    whitespace-separated token is the image id, the remaining tokens (joined
    by single spaces) form the label. Lines with fewer than two tokens are
    skipped.

    Args:
        data_dir: Directory containing the extracted dataset folder.
        split: Annotation kind used in the file name.
        subset: Dataset subset used in the file name.

    Returns:
        A tuple ``(samples, class_names, class_to_index)`` where ``samples``
        is a list of ``(image_path, label)`` pairs in file order,
        ``class_names`` is the sorted list of distinct labels and
        ``class_to_index`` maps each label to its position in that list.
    """
    data_root = os.path.join(data_dir, "fgvc-aircraft-2013b", "data")
    image_root = os.path.join(data_root, "images")
    annotation_path = os.path.join(data_root, f"images_{split}_{subset}.txt")

    with open(annotation_path, 'r') as handle:
        rows = [raw.strip().split() for raw in handle.readlines()]

    samples = [
        (os.path.join(image_root, f"{tokens[0]}.jpg"), ' '.join(tokens[1:]))
        for tokens in rows
        if len(tokens) >= 2
    ]

    class_names = sorted({name for _, name in samples})
    class_to_index = {name: position for position, name in enumerate(class_names)}
    return samples, class_names, class_to_index

# Per-class sample counts assigned to the equal-size groups of classes ranked
# by Dirichlet proportion. The last entry is not assigned by the bucketing
# loop: the top-ranked class receives FLAGS.maxsam instead.
_SAMPLE_BUCKETS = (2, 3, 4, 5, 6, 8, 10, 12, 14, 16)

# Prompt templates whose similarities are averaged per class at evaluation.
_PROMPT_TEMPLATES = (
    "a photo of a {}.", "a photograph of a {}.", "an image of a {}.",
    "a close-up photo of a {}.", "a bright photo of a {}.",
    "a cropped photo of a {}.", "a close-up image of a {}.",
    "a rendition of a {}.", "a side view of a {}.", "a detailed photo of a {}.",
)


def _load_rgb_image(image_path):
    """Decode an image fully into memory as RGB and close its file handle."""
    # Image.open is lazy and keeps the file open; holding hundreds of such
    # images at once can exhaust the process' file descriptors.
    with Image.open(image_path) as img:
        return img.convert("RGB")


def _build_sample_counts(num_classes, beta, max_samples, sample_buckets=_SAMPLE_BUCKETS):
    """Draw Dirichlet proportions and turn their ranking into per-class counts."""
    proportions = np.random.dirichlet([beta] * num_classes)
    max_index = np.argmax(proportions)
    # Remaining classes in ascending order of their drawn proportion.
    ranked = [idx for idx in np.argsort(proportions) if idx != max_index]

    sample_counts = np.ones(num_classes, dtype=int)
    sample_counts[max_index] = max_samples

    tail_buckets = sample_buckets[:-1]
    bucket_size = len(ranked) // len(tail_buckets)
    for i, bucket_value in enumerate(tail_buckets):
        start_idx = i * bucket_size
        # The last group absorbs the remainder of the integer division.
        end_idx = (i + 1) * bucket_size if i < len(tail_buckets) - 1 else len(ranked)
        for idx in ranked[start_idx:end_idx]:
            sample_counts[idx] = bucket_value
    return sample_counts


def _select_training_samples(train_data, label_to_idx, sample_counts):
    """Take samples in order until every class has reached its quota."""
    # Per-class counters never exceed their quota, so all quotas are met
    # exactly when the number of selected samples equals the quota total.
    target_total = int(np.clip(sample_counts, 0, None).sum())
    per_class = defaultdict(int)
    selected = []
    for image_path, label in train_data:
        if len(selected) >= target_total:
            break
        if per_class[label] < sample_counts[label_to_idx[label]]:
            selected.append((image_path, label))
            per_class[label] += 1
    return selected


def _embed_images(model, pixel_values):
    """Return L2-normalized projected image embeddings for ``pixel_values``."""
    vision_outputs = model.clip.vision_model(
        pixel_values,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    )
    image_embeds = model.clip.visual_projection(vision_outputs.pooler_output)
    return tf.nn.l2_normalize(image_embeds, axis=-1)


def _embed_text(model, input_ids, attention_mask):
    """Return L2-normalized projected text embeddings for tokenized prompts."""
    seq_length = tf.shape(input_ids)[1]
    position_ids = tf.range(0, seq_length, dtype=tf.int32)[tf.newaxis, :]
    text_outputs = model.clip.text_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    )
    text_embeds = model.clip.text_projection(text_outputs.pooler_output)
    return tf.nn.l2_normalize(text_embeds, axis=-1)


def _focal_loss(logits, labels, gamma, alpha):
    """Per-row focal loss of ``logits`` against integer ``labels``."""
    probs = tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-7, 1.0)
    labels_onehot = tf.one_hot(labels, depth=tf.shape(logits)[-1])
    # Probability of the target column; already within [1e-7, 1] after the clip.
    pt = tf.reduce_sum(labels_onehot * probs, axis=-1)
    focal_weight = tf.clip_by_value(tf.pow(1. - pt, gamma), 0.0, 10.0)
    return alpha * focal_weight * -tf.math.log(pt)


def _contrastive_focal_loss(image_embeds, text_embeds, logit_scale, gamma, alpha):
    """Symmetric image-to-text / text-to-image focal loss over one batch.

    Row ``i`` of the image embeddings is paired with row ``i`` of the text
    embeddings; all other rows in the batch act as negatives.
    """
    logit_scale = tf.clip_by_value(logit_scale, 1.0, 100.0)
    logits_per_image = tf.matmul(image_embeds, text_embeds, transpose_b=True) * logit_scale
    logits_per_text = tf.transpose(logits_per_image)
    labels = tf.range(tf.shape(logits_per_image)[0])

    loss_i2t = _focal_loss(logits_per_image, labels, gamma, alpha)
    loss_t2i = _focal_loss(logits_per_text, labels, gamma, alpha)
    loss = (tf.reduce_mean(loss_i2t) + tf.reduce_mean(loss_t2i)) / 2
    return tf.where(tf.math.is_nan(loss), 1.0, loss)  # Replace NaNs with 1.0


def main(argv):
    """Fine-tune one CLIP LayerNorm on an imbalanced subset and evaluate.

    Steps: download the dataset if needed, freeze the whole CLIP model except
    the first LayerNorm of the first vision encoder block, build an
    imbalanced training subset from a Dirichlet-based class ranking, train
    with a contrastive focal loss, then report prompt-ensemble accuracy on
    the test slice ``[FLAGS.testi, FLAGS.teste)``.

    Args:
        argv: Positional command-line arguments passed by ``app.run``; unused.

    Raises:
        ValueError: If the configured test slice contains no samples.
    """
    del argv  # Unused.

    download_and_extract_aircraft_dataset(FLAGS.data_dir)
    train_data, train_labels, label_to_idx = load_aircraft_data(FLAGS.data_dir, split='variant', subset='trainval')
    test_data, _, _ = load_aircraft_data(FLAGS.data_dir, split='variant', subset='test')
    test_samples = test_data[FLAGS.testi:FLAGS.teste]
    if not test_samples:
        raise ValueError(
            f"Test slice [{FLAGS.testi}:{FLAGS.teste}] is empty "
            f"(test set has {len(test_data)} samples)."
        )

    model = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    # Run one forward pass before touching the variables. It also consumes
    # NumPy RNG state, so removing it would change the Dirichlet draw below.
    dummy_image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
    inputs = processor(images=dummy_image, text=["an aircraft"], return_tensors="tf", padding=True)
    _ = model(**inputs)

    # Freeze everything, then unfreeze only the first LayerNorm found in the
    # first vision encoder block; the embedding layers all stay frozen.
    for var in model.variables:
        var._trainable = False
    first_block = model.clip.vision_model.encoder.layers[0]
    first_layer_norm = next(
        (layer for layer in first_block.submodules
         if isinstance(layer, tf.keras.layers.LayerNormalization)),
        None,
    )
    if first_layer_norm is None:
        print("No LayerNorm was unfrozen!")
    else:
        for var in first_layer_norm.variables:
            var._trainable = True
            print(f"Unfroze: {var.name}")

    # Confirm only those variables are trainable
    trainable_vars = [v for v in model.variables if v.trainable]
    print(f"\nTotal trainable variables: {len(trainable_vars)}")
    for v in trainable_vars:
        print(v.name, v.shape)

    # Build the imbalanced training subset.
    sample_counts = _build_sample_counts(len(train_labels), FLAGS.betaa, FLAGS.maxsam)
    random.shuffle(train_data)
    selected = _select_training_samples(train_data, label_to_idx, sample_counts)

    train_images = [_load_rgb_image(image_path) for image_path, _ in selected]
    train_prompts = [f"a photo of a {label}" for _, label in selected]
    train_inputs = processor(text=train_prompts, images=train_images, return_tensors="tf", padding=True)
    del train_images  # The decoded images are no longer needed once tensorized.
    # A plain dict is a structure tf.data handles natively.
    dataset = tf.data.Dataset.from_tensor_slices(dict(train_inputs)).batch(FLAGS.batch_size)
    optimizer = tf.keras.optimizers.Adam(FLAGS.learning_rate)

    # NOTE(review): samples of the same class share an identical prompt, so
    # when they land in the same batch they are treated as negatives of each
    # other by the contrastive loss - confirm this is intended.
    @tf.function
    def train_step(inputs):
        with tf.GradientTape() as tape:
            image_embeds = _embed_images(model, inputs['pixel_values'])
            text_embeds = _embed_text(model, inputs['input_ids'], inputs.get('attention_mask', None))
            logit_scale = tf.exp(model.clip.logit_scale)
            loss = _contrastive_focal_loss(image_embeds, text_embeds, logit_scale, FLAGS.gamma, FLAGS.alpha)
        grads = tape.gradient(loss, trainable_vars)
        optimizer.apply_gradients(zip(grads, trainable_vars))
        return loss

    for epoch in range(FLAGS.tek):
        batch_losses = [float(train_step(batch)) for batch in dataset]
        epoch_loss = sum(batch_losses) / len(batch_losses)
        print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

    # Candidate classes are the distinct labels present in the test slice, in
    # first-seen order (ties therefore resolve to the earliest class).
    class_names = list(dict.fromkeys(label for _, label in test_samples))
    all_prompts = [
        template.format(label)
        for label in class_names
        for template in _PROMPT_TEMPLATES
    ]
    prompt_features = []
    for i in range(0, len(all_prompts), FLAGS.batch_size):
        text_inputs = processor(text=all_prompts[i:i+FLAGS.batch_size], return_tensors="tf", padding=True)
        prompt_features.append(
            _embed_text(model, text_inputs['input_ids'], text_inputs.get('attention_mask', None))
        )
    prompt_features = tf.concat(prompt_features, axis=0)
    # Averaging the normalized template embeddings per class (without
    # re-normalizing) gives the same score as averaging the per-template
    # similarities, because the dot product is linear.
    class_features = tf.reduce_mean(
        tf.reshape(prompt_features, (len(class_names), len(_PROMPT_TEMPLATES), -1)),
        axis=1,
    )

    correct = 0
    for i in range(0, len(test_samples), FLAGS.batch_size):
        chunk = test_samples[i:i+FLAGS.batch_size]
        images = [_load_rgb_image(image_path) for image_path, _ in chunk]
        inputs = processor(images=images, return_tensors="tf", padding=True)
        image_features = _embed_images(model, inputs['pixel_values'])
        class_scores = tf.matmul(image_features, class_features, transpose_b=True).numpy()
        for (_, true_label), best in zip(chunk, class_scores.argmax(axis=1)):
            if class_names[best] == true_label:
                correct += 1
    print(f"\nAccuracy on test set: {correct}/{len(test_samples)} = {correct/len(test_samples):.2%}")

# Script entry point: hand main to absl's app.run when executed directly.
if __name__ == "__main__":
    app.run(main)