import tensorflow as tf
from transformers import TFCLIPModel, CLIPProcessor
import transformers

from collections import defaultdict
import tensorflow_datasets as tfds
import sys
import numpy as np
from PIL import Image
import random
import os

from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS

# --- Batching -------------------------------------------------------------
flags.DEFINE_integer(
    'batch_size', 32,
    'Batch size for training and for batched encoding at evaluation.')

# --- Evaluation slice of the Food-101 validation split --------------------
flags.DEFINE_integer(
    'testi', 1000,
    'Start index (inclusive) of the validation slice used for evaluation.')
flags.DEFINE_integer(
    'teste', 1400,
    'End index (exclusive) of the validation slice used for evaluation.')

# --- Imbalanced few-shot training-set construction ------------------------
flags.DEFINE_float(
    'betaa', 0.1,
    'Dirichlet concentration used to draw the per-class proportions that '
    'rank classes when assigning per-class sample counts.')
flags.DEFINE_integer(
    'maxsam', 16,
    'Number of training samples for the class with the largest proportion.')

# --- Optimisation ---------------------------------------------------------
flags.DEFINE_integer('tek', 100, 'Number of training epochs.')
flags.DEFINE_float(
    'learning_rate', 1e-3, 'Learning rate for the Adam optimizer.')
flags.DEFINE_float('alpha', 0.25, 'Focal loss scaling factor (alpha).')
flags.DEFINE_float('gamma', 1.0, 'Focal loss focusing exponent (gamma).')

# Seed every random number generator the script touches so runs are
# repeatable.
SEED = 42
# NOTE: assigning PYTHONHASHSEED here does not change hash randomisation of
# the interpreter that is already running; it only affects child processes
# that inherit the environment.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Food-101 class names, indexed by the integer label yielded by the dataset.
_LABEL_NAMES = (
    'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare',
    'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito',
    'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake',
    'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry', 'chicken_quesadilla',
    'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder',
    'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes',
    'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict',
    'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras',
    'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice',
    'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich',
    'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup',
    'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna',
    'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup',
    'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters',
    'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck',
    'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib',
    'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto',
    'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits',
    'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake',
    'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles',
)

# Prompt ensemble used at evaluation time; scores are averaged per class.
_EVAL_PROMPT_TEMPLATES = (
    "a photo of a {}.", "a photograph of a {}.",
    "an image of a {}.", "a close-up photo of a {}.",
    "a bright photo of a {}.", "a cropped photo of a {}.",
    "a close-up image of a {}.", "a rendition of a {}.",
    "a macro photo of a {}.", "a detailed photo of a {}.",
)

# Per-class sample counts handed out to the non-maximum classes. The last
# entry is skipped because the maximum class gets its count from a flag.
_SAMPLE_BUCKETS = (2, 3, 4, 5, 6, 8, 10, 12, 14, 16)


def _unfreeze_first_layernorm(block):
    """Mark the first LayerNormalization found in `block` as trainable.

    Walks `block.submodules` in order and stops at the first
    `tf.keras.layers.LayerNormalization`, printing each variable it unfreezes.

    Returns:
        True if a LayerNormalization layer was found, False otherwise.
    """
    for layer in block.submodules:
        if isinstance(layer, tf.keras.layers.LayerNormalization):
            for var in layer.variables:
                # NOTE(review): relies on the private `_trainable` attribute;
                # confirm it is honoured by the installed TF/Keras version.
                var._trainable = True
                print(f"Unfroze: {var.name}")
            return True
    return False


def _class_sample_counts(proportions, max_samples, sample_buckets=_SAMPLE_BUCKETS):
    """Turn class proportions into an imbalanced per-class sample budget.

    The class with the largest proportion receives `max_samples`. The other
    classes are sorted by ascending proportion and split into
    `len(sample_buckets) - 1` contiguous groups; group `i` receives
    `sample_buckets[i]` samples. The final group absorbs the remainder when
    the classes do not divide evenly.

    Args:
        proportions: 1-D array with one value per class.
        max_samples: sample count for the arg-max class.
        sample_buckets: candidate counts; the last entry is not used.

    Returns:
        Integer numpy array of per-class sample counts.
    """
    max_index = np.argmax(proportions)
    ascending = [idx for idx in np.argsort(proportions) if idx != max_index]
    counts = np.ones(len(proportions), dtype=int)
    counts[max_index] = max_samples

    buckets = sample_buckets[:-1]
    bucket_size = len(ascending) // len(buckets)
    last = len(buckets) - 1
    for i, bucket_value in enumerate(buckets):
        start = i * bucket_size
        end = (i + 1) * bucket_size if i < last else len(ascending)
        for idx in ascending[start:end]:
            counts[idx] = bucket_value
    return counts


def _focal_cross_entropy(logits, labels, gamma, alpha):
    """Per-row focal loss `alpha * (1 - pt)**gamma * -log(pt)`.

    `pt` is the softmax probability of the target column. Probabilities and
    the focal weight are clipped to keep the logarithm and power finite.
    """
    probs = tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-7, 1.0)
    labels_onehot = tf.one_hot(labels, depth=tf.shape(logits)[-1])
    pt = tf.reduce_sum(labels_onehot * probs, axis=-1)
    focal_weight = tf.clip_by_value(tf.pow(1. - pt, gamma), 0.0, 10.0)
    ce = -tf.math.log(tf.clip_by_value(pt, 1e-7, 1.0))
    return alpha * focal_weight * ce


def _contrastive_focal_loss(image_embeds, text_embeds, logit_scale, gamma, alpha):
    """Symmetric image/text contrastive loss with focal weighting.

    Row `i` of `image_embeds` is paired with row `i` of `text_embeds`; every
    other row in the batch is treated as a negative.

    Args:
        image_embeds: image embeddings, one row per sample.
        text_embeds: text embeddings, one row per sample.
        logit_scale: temperature multiplier; clipped to [1, 100].
        gamma: focal focusing exponent.
        alpha: focal scaling factor.

    Returns:
        Scalar loss; a NaN result is replaced by 1.0.
    """
    # NOTE(review): samples of the same class share an identical prompt, so
    # same-class pairs inside one batch are scored as negatives here.
    logit_scale = tf.clip_by_value(logit_scale, 1.0, 100.0)
    logits_per_image = tf.matmul(image_embeds, text_embeds, transpose_b=True) * logit_scale
    logits_per_text = tf.transpose(logits_per_image)

    targets = tf.range(tf.shape(logits_per_image)[0])
    loss_i2t = _focal_cross_entropy(logits_per_image, targets, gamma, alpha)
    loss_t2i = _focal_cross_entropy(logits_per_text, targets, gamma, alpha)

    loss = (tf.reduce_mean(loss_i2t) + tf.reduce_mean(loss_t2i)) / 2
    return tf.where(tf.math.is_nan(loss), 1.0, loss)  # Replace NaNs with 1.0


def main(argv):
    """Few-shot fine-tune CLIP on an imbalanced Food-101 subset and evaluate.

    Steps:
      1. Load Food-101 and the pretrained CLIP model/processor.
      2. Freeze every model variable, then unfreeze the first LayerNorm of
         the first vision encoder block.
      3. Build an imbalanced training subset whose per-class sizes are driven
         by a Dirichlet draw.
      4. Train with a symmetric contrastive focal loss for `FLAGS.tek` epochs.
      5. Classify the `[FLAGS.testi, FLAGS.teste)` slice of the validation
         split with a prompt ensemble and print the accuracy.

    Args:
        argv: positional command-line arguments from absl (unused).
    """
    del argv  # Unused.

    train_dataset = tfds.load("food101", split="train", as_supervised=True)
    test_dataset = tfds.load("food101", split="validation", as_supervised=True)
    label_names = _LABEL_NAMES

    model = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    # Stream only the requested slice instead of materialising the whole
    # validation split in memory. Negative offsets are clamped to zero.
    test_start = max(FLAGS.testi, 0)
    test_count = max(FLAGS.teste - test_start, 0)
    test_samples = list(tfds.as_numpy(test_dataset.skip(test_start).take(test_count)))
    test_images = [img for img, _ in test_samples]
    test_labels = [label_names[lbl] for _, lbl in test_samples]

    # Run one forward pass on a random image before touching the variables.
    # This draw also consumes numpy RNG state ahead of the Dirichlet sample
    # below, so it is kept in this position for reproducibility.
    dummy_image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
    inputs = processor(images=dummy_image, text=["a food"], return_tensors="tf", padding=True)
    _ = model(**inputs)

    # Freeze everything, then fine-tune only the first LayerNorm of the first
    # vision encoder block.
    for var in model.variables:
        var._trainable = False
    if not _unfreeze_first_layernorm(model.clip.vision_model.encoder.layers[0]):
        print("No LayerNorm was unfrozen!")

    trainable_vars = [v for v in model.variables if v.trainable]
    print(f"\nTotal trainable variables: {len(trainable_vars)}")
    for v in trainable_vars:
        print(v.name, v.shape)

    # ----- Imbalanced few-shot training subset -----------------------------
    num_classes = len(label_names)
    proportions = np.random.dirichlet([FLAGS.betaa] * num_classes)
    sample_counts = _class_sample_counts(proportions, FLAGS.maxsam)

    taken = np.zeros(num_classes, dtype=int)
    train_images, train_labels = [], []
    for image, label in tfds.as_numpy(train_dataset):
        if taken[label] >= sample_counts[label]:
            continue
        train_images.append(image)
        train_labels.append(label_names[label])
        taken[label] += 1
        # The budget can only become satisfied right after a sample is kept.
        if np.all(taken >= sample_counts):
            break

    train_prompts = [f"a photo of a {label}" for label in train_labels]
    train_inputs = processor(text=train_prompts, images=train_images, return_tensors="tf", padding=True)
    # Hand tf.data a plain dict rather than the processor's container type.
    dataset = tf.data.Dataset.from_tensor_slices(dict(train_inputs)).batch(FLAGS.batch_size)

    optimizer = tf.keras.optimizers.Adam(FLAGS.learning_rate)

    def embed_images(pixel_values):
        """Return L2-normalised projected image embeddings."""
        vision_outputs = model.clip.vision_model(
            pixel_values,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True
        )
        # Index 1 of the encoder output is the pooled representation.
        return tf.nn.l2_normalize(model.clip.visual_projection(vision_outputs[1]), axis=-1)

    def embed_texts(input_ids, attention_mask):
        """Return L2-normalised projected text embeddings."""
        seq_length = tf.shape(input_ids)[1]
        # The text model is called with explicit position ids 0..seq_length-1.
        position_ids = tf.range(0, seq_length, dtype=tf.int32)[tf.newaxis, :]
        text_outputs = model.clip.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True
        )
        # Index 1 of the encoder output is the pooled representation.
        return tf.nn.l2_normalize(model.clip.text_projection(text_outputs[1]), axis=-1)

    @tf.function
    def train_step(inputs):
        """Run one optimisation step on a batch and return its loss."""
        with tf.GradientTape() as tape:
            image_embeds = embed_images(inputs['pixel_values'])
            text_embeds = embed_texts(inputs['input_ids'], inputs.get('attention_mask', None))
            logit_scale = tf.exp(model.clip.logit_scale)
            loss = _contrastive_focal_loss(
                image_embeds, text_embeds, logit_scale, FLAGS.gamma, FLAGS.alpha)

        grads = tape.gradient(loss, trainable_vars)
        optimizer.apply_gradients(zip(grads, trainable_vars))
        return loss

    # ----- Training --------------------------------------------------------
    for epoch in range(FLAGS.tek):
        epoch_loss = 0
        for batch in dataset:
            epoch_loss += train_step(batch)
        epoch_loss /= len(dataset)
        print(f"Epoch {epoch+1}, Loss: {epoch_loss.numpy():.4f}")

    # ----- Evaluation ------------------------------------------------------
    if not test_images:
        print("\nNo test samples selected; skipping evaluation.")
        return

    # NOTE(review): the candidate classes are the ones that occur in the test
    # slice, not all Food-101 classes, which makes the task easier than
    # full 101-way classification. Kept as in the original protocol.
    # Each distinct class is encoded once; order follows first occurrence.
    class_names = list(dict.fromkeys(test_labels))
    num_templates = len(_EVAL_PROMPT_TEMPLATES)
    all_prompts = [
        template.format(name)
        for name in class_names
        for template in _EVAL_PROMPT_TEMPLATES
    ]

    text_feature_chunks = []
    for start in range(0, len(all_prompts), FLAGS.batch_size):
        batch_prompts = all_prompts[start:start + FLAGS.batch_size]
        text_inputs = processor(text=batch_prompts, return_tensors="tf", padding=True)
        text_feature_chunks.append(
            embed_texts(text_inputs['input_ids'], text_inputs.get('attention_mask', None)))
    all_text_features = tf.concat(text_feature_chunks, axis=0)

    correct = 0
    for start in range(0, len(test_images), FLAGS.batch_size):
        batch_images = test_images[start:start + FLAGS.batch_size]
        batch_truth = test_labels[start:start + FLAGS.batch_size]
        image_inputs = processor(images=batch_images, return_tensors="tf", padding=True)
        image_features = embed_images(image_inputs['pixel_values'])

        # Similarity to every prompt, then the mean over each class's
        # templates. Prompts are laid out class-major, so the reshape groups
        # the templates that belong to one class.
        sims = tf.matmul(image_features, all_text_features, transpose_b=True).numpy()
        class_scores = sims.reshape(len(batch_images), len(class_names), num_templates).mean(axis=2)
        # np.argmax returns the first maximum, so ties go to the class that
        # appears earliest in class_names.
        predictions = np.argmax(class_scores, axis=1)
        correct += sum(
            class_names[pred] == truth for pred, truth in zip(predictions, batch_truth))

    total = len(test_images)
    print(f"\nAccuracy on test set: {correct}/{total} = {correct/total:.2%}")

# Script entry point: hand `main` to absl's app runner when executed directly.
if __name__ == "__main__":
  app.run(main)