import tensorflow as tf
from transformers import TFCLIPModel, CLIPProcessor
import transformers

from collections import defaultdict
import tensorflow_datasets as tfds
import sys
import numpy as np
from PIL import Image
import random
import os

from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS

flags.DEFINE_integer('batch_size', 32,
                     'Batch size for training, text encoding and evaluation.')
flags.DEFINE_integer('testi', 1000,
                     'Start index (inclusive) of the evaluated test-split slice.')
flags.DEFINE_integer('teste', 1400,
                     'End index (exclusive) of the evaluated test-split slice.')
flags.DEFINE_float('betaa', 0.1,
                   'Dirichlet concentration for the per-class proportion draw; '
                   'only the ranking of the draw is used to assign per-class '
                   'training-sample counts.')
flags.DEFINE_integer('maxsam', 16,
                     'Number of training samples taken for the class with the '
                     'largest drawn proportion.')
flags.DEFINE_integer('tek', 100, 'Number of training epochs.')

flags.DEFINE_float('learning_rate', 1e-3, 'Adam learning rate.')
flags.DEFINE_float('alpha', 0.25, 'Focal-loss alpha (constant loss scale).')
flags.DEFINE_float('gamma', 1.0,
                   'Focal-loss gamma (exponent applied to 1 - p_t).')

SEED = 42
# NOTE: Python reads PYTHONHASHSEED only at interpreter startup, so setting it
# here affects child processes, not str hashing in this process.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

# Templates ensembled at evaluation time: a class's score is the mean
# image-text similarity over all of its templated prompts.
_PROMPT_TEMPLATES = (
    "a photo of a {}.", "a photograph of a {}.", "an image of a {}.",
    "a close-up photo of a {}.", "a picture of a {}.",
    "a cropped photo of a {}.", "a color photo of a {}.",
    "a rendition of a {}.", "a detailed photo of a {}.",
    "an object called {}.",
)

# Training-sample counts handed out to the non-dominant classes, in order of
# increasing drawn proportion. The dominant class receives --maxsam instead.
_SAMPLE_BUCKETS = (2, 3, 4, 5, 6, 8, 10, 12, 14)


def _load_test_samples(test_dataset):
  """Materialize the test examples selected by --testi / --teste.

  The slice is ``[start, stop)`` with ``stop = min(len(split), teste)`` and
  ``start = min(testi, stop - 1)``. An out-of-range ``testi`` therefore still
  yields the last available example rather than an empty slice.

  Args:
    test_dataset: the test split, loaded with ``as_supervised=True``.

  Returns:
    A list of ``(image, label_id)`` numpy pairs in dataset order.

  Raises:
    ValueError: if ``testi`` is negative, ``teste`` is not positive, or the
      split is empty.
  """
  if FLAGS.testi < 0 or FLAGS.teste <= 0:
    raise ValueError(
        f"Expected testi >= 0 and teste > 0, got testi={FLAGS.testi}, "
        f"teste={FLAGS.teste}.")
  # take() stops at the end of the split, so len(prefix) equals
  # min(len(split), teste) without materializing examples beyond `teste`.
  prefix = list(tfds.as_numpy(test_dataset.take(FLAGS.teste)))
  if not prefix:
    raise ValueError("The test split is empty.")
  start = min(FLAGS.testi, len(prefix) - 1)
  return prefix[start:]


def _preprocess_image(image):
  """Return ``image`` as a 224x224 3-channel uint8 array.

  Single-channel images are tiled to 3 channels. Non-uint8 images are scaled
  by 255 (assumes floats in [0, 1] -- TODO confirm if other dtypes appear).
  Any other size is resized (aspect ratio not preserved) with bilinear
  filtering.
  """
  if image.shape[-1] == 1:
    image = np.tile(image, [1, 1, 3])
  if image.dtype != np.uint8:
    image = (image * 255).astype(np.uint8)
  pil_image = Image.fromarray(image)
  if pil_image.width != 224 or pil_image.height != 224:
    pil_image = pil_image.resize((224, 224), Image.BILINEAR)
  return np.array(pil_image)


def _freeze_all_but_first_layernorm(model):
  """Freeze every model variable except the first LayerNorm of vision block 0.

  ``tf.Variable.trainable`` is read-only, so the private ``_trainable``
  attribute is toggled directly (relies on TensorFlow internals).

  Returns:
    The list of variables that remain trainable.
  """
  for var in model.variables:
    var._trainable = False

  first_block = model.clip.vision_model.encoder.layers[0]
  # NOTE(review): if transformers builds its layers from a different Keras
  # package than tf.keras, this isinstance check may match nothing -- confirm.
  layer_norm = next(
      (layer for layer in first_block.submodules
       if isinstance(layer, tf.keras.layers.LayerNormalization)),
      None)
  if layer_norm is None:
    print("No LayerNorm was unfrozen!")
  else:
    for var in layer_norm.variables:
      var._trainable = True
      print(f"Unfroze: {var.name}")
  return [v for v in model.variables if v.trainable]


def _compute_sample_counts(num_classes, beta, max_samples):
  """Assign a per-class training-sample budget from a Dirichlet draw.

  Draws ``Dirichlet([beta] * num_classes)`` from the global NumPy RNG. The
  class with the largest proportion gets ``max_samples``. The others, ranked
  by ascending proportion, are split into ``len(_SAMPLE_BUCKETS)`` equal
  groups receiving the bucket values in order; the last group also takes the
  remainder of the integer division.

  Returns:
    An int array of length ``num_classes``.
  """
  proportions = np.random.dirichlet([beta] * num_classes)
  max_index = int(np.argmax(proportions))
  ranked = [idx for idx in np.argsort(proportions) if idx != max_index]

  sample_counts = np.ones(num_classes, dtype=int)
  sample_counts[max_index] = max_samples
  num_buckets = len(_SAMPLE_BUCKETS)
  bucket_size = len(ranked) // num_buckets
  for i, bucket_value in enumerate(_SAMPLE_BUCKETS):
    start = i * bucket_size
    end = (i + 1) * bucket_size if i < num_buckets - 1 else len(ranked)
    for idx in ranked[start:end]:
      sample_counts[idx] = bucket_value
  return sample_counts


def _collect_train_samples(train_dataset, label_names, sample_counts):
  """Take the first ``sample_counts[c]`` training examples of each class ``c``.

  Iteration stops as soon as every class has reached its count.

  Returns:
    ``(images, labels)``: preprocessed images and their class-name strings,
    in dataset order.
  """
  taken = defaultdict(int)
  images, labels = [], []
  # Samples still needed overall; a non-positive count means "take none".
  remaining = int(np.maximum(sample_counts, 0).sum())
  for image, label in tfds.as_numpy(train_dataset):
    label_str = label_names[label]
    if taken[label_str] < sample_counts[label]:
      images.append(_preprocess_image(image))
      labels.append(label_str)
      taken[label_str] += 1
      remaining -= 1
    if remaining <= 0:
      break
  return images, labels


def _focal_loss_terms(logits, labels, gamma, alpha):
  """Per-row focal loss ``-alpha * (1 - p_t) ** gamma * log(p_t)``.

  Args:
    logits: ``[rows, classes]`` unnormalized scores.
    labels: ``[rows]`` integer target class per row.
    gamma: exponent applied to ``1 - p_t``.
    alpha: constant scale on the loss.

  Returns:
    A ``[rows]`` tensor of losses.
  """
  # log_softmax keeps log(p_t) finite and differentiable when p_t underflows;
  # clipping probabilities instead zeroes the gradient for the worst rows.
  log_probs = tf.nn.log_softmax(logits, axis=-1)
  labels_onehot = tf.one_hot(labels, depth=tf.shape(logits)[-1],
                             dtype=log_probs.dtype)
  log_pt = tf.reduce_sum(labels_onehot * log_probs, axis=-1)
  pt = tf.exp(log_pt)
  # Bounded to [0, 10]; the bound only binds for a negative gamma.
  focal_weight = tf.clip_by_value(tf.pow(1.0 - pt, gamma), 0.0, 10.0)
  return -alpha * focal_weight * log_pt


def _contrastive_focal_loss(image_embeds, text_embeds, logit_scale, gamma,
                            alpha):
  """Symmetric image->text / text->image focal loss over in-batch pairs.

  Row i of ``image_embeds`` is the positive for row i of ``text_embeds``;
  every other row in the batch acts as a negative.

  Args:
    image_embeds: ``[batch, dim]`` image embeddings (expected L2-normalized).
    text_embeds: ``[batch, dim]`` text embeddings (expected L2-normalized).
    logit_scale: similarity multiplier, clipped to ``[1, 100]``.
    gamma: focal-loss exponent.
    alpha: focal-loss scale.

  Returns:
    A scalar loss; a NaN result is replaced by 1.0.
  """
  logit_scale = tf.clip_by_value(logit_scale, 1.0, 100.0)
  logits_per_image = (
      tf.matmul(image_embeds, text_embeds, transpose_b=True) * logit_scale)
  logits_per_text = tf.transpose(logits_per_image)
  labels = tf.range(tf.shape(logits_per_image)[0])

  loss_i2t = tf.reduce_mean(
      _focal_loss_terms(logits_per_image, labels, gamma, alpha))
  loss_t2i = tf.reduce_mean(
      _focal_loss_terms(logits_per_text, labels, gamma, alpha))
  loss = (loss_i2t + loss_t2i) / 2
  # NOTE(review): a constant replacement has zero gradient, so a NaN batch
  # becomes a silent no-op update instead of surfacing the problem.
  return tf.where(tf.math.is_nan(loss), tf.ones_like(loss), loss)


def _embed_images(model, pixel_values):
  """Return L2-normalized CLIP image embeddings for ``pixel_values``."""
  vision_outputs = model.clip.vision_model(
      pixel_values,
      output_attentions=False,
      output_hidden_states=False,
      return_dict=True,
  )
  image_embeds = model.clip.visual_projection(vision_outputs.pooler_output)
  return tf.nn.l2_normalize(image_embeds, axis=-1)


def _embed_text(model, input_ids, attention_mask=None):
  """Return L2-normalized CLIP text embeddings for tokenized prompts."""
  seq_length = tf.shape(input_ids)[1]
  # The inner text transformer is called directly, so position ids are built
  # here (presumably it does not default them -- confirm for your version).
  position_ids = tf.range(0, seq_length, dtype=tf.int32)[tf.newaxis, :]
  text_outputs = model.clip.text_model(
      input_ids=input_ids,
      attention_mask=attention_mask,
      position_ids=position_ids,
      output_attentions=False,
      output_hidden_states=False,
      return_dict=True,
  )
  text_embeds = model.clip.text_projection(text_outputs.pooler_output)
  return tf.nn.l2_normalize(text_embeds, axis=-1)


def _encode_prompts(model, processor, prompts, batch_size):
  """Embed ``prompts`` in batches; returns a ``[len(prompts), dim]`` tensor."""
  chunks = []
  for start in range(0, len(prompts), batch_size):
    tokens = processor(text=prompts[start:start + batch_size],
                       return_tensors="tf", padding=True)
    chunks.append(
        _embed_text(model, tokens["input_ids"], tokens.get("attention_mask")))
  return tf.concat(chunks, axis=0)


def _evaluate(model, processor, test_images, test_labels, batch_size):
  """Count test images whose prompt-ensemble prediction equals their label.

  Each candidate class is scored by the mean cosine similarity over
  ``_PROMPT_TEMPLATES``. That mean equals the dot product with the mean
  template embedding, so the mean embedding is deliberately NOT
  re-normalized. Ties go to the class whose label appears first in
  ``test_labels``.

  NOTE(review): the candidate classes are the distinct labels present in the
  evaluated slice, not all dataset classes. That uses ground-truth knowledge
  of the slice; consider scoring against every label name instead.
  """
  class_names = list(dict.fromkeys(test_labels))
  num_templates = len(_PROMPT_TEMPLATES)
  prompts = [template.format(name)
             for name in class_names
             for template in _PROMPT_TEMPLATES]
  prompt_embeds = _encode_prompts(model, processor, prompts, batch_size)
  # Prompts are grouped class-major, so rows reshape to [class, template, dim].
  class_embeds = tf.reduce_mean(
      tf.reshape(prompt_embeds, [len(class_names), num_templates, -1]), axis=1)

  correct = 0
  for start in range(0, len(test_images), batch_size):
    batch_images = [_preprocess_image(img)
                    for img in test_images[start:start + batch_size]]
    pixel_values = processor(images=batch_images,
                             return_tensors="tf")["pixel_values"]
    scores = tf.matmul(_embed_images(model, pixel_values), class_embeds,
                       transpose_b=True)
    predictions = tf.argmax(scores, axis=-1).numpy()
    batch_labels = test_labels[start:start + batch_size]
    correct += sum(class_names[pred] == true_label
                   for pred, true_label in zip(predictions, batch_labels))
  return correct


def main(argv):
  """Fine-tune one CLIP LayerNorm on an imbalanced Caltech101 subset.

  Steps:
    1. Load Caltech101 and select the test slice given by --testi / --teste.
    2. Load CLIP and freeze everything except the first LayerNorm of the
       first vision-encoder block.
    3. Draw per-class training-sample counts and collect that subset.
    4. Train with a symmetric contrastive focal loss for --tek epochs.
    5. Print prompt-ensemble accuracy on the test slice.

  Args:
    argv: arguments left after absl flag parsing (unused).

  Raises:
    ValueError: on invalid test-slice flags, an empty test split, or when no
      training samples are collected.
  """
  del argv  # Unused; configuration comes from FLAGS.

  train_dataset = tfds.load("caltech101", split="train", as_supervised=True)
  test_dataset = tfds.load("caltech101", split="test", as_supervised=True)
  dataset_info = tfds.builder("caltech101").info
  label_names = dataset_info.features["label"].names

  test_samples = _load_test_samples(test_dataset)
  test_images = [img for img, _ in test_samples]
  test_labels = [label_names[lbl] for _, lbl in test_samples]

  model = TFCLIPModel.from_pretrained(_CLIP_CHECKPOINT)
  processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

  # One forward pass on a random image before freezing (presumably to make
  # sure all variables exist). It also consumes the global NumPy RNG before
  # the Dirichlet draw below, so keep this order to reproduce the subset.
  dummy_image = Image.fromarray(
      np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
  inputs = processor(images=dummy_image, text=["an object"],
                     return_tensors="tf", padding=True)
  model(**inputs)

  trainable_vars = _freeze_all_but_first_layernorm(model)
  print(f"\nTotal trainable variables: {len(trainable_vars)}")
  for v in trainable_vars:
    print(v.name, v.shape)

  sample_counts = _compute_sample_counts(len(label_names), FLAGS.betaa,
                                         FLAGS.maxsam)
  train_images, train_labels = _collect_train_samples(
      train_dataset, label_names, sample_counts)
  if not train_images:
    raise ValueError("No training samples were collected.")
  print(f"Collected training samples from {len(set(train_labels))} classes")

  # NOTE(review): batches are not reshuffled between epochs, and samples of
  # the same class share an identical prompt, so in-batch "negatives" can be
  # exact duplicates of the positive text.
  train_prompts = [f"a photo of a {label}" for label in train_labels]
  train_inputs = processor(text=train_prompts, images=train_images,
                           return_tensors="tf", padding=True)
  dataset = tf.data.Dataset.from_tensor_slices(
      dict(train_inputs)).batch(FLAGS.batch_size)
  num_batches = len(dataset)

  optimizer = tf.keras.optimizers.Adam(FLAGS.learning_rate)
  gamma, alpha = FLAGS.gamma, FLAGS.alpha

  @tf.function
  def train_step(batch):
    with tf.GradientTape() as tape:
      image_embeds = _embed_images(model, batch["pixel_values"])
      text_embeds = _embed_text(model, batch["input_ids"],
                                batch.get("attention_mask"))
      logit_scale = tf.exp(model.clip.logit_scale)
      loss = _contrastive_focal_loss(image_embeds, text_embeds, logit_scale,
                                     gamma, alpha)
    grads = tape.gradient(loss, trainable_vars)
    optimizer.apply_gradients(zip(grads, trainable_vars))
    return loss

  for epoch in range(FLAGS.tek):
    epoch_loss = 0.0
    for batch in dataset:
      epoch_loss += float(train_step(batch))
    epoch_loss /= num_batches
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

  correct = _evaluate(model, processor, test_images, test_labels,
                      FLAGS.batch_size)
  print(f"\nAccuracy on test set: {correct}/{len(test_images)} = "
        f"{correct/len(test_images):.2%}")

if __name__ == '__main__':
  # Script entry point: absl's app.run parses command-line flags, then
  # invokes main(argv).
  app.run(main)