import tensorflow as tf
from transformers import TFCLIPModel, CLIPProcessor
from collections import defaultdict
import tensorflow_datasets as tfds
import numpy as np
from PIL import Image
import random
import os
import copy
import json

from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS

# --- Data and evaluation -----------------------------------------------------
flags.DEFINE_integer('batch_size', 32, 'Batch size used for client training and for encoding evaluation prompts.')
flags.DEFINE_integer('testi', 1000, 'Start index (inclusive) of the test-split slice used for evaluation.')
flags.DEFINE_integer('teste', 1400, 'End index (exclusive) of the test-split slice used for evaluation.')
flags.DEFINE_float('betaa', 0.1, 'Base Dirichlet concentration used when ranking classes for each client.')
flags.DEFINE_integer('maxsam', 16, 'Number of samples requested for the dominant class of each client.')
# NOTE(review): 'tek' is not read anywhere in the visible source; it is kept so
# that existing command lines which pass --tek keep working.
flags.DEFINE_integer('tek', 100, 'Currently unused.')

# --- Optimisation and loss ---------------------------------------------------
# NOTE(review): the client optimizer in this file is constructed with a
# hard-coded learning rate; confirm whether this flag is meant to replace it.
flags.DEFINE_float('learning_rate', 1e-5, 'Optimizer learning rate.')
flags.DEFINE_float('alpha', 0.25, 'Scale factor (alpha) of the contrastive focal loss.')
flags.DEFINE_float('gamma', 1.0, 'Focusing exponent (gamma) of the contrastive focal loss.')

# --- Federated schedule ------------------------------------------------------
flags.DEFINE_integer('NUM_CLIENTS', 2, 'Number of federated clients.')
flags.DEFINE_integer('COMMUNICATION_ROUNDS', 100, 'Number of server/client communication rounds.')
flags.DEFINE_integer('LOCAL_EPOCHS', 1, 'Number of local training epochs each client runs per round.')

# Seed every random number generator the script draws from so that the client
# partitioning and the training runs are repeatable.
SEED = 42
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

class CLIPFederatedServer:
    """Holds the global CLIP model and aggregates client updates (FedAvg).

    Only the variables of the first LayerNormalization layer found in the
    first vision-encoder block are left trainable; those are the only weights
    exchanged with the clients.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        """Load the pretrained model/processor and select the trainable variables.

        Args:
            model_name: Hugging Face model identifier passed to
                ``from_pretrained`` for both the model and the processor.
        """
        self.model_name = model_name
        self.global_model = TFCLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)

        # One forward pass on dummy data before touching the variables
        # (presumably so that every layer has created its weights).
        dummy_image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
        inputs = self.processor(images=dummy_image, text=["a flower"], return_tensors="tf", padding=True)
        _ = self.global_model(**inputs)

        self._freeze_all_layers(self.global_model)
        self._unfreeze_first_layernorm()
        self.trainable_vars = [v for v in self.global_model.variables if v.trainable]
        print(f"Server initialized with {len(self.trainable_vars)} trainable variables:")
        for v in self.trainable_vars:
            print(f"  {v.name}, {v.shape}")

    def _freeze_all_layers(self, model):
        """Mark every variable of ``model`` as non-trainable."""
        for var in model.variables:
            # NOTE(review): this writes the private ``_trainable`` attribute of
            # the variable rather than using ``layer.trainable``.
            var._trainable = False

    def _unfreeze_first_layernorm(self):
        """Re-enable training for the first LayerNorm of vision encoder block 0."""
        first_block = self.global_model.clip.vision_model.encoder.layers[0]
        unfrozen = 0
        for layer in first_block.submodules:
            if isinstance(layer, tf.keras.layers.LayerNormalization):
                for var in layer.variables:
                    var._trainable = True
                    print(f"Unfroze: {var.name}")
                unfrozen += 1
                break  # Stop after first LayerNorm
        if unfrozen == 0:
            print("No LayerNorm was unfrozen!")

    def get_model_weights(self):
        """Return the current trainable weights as a list of NumPy arrays."""
        return [var.numpy() for var in self.trainable_vars]

    def update_global_model(self, client_weights_list, client_samples_list):
        """Replace the global trainable weights by the sample-weighted client mean.

        Args:
            client_weights_list: One list of arrays per client, each ordered
                like ``self.trainable_vars``.
            client_samples_list: Number of training samples of each client, in
                the same order as ``client_weights_list``.

        Returns:
            The list of averaged weight arrays that was assigned to the model.

        Raises:
            ValueError: If no client weights are given, if the two lists have
                different lengths, or if the total sample count is not positive.
        """
        if not client_weights_list:
            raise ValueError("update_global_model needs at least one set of client weights")
        if len(client_weights_list) != len(client_samples_list):
            raise ValueError(
                f"Got {len(client_weights_list)} weight sets but "
                f"{len(client_samples_list)} sample counts"
            )
        total_samples = sum(client_samples_list)
        if total_samples <= 0:
            raise ValueError("Total number of client samples must be positive")

        avg_weights = [np.zeros_like(w) for w in client_weights_list[0]]
        for client_weights, num_samples in zip(client_weights_list, client_samples_list):
            # Each client contributes proportionally to its share of the data.
            weight_ratio = num_samples / total_samples
            for j, w in enumerate(client_weights):
                avg_weights[j] += w * weight_ratio
        for var, avg_w in zip(self.trainable_vars, avg_weights):
            var.assign(avg_w)

        print(f"Global model updated with weights from {len(client_weights_list)} clients")
        return avg_weights

    def distribute_model(self, clients):
        """Distribute the current global model to all clients"""
        global_weights = self.get_model_weights()
        for client in clients:
            client.update_local_model(global_weights)


class CLIPFederatedClient:
    """Federated client that fine-tunes a local CLIP copy on its own data.

    Only the variables of the first LayerNormalization layer found in the
    first vision-encoder block are left trainable; those are the weights
    received from and sent back to the server.
    """

    def __init__(self, client_id, model_name="openai/clip-vit-base-patch32"):
        """Load the pretrained model/processor and select the trainable variables.

        Args:
            client_id: Identifier used in log messages.
            model_name: Hugging Face model identifier passed to
                ``from_pretrained`` for both the model and the processor.
        """
        self.client_id = client_id
        self.model = TFCLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.train_images = []
        self.train_labels = []
        self.num_samples = 0

        # One forward pass on dummy data before touching the variables
        # (presumably so that every layer has created its weights).
        dummy_image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
        inputs = self.processor(images=dummy_image, text=["a flower"], return_tensors="tf", padding=True)
        _ = self.model(**inputs)

        self._freeze_all_layers(self.model)
        self._unfreeze_first_layernorm()
        # NOTE(review): the learning rate is hard-coded here and the
        # --learning_rate flag (default 1e-5) is not consulted. It is left
        # untouched because switching would change the default behaviour.
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
        self.trainable_vars = [v for v in self.model.variables if v.trainable]

    def _freeze_all_layers(self, model):
        """Mark every variable of ``model`` as non-trainable."""
        for var in model.variables:
            # NOTE(review): this writes the private ``_trainable`` attribute of
            # the variable rather than using ``layer.trainable``.
            var._trainable = False

    def _unfreeze_first_layernorm(self):
        """Re-enable training for the first LayerNorm of vision encoder block 0."""
        first_block = self.model.clip.vision_model.encoder.layers[0]
        unfrozen = 0
        for layer in first_block.submodules:
            if isinstance(layer, tf.keras.layers.LayerNormalization):
                for var in layer.variables:
                    var._trainable = True
                unfrozen += 1
                break  # Stop after first LayerNorm
        if unfrozen == 0:
            print(f"Client {self.client_id}: No LayerNorm was unfrozen!")

    def update_local_model(self, global_weights):
        """Overwrite the local trainable variables with ``global_weights``."""
        for local_var, global_w in zip(self.trainable_vars, global_weights):
            local_var.assign(global_w)

    def set_data(self, images, labels):
        """Store this client's training images and their label strings."""
        self.train_images = images
        self.train_labels = labels
        self.num_samples = len(images)
        print(f"Client {self.client_id} received {self.num_samples} training samples")

    def contrastive_focal_loss(self, image_embeds, text_embeds, logit_scale, gamma=1.0, alpha=0.25):
        """Symmetric image/text contrastive loss with focal re-weighting.

        The i-th image is treated as the positive match of the i-th text.

        Args:
            image_embeds: Image embeddings, one row per sample.
            text_embeds: Text embeddings, one row per sample.
            logit_scale: Multiplier applied to the similarity matrix; clipped
                to the range [1, 100].
            gamma: Exponent applied to ``1 - p_t`` in the focal weight.
            alpha: Constant factor applied to every per-sample loss.

        Returns:
            Scalar loss: the mean of the image-to-text and text-to-image terms.
            A NaN result is replaced by 1.0.
        """
        logit_scale = tf.clip_by_value(logit_scale, 1.0, 100.0)
        logits_per_image = tf.matmul(image_embeds, text_embeds, transpose_b=True) * logit_scale
        logits_per_text = tf.transpose(logits_per_image)

        batch_size = tf.shape(logits_per_image)[0]
        labels = tf.range(batch_size)

        def stable_focal_loss(logits, labels, gamma=gamma, alpha=alpha):
            # Clip probabilities away from zero so the logarithm stays finite.
            probs = tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-7, 1.0)
            labels_onehot = tf.one_hot(labels, depth=tf.shape(logits)[-1])
            pt = tf.reduce_sum(labels_onehot * probs, axis=-1)
            focal_weight = tf.clip_by_value(tf.pow(1. - pt, gamma), 0.0, 10.0)
            ce = -tf.math.log(tf.clip_by_value(pt, 1e-7, 1.0))
            return alpha * focal_weight * ce

        loss_i2t = stable_focal_loss(logits_per_image, labels)
        loss_t2i = stable_focal_loss(logits_per_text, labels)
        loss = (tf.reduce_mean(loss_i2t) + tf.reduce_mean(loss_t2i)) / 2
        loss = tf.where(tf.math.is_nan(loss), 1.0, loss)  # Replace NaNs with 1.0
        return loss

    @tf.function
    def train_step(self, inputs):
        """Run one optimisation step on a batch and return its loss.

        Args:
            inputs: Mapping with ``pixel_values`` and ``input_ids`` and,
                optionally, ``attention_mask``.

        Returns:
            The scalar loss of the batch.
        """
        with tf.GradientTape() as tape:
            vision_outputs = self.model.clip.vision_model(
                inputs['pixel_values'],
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
            )
            image_embeds = self.model.clip.visual_projection(vision_outputs.pooler_output)
            image_embeds = tf.nn.l2_normalize(image_embeds, axis=-1)

            # Position ids are passed explicitly, one per token position.
            seq_length = tf.shape(inputs['input_ids'])[1]
            position_ids = tf.range(0, seq_length, dtype=tf.int32)[tf.newaxis, :]

            text_outputs = self.model.clip.text_model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs.get('attention_mask', None),
                position_ids=position_ids,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
            )
            text_embeds = self.model.clip.text_projection(text_outputs.pooler_output)
            text_embeds = tf.nn.l2_normalize(text_embeds, axis=-1)
            logit_scale = tf.exp(self.model.clip.logit_scale)
            # Take the focal-loss parameters from the command-line flags; their
            # defaults equal the values that used to be hard-coded here.
            loss = self.contrastive_focal_loss(
                image_embeds, text_embeds, logit_scale,
                gamma=FLAGS.gamma, alpha=FLAGS.alpha
            )

        grads = tape.gradient(loss, self.trainable_vars)
        grads = [tf.clip_by_norm(g, 1.0) if g is not None else g for g in grads]
        self.optimizer.apply_gradients(zip(grads, self.trainable_vars))
        return loss

    def train(self, epochs):
        """Train on the local data for ``epochs`` epochs.

        Args:
            epochs: Number of passes over the local data.

        Returns:
            List with the mean batch loss of every epoch; empty when the
            client holds no data.
        """
        if self.num_samples == 0:
            # Nothing to train on: skip instead of failing on an empty dataset.
            print(f"Client {self.client_id} has no training samples; skipping training")
            return []

        train_prompts = [f"a photo of a {label}" for label in self.train_labels]
        train_inputs = self.processor(
            text=train_prompts,
            images=self.train_images,
            return_tensors="tf",
            padding=True
        )

        # Hand tf.data a plain dict rather than the processor's own container.
        dataset = tf.data.Dataset.from_tensor_slices(dict(train_inputs)).batch(FLAGS.batch_size)

        history = []
        for epoch in range(epochs):
            total_loss = tf.constant(0.0)
            num_batches = 0
            for batch in dataset:
                total_loss += self.train_step(batch)
                num_batches += 1

            epoch_loss = float(total_loss) / num_batches
            history.append(epoch_loss)
            print(f"Client {self.client_id}, Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

        return history

    def get_model_weights(self):
        """Return the current trainable weights as a list of NumPy arrays."""
        return [var.numpy() for var in self.trainable_vars]


class _NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also accepts NumPy scalars/arrays and TF tensors."""

    def default(self, obj):
        if isinstance(obj, (np.integer, np.floating, np.bool_)):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if tf.is_tensor(obj):
            return obj.numpy().tolist()
        return super().default(obj)


def _draw_client_class_counts(num_classes):
    """Draw, for every client, how many samples to request from each class."""
    sample_buckets = [1, 2, 3, 4, 5, 6, 8, 10, 12, 16]  # Different possible sample counts
    # The last bucket is not handed out here: only the dominant class gets the
    # maximum, via FLAGS.maxsam.
    lower_buckets = sample_buckets[:-1]

    counts_per_client = []
    for _ in range(FLAGS.NUM_CLIENTS):
        client_beta = FLAGS.betaa * (0.5 + random.random())  # Between 0.5*beta and 1.5*beta
        proportions = np.random.dirichlet([client_beta] * num_classes)
        max_index = np.argmax(proportions)
        # Remaining classes ordered from the smallest to the largest proportion.
        ranked = [idx for idx in np.argsort(proportions) if idx != max_index]

        counts = np.ones(num_classes, dtype=int)  # Start with 1 for all classes
        counts[max_index] = FLAGS.maxsam

        bucket_size = len(ranked) // len(lower_buckets)
        last_bucket = len(lower_buckets) - 1
        for i, bucket_value in enumerate(lower_buckets):
            start_idx = i * bucket_size
            # The final bucket also absorbs the remainder of the division.
            end_idx = (i + 1) * bucket_size if i < last_bucket else len(ranked)
            for idx in ranked[start_idx:end_idx]:
                counts[idx] = bucket_value

        counts_per_client.append(counts)
    return counts_per_client


def _print_client_statistics(clients, num_classes):
    """Print a summary of the class distribution held by every client."""
    from collections import Counter

    print("\nClient data distribution statistics:")
    for i, client in enumerate(clients):
        class_counts = Counter(client.train_labels)

        print(f"Client {i}: {client.num_samples} samples")
        print(f"  Classes with samples: {len(class_counts)}/{num_classes}")
        min_samples = min(class_counts.values()) if class_counts else 0
        max_samples = max(class_counts.values()) if class_counts else 0
        print(f"  Samples per class: Min={min_samples}, Max={max_samples}")

        sample_count_distribution = Counter(class_counts.values())
        print("  Sample count distribution: " +
              ", ".join([f"{count} samples: {freq} classes"
                         for count, freq in sorted(sample_count_distribution.items())]))

        examples = list(class_counts.items())[:5]
        print("  Examples: " + ", ".join([f"{label}: {count}" for label, count in examples]))


def run_federated_learning(label_names, testi, teste, epochs):
    """Simulate federated fine-tuning of CLIP on Caltech-101.

    Args:
        label_names: Class-name string for every integer label of the dataset.
        testi: Start index (inclusive) of the test-split slice to evaluate on.
        teste: End index (exclusive) of the test-split slice to evaluate on.
        epochs: Number of local epochs every client trains per round.

    Returns:
        The ``CLIPFederatedServer`` holding the final global model. The
        per-round training history is also written to
        ``federated_learning_history.json``.
    """
    import itertools

    train_dataset = tfds.load("caltech101", split="train", as_supervised=True)
    test_dataset = tfds.load("caltech101", split="test", as_supervised=True)

    test_iter = tfds.as_numpy(test_dataset)
    if testi >= 0 and teste >= 0:
        # Stream only up to ``teste`` instead of holding the whole split in memory.
        test_samples = list(itertools.islice(test_iter, testi, teste))
    else:
        # Negative indices need the full list to keep Python slice semantics.
        test_samples = list(test_iter)[testi:teste]
    test_images = [img for img, lbl in test_samples]
    test_labels = [label_names[lbl] for img, lbl in test_samples]

    server = CLIPFederatedServer()
    clients = [CLIPFederatedClient(i) for i in range(FLAGS.NUM_CLIENTS)]

    # Group the training examples by integer class label.
    class_samples = defaultdict(list)
    for img, lbl in tfds.as_numpy(train_dataset):
        class_samples[int(lbl)].append((img, lbl))

    client_sample_counts = _draw_client_class_counts(len(label_names))

    for client, client_counts in zip(clients, client_sample_counts):
        client_data = []
        for class_idx, count in enumerate(client_counts):
            if class_idx in class_samples:
                available_samples = class_samples[class_idx]
                # The shuffle is in place, on the list shared by all clients.
                random.shuffle(available_samples)
                samples_to_take = min(count, len(available_samples))
                client_data.extend(available_samples[:samples_to_take])

        random.shuffle(client_data)
        client_images = [img for img, lbl in client_data]
        client_labels = [label_names[int(lbl)] for img, lbl in client_data]
        client.set_data(client_images, client_labels)

    _print_client_statistics(clients, len(label_names))

    fl_history = []
    for round_num in range(FLAGS.COMMUNICATION_ROUNDS):
        print(f"\n--- Round {round_num + 1}/{FLAGS.COMMUNICATION_ROUNDS} ---")
        server.distribute_model(clients)
        client_weights = []
        client_samples = []
        round_history = {"round": round_num + 1, "clients": []}

        for client in clients:
            print(f"\nTraining on client {client.client_id}")
            client_history = client.train(epochs)
            client_weights.append(client.get_model_weights())
            client_samples.append(client.num_samples)

            round_history["clients"].append({
                "client_id": client.client_id,
                "num_samples": client.num_samples,
                "history": client_history
            })

        server.update_global_model(client_weights, client_samples)
        fl_history.append(round_history)

        # Evaluate every round
        evaluate_global_model(server, test_images, test_labels)

    with open("federated_learning_history.json", "w", encoding="utf-8") as f:
        json.dump(fl_history, f, cls=_NumpyEncoder)

    return server


def evaluate_global_model(server, test_images, test_labels):
    """Zero-shot classify the test images with the server's global model.

    The candidate classes are the distinct label strings that occur in
    ``test_labels``. Every class is described by several prompt templates; an
    image is assigned to the class whose prompts have the highest mean
    similarity to the image embedding.

    Args:
        server: Object exposing ``global_model`` and ``processor``.
        test_images: Images to classify.
        test_labels: Ground-truth label string of every image.

    Returns:
        The fraction of correctly classified images (0.0 when there are no
        test images). The result is also printed.
    """
    model = server.global_model
    processor = server.processor

    if len(test_images) == 0:
        print("\nNo test samples provided; skipping evaluation.")
        return 0.0

    prompt_templates = [
        "a photo of a {}.", "a photograph of a {}.",
        "an image of a {}.", "a close-up photo of a {}.",
        "a bright photo of a {}.", "a cropped photo of a {}.",
        "a close-up image of a {}.", "a rendition of a {}.",
        "a picture of a {}.", "a detailed photo of a {}."
    ]

    # Encode every class once. ``dict.fromkeys`` removes duplicates while
    # keeping first-occurrence order, so ties are still resolved in favour of
    # the class that appears first in ``test_labels``.
    class_names = list(dict.fromkeys(test_labels))
    num_classes = len(class_names)
    num_templates = len(prompt_templates)
    all_prompts = [
        template.format(name)
        for name in class_names
        for template in prompt_templates
    ]

    all_text_features = []
    for i in range(0, len(all_prompts), FLAGS.batch_size):
        batch_prompts = all_prompts[i:i + FLAGS.batch_size]
        text_inputs = processor(text=batch_prompts, return_tensors="tf", padding=True)

        # Position ids are passed explicitly, one per token position.
        seq_length = tf.shape(text_inputs['input_ids'])[1]
        position_ids = tf.range(0, seq_length, dtype=tf.int32)[tf.newaxis, :]

        text_outputs = model.clip.text_model(
            input_ids=text_inputs['input_ids'],
            attention_mask=text_inputs.get('attention_mask', None),
            position_ids=position_ids,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True
        )
        batch_text_features = model.clip.text_projection(text_outputs.pooler_output)
        batch_text_features = tf.nn.l2_normalize(batch_text_features, axis=-1)
        all_text_features.append(batch_text_features)

    all_text_features = tf.concat(all_text_features, axis=0)

    # Inference
    correct = 0
    for img, true_label in zip(test_images, test_labels):
        inputs = processor(images=img, return_tensors="tf", padding=True)
        vision_outputs = model.clip.vision_model(
            inputs['pixel_values'],
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True
        )
        image_embeds = model.clip.visual_projection(vision_outputs.pooler_output)
        image_features = tf.nn.l2_normalize(image_embeds, axis=-1)
        sims = tf.matmul(image_features, all_text_features, transpose_b=True)

        # Prompts are laid out class by class, so a reshape groups the
        # similarities of one class in one row; average over the templates.
        class_scores = tf.reduce_mean(
            tf.reshape(sims[0], (num_classes, num_templates)), axis=1
        )
        pred_label = class_names[int(tf.argmax(class_scores))]

        if pred_label == true_label:
            correct += 1

    accuracy = correct / len(test_images)
    print(f"\nAccuracy on test set: {correct}/{len(test_images)} = {accuracy:.2%}")
    return accuracy


def main(argv):
    """Entry point for ``absl.app.run``: run the federated simulation.

    Args:
        argv: Remaining command-line arguments; not used.
    """
    del argv  # Unused.

    # Prompt-friendly class names, indexed by the integer label of the TFDS
    # ``caltech101`` dataset. That dataset defines 102 label ids because it
    # contains both ``faces`` and ``faces_easy``; both are named 'face' here so
    # that every later id lines up with its class name.
    label_names = [
        'accordion', 'airplane', 'anchor', 'ant', 'background_google', 'barrel',
        'bass', 'beaver', 'binocular', 'bonsai', 'brain', 'brontosaurus',
        'buddha', 'butterfly', 'camera', 'cannon', 'car_side', 'ceiling_fan',
        'cellphone', 'chair', 'chandelier', 'cougar_body', 'cougar_face', 'crab',
        'crayfish', 'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill',
        'dolphin', 'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium',
        'ewer', 'face', 'face', 'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk',
        'gramophone', 'grand_piano', 'hawksbill', 'headphone', 'hedgehog', 'helicopter',
        'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 'ketch', 'lamp', 'laptop',
        'leopard', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly', 'menorah',
        'metronome', 'minaret', 'motorbike', 'nautilus', 'octopus', 'okapi', 'pagoda',
        'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', 'rhino',
        'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion', 'sea_horse',
        'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus', 'stop_sign',
        'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', 'watch', 'water_lilly',
        'wheelchair', 'wild_cat', 'windsor_chair', 'wrench', 'yin_yang'
    ]

    run_federated_learning(label_names, FLAGS.testi, FLAGS.teste, FLAGS.LOCAL_EPOCHS)

    print("\nFederated Learning completed!")


if __name__ == "__main__":
    # Hand control to absl, which calls main() with the command-line arguments.
    app.run(main)