import tensorflow as tf
from transformers import TFCLIPModel, CLIPProcessor
from collections import defaultdict
import tensorflow_datasets as tfds
import numpy as np
from PIL import Image
import random
import os
import copy
import json

from absl import app
from absl import flags
from absl import logging

# Global absl flag registry; app.run() parses the command line into it
# before main() is called.
FLAGS = flags.FLAGS

# --- Data / evaluation ---
flags.DEFINE_integer('batch_size', 32, 'Batch size for training.')
# testi/teste select the half-open window [testi, teste) of the validation
# split, i.e. `teste` itself is NOT included.
flags.DEFINE_integer('testi', 1000, 'Index of first test sample to evaluate')
flags.DEFINE_integer('teste', 1400, 'Index of last test sample to evaluate')
# Base Dirichlet concentration; each client multiplies it by a random
# factor in [0.5, 1.5) before drawing its class proportions.
flags.DEFINE_float('betaa', 0.1, 'Dirichlet distribution parameter for client data allocation')
# Sample count given to each client's dominant class; all other classes get
# a count from a fixed bucket list (1..12).
flags.DEFINE_integer('maxsam', 16, 'Maximum samples per class for client data distribution')
# NOTE(review): 'tek' is not read anywhere in the code visible here.
flags.DEFINE_integer('tek', 100, 'Training parameter')

# --- Optimisation / loss ---
# Adam learning rate used by every client.
flags.DEFINE_float('learning_rate', 1e-5, 'Learning rate for client optimizers')
# Defaults for the focal weighting in CLIPFederatedClient.contrastive_focal_loss.
flags.DEFINE_float('alpha', 0.25, 'Alpha parameter for focal loss')
flags.DEFINE_float('gamma', 1.0, 'Gamma parameter for focal loss')

# --- Federated setup ---
flags.DEFINE_integer('NUM_CLIENTS', 2, 'Number of federated clients')
flags.DEFINE_integer('COMMUNICATION_ROUNDS', 100, 'Number of communication rounds')
flags.DEFINE_integer('LOCAL_EPOCHS', 1, 'Number of local training epochs per communication round')
# NOTE(review): not read by any active (non-commented) code visible here.
flags.DEFINE_string('output_dir', 'federated_food101_model', 'Directory to save the trained model')

# Seed Python, NumPy and TensorFlow RNGs for reproducible client splits.
# Caveat: PYTHONHASHSEED is only honoured at interpreter start-up, so setting
# it here does not change str hashing in this process (child processes
# inherit it, though).
SEED = 42; os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED); np.random.seed(SEED); tf.random.set_seed(SEED)

def preprocess_image(image, target_size=(224, 224)):
    """Resize an image and scale it to float32 RGB values in [0, 1].

    The result is intended for ``CLIPProcessor``, which applies CLIP's own
    mean/std normalisation.

    Args:
        image: A PIL image, or an array-like of shape (H, W), (H, W, 1),
            (H, W, 2) or (H, W, C) with C >= 3. Arrays that are not uint8 are
            converted first: floating arrays whose maximum is <= 1.0 are
            treated as [0, 1] intensities, anything else as [0, 255] values.
        target_size: Output size as ``(width, height)`` (PIL's convention).

    Returns:
        np.ndarray of dtype float32, shape (height, width, 3), values in [0, 1].
    """
    if isinstance(image, Image.Image):
        # PIL maps palette, grayscale(+alpha), RGBA and CMYK modes to RGB
        # correctly; reading the raw array would misinterpret P and CMYK.
        image = np.asarray(image.convert("RGB"))
    else:
        image = np.asarray(image)

    # Give channel-less grayscale an explicit channel axis.
    if image.ndim == 2:
        image = image[:, :, np.newaxis]

    channels = image.shape[-1]
    if channels in (1, 2):
        # Grayscale, optionally with alpha: replicate luminance into RGB and
        # drop alpha (a 2-channel array would otherwise stay 2-channel).
        image = np.repeat(image[:, :, :1], 3, axis=-1)
    elif channels > 3:
        image = image[:, :, :3]  # e.g. RGBA -> RGB; alpha is dropped

    # Image.fromarray only accepts uint8 for 3-channel data.
    if image.dtype != np.uint8:
        if np.issubdtype(image.dtype, np.floating) and image.size and image.max() <= 1.0:
            image = image * 255.0
        image = np.clip(np.rint(image), 0, 255).astype(np.uint8)

    pil_image = Image.fromarray(np.ascontiguousarray(image))
    pil_image = pil_image.resize(target_size, Image.LANCZOS)

    # Scale to [0, 1]; CLIP-specific normalisation is left to the processor.
    return np.asarray(pil_image, dtype=np.float32) / 255.0

class CLIPFederatedServer:
    """Holds the global CLIP model and aggregates client updates (FedAvg).

    Only a small subset of vision-tower variables is trainable (see
    ``_unfreeze_first_layernorm``). Those variables, in ``model.variables``
    order, are what gets distributed to and averaged from the clients.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        self.model_name = model_name
        self.global_model = TFCLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)

        # Run one forward pass so Keras creates (builds) every variable.
        dummy_image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
        inputs = self.processor(images=dummy_image, text=["a sandwich"], return_tensors="tf", padding=True)
        _ = self.global_model(**inputs)

        # Freeze everything, then re-enable the selected vision variables.
        self._freeze_all_layers(self.global_model)
        self._unfreeze_first_layernorm()

        # Built in model.variables order; clients build their list the same
        # way, so index i is assumed to address the same variable everywhere.
        self.trainable_vars = [v for v in self.global_model.variables if v.trainable]
        print(f"Server initialized with {len(self.trainable_vars)} trainable variables:")
        for v in self.trainable_vars:
            print(f"  {v.name}, {v.shape}")

    def _freeze_all_layers(self, model):
        """Mark every variable of ``model`` as non-trainable."""
        # Sets the private ``_trainable`` flag directly; the trainable-variable
        # filter in __init__ reads ``v.trainable`` afterwards.
        for var in model.variables:
            var._trainable = False

    def _unfreeze_first_layernorm(self):
        """Re-enable training for a small subset of the vision tower.

        Despite the name, this unfreezes:
          * every Conv2D submodule of the vision embeddings (patch embedding),
          * every embeddings submodule whose name contains "position_embedding",
          * every LayerNormalization submodule of the vision embeddings, and
          * the first LayerNormalization found in the first encoder block.
        """
        for layer in self.global_model.clip.vision_model.embeddings.submodules:
            # Patch embedding (Conv2D).
            if isinstance(layer, tf.keras.layers.Conv2D):
                for var in layer.variables:
                    var._trainable = True
                    print(f"Unfroze patch encoder: {var.name}")

            # NOTE(review): this only matches *layers* named like
            # "position_embedding"; if the installed transformers version
            # stores it as a bare weight, nothing is unfrozen here — confirm.
            if 'position_embedding' in layer.name.lower():
                for var in layer.variables:
                    var._trainable = True
                    print(f"Unfroze position embeddings: {var.name}")

            # Any LayerNorm that lives inside the embeddings module.
            if isinstance(layer, tf.keras.layers.LayerNormalization):
                for var in layer.variables:
                    var._trainable = True
                    print(f"Unfroze embedding LayerNorm: {var.name}")

        # First LayerNorm of the first vision encoder block.
        first_block = self.global_model.clip.vision_model.encoder.layers[0]
        unfrozen = 0
        for layer in first_block.submodules:
            if isinstance(layer, tf.keras.layers.LayerNormalization):
                for var in layer.variables:
                    var._trainable = True
                    print(f"Unfroze: {var.name}")
                unfrozen += 1
                break  # Stop after first LayerNorm
        if unfrozen == 0:
            print("No LayerNorm was unfrozen!")

    def get_model_weights(self):
        """Return the trainable variables' current values as NumPy arrays."""
        return [var.numpy() for var in self.trainable_vars]

    def update_global_model(self, client_weights_list, client_samples_list):
        """Set the global weights to the sample-weighted average of client weights.

        Args:
            client_weights_list: One list of NumPy arrays per client, each in
                the same order as ``self.trainable_vars``.
            client_samples_list: Number of training samples per client, used
                as averaging weights; same length as ``client_weights_list``.

        Returns:
            The list of averaged arrays that was assigned to the model.

        Raises:
            ValueError: if no client weights are given, the two lists differ in
                length, or the total sample count is not positive.
        """
        if not client_weights_list:
            raise ValueError("update_global_model needs weights from at least one client")
        if len(client_weights_list) != len(client_samples_list):
            raise ValueError(
                f"Got weights for {len(client_weights_list)} clients but sample counts "
                f"for {len(client_samples_list)}"
            )
        total_samples = sum(client_samples_list)
        if total_samples <= 0:
            raise ValueError("Cannot average client weights: total sample count is not positive")

        # Accumulate sum_k (n_k / n) * w_k, starting from zeros of each shape.
        avg_weights = [np.zeros_like(w) for w in client_weights_list[0]]
        for client_weights, num_samples in zip(client_weights_list, client_samples_list):
            weight_ratio = num_samples / total_samples
            for j, w in enumerate(client_weights):
                avg_weights[j] += w * weight_ratio

        for var, avg_w in zip(self.trainable_vars, avg_weights):
            var.assign(avg_w)

        print(f"Global model updated with weights from {len(client_weights_list)} clients")
        return avg_weights

    def distribute_model(self, clients):
        """Copy the current global trainable weights into every client."""
        global_weights = self.get_model_weights()
        for client in clients:
            client.update_local_model(global_weights)


class CLIPFederatedClient:
    """Federated-learning participant that fine-tunes a local CLIP copy.

    Each client loads its own ``TFCLIPModel``, freezes it, and re-enables a
    small set of vision-tower variables (see ``_unfreeze_first_layernorm``).
    ``trainable_vars`` is filled from ``model.variables`` in order, the same
    way the server builds its list, so position ``i`` is assumed to refer to
    the same variable on both sides.
    """

    def __init__(self, client_id, model_name="openai/clip-vit-base-patch32", learning_rate=1e-3):
        self.client_id = client_id
        self.model = TFCLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.train_images = []
        self.train_labels = []
        self.num_samples = 0

        # Run one forward pass so Keras creates (builds) every variable.
        dummy_image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
        inputs = self.processor(images=dummy_image, text=["a sandwich"], return_tensors="tf", padding=True)
        _ = self.model(**inputs)

        # Freeze everything, then re-enable the selected vision variables.
        self._freeze_all_layers(self.model)
        self._unfreeze_first_layernorm()

        self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        self.trainable_vars = [v for v in self.model.variables if v.trainable]

    def _freeze_all_layers(self, model):
        """Mark every variable of ``model`` as non-trainable."""
        # Sets the private ``_trainable`` flag directly; the trainable-variable
        # filter in __init__ reads ``v.trainable`` afterwards.
        for var in model.variables:
            var._trainable = False

    def _unfreeze_first_layernorm(self):
        """Re-enable training for a small subset of the vision tower.

        Despite the name, this unfreezes:
          * every Conv2D submodule of the vision embeddings (patch embedding),
          * every embeddings submodule whose name contains "position_embedding",
          * every LayerNormalization submodule of the vision embeddings, and
          * the first LayerNormalization found in the first encoder block.
        """
        for layer in self.model.clip.vision_model.embeddings.submodules:
            # Patch embedding (Conv2D).
            if isinstance(layer, tf.keras.layers.Conv2D):
                for var in layer.variables:
                    var._trainable = True
                    print(f"Unfroze patch encoder: {var.name}")

            # NOTE(review): this only matches *layers* named like
            # "position_embedding"; if the installed transformers version
            # stores it as a bare weight, nothing is unfrozen here — confirm.
            if 'position_embedding' in layer.name.lower():
                for var in layer.variables:
                    var._trainable = True
                    print(f"Unfroze position embeddings: {var.name}")

            # Any LayerNorm that lives inside the embeddings module.
            if isinstance(layer, tf.keras.layers.LayerNormalization):
                for var in layer.variables:
                    var._trainable = True
                    print(f"Unfroze embedding LayerNorm: {var.name}")

        # First LayerNorm of the first vision encoder block.
        first_block = self.model.clip.vision_model.encoder.layers[0]
        unfrozen = 0
        for layer in first_block.submodules:
            if isinstance(layer, tf.keras.layers.LayerNormalization):
                for var in layer.variables:
                    var._trainable = True
                unfrozen += 1
                break  # Stop after first LayerNorm
        if unfrozen == 0:
            print(f"Client {self.client_id}: No LayerNorm was unfrozen!")

    def update_local_model(self, global_weights):
        """Overwrite the local trainable variables with the server's values.

        Args:
            global_weights: Arrays in the same order as ``self.trainable_vars``.

        Raises:
            ValueError: if the number of arrays differs from the number of
                local trainable variables. ``zip`` would otherwise silently
                skip the extras and leave the models out of sync.
        """
        if len(global_weights) != len(self.trainable_vars):
            raise ValueError(
                f"Client {self.client_id}: expected {len(self.trainable_vars)} weight arrays, "
                f"got {len(global_weights)}"
            )
        for local_var, global_w in zip(self.trainable_vars, global_weights):
            local_var.assign(global_w)

    def set_data(self, images, labels):
        """Store this client's local training images and their label names."""
        self.train_images = images
        self.train_labels = labels
        self.num_samples = len(images)
        print(f"Client {self.client_id} received {self.num_samples} training samples")

    def contrastive_focal_loss(self, image_embeds, text_embeds, logit_scale, gamma=None, alpha=None):
        """Symmetric CLIP-style contrastive loss with a focal weighting term.

        Row i of ``image_embeds`` is treated as the match for row i of
        ``text_embeds``; every other row in the batch is a negative.
        NOTE(review): samples sharing a label get identical prompts, yet are
        still treated as negatives of each other.

        Args:
            image_embeds: L2-normalised image embeddings (presumably
                (batch, dim)).
            text_embeds: L2-normalised text embeddings with the same shape.
            logit_scale: Temperature multiplier; clipped to [1, 100].
            gamma: Focal exponent; defaults to FLAGS.gamma.
            alpha: Focal scale; defaults to FLAGS.alpha.

        Returns:
            Scalar: mean of the image->text and text->image focal losses,
            with NaN replaced by 1.0.
        """
        if gamma is None:
            gamma = FLAGS.gamma
        if alpha is None:
            alpha = FLAGS.alpha

        # Similarity logits with clipped scaling to prevent extreme values.
        logit_scale = tf.clip_by_value(logit_scale, 1.0, 100.0)
        logits_per_image = tf.matmul(image_embeds, text_embeds, transpose_b=True) * logit_scale
        logits_per_text = tf.transpose(logits_per_image)

        batch_size = tf.shape(logits_per_image)[0]
        labels = tf.range(batch_size)

        def stable_focal_loss(logits, labels, gamma=gamma, alpha=alpha):
            # Per-row focal loss: alpha * (1 - p_t)^gamma * -log(p_t).
            probs = tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-7, 1.0)
            labels_onehot = tf.one_hot(labels, depth=tf.shape(logits)[-1])
            pt = tf.reduce_sum(labels_onehot * probs, axis=-1)
            focal_weight = tf.clip_by_value(tf.pow(1. - pt, gamma), 0.0, 10.0)
            ce = -tf.math.log(tf.clip_by_value(pt, 1e-7, 1.0))
            loss = alpha * focal_weight * ce
            return loss

        loss_i2t = stable_focal_loss(logits_per_image, labels)
        loss_t2i = stable_focal_loss(logits_per_text, labels)

        loss = (tf.reduce_mean(loss_i2t) + tf.reduce_mean(loss_t2i)) / 2

        # Safety net: a NaN loss becomes a constant (zero-gradient) 1.0.
        loss = tf.where(tf.math.is_nan(loss), 1.0, loss)
        return loss

    @tf.function
    def train_step(self, inputs):
        """Run one optimisation step on a batch and return its scalar loss.

        Args:
            inputs: Mapping with ``pixel_values``, ``input_ids`` and optionally
                ``attention_mask``, as produced by CLIPProcessor.
        """
        with tf.GradientTape() as tape:
            # Vision tower -> pooled output (index 1) -> projection.
            vision_outputs = self.model.clip.vision_model(
                inputs['pixel_values'],
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
            )
            image_embeds = self.model.clip.visual_projection(vision_outputs[1])
            image_embeds = tf.nn.l2_normalize(image_embeds, axis=-1)

            # Explicit position ids 0..seq_len-1, broadcast over the batch.
            input_shape = tf.shape(inputs['input_ids'])
            seq_length = input_shape[1]
            position_ids = tf.range(0, seq_length, dtype=tf.int32)[tf.newaxis, :]

            # Text tower -> pooled output (index 1) -> projection.
            text_outputs = self.model.clip.text_model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs.get('attention_mask', None),
                position_ids=position_ids,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
            )
            text_embeds = self.model.clip.text_projection(text_outputs[1])
            text_embeds = tf.nn.l2_normalize(text_embeds, axis=-1)

            logit_scale = tf.exp(self.model.clip.logit_scale)
            loss = self.contrastive_focal_loss(
                image_embeds, text_embeds, logit_scale
            )

        # Only the unfrozen variables are updated; clip each gradient's norm.
        grads = tape.gradient(loss, self.trainable_vars)
        grads = [tf.clip_by_norm(g, 1.0) if g is not None else g for g in grads]
        self.optimizer.apply_gradients(zip(grads, self.trainable_vars))
        return loss

    @staticmethod
    def _as_uint8_pixels(image):
        """Return ``image`` as an array, mapping [0, 1] floats back to uint8 [0, 255].

        Non-floating arrays (and PIL images) are returned as plain arrays
        without rescaling.
        """
        pixels = np.asarray(image)
        if np.issubdtype(pixels.dtype, np.floating):
            pixels = np.clip(np.rint(pixels * 255.0), 0, 255).astype(np.uint8)
        return pixels

    def train(self, epochs):
        """Train the local model for ``epochs`` passes over the local data.

        Returns:
            List with the mean batch loss of each epoch (empty when the client
            has no data or ``epochs`` is 0).
        """
        if len(self.train_images) == 0:
            print(f"Client {self.client_id} has no training data; skipping local training")
            return []

        train_prompts = [f"a photo of {label}" for label in self.train_labels]
        # preprocess_image() yields float images already scaled to [0, 1].
        # Recent transformers image processors rescale by 1/255 again
        # (do_rescale defaults to True), which squashes every image to nearly
        # constant pixels. Handing over uint8 pixels (an exact round-trip of
        # preprocess_image's output) keeps the processor's rescale and
        # normalisation correct.
        pixel_images = [self._as_uint8_pixels(img) for img in self.train_images]
        train_inputs = self.processor(
            text=train_prompts,
            images=pixel_images,
            return_tensors="tf",
            padding=True
        )

        dataset = tf.data.Dataset.from_tensor_slices(train_inputs).batch(FLAGS.batch_size)
        num_batches = len(dataset)

        history = []
        for epoch in range(epochs):
            epoch_loss = 0.0
            for batch in dataset:
                epoch_loss += float(self.train_step(batch))
            epoch_loss /= num_batches
            history.append(epoch_loss)
            print(f"Client {self.client_id}, Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

        return history

    def get_model_weights(self):
        """Return the trainable variables' current values as NumPy arrays."""
        return [var.numpy() for var in self.trainable_vars]


def _slice_split(dataset, start, stop):
    """Return ``list(tfds.as_numpy(dataset))[start:stop]`` without decoding the whole split.

    Non-negative bounds are served with ``skip``/``take``, so only the
    requested window is read into memory. Negative bounds count from the end
    and need the split length, so they fall back to materialising the split.
    """
    if start < 0 or stop < 0:
        return list(tfds.as_numpy(dataset))[start:stop]
    window = dataset.skip(start).take(max(stop - start, 0))
    return list(tfds.as_numpy(window))


def _index_by_class(dataset):
    """Group examples by class without touching their images.

    Returns:
        Dict mapping ``int(label)`` to a list of ``(position, label)`` tuples,
        where ``position`` is the example's index in iteration order.
    """
    class_samples = {}
    for position, (_, lbl) in enumerate(tfds.as_numpy(dataset)):
        class_samples.setdefault(int(lbl), []).append((position, lbl))
    return class_samples


def _draw_client_class_counts(num_clients, num_classes):
    """Draw a non-IID per-class sample count vector for each client.

    For each client, a Dirichlet draw ranks the classes:
      * the top class gets FLAGS.maxsam samples;
      * the remaining classes, sorted by proportion, are split into equal
        groups receiving 1, 2, 3, 4, 5, 6, 8, 10 and 12 samples;
      * any remainder goes into the last (12) group.
    """
    sample_buckets = [1, 2, 3, 4, 5, 6, 8, 10, 12, 16]
    client_sample_counts = []
    for _ in range(num_clients):
        # Per-client concentration in [0.5*beta, 1.5*beta).
        client_beta = FLAGS.betaa * (0.5 + random.random())
        proportions = np.random.dirichlet([client_beta] * num_classes)
        max_index = np.argmax(proportions)

        sorted_indices = np.argsort(proportions)
        indices_without_max = [idx for idx in sorted_indices if idx != max_index]

        client_counts = np.ones(num_classes, dtype=int)
        client_counts[max_index] = FLAGS.maxsam

        # The last bucket value (16) is reserved for the max class.
        bucket_size = len(indices_without_max) // (len(sample_buckets) - 1)
        for i, bucket_value in enumerate(sample_buckets[:-1]):
            start_idx = i * bucket_size
            end_idx = (i + 1) * bucket_size if i < len(sample_buckets) - 2 else len(indices_without_max)
            for idx in indices_without_max[start_idx:end_idx]:
                client_counts[idx] = bucket_value

        client_sample_counts.append(client_counts)
    return client_sample_counts


def _select_client_samples(class_samples, client_counts):
    """Pick ``client_counts[c]`` random entries of each class ``c``, then shuffle.

    Each class list in ``class_samples`` is shuffled in place before taking
    its first entries. Clients therefore draw independently and may receive
    overlapping examples.
    """
    selected = []
    for class_idx, count in enumerate(client_counts):
        if class_idx in class_samples:
            available = class_samples[class_idx]
            random.shuffle(available)
            selected.extend(available[:min(count, len(available))])
    random.shuffle(selected)
    return selected


def _decode_selected_images(encoded_dataset, positions):
    """Decode only the examples whose iteration position is in ``positions``.

    ``encoded_dataset`` must yield ``(encoded_image_bytes, label)`` in the same
    order that was used to compute ``positions``.

    Returns:
        Dict ``position -> uint8 array of shape (H, W, 3)``.
    """
    wanted = set(positions)
    images = {}
    if not wanted:
        return images
    last = max(wanted)
    for position, (encoded, _) in enumerate(tfds.as_numpy(encoded_dataset)):
        if position > last:
            break
        if position in wanted:
            images[position] = tf.io.decode_image(
                encoded, channels=3, expand_animations=False
            ).numpy()
    return images


def _print_client_distribution(clients, num_classes):
    """Print per-client sample and class-count statistics."""
    print("\nClient data distribution statistics:")
    for i, client in enumerate(clients):
        class_counts = defaultdict(int)
        for label in client.train_labels:
            class_counts[label] += 1

        print(f"Client {i}: {client.num_samples} samples")
        print(f"  Classes with samples: {len(class_counts)}/{num_classes}")

        min_samples = min(class_counts.values()) if class_counts else 0
        max_samples = max(class_counts.values()) if class_counts else 0
        print(f"  Samples per class: Min={min_samples}, Max={max_samples}")

        sample_count_distribution = defaultdict(int)
        for count in class_counts.values():
            sample_count_distribution[count] += 1
        print("  Sample count distribution: " +
              ", ".join(f"{count} samples: {freq} classes"
                        for count, freq in sorted(sample_count_distribution.items())))

        examples = list(class_counts.items())[:5]
        print("  Examples: " + ", ".join(f"{label}: {count}" for label, count in examples))


class _NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serialises NumPy scalars/arrays and TF tensors."""

    def default(self, obj):
        if isinstance(obj, (np.integer, np.floating, np.bool_)):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if tf.is_tensor(obj):
            return obj.numpy().tolist()
        return super().default(obj)


def run_federated_learning(label_names, testi, teste, epochs):
    """Run FedAvg fine-tuning of CLIP on a non-IID split of Food-101.

    Args:
        label_names: Class names indexed by Food-101 label id.
        testi: First validation index of the prepared evaluation window.
        teste: End (exclusive) of the evaluation window.
        epochs: Local epochs each client trains per communication round.

    Returns:
        The CLIPFederatedServer holding the aggregated global model. The
        per-round history is written to ``federated_learning_history.json``
        in the current working directory.
    """
    # Train images stay JPEG-encoded until selected. Decoding the whole
    # split up front would hold every full-resolution image in memory.
    train_dataset = tfds.load(
        "food101", split="train", as_supervised=True,
        decoders={"image": tfds.decode.SkipDecoding()},
    )
    test_dataset = tfds.load("food101", split="validation", as_supervised=True)

    print(f"Food-101 dataset loaded with {len(label_names)} classes")

    # Evaluation data (kept for the currently disabled evaluation step).
    test_samples = _slice_split(test_dataset, testi, teste)
    test_images = [preprocess_image(img) for img, lbl in test_samples]
    test_labels = [label_names[lbl] for img, lbl in test_samples]
    print(f"Prepared {len(test_images)} test samples for evaluation")

    server = CLIPFederatedServer()
    clients = [CLIPFederatedClient(i, learning_rate=FLAGS.learning_rate) for i in range(FLAGS.NUM_CLIENTS)]

    # Non-IID allocation: index the split by class, draw per-client counts,
    # then sample positions. Only the chosen positions get decoded.
    class_samples = _index_by_class(train_dataset)
    client_sample_counts = _draw_client_class_counts(len(clients), len(label_names))
    client_selections = [
        _select_client_samples(class_samples, counts) for counts in client_sample_counts
    ]
    decoded = _decode_selected_images(
        train_dataset,
        [position for selection in client_selections for position, _ in selection],
    )

    for client, selection in zip(clients, client_selections):
        client_images = [preprocess_image(decoded[position]) for position, lbl in selection]
        client_labels = [label_names[int(lbl)] for position, lbl in selection]
        client.set_data(client_images, client_labels)
    del decoded  # full-resolution copies are no longer needed

    _print_client_distribution(clients, len(label_names))

    fl_history = []
    for round_num in range(FLAGS.COMMUNICATION_ROUNDS):
        print(f"\n--- Round {round_num + 1}/{FLAGS.COMMUNICATION_ROUNDS} ---")

        server.distribute_model(clients)

        client_weights = []
        client_samples = []
        round_history = {"round": round_num + 1, "clients": []}

        for client in clients:
            print(f"\nTraining on client {client.client_id}")
            client_history = client.train(epochs)

            client_weights.append(client.get_model_weights())
            client_samples.append(client.num_samples)

            round_history["clients"].append({
                "client_id": client.client_id,
                "num_samples": client.num_samples,
                "history": client_history
            })

        # Sample-weighted FedAvg of the clients' trainable variables.
        server.update_global_model(client_weights, client_samples)
        fl_history.append(round_history)

        # TODO: evaluate the global model on (test_images, test_labels) here.

    with open("federated_learning_history.json", "w", encoding="utf-8") as f:
        json.dump(fl_history, f, cls=_NumpyEncoder)

    return server


def main(argv):
    """absl entry point: federated CLIP fine-tuning on Food-101.

    Args:
        argv: Positional arguments left after absl flag parsing (unused).

    Returns:
        None, on purpose. ``app.run`` passes main's return value to
        ``sys.exit``; any non-int object would be printed to stderr and turned
        into exit status 1 even after a successful run.
    """
    del argv  # Unused; all configuration comes from absl flags.

    # Food-101 class names; assumed to follow the tfds label-id order, since
    # labels are looked up as label_names[label_id].
    label_names = [
        'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad',
        'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad',
        'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche', 'cheesecake', 'cheese_plate',
        'chicken_curry', 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse',
        'churros', 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame',
        'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots',
        'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup',
        'french_toast', 'fried_calamari', 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi',
        'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger',
        'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna',
        'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup',
        'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes',
        'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib',
        'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi',
        'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara',
        'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu',
        'tuna_tartare', 'waffles'
    ]

    print("Starting federated learning with Food-101 dataset")
    print(f"Parameters: batch_size={FLAGS.batch_size}, learning_rate={FLAGS.learning_rate}")
    print(f"Num clients: {FLAGS.NUM_CLIENTS}, Communication rounds: {FLAGS.COMMUNICATION_ROUNDS}")
    print(f"Local epochs: {FLAGS.LOCAL_EPOCHS}")
    print(f"Test sample range: {FLAGS.testi} to {FLAGS.teste}")

    run_federated_learning(label_names, FLAGS.testi, FLAGS.teste, FLAGS.LOCAL_EPOCHS)

    # TODO: persist the trained global model (e.g. under FLAGS.output_dir).
    print("\nFederated Learning completed!")
    print("You can use the sample inference script to test the model.")


if __name__ == "__main__":
    # app.run parses the absl flags defined at the top of this file, then
    # calls main(argv).
    app.run(main)