import tensorflow as tf
from transformers import TFCLIPModel, CLIPProcessor
import transformers

from collections import defaultdict
import tensorflow_datasets as tfds
import sys
import numpy as np
from PIL import Image
import random
import os

from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS

# Command-line configuration. Flag names, types and defaults are part of the
# script's interface; only the help texts describe how main() uses each value.
flags.DEFINE_integer('batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer(
    'testi', 1000,
    'Start index (inclusive) of the test-split slice used for evaluation.')
flags.DEFINE_integer(
    'teste', 1400,
    'End index (exclusive) of the test-split slice used for evaluation.')
flags.DEFINE_float(
    'betaa', 0.1,
    'Concentration of the symmetric Dirichlet draw used to rank classes '
    'when building the imbalanced training set.')
flags.DEFINE_integer(
    'maxsam', 16,
    'Number of training samples requested for the top-ranked class.')
flags.DEFINE_integer('tek', 100, 'Number of training epochs.')

flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate of the Adam optimizer.')
flags.DEFINE_float('alpha', 0.25, 'Focal loss weighting factor (alpha).')
flags.DEFINE_float('gamma', 1.0, 'Focal loss focusing exponent (gamma).')

# Seed every random number generator the script touches so that the class
# ranking drawn in main() and the TensorFlow ops are repeatable across runs.
SEED = 42
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

def _freeze_all_but_first_vision_layernorm(model):
  """Freeze `model`, then re-enable the first LayerNorm of vision block 0.

  Args:
    model: A built TFCLIPModel (its variables must already exist).

  Returns:
    The list of variables of `model` that are trainable afterwards.
  """
  # NOTE(review): this toggles the private `_trainable` attribute of the
  # variables directly, exactly as the original script did.
  for var in model.variables:
    var._trainable = False

  first_block = model.clip.vision_model.encoder.layers[0]
  first_layernorm = next(
      (layer for layer in first_block.submodules
       if isinstance(layer, tf.keras.layers.LayerNormalization)),
      None)
  if first_layernorm is None:
    print("No LayerNorm was unfrozen!")
  else:
    for var in first_layernorm.variables:
      var._trainable = True
      print(f"Unfroze: {var.name}")

  return [v for v in model.variables if v.trainable]


def _imbalanced_sample_counts(num_classes, beta, max_samples):
  """Draw a per-class sample budget that is deliberately imbalanced.

  Classes are ranked by a symmetric Dirichlet(beta) draw. The top-ranked class
  gets `max_samples`; the remaining classes are split, in ascending rank
  order, into equally sized groups that receive 2, 3, 4, 5, 6, 8, 10, 12 and
  14 samples respectively (the last group absorbs the remainder).

  Args:
    num_classes: Number of classes.
    beta: Dirichlet concentration parameter.
    max_samples: Budget of the top-ranked class.

  Returns:
    An integer numpy array of length `num_classes`.
  """
  proportions = np.random.dirichlet([beta] * num_classes)
  max_index = np.argmax(proportions)
  sample_buckets = [2, 3, 4, 5, 6, 8, 10, 12, 14]
  remaining_classes = [
      idx for idx in np.argsort(proportions) if idx != max_index]

  sample_counts = np.ones(num_classes, dtype=int)
  sample_counts[max_index] = max_samples
  bucket_size = len(remaining_classes) // len(sample_buckets)
  last_bucket = len(sample_buckets) - 1
  for i, bucket_value in enumerate(sample_buckets):
    start = i * bucket_size
    end = (i + 1) * bucket_size if i < last_bucket else len(remaining_classes)
    for idx in remaining_classes[start:end]:
      sample_counts[idx] = bucket_value
  return sample_counts


def _collect_training_samples(train_dataset, label_names, sample_counts):
  """Pick training images in dataset order until each class budget is met.

  Args:
    train_dataset: Dataset yielding (image, integer label) pairs.
    label_names: Class names indexed by integer label.
    sample_counts: Per-class budget indexed by integer label.

  Returns:
    A (images, label_strings) pair of equally long lists.
  """
  taken = defaultdict(int)
  images, labels = [], []
  # Counting the outstanding samples replaces re-scanning every class after
  # each element: all budgets are met exactly when this reaches zero.
  outstanding = int(np.sum(sample_counts))
  for image, label in tfds.as_numpy(train_dataset):
    label = int(label)
    if taken[label] < sample_counts[label]:
      images.append(image)
      labels.append(label_names[label])
      taken[label] += 1
      outstanding -= 1
      if outstanding == 0:
        break
  return images, labels


def _contrastive_focal_loss(image_embeds, text_embeds, logit_scale, gamma,
                            alpha):
  """Symmetric image<->text contrastive loss with focal weighting.

  The i-th image and the i-th text of the batch are treated as the positive
  pair. Probabilities, focal weights and the logit scale are clipped, and a
  NaN result is replaced by 1.0, to keep training numerically stable.

  Args:
    image_embeds: L2-normalised image embeddings, one row per sample.
    text_embeds: L2-normalised text embeddings, one row per sample.
    logit_scale: Temperature multiplier applied to the similarities.
    gamma: Focal focusing exponent.
    alpha: Focal weighting factor.

  Returns:
    A scalar loss tensor.
  """
  logit_scale = tf.clip_by_value(logit_scale, 1.0, 100.0)
  logits_per_image = (
      tf.matmul(image_embeds, text_embeds, transpose_b=True) * logit_scale)
  logits_per_text = tf.transpose(logits_per_image)
  labels = tf.range(tf.shape(logits_per_image)[0])

  def stable_focal_loss(logits):
    probs = tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-7, 1.0)
    labels_onehot = tf.one_hot(labels, depth=tf.shape(logits)[-1])
    pt = tf.reduce_sum(labels_onehot * probs, axis=-1)
    focal_weight = tf.clip_by_value(tf.pow(1. - pt, gamma), 0.0, 10.0)
    ce = -tf.math.log(tf.clip_by_value(pt, 1e-7, 1.0))
    return alpha * focal_weight * ce

  loss_i2t = tf.reduce_mean(stable_focal_loss(logits_per_image))
  loss_t2i = tf.reduce_mean(stable_focal_loss(logits_per_text))
  loss = (loss_i2t + loss_t2i) / 2
  return tf.where(tf.math.is_nan(loss), 1.0, loss)  # Replace NaNs with 1.0


def _image_embeddings(model, pixel_values):
  """Return L2-normalised projected CLIP embeddings for `pixel_values`."""
  outputs = model.clip.vision_model(
      pixel_values,
      output_attentions=False,
      output_hidden_states=False,
      return_dict=True
  )
  embeds = model.clip.visual_projection(outputs.pooler_output)
  return tf.nn.l2_normalize(embeds, axis=-1)


def _text_embeddings(model, input_ids, attention_mask):
  """Return L2-normalised projected CLIP embeddings for tokenised text."""
  # Position ids are passed explicitly, as the original code did for both
  # training and evaluation (presumably so the text model does not have to
  # derive them itself under tf.function -- TODO confirm).
  seq_length = tf.shape(input_ids)[1]
  position_ids = tf.range(0, seq_length, dtype=tf.int32)[tf.newaxis, :]
  outputs = model.clip.text_model(
      input_ids=input_ids,
      attention_mask=attention_mask,
      position_ids=position_ids,
      output_attentions=False,
      output_hidden_states=False,
      return_dict=True
  )
  embeds = model.clip.text_projection(outputs.pooler_output)
  return tf.nn.l2_normalize(embeds, axis=-1)


def main(argv):
  """Fine-tune one CLIP LayerNorm on imbalanced flowers data and evaluate.

  Steps: load Oxford Flowers 102, freeze CLIP except the first LayerNorm of
  the first vision encoder block, build a class-imbalanced training subset,
  train with a focal contrastive loss for FLAGS.tek epochs, then report
  prompt-ensemble accuracy on test samples [FLAGS.testi, FLAGS.teste).

  Args:
    argv: Positional command-line arguments passed by absl (unused).

  Raises:
    ValueError: If the requested test slice contains no samples.
  """
  del argv  # Unused.

  train_dataset = tfds.load("oxford_flowers102", split="train", as_supervised=True)
  test_dataset = tfds.load("oxford_flowers102", split="test", as_supervised=True)
  label_names = [
      'pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',
      'english marigold', 'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood',
      'globe thistle', 'snapdragon', "colt's foot", 'king protea', 'spear thistle',
      'yellow iris', 'globe-flower', 'purple coneflower', 'peruvian lily', 'balloon flower',
      'giant white arum lily', 'fire lily', 'pincushion flower', 'fritillary', 'red ginger',
      'grape hyacinth', 'corn poppy', 'prince of wales feathers', 'stemless gentian',
      'artichoke', 'sweet william', 'carnation', 'garden phlox', 'love in the mist', 'mexican aster',
      'alpine sea holly', 'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip',
      'lenten rose', 'barbeton daisy', 'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue',
      'wallflower', 'marigold', 'buttercup', 'oxeye daisy', 'common dandelion', 'petunia',
      'wild pansy', 'primula', 'sunflower', 'pelargonium', 'bishop of llandaff', 'gaura',
      'geranium', 'orange dahlia', 'pink-yellow dahlia', 'cautleya spicata', 'japanese anemone',
      'black-eyed susan', 'silverbush', 'californian poppy', 'osteospermum', 'spring crocus',
      'bearded iris', 'windflower', 'tree poppy', 'gazania', 'azalea', 'water lily', 'rose',
      'thorn apple', 'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium',
      'frangipani', 'clematis', 'hibiscus', 'columbine', 'desert-rose', 'tree mallow',
      'magnolia', 'cyclamen', 'watercress', 'canna lily', 'hippeastrum', 'bee balm',
      'ball moss', 'foxglove', 'bougainvillea', 'camellia', 'mallow', 'mexican petunia',
      'bromelia', 'blanket flower', 'trumpet creeper', 'blackberry lily'
  ]
  model = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

  # Only materialise the requested slice instead of the whole test split.
  num_test = max(FLAGS.teste - FLAGS.testi, 0)
  test_samples = list(
      tfds.as_numpy(test_dataset.skip(FLAGS.testi).take(num_test)))
  if not test_samples:
    raise ValueError(
        f"No test samples in range [{FLAGS.testi}, {FLAGS.teste}).")
  test_images = [img for img, _ in test_samples]
  test_labels = [label_names[lbl] for _, lbl in test_samples]
  # Candidate classes at evaluation time are the classes present in the test
  # slice, each listed once, in order of first appearance.
  eval_classes = list(dict.fromkeys(test_labels))

  # Run a forward pass to build weights
  dummy_image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
  inputs = processor(images=dummy_image, text=["a flower"], return_tensors="tf", padding=True)
  _ = model(**inputs)

  trainable_vars = _freeze_all_but_first_vision_layernorm(model)
  # Confirm only those variables are trainable
  print(f"\nTotal trainable variables: {len(trainable_vars)}")
  for v in trainable_vars:
    print(v.name, v.shape)

  # Build the class-imbalanced training subset.
  sample_counts = _imbalanced_sample_counts(
      len(label_names), FLAGS.betaa, FLAGS.maxsam)
  train_images, train_labels = _collect_training_samples(
      train_dataset, label_names, sample_counts)
  train_prompts = [f"a photo of a {label}" for label in train_labels]
  train_inputs = processor(text=train_prompts, images=train_images, return_tensors="tf", padding=True)
  # Convert the processor output to a plain dict so tf.data sees an ordinary
  # mapping of tensors.
  dataset = tf.data.Dataset.from_tensor_slices(dict(train_inputs)).batch(FLAGS.batch_size)

  optimizer = tf.keras.optimizers.Adam(FLAGS.learning_rate)

  @tf.function
  def train_step(inputs):
    with tf.GradientTape() as tape:
      image_embeds = _image_embeddings(model, inputs['pixel_values'])
      text_embeds = _text_embeddings(
          model, inputs['input_ids'], inputs.get('attention_mask', None))
      logit_scale = tf.exp(model.clip.logit_scale)
      loss = _contrastive_focal_loss(
          image_embeds, text_embeds, logit_scale, FLAGS.gamma, FLAGS.alpha)

    grads = tape.gradient(loss, trainable_vars)
    optimizer.apply_gradients(zip(grads, trainable_vars))
    return loss

  num_batches = len(dataset)
  for epoch in range(FLAGS.tek):
    epoch_loss = 0
    for batch in dataset:
      epoch_loss += train_step(batch)
    epoch_loss /= num_batches
    print(f"Epoch {epoch+1}, Loss: {epoch_loss.numpy():.4f}")

  prompt_templates = [
      "a photo of a {}.", "a photograph of a {}.",
      "an image of a {}.", "a close-up photo of a {}.",
      "a bright photo of a {}.", "a cropped photo of a {}.",
      "a close-up image of a {}.", "a rendition of a {}.",
      "a macro photo of a {}.", "a detailed photo of a {}."
  ]
  num_templates = len(prompt_templates)
  # Class-major layout: the prompts of class c occupy rows
  # [c * num_templates, (c + 1) * num_templates).
  all_prompts = [
      template.format(name)
      for name in eval_classes
      for template in prompt_templates
  ]

  # Process all prompts in batches to avoid memory issues
  text_feature_batches = []
  for start in range(0, len(all_prompts), FLAGS.batch_size):
    batch_prompts = all_prompts[start:start + FLAGS.batch_size]
    text_inputs = processor(text=batch_prompts, return_tensors="tf", padding=True)
    text_feature_batches.append(_text_embeddings(
        model, text_inputs['input_ids'],
        text_inputs.get('attention_mask', None)))
  all_text_features = tf.concat(text_feature_batches, axis=0)

  # Inference
  correct = 0
  for img, true_label in zip(test_images, test_labels):
    inputs = processor(images=img, return_tensors="tf", padding=True)
    image_features = _image_embeddings(model, inputs['pixel_values'])
    sims = tf.matmul(image_features, all_text_features, transpose_b=True)
    # Average the similarities over the templates of each class in one op
    # instead of reading every similarity back to Python individually.
    class_scores = tf.reduce_mean(
        tf.reshape(sims[0], (len(eval_classes), num_templates)), axis=-1)
    pred_label = eval_classes[int(tf.argmax(class_scores).numpy())]

    if pred_label == true_label:
      correct += 1
  print(f"\nAccuracy on test set: {correct}/{len(test_images)} = {correct/len(test_images):.2%}")

# Script entry point: hand main to absl's app.run (which is expected to parse
# the command-line flags defined at the top of this file before invoking it).
if __name__ == "__main__":
  app.run(main)