import tensorflow as tf
from transformers import TFCLIPModel, CLIPProcessor
import transformers

from collections import defaultdict
import tensorflow_datasets as tfds
import sys
import numpy as np
from PIL import Image
import random
import os

from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS

# Command-line configuration. Help strings describe how `main` uses each flag.
flags.DEFINE_integer(
    'batch_size', 32,
    'Batch size for training and for encoding evaluation inputs.')
flags.DEFINE_integer(
    'testi', 1000,
    'Start index (inclusive) of the CIFAR-100 test slice used for evaluation.')
flags.DEFINE_integer(
    'teste', 1400,
    'End index (exclusive) of the CIFAR-100 test slice used for evaluation.')
flags.DEFINE_float(
    'betaa', 0.1,
    'Concentration of the symmetric Dirichlet used to rank classes when '
    'building the imbalanced training subset.')
flags.DEFINE_integer(
    'maxsam', 16,
    'Number of training samples kept for the highest-ranked class.')
flags.DEFINE_integer('tek', 100, 'Number of training epochs.')

flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate of the Adam optimizer.')
flags.DEFINE_float('alpha', 0.25, 'Focal-loss scaling factor (alpha).')
flags.DEFINE_float('gamma', 1.0, 'Focal-loss focusing exponent (gamma).')

# Seed every random number generator the script relies on.
SEED = 42
# NOTE(review): Python reads PYTHONHASHSEED at interpreter start-up, so this
# assignment presumably only influences child processes; export the variable
# before launching the script if hash determinism matters.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# CIFAR-100 class names, indexed by the integer label.
_CIFAR100_LABEL_NAMES = (
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
    'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
    'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
    'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
    'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
    'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
    'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
    'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
    'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman',
    'worm',
)

# Prompt ensemble used at evaluation time; scores are averaged per class.
_PROMPT_TEMPLATES = (
    "a photo of a {}.", "a photograph of a {}.",
    "an image of a {}.", "a picture of a {}.",
    "a bright photo of a {}.", "a colorful image of a {}.",
    "a close-up photo of a {}.", "a rendition of a {}.",
    "a detailed photo of a {}.", "a clear image of a {}.",
)

# Candidate per-class training-set sizes. The last entry is never assigned to
# a group: the top-ranked class gets `max_samples` instead.
_SAMPLE_BUCKETS = (2, 3, 4, 5, 6, 8, 10, 12, 14, 16)


def _long_tailed_sample_counts(proportions, max_samples, buckets=_SAMPLE_BUCKETS):
  """Return how many training samples to keep for each class.

  The class with the largest proportion receives `max_samples`. The remaining
  classes are ordered by ascending proportion and split into
  `len(buckets) - 1` contiguous groups; group `i` receives `buckets[i]`
  samples, and the final group absorbs any remainder of the integer division.

  Args:
    proportions: 1-D array-like with one ranking score per class.
    max_samples: Sample count for the class with the largest proportion.
    buckets: Sequence of sample counts; the last element is not used.

  Returns:
    Integer NumPy array with one sample count per class.
  """
  proportions = np.asarray(proportions)
  max_index = np.argmax(proportions)
  order = np.argsort(proportions)
  ranked = order[order != max_index]

  sample_counts = np.ones(len(proportions), dtype=int)
  sample_counts[max_index] = max_samples

  tail_buckets = buckets[:-1]
  bucket_size = len(ranked) // len(tail_buckets)
  for i, bucket_value in enumerate(tail_buckets):
    start = i * bucket_size
    is_last = i == len(tail_buckets) - 1
    end = len(ranked) if is_last else (i + 1) * bucket_size
    sample_counts[ranked[start:end]] = bucket_value
  return sample_counts


def _focal_cross_entropy(logits, labels, gamma, alpha):
  """Per-example focal loss `alpha * (1 - p_t)**gamma * -log(p_t)`.

  Probabilities and the focal weight are clipped for numerical stability.
  """
  probs = tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-7, 1.0)
  labels_onehot = tf.one_hot(labels, depth=tf.shape(logits)[-1])
  pt = tf.reduce_sum(labels_onehot * probs, axis=-1)
  focal_weight = tf.clip_by_value(tf.pow(1. - pt, gamma), 0.0, 10.0)
  ce = -tf.math.log(tf.clip_by_value(pt, 1e-7, 1.0))
  return alpha * focal_weight * ce


def _contrastive_focal_loss(image_embeds, text_embeds, logit_scale, gamma, alpha):
  """Symmetric image/text contrastive loss with focal weighting.

  Row `i` of `image_embeds` is treated as the positive match of row `i` of
  `text_embeds`; every other row in the batch is a negative. The result is the
  mean of the image-to-text and text-to-image focal losses.

  NOTE(review): samples of the same class in one batch share an identical
  prompt yet are still treated as negatives here - confirm this is intended.
  """
  logit_scale = tf.clip_by_value(logit_scale, 1.0, 100.0)
  logits_per_image = tf.matmul(image_embeds, text_embeds, transpose_b=True) * logit_scale
  logits_per_text = tf.transpose(logits_per_image)
  labels = tf.range(tf.shape(logits_per_image)[0])

  loss_i2t = _focal_cross_entropy(logits_per_image, labels, gamma, alpha)
  loss_t2i = _focal_cross_entropy(logits_per_text, labels, gamma, alpha)
  loss = (tf.reduce_mean(loss_i2t) + tf.reduce_mean(loss_t2i)) / 2
  # Replace a NaN loss with 1.0 so that a single bad batch does not poison
  # the reported epoch loss.
  return tf.where(tf.math.is_nan(loss), 1.0, loss)


def main(argv):
  """Fine-tune one CLIP LayerNorm on imbalanced CIFAR-100 and evaluate it.

  Steps:
    1. Load CIFAR-100 and the `openai/clip-vit-base-patch32` CLIP model.
    2. Freeze every variable except the first LayerNorm of the first vision
       encoder block.
    3. Build a class-imbalanced training subset and train with a contrastive
       focal loss for `FLAGS.tek` epochs.
    4. Evaluate prompt-ensemble classification accuracy on the test slice
       `[FLAGS.testi:FLAGS.teste]`.

  Args:
    argv: Positional command-line arguments left over by absl (unused).

  Raises:
    ValueError: If the requested test slice contains no samples.
  """
  del argv  # Unused.

  label_names = _CIFAR100_LABEL_NAMES
  num_classes = len(label_names)

  # Pre-processing function to resize CIFAR images (32x32) to CLIP input size (224x224)
  def preprocess_image(image, label):
    image = tf.image.resize(image, (224, 224))
    return image, label

  train_dataset = tfds.load("cifar100", split="train", as_supervised=True)
  train_dataset = train_dataset.map(preprocess_image)

  # Select the evaluation slice *before* resizing and materialising, so only
  # the requested samples are held in memory. `slice.indices` reproduces
  # Python list-slicing semantics (negative / out-of-range bounds).
  test_dataset, test_info = tfds.load(
      "cifar100", split="test", as_supervised=True, with_info=True)
  num_test_total = test_info.splits["test"].num_examples
  start, stop, _ = slice(FLAGS.testi, FLAGS.teste).indices(num_test_total)
  test_dataset = test_dataset.skip(start).take(max(stop - start, 0))
  test_dataset = test_dataset.map(preprocess_image)

  model = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

  test_samples = list(tfds.as_numpy(test_dataset))
  if not test_samples:
    raise ValueError(
        f"Test slice [{FLAGS.testi}:{FLAGS.teste}] contains no samples.")
  test_images = [img for img, _ in test_samples]
  test_labels = [label_names[lbl] for _, lbl in test_samples]

  # Run one forward pass on a random image; presumably this guarantees that
  # all model variables exist before they are frozen - TODO confirm. It also
  # consumes NumPy random state, so it must stay ahead of the Dirichlet draw
  # below to keep the sampled training subset unchanged.
  dummy_image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
  inputs = processor(images=dummy_image, text=["a photo"], return_tensors="tf", padding=True)
  _ = model(**inputs)

  # Freeze everything. NOTE(review): this writes the private `_trainable`
  # attribute of each variable; verify it against the installed TF version.
  for var in model.variables:
    var._trainable = False

  # Unfreeze the first LayerNorm in the first vision encoder block
  first_block = model.clip.vision_model.encoder.layers[0]
  unfrozen = 0
  for layer in first_block.submodules:
    if isinstance(layer, tf.keras.layers.LayerNormalization):
      for var in layer.variables:
        var._trainable = True
        print(f"Unfroze: {var.name}")
      unfrozen += 1
      break  # Stop after first LayerNorm
  if unfrozen == 0:
    print("No LayerNorm was unfrozen!")

  trainable_vars = [v for v in model.variables if v.trainable]
  print(f"\nTotal trainable variables: {len(trainable_vars)}")
  for v in trainable_vars:
    print(v.name, v.shape)

  # Build the imbalanced training subset: classes are ranked by a Dirichlet
  # draw and the first images of each class are kept up to its quota.
  proportions = np.random.dirichlet([FLAGS.betaa] * num_classes)
  sample_counts = _long_tailed_sample_counts(proportions, FLAGS.maxsam)

  remaining = sample_counts.copy()
  train_images, train_labels = [], []
  for image, label in tfds.as_numpy(train_dataset):
    if remaining[label] > 0:
      train_images.append(image)
      train_labels.append(label_names[label])
      remaining[label] -= 1
    if not (remaining > 0).any():
      break  # Every class quota is filled.

  train_prompts = [f"a photo of a {label}" for label in train_labels]
  train_inputs = processor(text=train_prompts, images=train_images, return_tensors="tf", padding=True)
  # Pass a plain dict so tf.data sees a standard nested structure.
  dataset = tf.data.Dataset.from_tensor_slices(dict(train_inputs)).batch(FLAGS.batch_size)

  optimizer = tf.keras.optimizers.Adam(FLAGS.learning_rate)

  def encode_images(pixel_values):
    """Return L2-normalised CLIP image embeddings for `pixel_values`."""
    vision_outputs = model.clip.vision_model(
        pixel_values,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True
    )
    image_embeds = model.clip.visual_projection(vision_outputs.pooler_output)
    return tf.nn.l2_normalize(image_embeds, axis=-1)

  def encode_text(input_ids, attention_mask):
    """Return L2-normalised CLIP text embeddings for tokenised prompts."""
    # Position ids are passed explicitly, as in the original training code.
    seq_length = tf.shape(input_ids)[1]
    position_ids = tf.range(0, seq_length, dtype=tf.int32)[tf.newaxis, :]
    text_outputs = model.clip.text_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True
    )
    text_embeds = model.clip.text_projection(text_outputs.pooler_output)
    return tf.nn.l2_normalize(text_embeds, axis=-1)

  @tf.function
  def train_step(inputs):
    """Run one optimisation step on a batch and return its loss."""
    with tf.GradientTape() as tape:
      image_embeds = encode_images(inputs['pixel_values'])
      text_embeds = encode_text(
          inputs['input_ids'], inputs.get('attention_mask', None))
      logit_scale = tf.exp(model.clip.logit_scale)
      loss = _contrastive_focal_loss(
          image_embeds, text_embeds, logit_scale, FLAGS.gamma, FLAGS.alpha)

    grads = tape.gradient(loss, trainable_vars)
    optimizer.apply_gradients(zip(grads, trainable_vars))
    return loss

  num_batches = len(dataset)
  for epoch in range(FLAGS.tek):
    epoch_loss = 0
    for batch in dataset:
      epoch_loss += train_step(batch)
    epoch_loss /= num_batches
    print(f"Epoch {epoch+1}, Loss: {epoch_loss.numpy():.4f}")

  # Candidate classes are the distinct labels present in the test slice, in
  # order of first appearance (this matches the original scoring, where
  # duplicate labels collapsed into one dictionary entry).
  # NOTE(review): restricting candidates to labels seen in the test slice
  # uses test-label information; consider scoring against all classes.
  candidate_labels = list(dict.fromkeys(test_labels))
  all_prompts = [
      template.format(label)
      for label in candidate_labels
      for template in _PROMPT_TEMPLATES
  ]

  all_text_features = []
  for i in range(0, len(all_prompts), FLAGS.batch_size):
    batch_prompts = all_prompts[i:i + FLAGS.batch_size]
    text_inputs = processor(text=batch_prompts, return_tensors="tf", padding=True)
    all_text_features.append(
        encode_text(text_inputs['input_ids'], text_inputs.get('attention_mask', None)))
  all_text_features = tf.concat(all_text_features, axis=0)

  # Inference
  num_templates = len(_PROMPT_TEMPLATES)
  correct = 0
  for i in range(0, len(test_images), FLAGS.batch_size):
    batch_images = test_images[i:i + FLAGS.batch_size]
    batch_true = test_labels[i:i + FLAGS.batch_size]
    inputs = processor(images=batch_images, return_tensors="tf", padding=True)
    image_features = encode_images(inputs['pixel_values'])
    sims = tf.matmul(image_features, all_text_features, transpose_b=True).numpy()

    # Prompts are laid out class-major, so reshaping groups the templates of
    # each class together; average them to get one score per class.
    class_scores = sims.reshape(
        len(batch_images), len(candidate_labels), num_templates).mean(axis=-1)
    # np.argmax returns the first maximum, like max() over the ordered dict.
    predictions = class_scores.argmax(axis=-1)
    correct += sum(
        candidate_labels[pred] == true
        for pred, true in zip(predictions, batch_true))
  print(f"\nAccuracy on test set: {correct}/{len(test_images)} = {correct/len(test_images):.2%}")

# Script entry point: hand control to absl, which parses the command-line
# flags defined above and then calls `main`.
if __name__ == "__main__":
  app.run(main)