import tensorflow as tf
from transformers import TFCLIPModel, CLIPProcessor
import transformers

from collections import defaultdict
import tensorflow_datasets as tfds
import sys
import numpy as np
from PIL import Image
import random
import os

from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS

# Command-line configuration. The short flag names are kept for CLI
# compatibility; the help strings describe how main() actually uses each one.
flags.DEFINE_integer('batch_size', 32, 'Batch size for training and for batched encoding during evaluation.')
flags.DEFINE_integer('testi', 1000, 'Start index (inclusive) of the test-split slice used for evaluation; reset to 0 if past the end of the split.')
flags.DEFINE_integer('teste', 1400, 'End index (exclusive) of the test-split slice used for evaluation; clamped to the split size.')
flags.DEFINE_float('betaa', 0.1, 'Dirichlet concentration for sampling per-class proportions; their ranking assigns per-class training shot counts.')
flags.DEFINE_integer('maxsam', 16, 'Number of training shots for the class with the largest sampled proportion.')
flags.DEFINE_integer('tek', 100, 'Number of training epochs.')

flags.DEFINE_float('learning_rate', 1e-3, 'Adam learning rate.')
flags.DEFINE_float('alpha', 0.25, 'Focal-loss weighting factor alpha.')
flags.DEFINE_float('gamma', 1.0, 'Focal-loss focusing exponent gamma applied to (1 - p_t).')

# Seed every RNG the script touches (Python, NumPy, TensorFlow) so the
# dummy image, the Dirichlet class ranking and TF ops are reproducible.
SEED = 42
# NOTE: PYTHONHASHSEED is read at interpreter startup, so setting it here only
# influences child processes, not hashing in the running interpreter.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

def main(argv):
  """Fine-tune one LayerNorm of CLIP ViT-B/32 on a long-tailed few-shot subset
  of Oxford-IIIT Pets with a contrastive focal loss, then report
  prompt-ensemble classification accuracy on a slice of the test split.

  Args:
    argv: Positional arguments left over after absl flag parsing (unused).

  Raises:
    app.UsageError: If --testi/--teste select an empty test slice.
    RuntimeError: If no variable ends up trainable (nothing to fine-tune).
  """
  del argv  # Unused; all configuration comes from FLAGS.

  train_dataset = tfds.load("oxford_iiit_pet", split="train", as_supervised=True)
  test_dataset = tfds.load("oxford_iiit_pet", split="test", as_supervised=True)
  # Indexed by the integer labels tfds yields; assumed to follow the dataset's
  # label order.
  label_names = [
      'Abyssinian', 'american_bulldog', 'american_pit_bull_terrier', 'basset_hound',
      'beagle', 'Bengal', 'Birman', 'Bombay', 'boxer', 'British_Shorthair',
      'chihuahua', 'Egyptian_Mau', 'english_cocker_spaniel', 'english_setter',
      'german_shorthaired', 'great_pyrenees', 'havanese', 'japanese_chin',
      'keeshond', 'leonberger', 'Maine_Coon', 'miniature_pinscher', 'newfoundland',
      'Persian', 'pomeranian', 'pug', 'Ragdoll', 'Russian_Blue', 'saint_bernard',
      'samoyed', 'scottish_terrier', 'shiba_inu', 'Siamese', 'Sphynx',
      'staffordshire_bull_terrier', 'wheaten_terrier', 'yorkshire_terrier'
  ]
  num_classes = len(label_names)

  model = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

  # --- Evaluation slice [testi, teste) ----------------------------------------
  # Stream only the requested slice instead of decoding the whole split.
  test_samples_count = int(test_dataset.cardinality())
  if test_samples_count < 0:  # Unknown cardinality: count by iterating once.
      test_samples_count = sum(1 for _ in test_dataset)
  if FLAGS.testi >= test_samples_count:
      FLAGS.testi = 0
  if FLAGS.teste > test_samples_count:
      FLAGS.teste = test_samples_count
  if FLAGS.testi >= FLAGS.teste:
      raise app.UsageError(
          f"Empty test slice: --testi={FLAGS.testi} must be < --teste={FLAGS.teste}.")
  test_samples = list(tfds.as_numpy(
      test_dataset.skip(FLAGS.testi).take(FLAGS.teste - FLAGS.testi)))
  test_images = [img for img, _ in test_samples]
  test_labels = [label_names[lbl] for _, lbl in test_samples]

  # One dummy forward pass, presumably so all sub-layers are built before
  # trainability is toggled. NOTE: this draws from np.random, so it must stay
  # ahead of the Dirichlet sampling below to keep the seeded class ranking.
  dummy_image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
  inputs = processor(images=dummy_image, text=["a pet"], return_tensors="tf", padding=True)
  _ = model(**inputs)

  # --- Trainable parameters ---------------------------------------------------
  # Freeze everything, then unfreeze only the first LayerNorm found among the
  # submodules of the first vision encoder block. `_trainable` is a private
  # tf.Variable attribute (the public `trainable` property is read-only).
  for var in model.variables:
      var._trainable = False
  first_block = model.clip.vision_model.encoder.layers[0]
  first_layer_norm = next(
      (layer for layer in first_block.submodules
       if isinstance(layer, tf.keras.layers.LayerNormalization)),
      None)
  if first_layer_norm is None:
      print("No LayerNorm was unfrozen!")
  else:
      for var in first_layer_norm.variables:
          var._trainable = True
          print(f"Unfroze: {var.name}")

  # Confirm only those variables are trainable.
  trainable_vars = [v for v in model.variables if v.trainable]
  print(f"\nTotal trainable variables: {len(trainable_vars)}")
  for v in trainable_vars:
      print(v.name, v.shape)
  if not trainable_vars:
      raise RuntimeError("No trainable variables: nothing to fine-tune.")

  # --- Long-tailed few-shot training subset ------------------------------------
  # Sample Dirichlet class proportions; only their ranking is used. The top class
  # gets FLAGS.maxsam shots. The remaining classes, in ascending-proportion order,
  # are split into equal groups receiving 2, 3, 4, 5, 6, 8, 10, 12 and 14 shots;
  # the last group absorbs any remainder.
  proportions = np.random.dirichlet([FLAGS.betaa] * num_classes)
  max_index = np.argmax(proportions)
  sample_buckets = [2, 3, 4, 5, 6, 8, 10, 12, 14, 16]
  indices_without_max = [idx for idx in np.argsort(proportions) if idx != max_index]
  sample_counts = np.ones(num_classes, dtype=int)  # Overwritten below for every class.
  sample_counts[max_index] = FLAGS.maxsam
  # -1: the last bucket (16) stands for the max class, which uses FLAGS.maxsam.
  bucket_size = len(indices_without_max) // (len(sample_buckets) - 1)
  for i, bucket_value in enumerate(sample_buckets[:-1]):
      start_idx = i * bucket_size
      end_idx = (i + 1) * bucket_size if i < len(sample_buckets) - 2 else len(indices_without_max)
      for idx in indices_without_max[start_idx:end_idx]:
          sample_counts[idx] = bucket_value

  # Take the first sample_counts[c] training images of each class c in tfds order.
  class_counter = defaultdict(int)
  train_images, train_labels = [], []
  for image, label in tfds.as_numpy(train_dataset):
      label_str = label_names[label]
      if class_counter[label_str] < sample_counts[label]:
          train_images.append(image)
          train_labels.append(label_str)
          class_counter[label_str] += 1
      if all(class_counter[label_names[i]] >= sample_counts[i] for i in range(num_classes)):
          break

  train_prompts = [f"a photo of a {label}" for label in train_labels]
  train_inputs = processor(text=train_prompts, images=train_images, return_tensors="tf", padding=True)
  # dict(): give tf.data a plain dict rather than transformers' BatchEncoding.
  # NOTE: the dataset is not shuffled, and same-class samples sharing a batch
  # are treated as negatives by the in-batch contrastive loss below.
  dataset = tf.data.Dataset.from_tensor_slices(dict(train_inputs)).batch(FLAGS.batch_size)
  num_batches = len(dataset)

  # --- Embedding helpers (shared by training and evaluation) -------------------
  def embed_images(pixel_values):
      """Return L2-normalised CLIP image embeddings for a batch of pixel values."""
      vision_outputs = model.clip.vision_model(
          pixel_values,
          output_attentions=False,
          output_hidden_states=False,
          return_dict=True
      )
      image_embeds = model.clip.visual_projection(vision_outputs.pooler_output)
      return tf.nn.l2_normalize(image_embeds, axis=-1)

  def embed_texts(input_ids, attention_mask=None):
      """Return L2-normalised CLIP text embeddings for tokenised prompts."""
      seq_length = tf.shape(input_ids)[1]
      # Explicit position ids 0..seq_length-1 (broadcast over the batch), passed
      # because the inner text transformer is called directly.
      position_ids = tf.range(0, seq_length, dtype=tf.int32)[tf.newaxis, :]
      text_outputs = model.clip.text_model(
          input_ids=input_ids,
          attention_mask=attention_mask,
          position_ids=position_ids,
          output_attentions=False,
          output_hidden_states=False,
          return_dict=True
      )
      text_embeds = model.clip.text_projection(text_outputs.pooler_output)
      return tf.nn.l2_normalize(text_embeds, axis=-1)

  def contrastive_focal_loss(image_embeds, text_embeds, logit_scale, gamma=FLAGS.gamma, alpha=FLAGS.alpha):
      """Symmetric (image->text and text->image) focal loss over in-batch pairs.

      Row i of `image_embeds` and row i of `text_embeds` form the positive pair;
      every other row in the batch is a negative.

      Args:
        image_embeds: L2-normalised image embeddings, one row per example.
        text_embeds: L2-normalised text embeddings aligned row-wise with images.
        logit_scale: Similarity multiplier, clipped to [1, 100].
        gamma: Focal focusing exponent applied to (1 - p_t).
        alpha: Scalar focal weighting factor.

      Returns:
        Scalar mean of the two directional losses; a NaN result is replaced by
        1.0. NOTE(review): tf.where does not necessarily keep NaNs out of the
        gradient of the unselected branch — confirm this guard is sufficient.
      """
      logit_scale = tf.clip_by_value(logit_scale, 1.0, 100.0)
      logits_per_image = tf.matmul(image_embeds, text_embeds, transpose_b=True) * logit_scale
      logits_per_text = tf.transpose(logits_per_image)
      labels = tf.range(tf.shape(logits_per_image)[0])

      def stable_focal_loss(logits):
          # Per-row focal loss with clipping to keep log/pow finite.
          probs = tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-7, 1.0)
          labels_onehot = tf.one_hot(labels, depth=tf.shape(logits)[-1])
          pt = tf.reduce_sum(labels_onehot * probs, axis=-1)
          focal_weight = tf.clip_by_value(tf.pow(1. - pt, gamma), 0.0, 10.0)
          ce = -tf.math.log(tf.clip_by_value(pt, 1e-7, 1.0))
          return alpha * focal_weight * ce

      loss = (tf.reduce_mean(stable_focal_loss(logits_per_image))
              + tf.reduce_mean(stable_focal_loss(logits_per_text))) / 2
      return tf.where(tf.math.is_nan(loss), 1.0, loss)

  # --- Training -----------------------------------------------------------------
  optimizer = tf.keras.optimizers.Adam(FLAGS.learning_rate)

  @tf.function
  def train_step(inputs):
      """Run one optimisation step on `trainable_vars`; return the batch loss."""
      with tf.GradientTape() as tape:
          image_embeds = embed_images(inputs['pixel_values'])
          text_embeds = embed_texts(inputs['input_ids'], inputs.get('attention_mask', None))
          logit_scale = tf.exp(model.clip.logit_scale)
          loss = contrastive_focal_loss(image_embeds, text_embeds, logit_scale, FLAGS.gamma, FLAGS.alpha)
      grads = tape.gradient(loss, trainable_vars)
      optimizer.apply_gradients(zip(grads, trainable_vars))
      return loss

  for epoch in range(FLAGS.tek):
      epoch_loss = 0.0
      for batch in dataset:
          epoch_loss += train_step(batch)
      epoch_loss /= num_batches
      print(f"Epoch {epoch+1}, Loss: {epoch_loss.numpy():.4f}")

  # --- Evaluation: prompt-ensemble classification over all classes -------------
  prompt_templates = [
      "a photo of a {}.", "a photograph of a {}.",
      "an image of a {}.", "a close-up photo of a {}.",
      "a picture of the {}.", "a {} pet.",
      "a {} animal.", "a cute {}.",
      "a portrait of a {}.", "a domestic {}."
  ]
  num_templates = len(prompt_templates)
  # Class-major prompt list: entry c * num_templates + t is template t for
  # class c. Every class is a candidate, not only those present in the slice.
  class_prompts = [template.format(label) for label in label_names for template in prompt_templates]
  text_feature_batches = []
  for i in range(0, len(class_prompts), FLAGS.batch_size):
      text_inputs = processor(text=class_prompts[i:i + FLAGS.batch_size], return_tensors="tf", padding=True)
      text_feature_batches.append(
          embed_texts(text_inputs['input_ids'], text_inputs.get('attention_mask', None)))
  class_text_features = tf.concat(text_feature_batches, axis=0)

  correct = 0
  for i in range(0, len(test_images), FLAGS.batch_size):
      batch_images = test_images[i:i + FLAGS.batch_size]
      batch_labels = test_labels[i:i + FLAGS.batch_size]
      image_inputs = processor(images=batch_images, return_tensors="tf")
      image_features = embed_images(image_inputs['pixel_values'])
      sims = tf.matmul(image_features, class_text_features, transpose_b=True)
      # Average the similarity over templates -> one score per (image, class).
      class_scores = tf.reduce_mean(tf.reshape(sims, [-1, num_classes, num_templates]), axis=-1)
      predictions = tf.argmax(class_scores, axis=-1).numpy()
      correct += sum(label_names[p] == true_label for p, true_label in zip(predictions, batch_labels))
  print(f"\nAccuracy on test set: {correct}/{len(test_images)} = {correct/len(test_images):.2%}")

if __name__ == "__main__":
  # absl parses the command-line flags into FLAGS, then calls main(argv).
  app.run(main=main)