import tensorflow as tf
from transformers import TFCLIPModel, CLIPProcessor
from collections import defaultdict
import tensorflow_datasets as tfds
import numpy as np
from PIL import Image
import random
import os
import copy
import json

from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS

# --- Data / evaluation settings -------------------------------------------
flags.DEFINE_integer(
    name='batch_size', default=32, help='Batch size for training.')
flags.DEFINE_integer(
    name='testi', default=1000, help='Starting index for test samples')
flags.DEFINE_integer(
    name='teste', default=1400, help='Ending index for test samples')
flags.DEFINE_float(
    name='betaa', default=0.1,
    help='Dirichlet parameter for non-IID data distribution')
flags.DEFINE_integer(
    name='maxsam', default=16, help='Maximum samples per class for a client')
# NOTE(review): 'tek' is defined but not read anywhere in this file.
flags.DEFINE_integer(name='tek', default=1, help='Training parameter')

# --- Client optimisation / loss settings ----------------------------------
flags.DEFINE_float(
    name='learning_rate', default=1e-5,
    help='Learning rate for client optimization')
flags.DEFINE_float(
    name='alpha', default=0.25, help='Alpha parameter for focal loss')
flags.DEFINE_float(
    name='gamma', default=1.0, help='Gamma parameter for focal loss')

# --- Federated schedule ----------------------------------------------------
flags.DEFINE_integer(
    name='NUM_CLIENTS', default=2, help='Number of federated clients')
flags.DEFINE_integer(
    name='COMMUNICATION_ROUNDS', default=15,
    help='Number of communication rounds')
flags.DEFINE_integer(
    name='LOCAL_EPOCHS', default=1, help='Number of local epochs per round')

# Seed every random source used in this file (Python, NumPy, TensorFlow).
# NOTE(review): PYTHONHASHSEED is read by the interpreter at start-up, so
# assigning it here is only seen by child processes -- confirm intent.
SEED = 42
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

class CLIPFederatedServer:
    """Central node of the federated setup; owns the global CLIP model.

    After construction every variable of the model is frozen except those of
    the first ``LayerNormalization`` found in vision encoder block 0. Only
    that trainable subset (``self.trainable_vars``) is sent to clients and
    averaged back.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        """Load the pretrained model and processor and select trainable vars.

        Args:
            model_name: Checkpoint identifier handed to ``from_pretrained``.
        """
        self.model_name = model_name
        self.global_model = TFCLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)

        # One forward pass on a random image/prompt before freezing --
        # presumably so that every variable exists before it is touched.
        dummy_image = Image.fromarray(
            np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
        inputs = self.processor(
            images=dummy_image, text=["a building"],
            return_tensors="tf", padding=True)
        _ = self.global_model(**inputs)

        self._freeze_all_layers(self.global_model)
        self._unfreeze_first_layernorm()
        self.trainable_vars = [
            v for v in self.global_model.variables if v.trainable]
        print(f"Server initialized with {len(self.trainable_vars)} trainable variables:")
        for v in self.trainable_vars:
            print(f"  {v.name}, {v.shape}")

    def _freeze_all_layers(self, model):
        """Mark every variable of ``model`` as non-trainable."""
        # NOTE(review): this writes the private ``_trainable`` attribute of
        # each variable; re-check that it still works when upgrading TF.
        for var in model.variables:
            var._trainable = False

    def _unfreeze_first_layernorm(self):
        """Unfreeze only the first LayerNorm in the first vision encoder block."""
        first_block = self.global_model.clip.vision_model.encoder.layers[0]
        unfrozen = 0
        for layer in first_block.submodules:
            if isinstance(layer, tf.keras.layers.LayerNormalization):
                for var in layer.variables:
                    var._trainable = True
                    print(f"Unfroze: {var.name}")
                unfrozen += 1
                break  # Stop after first LayerNorm
        if unfrozen == 0:
            print("No LayerNorm was unfrozen!")

    def get_model_weights(self):
        """Return the current values of the trainable variables as arrays."""
        return [var.numpy() for var in self.trainable_vars]

    def update_global_model(self, client_weights_list, client_samples_list):
        """Replace the trainable variables by a sample-weighted client average.

        Args:
            client_weights_list: One list of arrays per client, each ordered
                like ``self.trainable_vars``.
            client_samples_list: Number of training samples of each client,
                in the same client order.

        Returns:
            The list of averaged arrays that were assigned to the variables.

        Raises:
            ValueError: If no client weights are given or the two lists have
                different lengths.
        """
        if len(client_weights_list) == 0:
            raise ValueError("update_global_model needs at least one client")
        if len(client_weights_list) != len(client_samples_list):
            raise ValueError(
                "client_weights_list and client_samples_list differ in length: "
                f"{len(client_weights_list)} vs {len(client_samples_list)}")

        total_samples = sum(client_samples_list)
        if total_samples > 0:
            ratios = [float(n) / total_samples for n in client_samples_list]
        else:
            # No client reported any sample: fall back to a plain mean
            # instead of dividing by zero.
            print("No client samples reported; using an unweighted average")
            ratios = [1.0 / len(client_weights_list)] * len(client_weights_list)

        # Accumulate in place so each average keeps the dtype of the weights.
        avg_weights = [np.zeros_like(w) for w in client_weights_list[0]]
        for ratio, client_weights in zip(ratios, client_weights_list):
            for j, w in enumerate(client_weights):
                avg_weights[j] += w * ratio

        for var, avg_w in zip(self.trainable_vars, avg_weights):
            var.assign(avg_w)

        print(f"Global model updated with weights from {len(client_weights_list)} clients")
        return avg_weights

    def distribute_model(self, clients):
        """Distribute the current global model to all clients"""
        global_weights = self.get_model_weights()
        for client in clients:
            client.update_local_model(global_weights)


class CLIPFederatedClient:
    """One federated participant holding a private CLIP copy and local data.

    As on the server, all variables are frozen except those of the first
    ``LayerNormalization`` found in vision encoder block 0; only these are
    trained locally and exchanged.
    """

    def __init__(self, client_id, model_name="openai/clip-vit-base-patch32"):
        """Load the pretrained model/processor and set up the optimizer.

        Args:
            client_id: Identifier used in log messages and history records.
            model_name: Checkpoint identifier handed to ``from_pretrained``.
        """
        self.client_id = client_id
        self.model = TFCLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.train_images = []
        self.train_labels = []
        self.num_samples = 0

        # One forward pass on a random image/prompt before freezing --
        # presumably so that every variable exists before it is touched.
        dummy_image = Image.fromarray(
            np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
        inputs = self.processor(
            images=dummy_image, text=["a building"],
            return_tensors="tf", padding=True)
        _ = self.model(**inputs)

        self._freeze_all_layers(self.model)
        self._unfreeze_first_layernorm()
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)
        self.trainable_vars = [v for v in self.model.variables if v.trainable]

    def _freeze_all_layers(self, model):
        """Mark every variable of ``model`` as non-trainable."""
        # NOTE(review): this writes the private ``_trainable`` attribute of
        # each variable; re-check that it still works when upgrading TF.
        for var in model.variables:
            var._trainable = False

    def _unfreeze_first_layernorm(self):
        """Unfreeze only the first LayerNorm in the first vision encoder block."""
        first_block = self.model.clip.vision_model.encoder.layers[0]
        unfrozen = 0
        for layer in first_block.submodules:
            if isinstance(layer, tf.keras.layers.LayerNormalization):
                for var in layer.variables:
                    var._trainable = True
                unfrozen += 1
                break  # Stop after first LayerNorm
        if unfrozen == 0:
            print(f"Client {self.client_id}: No LayerNorm was unfrozen!")

    def update_local_model(self, global_weights):
        """Overwrite the local trainable variables with ``global_weights``.

        Args:
            global_weights: Arrays ordered like ``self.trainable_vars``.
        """
        for local_var, global_w in zip(self.trainable_vars, global_weights):
            local_var.assign(global_w)

    def set_data(self, images, labels):
        """Store the local training set and record its size.

        Args:
            images: Sequence of training images.
            labels: Label strings, one per image.
        """
        self.train_images = images
        self.train_labels = labels
        self.num_samples = len(images)
        print(f"Client {self.client_id} received {self.num_samples} training samples")

    def contrastive_focal_loss(self, image_embeds, text_embeds, logit_scale, gamma=1.0, alpha=0.25):
        """Symmetric image/text contrastive loss with focal weighting.

        Row ``i`` of the similarity matrix is treated as a classification
        problem whose correct column is ``i``; the same is done on the
        transposed matrix and both means are averaged.

        Args:
            image_embeds: Image embeddings; assumed (batch, dim) -- TODO confirm.
            text_embeds: Text embeddings with the same batch size.
            logit_scale: Multiplier for the similarities; clipped to [1, 100].
            gamma: Focal exponent applied to ``1 - p_target``.
            alpha: Constant factor applied to every per-sample loss.

        Returns:
            Scalar loss tensor; a NaN result is replaced by 1.0.
        """
        logit_scale = tf.clip_by_value(logit_scale, 1.0, 100.0)
        logits_per_image = tf.matmul(image_embeds, text_embeds, transpose_b=True) * logit_scale
        logits_per_text = tf.transpose(logits_per_image)
        batch_size = tf.shape(logits_per_image)[0]
        labels = tf.range(batch_size)

        def stable_focal_loss(logits, labels, gamma=gamma, alpha=alpha):
            # Probabilities and weights are clipped to keep log/pow finite.
            probs = tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-7, 1.0)
            labels_onehot = tf.one_hot(labels, depth=tf.shape(logits)[-1])
            pt = tf.reduce_sum(labels_onehot * probs, axis=-1)
            focal_weight = tf.clip_by_value(tf.pow(1. - pt, gamma), 0.0, 10.0)
            ce = -tf.math.log(tf.clip_by_value(pt, 1e-7, 1.0))
            return alpha * focal_weight * ce

        loss_i2t = stable_focal_loss(logits_per_image, labels)
        loss_t2i = stable_focal_loss(logits_per_text, labels)
        loss = (tf.reduce_mean(loss_i2t) + tf.reduce_mean(loss_t2i)) / 2
        loss = tf.where(tf.math.is_nan(loss), 1.0, loss)  # Replace NaNs with 1.0
        return loss

    @tf.function
    def train_step(self, inputs):
        """Run one optimisation step on a batch and return its loss.

        Args:
            inputs: Mapping with ``pixel_values`` and ``input_ids`` and,
                optionally, ``attention_mask``.

        Returns:
            Scalar loss tensor of the batch.
        """
        with tf.GradientTape() as tape:
            vision_outputs = self.model.clip.vision_model(
                inputs['pixel_values'],
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
            )
            # Index 1 of the encoder output is fed to the projection head.
            image_embeds = self.model.clip.visual_projection(vision_outputs[1])
            image_embeds = tf.nn.l2_normalize(image_embeds, axis=-1)

            # Explicit position ids 0..seq_len-1, shared across the batch.
            seq_length = tf.shape(inputs['input_ids'])[1]
            position_ids = tf.range(0, seq_length, dtype=tf.int32)[tf.newaxis, :]
            text_outputs = self.model.clip.text_model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs.get('attention_mask', None),
                position_ids=position_ids,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
            )
            text_embeds = self.model.clip.text_projection(text_outputs[1])
            text_embeds = tf.nn.l2_normalize(text_embeds, axis=-1)

            logit_scale = tf.exp(self.model.clip.logit_scale)
            loss = self.contrastive_focal_loss(
                image_embeds, text_embeds, logit_scale,
                gamma=FLAGS.gamma, alpha=FLAGS.alpha
            )

        grads = tape.gradient(loss, self.trainable_vars)
        # Clip each gradient to norm 1.0; missing gradients pass through.
        grads = [tf.clip_by_norm(g, 1.0) if g is not None else g for g in grads]
        self.optimizer.apply_gradients(zip(grads, self.trainable_vars))
        return loss

    def train(self, epochs):
        """Train on the local data for ``epochs`` passes.

        Args:
            epochs: Number of passes over the local data.

        Returns:
            List with the mean batch loss of every epoch as Python floats;
            empty when the client holds no data.
        """
        if len(self.train_images) == 0:
            # Nothing to preprocess or batch; avoid dividing by zero batches.
            print(f"Client {self.client_id} has no training data; skipping local training")
            return []

        train_prompts = [f"a photo of a {label}" for label in self.train_labels]
        train_inputs = self.processor(
            text=train_prompts,
            images=self.train_images,
            return_tensors="tf",
            padding=True
        )

        # Hand tf.data a plain dict rather than the processor's own container.
        dataset = tf.data.Dataset.from_tensor_slices(
            dict(train_inputs)).batch(FLAGS.batch_size)

        history = []
        for epoch in range(epochs):
            loss_sum = 0.0
            num_batches = 0
            for batch in dataset:
                loss_sum += float(self.train_step(batch))
                num_batches += 1

            epoch_loss = loss_sum / num_batches
            history.append(epoch_loss)
            print(f"Client {self.client_id}, Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

        return history

    def get_model_weights(self):
        """Extract weights from trainable variables"""
        return [var.numpy() for var in self.trainable_vars]


def _make_random_split(num_samples, num_classes):
    """Build a dataset of random 224x224x3 uint8 images with random labels."""
    images = []
    labels = []
    for _ in range(num_samples):
        # Image first, then label, so the NumPy RNG draw order stays fixed.
        images.append(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
        labels.append(np.random.randint(0, num_classes))
    return tf.data.Dataset.from_tensor_slices((images, labels))


def load_sun_dataset(num_fallback_train=1000, num_fallback_test=500):
    """Load SUN397 (``standard-part1``) or fall back to random stand-in data.

    Args:
        num_fallback_train: Number of random training samples generated when
            the real dataset cannot be loaded.
        num_fallback_test: Number of random test samples generated in that
            same case.

    Returns:
        Tuple ``(train_dataset, test_dataset, label_names)`` where the
        datasets yield ``(image, label_index)`` pairs and ``label_names`` is
        the list of class names indexed by label.
    """
    dataset_name = "sun397/standard-part1"
    try:
        # Load both splits and the metadata in one call so the label names
        # come from exactly the configuration that was loaded.
        (train_dataset, test_dataset), dataset_info = tfds.load(
            dataset_name,
            split=["train", "test"],
            as_supervised=True,
            with_info=True,
        )
        label_names = dataset_info.features["label"].names

    except (tfds.core.registered.DatasetNotFoundError, ValueError):
        print("SUN397 dataset not found in tfds. Using a subset of scene categories.")
        label_names = [
            'kitchen', 'bedroom', 'living_room', 'bathroom', 'dining_room',
            'office', 'classroom', 'conference_room', 'auditorium', 'theater',
            'restaurant', 'supermarket', 'bakery', 'coffee_shop', 'bar',
            'airport_terminal', 'subway_station', 'train_station', 'bus_station', 'harbor',
            'street', 'highway', 'field', 'forest', 'mountain',
            'beach', 'ocean', 'lake', 'river', 'valley',
            'desert', 'snow', 'garden', 'park', 'playground',
            'campus', 'church', 'mosque', 'temple', 'synagogue',
            'hospital', 'library', 'bookstore', 'clothing_store', 'toy_store',
            'garage', 'gas_station', 'parking_lot', 'construction_site', 'bridge'
        ]

        # Train split is generated before the test split (RNG order).
        train_dataset = _make_random_split(num_fallback_train, len(label_names))
        test_dataset = _make_random_split(num_fallback_test, len(label_names))
        print(
            f"Created simulated SUN dataset with {num_fallback_train} training "
            f"and {num_fallback_test} test samples")

    return train_dataset, test_dataset, label_names


class _NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also accepts NumPy scalars/arrays and TF tensors."""

    def default(self, obj):
        if isinstance(obj, (np.integer, np.floating, np.bool_)):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if tf.is_tensor(obj):
            return obj.numpy().tolist()
        return super().default(obj)


def _sample_client_class_counts(num_classes, sample_buckets):
    """Draw one client's per-class sample budget from a Dirichlet ranking."""
    client_beta = FLAGS.betaa * (0.5 + random.random())  # Between 0.5*beta and 1.5*beta
    proportions = np.random.dirichlet([client_beta] * num_classes)
    max_index = np.argmax(proportions)
    sorted_indices = np.argsort(proportions)
    # Classes in ascending order of proportion, without the dominant one.
    remaining = sorted_indices[sorted_indices != max_index]

    client_counts = np.ones(num_classes, dtype=int)
    client_counts[max_index] = FLAGS.maxsam  # Dominant class gets the maximum

    # The last bucket value is not used here: the dominant class was
    # already given FLAGS.maxsam above.
    lower_buckets = sample_buckets[:-1]
    bucket_size = len(remaining) // len(lower_buckets)
    for i, bucket_value in enumerate(lower_buckets):
        start = i * bucket_size
        # The final bucket absorbs whatever the integer division left over.
        end = (i + 1) * bucket_size if i < len(lower_buckets) - 1 else len(remaining)
        client_counts[remaining[start:end]] = bucket_value
    return client_counts


def _draw_client_samples(class_samples, client_counts):
    """Pick up to ``client_counts[c]`` shuffled samples of every class ``c``."""
    client_data = []
    for class_idx, count in enumerate(client_counts):
        if class_idx in class_samples:
            available_samples = class_samples[class_idx]
            random.shuffle(available_samples)  # in place, shared across clients
            client_data.extend(available_samples[:count])
    random.shuffle(client_data)
    return client_data


def _print_client_statistics(clients, num_classes):
    """Print per-client sample totals and class-count summaries."""
    print("\nClient data distribution statistics:")
    for i, client in enumerate(clients):
        class_counts = defaultdict(int)
        for label in client.train_labels:
            class_counts[label] += 1
        print(f"Client {i}: {client.num_samples} samples")
        print(f"  Classes with samples: {len(class_counts)}/{num_classes}")

        min_samples = min(class_counts.values(), default=0)
        max_samples = max(class_counts.values(), default=0)
        print(f"  Samples per class: Min={min_samples}, Max={max_samples}")

        sample_count_distribution = defaultdict(int)
        for count in class_counts.values():
            sample_count_distribution[count] += 1

        print("  Sample count distribution: " +
              ", ".join(f"{count} samples: {freq} classes"
                        for count, freq in sorted(sample_count_distribution.items())))

        examples = list(class_counts.items())[:5]
        print("  Examples: " + ", ".join(f"{label}: {count}" for label, count in examples))


def run_federated_learning(label_names, testi, teste, epochs):
    """Run the complete federated training loop and save its history.

    Args:
        label_names: Ignored; the names returned by ``load_sun_dataset``
            replace it. Kept so existing callers keep working.
        testi: Index of the first test sample to evaluate on.
        teste: Index one past the last test sample to evaluate on.
        epochs: Number of local epochs each client trains per round.

    Returns:
        The ``CLIPFederatedServer`` holding the final global model.
    """
    train_dataset, test_dataset, label_names = load_sun_dataset()

    # Only materialise the requested window of the test split instead of
    # loading the whole split into memory and slicing it afterwards.
    first = max(testi, 0)
    window = max(teste - first, 0)
    test_samples = list(tfds.as_numpy(test_dataset.skip(first).take(window)))
    test_images = [img for img, _ in test_samples]
    test_labels = [label_names[int(lbl)] for _, lbl in test_samples]

    server = CLIPFederatedServer()
    clients = [CLIPFederatedClient(i) for i in range(FLAGS.NUM_CLIENTS)]

    # Group the training samples by integer class index.
    class_samples = defaultdict(list)
    for img, lbl in tfds.as_numpy(train_dataset):
        class_samples[int(lbl)].append((img, lbl))

    sample_buckets = [1, 2, 3, 4, 5, 6, 8, 10, 12, 16]  # Different possible sample counts
    # All budgets are drawn before any data is shuffled, which keeps the
    # order of random draws fixed.
    client_sample_counts = [
        _sample_client_class_counts(len(label_names), sample_buckets)
        for _ in range(FLAGS.NUM_CLIENTS)
    ]

    for client, client_counts in zip(clients, client_sample_counts):
        client_data = _draw_client_samples(class_samples, client_counts)
        client.set_data(
            [img for img, _ in client_data],
            [label_names[int(lbl)] for _, lbl in client_data],
        )

    _print_client_statistics(clients, len(label_names))

    fl_history = []
    for round_num in range(FLAGS.COMMUNICATION_ROUNDS):
        print(f"\n--- Round {round_num + 1}/{FLAGS.COMMUNICATION_ROUNDS} ---")

        server.distribute_model(clients)
        client_weights = []
        client_samples = []
        round_history = {"round": round_num + 1, "clients": []}

        for client in clients:
            print(f"\nTraining on client {client.client_id}")
            client_history = client.train(epochs)
            client_weights.append(client.get_model_weights())
            client_samples.append(client.num_samples)

            round_history["clients"].append({
                "client_id": client.client_id,
                "num_samples": client.num_samples,
                "history": client_history
            })

        server.update_global_model(client_weights, client_samples)
        fl_history.append(round_history)

        # Evaluate after every round.
        evaluate_global_model(server, test_images, test_labels)

    with open("federated_learning_history_sun.json", "w", encoding="utf-8") as f:
        json.dump(fl_history, f, cls=_NumpyEncoder)

    return server


def evaluate_global_model(server, test_images, test_labels):
    """Evaluate the global model on test data with prompt-ensemble zero-shot.

    Every distinct label in ``test_labels`` is turned into a set of prompts;
    an image is assigned the label whose prompts have the highest mean
    similarity to the image embedding.

    Args:
        server: Object exposing ``global_model`` and ``processor``.
        test_images: Sequence of test images.
        test_labels: Label strings, one per image.

    Returns:
        The accuracy as a float in [0, 1], or ``None`` when there is no
        test data.
    """
    if len(test_images) == 0 or len(test_labels) == 0:
        print("No test data available for evaluation")
        return None

    model = server.global_model
    processor = server.processor

    prompt_templates = [
        "a photo of a {}.", "a photograph of a {}.",
        "an image of a {}.", "a picture of a {}.",
        "a scene showing a {}.", "a view of a {}.",
        "a scene of a {}.", "a landscape with a {}.",
        "an outdoor scene of a {}.", "an indoor scene of a {}."
    ]
    # Sorted so the candidate order (and therefore tie-breaking) does not
    # depend on string hashing.
    unique_labels = sorted(set(test_labels))
    num_labels = len(unique_labels)
    num_templates = len(prompt_templates)

    # Label-major, template-minor layout; the reshape below relies on it.
    all_prompts = [
        template.format(label)
        for label in unique_labels
        for template in prompt_templates
    ]

    text_feature_chunks = []
    for start in range(0, len(all_prompts), FLAGS.batch_size):
        batch_prompts = all_prompts[start:start + FLAGS.batch_size]
        text_inputs = processor(text=batch_prompts, return_tensors="tf", padding=True)
        seq_length = tf.shape(text_inputs['input_ids'])[1]
        position_ids = tf.range(0, seq_length, dtype=tf.int32)[tf.newaxis, :]
        text_outputs = model.clip.text_model(
            input_ids=text_inputs['input_ids'],
            attention_mask=text_inputs.get('attention_mask', None),
            position_ids=position_ids,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True
        )
        batch_text_features = model.clip.text_projection(text_outputs.pooler_output)
        text_feature_chunks.append(tf.nn.l2_normalize(batch_text_features, axis=-1))

    all_text_features = tf.concat(text_feature_chunks, axis=0)

    correct = 0
    for start in range(0, len(test_images), FLAGS.batch_size):
        batch_images = list(test_images[start:start + FLAGS.batch_size])
        batch_labels = test_labels[start:start + FLAGS.batch_size]

        inputs = processor(images=batch_images, return_tensors="tf", padding=True)
        vision_outputs = model.clip.vision_model(
            inputs['pixel_values'],
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True
        )
        image_embeds = model.clip.visual_projection(vision_outputs.pooler_output)
        image_features = tf.nn.l2_normalize(image_embeds, axis=-1)

        # (batch, num_labels * num_templates) -> mean over each label's prompts.
        sims = tf.matmul(image_features, all_text_features, transpose_b=True)
        class_scores = tf.reduce_mean(
            tf.reshape(sims, (-1, num_labels, num_templates)), axis=-1)
        predicted = tf.argmax(class_scores, axis=-1).numpy()

        correct += sum(
            unique_labels[idx] == true_label
            for idx, true_label in zip(predicted, batch_labels)
        )

    accuracy = correct / len(test_images)
    print(f"\nAccuracy on test set: {correct}/{len(test_images)} = {accuracy:.2%}")
    return accuracy


def main(argv):
    """absl entry point: run federated training with the parsed flag values."""
    del argv  # Not used; all settings are read through FLAGS.
    run_federated_learning(
        label_names=None,
        testi=FLAGS.testi,
        teste=FLAGS.teste,
        epochs=FLAGS.LOCAL_EPOCHS,
    )

    print("\nFederated Learning completed!")


if __name__ == "__main__":
    # absl parses the command-line flags and then calls main().
    app.run(main)