import tensorflow as tf
from transformers import TFCLIPModel, CLIPProcessor
from collections import defaultdict
import tensorflow_datasets as tfds
import numpy as np
from PIL import Image
import random
import os
import copy
import json

from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS

# --- Data, evaluation window and partitioning ------------------------------
flags.DEFINE_integer(
    name='batch_size', default=32, help='Batch size for training.')
flags.DEFINE_integer(
    name='testi', default=1000, help='Start index for test set evaluation')
flags.DEFINE_integer(
    name='teste', default=1400, help='End index for test set evaluation')
flags.DEFINE_float(
    name='betaa', default=0.1,
    help='Dirichlet distribution parameter for non-IID data partitioning')
flags.DEFINE_integer(
    name='maxsam', default=16,
    help='Maximum samples per class for dominant classes')
flags.DEFINE_integer(
    name='tek', default=100, help='Training parameter')

# --- Optimisation and focal-loss hyper-parameters --------------------------
flags.DEFINE_float(
    name='learning_rate', default=1e-5,
    help='Learning rate for client optimization')
flags.DEFINE_float(
    name='alpha', default=0.25, help='Alpha parameter for focal loss')
flags.DEFINE_float(
    name='gamma', default=1.0, help='Gamma parameter for focal loss')

# --- Federated topology ----------------------------------------------------
flags.DEFINE_integer(
    name='NUM_CLIENTS', default=2, help='Number of federated clients')
flags.DEFINE_integer(
    name='COMMUNICATION_ROUNDS', default=100,
    help='Number of communication rounds')
flags.DEFINE_integer(
    name='LOCAL_EPOCHS', default=1,
    help='Number of local training epochs per round')

# Seed every random source used in this file (Python, NumPy, TensorFlow) with
# the same value so repeated runs draw the same partitions.
SEED = 42
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

class CLIPFederatedServer:
    """Server side of the federated setup; owns the global CLIP model.

    All model variables are frozen except those of the first
    ``LayerNormalization`` found in the first vision-encoder block. Only that
    subset (``self.trainable_vars``) is distributed to clients and averaged
    back into the global model.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        """Load the pretrained model and processor and pick the trainable subset.

        Args:
            model_name: Identifier handed to ``from_pretrained`` for both the
                model and the processor.
        """
        self.model_name = model_name
        self.global_model = TFCLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)

        # Run a forward pass to build weights
        dummy_image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
        inputs = self.processor(images=dummy_image, text=["a photo"], return_tensors="tf", padding=True)
        _ = self.global_model(**inputs)

        # Freeze everything, then re-enable the single LayerNorm we train.
        self._freeze_all_layers(self.global_model)
        self._unfreeze_first_layernorm()

        # Variables that take part in federated averaging.
        self.trainable_vars = [v for v in self.global_model.variables if v.trainable]
        print(f"Server initialized with {len(self.trainable_vars)} trainable variables:")
        for v in self.trainable_vars:
            print(f"  {v.name}, {v.shape}")

    def _freeze_all_layers(self, model):
        """Mark every variable of ``model`` as non-trainable.

        NOTE(review): this writes the private ``_trainable`` attribute and
        relies on ``var.trainable`` reflecting it -- confirm for the
        TensorFlow version in use.
        """
        for var in model.variables:
            var._trainable = False

    def _unfreeze_first_layernorm(self):
        """Unfreeze only the first LayerNorm in the first vision encoder block."""
        first_block = self.global_model.clip.vision_model.encoder.layers[0]
        unfrozen = 0
        for layer in first_block.submodules:
            if isinstance(layer, tf.keras.layers.LayerNormalization):
                for var in layer.variables:
                    var._trainable = True
                    print(f"Unfroze: {var.name}")
                unfrozen += 1
                break  # Stop after first LayerNorm
        if unfrozen == 0:
            print("No LayerNorm was unfrozen!")

    def get_model_weights(self):
        """Return the current values of the trainable variables as NumPy arrays."""
        return [var.numpy() for var in self.trainable_vars]

    def update_global_model(self, client_weights_list, client_samples_list):
        """Replace the trainable variables with the sample-weighted client mean.

        Args:
            client_weights_list: One list of arrays per client, each ordered
                like ``self.trainable_vars``.
            client_samples_list: Number of training samples per client, in
                the same order as ``client_weights_list``.

        Returns:
            The list of averaged arrays that was assigned to the model.

        Raises:
            ValueError: If no client weights are given, the two lists differ
                in length, or the total sample count is not positive.
        """
        if not client_weights_list:
            raise ValueError("update_global_model needs weights from at least one client")
        if len(client_weights_list) != len(client_samples_list):
            raise ValueError(
                f"Got weights from {len(client_weights_list)} clients but "
                f"{len(client_samples_list)} sample counts"
            )
        total_samples = sum(client_samples_list)
        if total_samples <= 0:
            raise ValueError("Total number of client samples must be positive")

        avg_weights = [np.zeros_like(w) for w in client_weights_list[0]]
        for client_weights, n_samples in zip(client_weights_list, client_samples_list):
            # Each client contributes proportionally to its share of the data.
            weight_ratio = n_samples / total_samples
            for j, w in enumerate(client_weights):
                avg_weights[j] += w * weight_ratio

        for var, avg_w in zip(self.trainable_vars, avg_weights):
            var.assign(avg_w)

        print(f"Global model updated with weights from {len(client_weights_list)} clients")
        return avg_weights

    def distribute_model(self, clients):
        """Push the current global trainable weights to every client."""
        global_weights = self.get_model_weights()
        for client in clients:
            client.update_local_model(global_weights)


class CLIPFederatedClient:
    """One federated participant with its own CLIP copy and local data.

    As on the server side of this file, every variable is frozen except those
    of the first ``LayerNormalization`` found in the first vision-encoder
    block; only that subset is trained locally and exchanged.
    """

    def __init__(self, client_id, model_name="openai/clip-vit-base-patch32", learning_rate=None):
        """Load the pretrained model, select trainable variables, build the optimizer.

        Args:
            client_id: Identifier used in log messages.
            model_name: Identifier handed to ``from_pretrained`` for both the
                model and the processor.
            learning_rate: Adam learning rate. ``None`` (default) uses
                ``FLAGS.learning_rate``, so flags must already be parsed when
                the client is constructed without an explicit value.
        """
        self.client_id = client_id
        self.model = TFCLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.train_images = []
        self.train_labels = []
        self.num_samples = 0

        # One forward pass on a random image so that all weights get created.
        dummy_image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
        inputs = self.processor(images=dummy_image, text=["a photo"], return_tensors="tf", padding=True)
        _ = self.model(**inputs)
        self._freeze_all_layers(self.model)
        self._unfreeze_first_layernorm()

        # The learning rate used to be hard-coded to 1e-3, which left the
        # --learning_rate flag without any effect.
        if learning_rate is None:
            learning_rate = FLAGS.learning_rate
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        self.trainable_vars = [v for v in self.model.variables if v.trainable]

    def _freeze_all_layers(self, model):
        """Mark every variable of ``model`` as non-trainable.

        NOTE(review): this writes the private ``_trainable`` attribute and
        relies on ``var.trainable`` reflecting it -- confirm for the
        TensorFlow version in use.
        """
        for var in model.variables:
            var._trainable = False

    def _unfreeze_first_layernorm(self):
        """Unfreeze only the first LayerNorm in the first vision encoder block."""
        first_block = self.model.clip.vision_model.encoder.layers[0]
        unfrozen = 0
        for layer in first_block.submodules:
            if isinstance(layer, tf.keras.layers.LayerNormalization):
                for var in layer.variables:
                    var._trainable = True
                unfrozen += 1
                break  # Stop after first LayerNorm
        if unfrozen == 0:
            print(f"Client {self.client_id}: No LayerNorm was unfrozen!")

    def update_local_model(self, global_weights):
        """Overwrite the local trainable variables with ``global_weights``.

        Args:
            global_weights: Arrays ordered like ``self.trainable_vars``.
        """
        for local_var, global_w in zip(self.trainable_vars, global_weights):
            local_var.assign(global_w)

    def set_data(self, images, labels):
        """Store the local training images and their label strings."""
        self.train_images = images
        self.train_labels = labels
        self.num_samples = len(images)
        print(f"Client {self.client_id} received {self.num_samples} training samples")

    def contrastive_focal_loss(self, image_embeds, text_embeds, logit_scale, gamma=1.0, alpha=0.25):
        """Symmetric image/text contrastive loss with focal weighting.

        The i-th image and the i-th text of the batch are treated as the
        matching pair; all other pairings in the batch act as negatives.

        Args:
            image_embeds: Image embeddings, one row per batch element.
            text_embeds: Text embeddings, one row per batch element.
            logit_scale: Similarity scale; clipped to [1, 100] before use.
            gamma: Focal exponent applied to ``1 - p_target``.
            alpha: Constant factor applied to the focal term.

        Returns:
            Scalar loss: mean of the image-to-text and text-to-image terms.
            A NaN result is replaced by 1.0.
        """
        logit_scale = tf.clip_by_value(logit_scale, 1.0, 100.0)
        logits_per_image = tf.matmul(image_embeds, text_embeds, transpose_b=True) * logit_scale
        logits_per_text = tf.transpose(logits_per_image)

        batch_size = tf.shape(logits_per_image)[0]
        labels = tf.range(batch_size)

        def stable_focal_loss(logits, labels, gamma=gamma, alpha=alpha):
            # Probabilities and the focal weight are clipped to keep the log
            # and the power finite.
            probs = tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-7, 1.0)
            labels_onehot = tf.one_hot(labels, depth=tf.shape(logits)[-1])
            pt = tf.reduce_sum(labels_onehot * probs, axis=-1)
            focal_weight = tf.clip_by_value(tf.pow(1. - pt, gamma), 0.0, 10.0)
            ce = -tf.math.log(tf.clip_by_value(pt, 1e-7, 1.0))
            return alpha * focal_weight * ce

        loss_i2t = stable_focal_loss(logits_per_image, labels)
        loss_t2i = stable_focal_loss(logits_per_text, labels)

        loss = (tf.reduce_mean(loss_i2t) + tf.reduce_mean(loss_t2i)) / 2
        loss = tf.where(tf.math.is_nan(loss), 1.0, loss)  # Replace NaNs with 1.0
        return loss

    @tf.function
    def train_step(self, inputs):
        """Run one optimisation step on a batch and return its loss.

        Args:
            inputs: Mapping with ``pixel_values`` and ``input_ids`` and,
                optionally, ``attention_mask``.
        """
        with tf.GradientTape() as tape:
            vision_outputs = self.model.clip.vision_model(
                inputs['pixel_values'],
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
            )
            image_embeds = self.model.clip.visual_projection(vision_outputs[1])
            image_embeds = tf.nn.l2_normalize(image_embeds, axis=-1)
            input_shape = tf.shape(inputs['input_ids'])
            seq_length = input_shape[1]
            position_ids = tf.range(0, seq_length, dtype=tf.int32)[tf.newaxis, :]

            text_outputs = self.model.clip.text_model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs.get('attention_mask', None),
                position_ids=position_ids,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
            )

            text_embeds = self.model.clip.text_projection(text_outputs[1])
            text_embeds = tf.nn.l2_normalize(text_embeds, axis=-1)

            logit_scale = tf.exp(self.model.clip.logit_scale)
            loss = self.contrastive_focal_loss(
                image_embeds, text_embeds, logit_scale,
                gamma=FLAGS.gamma, alpha=FLAGS.alpha
            )

        grads = tape.gradient(loss, self.trainable_vars)
        # Per-variable norm clipping; variables without a gradient stay None.
        grads = [tf.clip_by_norm(g, 1.0) if g is not None else g for g in grads]
        self.optimizer.apply_gradients(zip(grads, self.trainable_vars))
        return loss

    def train(self, epochs):
        """Train on the local data for ``epochs`` epochs.

        Args:
            epochs: Number of passes over the local dataset.

        Returns:
            List with the mean batch loss of every epoch; empty when the
            client holds no data.
        """
        if len(self.train_images) == 0:
            # Nothing to fit; avoids failing on an empty batch list below.
            print(f"Client {self.client_id} has no training samples; skipping local training")
            return []

        train_prompts = [f"a photo of a {label}" for label in self.train_labels]
        train_inputs = self.processor(
            text=train_prompts,
            images=self.train_images,
            return_tensors="tf",
            padding=True
        )

        # Hand tf.data a plain dict of tensors rather than the processor's
        # own container type.
        dataset = tf.data.Dataset.from_tensor_slices(dict(train_inputs)).batch(FLAGS.batch_size)
        num_batches = len(dataset)

        history = []
        for epoch in range(epochs):
            total_loss = tf.constant(0.0)
            for batch in dataset:
                total_loss += self.train_step(batch)

            epoch_loss = total_loss.numpy() / num_batches
            history.append(epoch_loss)
            print(f"Client {self.client_id}, Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

        return history

    def get_model_weights(self):
        """Return the current values of the trainable variables as NumPy arrays."""
        return [var.numpy() for var in self.trainable_vars]


def _resize_for_clip(image_array):
    """Turn an image array yielded by ``tfds.as_numpy`` into a 224x224 PIL image."""
    return Image.fromarray(image_array).resize((224, 224), Image.BICUBIC)


def _load_test_split(test_dataset, label_names, testi, teste):
    """Return (images, label strings) for test elements ``testi`` .. ``teste - 1``."""
    if testi < 0 or teste <= testi:
        raise ValueError(
            f"Invalid test window: testi={testi}, teste={teste} (need 0 <= testi < teste)"
        )
    images, labels = [], []
    # skip/take avoids decoding and resizing the first `testi` images that
    # would be thrown away anyway.
    for img, lbl in tfds.as_numpy(test_dataset.skip(testi).take(teste - testi)):
        images.append(_resize_for_clip(img))
        labels.append(label_names[lbl])
    return images, labels


def _draw_client_class_counts(num_classes, beta, max_samples,
                              sample_buckets=(1, 2, 3, 4, 5, 6, 8, 10, 12, 16)):
    """Draw the per-class sample counts for one client.

    A Dirichlet draw is used only to rank the classes: the top-ranked class
    gets ``max_samples`` and the remaining classes are split, in ascending
    rank order, into equally sized groups that receive the bucket values
    (all but the last bucket; the final group absorbs the remainder).
    """
    client_beta = beta * (0.5 + random.random())  # Between 0.5*beta and 1.5*beta
    proportions = np.random.dirichlet([client_beta] * num_classes)
    max_index = np.argmax(proportions)
    indices_without_max = [idx for idx in np.argsort(proportions) if idx != max_index]

    client_counts = np.ones(num_classes, dtype=int)  # Start with 1 for all classes
    client_counts[max_index] = max_samples

    lower_buckets = sample_buckets[:-1]  # the last bucket is covered by max_samples
    bucket_size = len(indices_without_max) // len(lower_buckets)
    last = len(lower_buckets) - 1
    for i, bucket_value in enumerate(lower_buckets):
        start_idx = i * bucket_size
        end_idx = (i + 1) * bucket_size if i < last else len(indices_without_max)
        for idx in indices_without_max[start_idx:end_idx]:
            client_counts[idx] = bucket_value
    return client_counts


def _print_client_statistics(clients, num_classes):
    """Print how many samples per class every client received."""
    print("\nClient data distribution statistics:")
    for i, client in enumerate(clients):
        class_counts = {}
        for label in client.train_labels:
            class_counts[label] = class_counts.get(label, 0) + 1

        print(f"Client {i}: {client.num_samples} samples")
        print(f"  Classes with samples: {len(class_counts)}/{num_classes}")
        min_samples = min(class_counts.values()) if class_counts else 0
        max_samples = max(class_counts.values()) if class_counts else 0
        print(f"  Samples per class: Min={min_samples}, Max={max_samples}")

        sample_count_distribution = {}
        for count in class_counts.values():
            sample_count_distribution[count] = sample_count_distribution.get(count, 0) + 1

        print(f"  Sample count distribution: " +
              ", ".join([f"{count} samples: {freq} classes"
                         for count, freq in sorted(sample_count_distribution.items())]))
        examples = list(class_counts.items())[:5]
        print(f"  Examples: " + ", ".join([f"{label}: {count}" for label, count in examples]))


class _NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also accepts NumPy scalars/arrays and TF tensors."""

    def default(self, obj):
        if isinstance(obj, (np.integer, np.floating, np.bool_)):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if tf.is_tensor(obj):
            return obj.numpy().tolist()
        return super().default(obj)


def run_federated_learning(label_names, testi, teste, epochs):
    """Run the full federated training loop on CIFAR-100 and return the server.

    Steps: load the test window, create server and clients, give every client
    a skewed per-class subset of the training split, then for each
    communication round distribute the global weights, train every client,
    average the results and evaluate. The per-round history is written to
    ``federated_learning_history.json``.

    Args:
        label_names: Class-name strings indexed by the dataset's integer label.
        testi: Index of the first test element to evaluate on.
        teste: Index one past the last test element to evaluate on.
        epochs: Local training epochs per client and round.

    Returns:
        The ``CLIPFederatedServer`` holding the final global model.

    Raises:
        ValueError: If ``testi``/``teste`` do not describe a non-empty window.
    """
    train_dataset = tfds.load("cifar100", split="train", as_supervised=True)
    test_dataset = tfds.load("cifar100", split="test", as_supervised=True)

    test_images, test_labels = _load_test_split(test_dataset, label_names, testi, teste)

    server = CLIPFederatedServer()
    clients = [CLIPFederatedClient(i) for i in range(FLAGS.NUM_CLIENTS)]

    # Group the raw training arrays by class. Resizing is postponed until a
    # sample has actually been selected, so the whole training split is never
    # held in memory as 224x224 images.
    class_samples = defaultdict(list)
    for img, lbl in tfds.as_numpy(train_dataset):
        class_samples[int(lbl)].append((np.array(img), lbl))

    # Draw the count vectors for all clients first so the random streams are
    # consumed in the same order as the sampling below expects.
    client_sample_counts = [
        _draw_client_class_counts(len(label_names), FLAGS.betaa, FLAGS.maxsam)
        for _ in range(FLAGS.NUM_CLIENTS)
    ]

    for client, client_counts in zip(clients, client_sample_counts):
        client_data = []
        for class_idx, count in enumerate(client_counts):
            if class_idx in class_samples:
                available_samples = class_samples[class_idx]
                random.shuffle(available_samples)
                samples_to_take = min(count, len(available_samples))
                client_data.extend(available_samples[:samples_to_take])

        random.shuffle(client_data)

        client_images = [_resize_for_clip(img) for img, _ in client_data]
        client_labels = [label_names[int(lbl)] for _, lbl in client_data]
        client.set_data(client_images, client_labels)

    _print_client_statistics(clients, len(label_names))

    # Federated Learning Process
    fl_history = []
    for round_num in range(FLAGS.COMMUNICATION_ROUNDS):
        print(f"\n--- Round {round_num + 1}/{FLAGS.COMMUNICATION_ROUNDS} ---")

        server.distribute_model(clients)

        client_weights = []
        client_samples = []
        round_history = {"round": round_num + 1, "clients": []}

        for client in clients:
            print(f"\nTraining on client {client.client_id}")
            client_history = client.train(epochs)

            client_weights.append(client.get_model_weights())
            client_samples.append(client.num_samples)

            round_history["clients"].append({
                "client_id": client.client_id,
                "num_samples": client.num_samples,
                "history": client_history
            })
        server.update_global_model(client_weights, client_samples)
        fl_history.append(round_history)

        # Evaluate after every round.
        evaluate_global_model(server, test_images, test_labels)

    with open("federated_learning_history.json", "w") as f:
        json.dump(fl_history, f, cls=_NumpyEncoder)

    return server


def evaluate_global_model(server, test_images, test_labels, class_names=None):
    """Evaluate the global model on test data by prompt-ensemble classification.

    Every candidate class is described by several prompt templates; an image
    is assigned to the class whose prompts have the highest mean cosine
    similarity with the image embedding.

    Args:
        server: Object exposing ``global_model`` and ``processor``.
        test_images: Images to classify.
        test_labels: Ground-truth label string for every image.
        class_names: Optional candidate class strings. ``None`` (default)
            uses the distinct labels found in ``test_labels``, in order of
            first appearance.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 when there are no test images.
    """
    model = server.global_model
    processor = server.processor

    num_images = len(test_images)
    if num_images == 0:
        # Avoid dividing by zero below.
        print("\nAccuracy on test set: 0/0 (no test samples)")
        return 0.0

    prompt_templates = [
        "a photo of a {}.","a photograph of a {}.",
        "an image of a {}.","a close-up photo of a {}.",
        "a bright photo of a {}.","a cropped photo of a {}.",
        "a close-up image of a {}.","a rendition of a {}.",
        "an example of a {}.","a picture of a {}."
    ]
    num_templates = len(prompt_templates)

    # One group of prompts per distinct class instead of one per test sample:
    # repeated labels would only produce identical text features again.
    candidates = list(dict.fromkeys(test_labels if class_names is None else class_names))
    num_classes = len(candidates)
    all_prompts = [
        template.format(label) for label in candidates for template in prompt_templates
    ]

    text_feature_batches = []
    for i in range(0, len(all_prompts), FLAGS.batch_size):
        batch_prompts = all_prompts[i:i+FLAGS.batch_size]
        text_inputs = processor(text=batch_prompts, return_tensors="tf", padding=True)
        seq_length = tf.shape(text_inputs['input_ids'])[1]
        position_ids = tf.range(0, seq_length, dtype=tf.int32)[tf.newaxis, :]
        text_outputs = model.clip.text_model(
            input_ids=text_inputs['input_ids'],
            attention_mask=text_inputs.get('attention_mask', None),
            position_ids=position_ids,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True
        )
        batch_text_features = model.clip.text_projection(text_outputs.pooler_output)
        text_feature_batches.append(tf.nn.l2_normalize(batch_text_features, axis=-1))

    # Rows are ordered class-major: all templates of class 0, then class 1, ...
    all_text_features = tf.concat(text_feature_batches, axis=0)

    correct = 0
    for start in range(0, num_images, FLAGS.batch_size):
        batch_images = test_images[start:start + FLAGS.batch_size]
        batch_labels = test_labels[start:start + FLAGS.batch_size]

        inputs = processor(images=batch_images, return_tensors="tf", padding=True)
        vision_outputs = model.clip.vision_model(
            inputs['pixel_values'],
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True
        )
        image_embeds = model.clip.visual_projection(vision_outputs.pooler_output)
        image_features = tf.nn.l2_normalize(image_embeds, axis=-1)

        sims = tf.matmul(image_features, all_text_features, transpose_b=True).numpy()
        # (batch, classes * templates) -> (batch, classes): mean over templates.
        class_scores = sims.reshape(sims.shape[0], num_classes, num_templates).mean(axis=-1)
        # argmax keeps the first maximum, i.e. the earliest candidate on ties.
        predicted = class_scores.argmax(axis=-1)
        correct += sum(
            candidates[pred_idx] == true_label
            for pred_idx, true_label in zip(predicted, batch_labels)
        )

    accuracy = correct / num_images
    print(f"\nAccuracy on test set: {correct}/{len(test_images)} = {correct/len(test_images):.2%}")
    return accuracy


def main(argv):
    """absl entry point: run federated training with the flag-configured setup.

    The class-name list below is indexed by the integer label of the
    ``cifar100`` dataset when it is used by ``run_federated_learning``.
    """
    del argv  # not used

    label_names = [
        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver',
        'bed', 'bee', 'beetle', 'bicycle', 'bottle',
        'bowl', 'boy', 'bridge', 'bus', 'butterfly',
        'camel', 'can', 'castle', 'caterpillar', 'cattle',
        'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach',
        'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
        'dolphin', 'elephant', 'flatfish', 'forest', 'fox',
        'girl', 'hamster', 'house', 'kangaroo', 'keyboard',
        'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard',
        'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
        'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid',
        'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',
        'plain', 'plate', 'poppy', 'porcupine', 'possum',
        'rabbit', 'raccoon', 'ray', 'road', 'rocket',
        'rose', 'sea', 'seal', 'shark', 'shrew',
        'skunk', 'skyscraper', 'snail', 'snake', 'spider',
        'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
        'tank', 'telephone', 'television', 'tiger', 'tractor',
        'train', 'trout', 'tulip', 'turtle', 'wardrobe',
        'whale', 'willow_tree', 'wolf', 'woman', 'worm',
    ]

    # Train; the returned server object is not needed afterwards.
    _trained_server = run_federated_learning(
        label_names=label_names,
        testi=FLAGS.testi,
        teste=FLAGS.teste,
        epochs=FLAGS.LOCAL_EPOCHS,
    )

    print("\nFederated Learning completed!")


if __name__ == "__main__":
    # Hand control to absl, which invokes main() with the command line.
    app.run(main)