import tensorflow as tf
from transformers import TFCLIPModel, CLIPProcessor
from collections import defaultdict
import tensorflow_datasets as tfds
import numpy as np
from PIL import Image
import random
import os
import copy
import json

from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS

# --- Data and evaluation -----------------------------------------------------
flags.DEFINE_integer('batch_size', 32, 'Batch size for client training and for evaluation encoding.')
flags.DEFINE_integer('testi', 1000, 'Start index (inclusive) of the test-set slice used for evaluation.')
flags.DEFINE_integer('teste', 1400, 'End index (exclusive) of the test-set slice used for evaluation.')
flags.DEFINE_float('betaa', 0.1, 'Base Dirichlet concentration; each client draws its own value in [0.5, 1.5) * betaa.')
flags.DEFINE_integer('maxsam', 16, 'Number of samples requested for the dominant class of each client.')
flags.DEFINE_integer('tek', 100, 'Reserved; not referenced in this script.')

# --- Optimisation and loss ---------------------------------------------------
flags.DEFINE_float('learning_rate', 1e-5, 'Learning rate. NOTE: not read by this script; the client optimizer uses its own default.')
flags.DEFINE_float('alpha', 0.25, 'Scaling factor (alpha) of the contrastive focal loss.')
flags.DEFINE_float('gamma', 1.0, 'Focusing exponent (gamma) of the contrastive focal loss.')

# --- Federated schedule ------------------------------------------------------
flags.DEFINE_integer('NUM_CLIENTS', 2, 'Number of federated clients.')
flags.DEFINE_integer('COMMUNICATION_ROUNDS', 100, 'Number of server/client communication rounds.')
flags.DEFINE_integer('LOCAL_EPOCHS', 1, 'Number of local training epochs per client per round.')

# Seed every random number generator the script uses.
SEED = 42
# PYTHONHASHSEED is read at interpreter start-up, so setting it here can only
# influence child processes, not the hash seed of the running interpreter.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

class CLIPFederatedServer:
    """Holds the global CLIP model and aggregates client updates.

    Only a small subset of the vision tower is left trainable; that subset is
    what gets broadcast to clients and averaged back. Weights are exchanged as
    plain lists that are matched to variables *by position*, so the server
    must unfreeze exactly the same variables, in the same order, as
    ``CLIPFederatedClient`` does.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        """Load the pretrained model and select the trainable subset.

        Args:
            model_name: Hugging Face identifier of the CLIP checkpoint.
        """
        self.model_name = model_name
        self.global_model = TFCLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        # Run one dummy forward pass before touching the variables
        # (presumably to make sure every variable has been created).
        dummy_image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
        inputs = self.processor(images=dummy_image, text=["a plane"], return_tensors="tf", padding=True)
        _ = self.global_model(**inputs)
        self._freeze_all_layers(self.global_model)
        self._unfreeze_first_layernorm()

        self.trainable_vars = [v for v in self.global_model.variables if v.trainable]
        print(f"Server initialized with {len(self.trainable_vars)} trainable variables:")
        for v in self.trainable_vars:
            print(f"  {v.name}, {v.shape}")

    def _freeze_all_layers(self, model):
        """Mark every variable of ``model`` as non-trainable."""
        for var in model.variables:
            # `tf.Variable.trainable` is a read-only property, hence the
            # write to the private attribute backing it.
            var._trainable = False

    def _unfreeze_first_layernorm(self):
        """Re-enable training for the same variables the clients train.

        Mirrors ``CLIPFederatedClient._unfreeze_first_layernorm``: Conv2D,
        position-embedding and LayerNorm sub-layers of the vision embeddings,
        plus the first LayerNorm found in the first encoder block. Previously
        the server unfroze only the encoder LayerNorm, so its list of
        trainable variables did not line up with the clients' lists.
        """
        vision_model = self.global_model.clip.vision_model
        for layer in vision_model.embeddings.submodules:
            if isinstance(layer, tf.keras.layers.Conv2D):
                for var in layer.variables:
                    var._trainable = True
                    print(f"Unfroze patch encoder: {var.name}")

            if 'position_embedding' in layer.name.lower():
                for var in layer.variables:
                    var._trainable = True
                    print(f"Unfroze position embeddings: {var.name}")

            if isinstance(layer, tf.keras.layers.LayerNormalization):
                for var in layer.variables:
                    var._trainable = True
                    print(f"Unfroze embedding LayerNorm: {var.name}")

        first_block = vision_model.encoder.layers[0]
        unfrozen = 0
        for layer in first_block.submodules:
            if isinstance(layer, tf.keras.layers.LayerNormalization):
                for var in layer.variables:
                    var._trainable = True
                    print(f"Unfroze: {var.name}")
                unfrozen += 1
                break  # Stop after first LayerNorm
        if unfrozen == 0:
            print("No LayerNorm was unfrozen!")

    def get_model_weights(self):
        """Return the current trainable weights as a list of NumPy arrays."""
        return [var.numpy() for var in self.trainable_vars]

    def update_global_model(self, client_weights_list, client_samples_list):
        """Replace the global weights with the sample-weighted client average.

        Args:
            client_weights_list: One list of arrays per client, each ordered
                like ``self.trainable_vars``.
            client_samples_list: Number of training samples of each client,
                in the same client order.

        Returns:
            The list of averaged arrays that was assigned to the global model.

        Raises:
            ValueError: If no client is given, the two lists differ in
                length, the total sample count is not positive, or a client
                supplies a different number of arrays than the server has
                trainable variables.
        """
        if not client_weights_list:
            raise ValueError("update_global_model needs the weights of at least one client")
        if len(client_weights_list) != len(client_samples_list):
            raise ValueError(
                f"Got weights for {len(client_weights_list)} clients but "
                f"sample counts for {len(client_samples_list)}"
            )
        total_samples = sum(client_samples_list)
        if total_samples <= 0:
            raise ValueError("Total number of client samples must be positive")
        expected = len(self.trainable_vars)
        for i, client_weights in enumerate(client_weights_list):
            # A positional zip would silently truncate here and average or
            # assign the wrong tensors; fail loudly instead.
            if len(client_weights) != expected:
                raise ValueError(
                    f"Client entry {i} supplied {len(client_weights)} weight arrays, "
                    f"but the server has {expected} trainable variables"
                )

        avg_weights = [np.zeros_like(w) for w in client_weights_list[0]]
        for client_weights, num_samples in zip(client_weights_list, client_samples_list):
            weight_ratio = num_samples / total_samples
            for j, w in enumerate(client_weights):
                avg_weights[j] += w * weight_ratio
        for var, avg_w in zip(self.trainable_vars, avg_weights):
            var.assign(avg_w)
        print(f"Global model updated with weights from {len(client_weights_list)} clients")
        return avg_weights

    def distribute_model(self, clients):
        """Push the current global trainable weights to every client."""
        global_weights = self.get_model_weights()
        for client in clients:
            client.update_local_model(global_weights)


class CLIPFederatedClient:
    """A federated participant that fine-tunes part of a local CLIP copy.

    The client keeps its own pretrained model, trains only the unfrozen
    subset of vision-tower variables on its private images, and exchanges
    those variables with the server as position-ordered lists of arrays.
    """

    def __init__(self, client_id, model_name="openai/clip-vit-base-patch32", learning_rate=1e-3):
        """Load the pretrained model and prepare the optimizer.

        Args:
            client_id: Identifier used in log output and history records.
            model_name: Hugging Face identifier of the CLIP checkpoint.
            learning_rate: Adam learning rate (defaults to the value that
                used to be hard-coded).
        """
        self.client_id = client_id
        self.model = TFCLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.train_images = []
        self.train_labels = []
        self.num_samples = 0
        # Run one dummy forward pass before touching the variables
        # (presumably to make sure every variable has been created).
        dummy_image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
        inputs = self.processor(images=dummy_image, text=["a plane"], return_tensors="tf", padding=True)
        _ = self.model(**inputs)
        self._freeze_all_layers(self.model)
        self._unfreeze_first_layernorm()
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        self.trainable_vars = [v for v in self.model.variables if v.trainable]

    def _freeze_all_layers(self, model):
        """Mark every variable of ``model`` as non-trainable."""
        for var in model.variables:
            # `tf.Variable.trainable` is a read-only property, hence the
            # write to the private attribute backing it.
            var._trainable = False

    def _unfreeze_first_layernorm(self):
        """Re-enable training for the vision embeddings and one LayerNorm.

        Unfreezes Conv2D, position-embedding and LayerNorm sub-layers found
        among the vision embeddings' submodules, then the first LayerNorm
        found in the first encoder block.
        """
        for layer in self.model.clip.vision_model.embeddings.submodules:
            if isinstance(layer, tf.keras.layers.Conv2D):
                for var in layer.variables:
                    var._trainable = True
                    print(f"Unfroze patch encoder: {var.name}")

            if 'position_embedding' in layer.name.lower():
                for var in layer.variables:
                    var._trainable = True
                    print(f"Unfroze position embeddings: {var.name}")

            if isinstance(layer, tf.keras.layers.LayerNormalization):
                for var in layer.variables:
                    var._trainable = True
                    print(f"Unfroze embedding LayerNorm: {var.name}")
        first_block = self.model.clip.vision_model.encoder.layers[0]
        unfrozen = 0
        for layer in first_block.submodules:
            if isinstance(layer, tf.keras.layers.LayerNormalization):
                for var in layer.variables:
                    var._trainable = True
                unfrozen += 1
                break  # Stop after first LayerNorm
        if unfrozen == 0:
            print(f"Client {self.client_id}: No LayerNorm was unfrozen!")

    def update_local_model(self, global_weights):
        """Overwrite the local trainable variables with ``global_weights``.

        Args:
            global_weights: Arrays ordered like ``self.trainable_vars``.

        Raises:
            ValueError: If the number of arrays differs from the number of
                local trainable variables (a positional zip would otherwise
                silently pair the wrong tensors or drop some).
        """
        if len(global_weights) != len(self.trainable_vars):
            raise ValueError(
                f"Client {self.client_id} has {len(self.trainable_vars)} trainable "
                f"variables but received {len(global_weights)} weight arrays"
            )
        for local_var, global_w in zip(self.trainable_vars, global_weights):
            local_var.assign(global_w)

    def set_data(self, images, labels):
        """Store the client's private training images and label strings."""
        self.train_images = images
        self.train_labels = labels
        self.num_samples = len(images)
        print(f"Client {self.client_id} received {self.num_samples} training samples")

    def contrastive_focal_loss(self, image_embeds, text_embeds, logit_scale, gamma=1.0, alpha=0.25):
        """Symmetric image/text contrastive loss with focal re-weighting.

        Row ``i`` of ``image_embeds`` is treated as the positive match of row
        ``i`` of ``text_embeds``; every other pairing in the batch is a
        negative.

        Args:
            image_embeds: Image embeddings, one row per sample.
            text_embeds: Text embeddings, one row per sample.
            logit_scale: Multiplicative temperature; clipped to [1, 100].
            gamma: Focal focusing exponent.
            alpha: Focal scaling factor.

        Returns:
            Scalar loss tensor; a NaN result is replaced by 1.0.
        """
        logit_scale = tf.clip_by_value(logit_scale, 1.0, 100.0)
        logits_per_image = tf.matmul(image_embeds, text_embeds, transpose_b=True) * logit_scale
        logits_per_text = tf.transpose(logits_per_image)

        # The matching pair sits on the diagonal of the similarity matrix.
        labels = tf.range(tf.shape(logits_per_image)[0])

        def stable_focal_loss(logits):
            probs = tf.clip_by_value(tf.nn.softmax(logits, axis=-1), 1e-7, 1.0)
            labels_onehot = tf.one_hot(labels, depth=tf.shape(logits)[-1])
            # Probability assigned to the correct pairing.
            pt = tf.reduce_sum(labels_onehot * probs, axis=-1)
            focal_weight = tf.clip_by_value(tf.pow(1. - pt, gamma), 0.0, 10.0)
            ce = -tf.math.log(tf.clip_by_value(pt, 1e-7, 1.0))
            return alpha * focal_weight * ce

        loss_i2t = stable_focal_loss(logits_per_image)
        loss_t2i = stable_focal_loss(logits_per_text)

        loss = (tf.reduce_mean(loss_i2t) + tf.reduce_mean(loss_t2i)) / 2
        loss = tf.where(tf.math.is_nan(loss), 1.0, loss)  # Replace NaNs with 1.0
        return loss

    @tf.function
    def train_step(self, inputs):
        """Run one optimisation step on a batch and return its loss.

        Args:
            inputs: Mapping with ``pixel_values``, ``input_ids`` and
                optionally ``attention_mask``.
        """
        with tf.GradientTape() as tape:
            vision_outputs = self.model.clip.vision_model(
                inputs['pixel_values'],
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
            )
            image_embeds = self.model.clip.visual_projection(vision_outputs.pooler_output)
            image_embeds = tf.nn.l2_normalize(image_embeds, axis=-1)

            seq_length = tf.shape(inputs['input_ids'])[1]
            position_ids = tf.range(0, seq_length, dtype=tf.int32)[tf.newaxis, :]
            text_outputs = self.model.clip.text_model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs.get('attention_mask', None),
                position_ids=position_ids,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
            )
            text_embeds = self.model.clip.text_projection(text_outputs.pooler_output)
            text_embeds = tf.nn.l2_normalize(text_embeds, axis=-1)

            logit_scale = tf.exp(self.model.clip.logit_scale)
            loss = self.contrastive_focal_loss(
                image_embeds, text_embeds, logit_scale,
                gamma=FLAGS.gamma, alpha=FLAGS.alpha
            )

        grads = tape.gradient(loss, self.trainable_vars)
        # Clip each gradient tensor individually to norm 1.0.
        grads = [tf.clip_by_norm(g, 1.0) if g is not None else g for g in grads]
        self.optimizer.apply_gradients(zip(grads, self.trainable_vars))
        return loss

    def train(self, epochs):
        """Train on the local data and return the mean loss of each epoch.

        Args:
            epochs: Number of passes over the local data.

        Returns:
            List with one mean batch loss per epoch; empty when the client
            holds no training data.
        """
        if len(self.train_images) == 0:
            # Without this guard the epoch-loss bookkeeping below would fail
            # on an empty dataset (no tensor to call .numpy() on, zero batches).
            print(f"Client {self.client_id} has no training data; skipping local training")
            return []

        train_prompts = [f"a photo of a {label}" for label in self.train_labels]
        train_inputs = self.processor(
            text=train_prompts,
            images=self.train_images,
            return_tensors="tf",
            padding=True
        )
        dataset = tf.data.Dataset.from_tensor_slices(train_inputs).batch(FLAGS.batch_size)
        num_batches = len(dataset)
        history = []
        for epoch in range(epochs):
            epoch_loss = 0
            for batch in dataset:
                epoch_loss += self.train_step(batch)

            epoch_loss = epoch_loss.numpy() / num_batches
            history.append(epoch_loss)
            print(f"Client {self.client_id}, Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

        return history

    def get_model_weights(self):
        """Return the current trainable weights as a list of NumPy arrays."""
        return [var.numpy() for var in self.trainable_vars]


def load_and_prepare_fgvc_aircraft():
    """Load the FGVC Aircraft train/test splits and their label names.

    Returns:
        ``(train_dataset, test_dataset, label_names)``. When loading fails
        for any reason the two datasets are ``None`` and ``label_names`` is a
        built-in list of aircraft type names.
    """
    try:
        train_dataset = tfds.load("fgvc_aircraft", split="train", as_supervised=True)
        test_dataset = tfds.load("fgvc_aircraft", split="test", as_supervised=True)
        ds_info = tfds.builder("fgvc_aircraft").info
        label_names = ds_info.features['label'].names

        return train_dataset, test_dataset, label_names
    except Exception as exc:
        # `Exception` rather than a bare `except:` so Ctrl-C and SystemExit
        # still propagate; the cause is reported instead of being hidden.
        print("FGVC Aircraft dataset not found in TensorFlow Datasets.")
        print(f"Underlying error: {exc!r}")
        print("Using a predefined list of aircraft types.")

    label_names = [
        '707-320', '727-200', '737-200', '737-300', '737-400', '737-500', '737-600', '737-700', '737-800', '737-900',
        '747-100', '747-200', '747-300', '747-400', '757-200', '757-300', '767-200', '767-300', '767-400',
        '777-200', '777-300', 'A300B4', 'A310', 'A318', 'A319', 'A320', 'A321', 'A330-200', 'A330-300',
        'A340-200', 'A340-300', 'A340-500', 'A340-600', 'A380', 'ATR-42', 'ATR-72', 'BAe 146-200',
        'BAe 146-300', 'BAe 125', 'Beechcraft 1900', 'Boeing 717', 'CRJ-200', 'CRJ-700', 'CRJ-900',
        'Cessna 172', 'Cessna 208', 'Cessna 525', 'Cessna 560', 'Challenger 600', 'DC-10', 'DC-9-30',
        'DHC-8-100', 'DHC-8-300', 'DR-400', 'Dornier 328', 'E-170', 'E-190', 'E-195', 'EMB-120', 'Embraer Legacy 600',
        'Eurofighter Typhoon', 'F-16', 'F/A-18', 'Falcon 2000', 'Falcon 900', 'Fokker 100', 'Fokker 50', 'Fokker 70',
        'Global Express', 'Gulfstream IV', 'Gulfstream V', 'Hawk T1', 'Il-76', 'L-1011', 'MD-11', 'MD-80', 'MD-87',
        'MD-90', 'Metroliner', 'PA-28', 'SR-20', 'Saab 2000', 'Saab 340', 'Spitfire', 'Tornado', 'Tu-134',
        'Tu-154', 'Yak-42'
    ]
    print("Note: You need to install the FGVC Aircraft dataset or use a custom dataset loader.")
    return None, None, label_names


def run_federated_learning(label_names, testi, teste, epochs):
    """Run the full federated fine-tuning experiment.

    Steps: load the dataset, build one server and ``FLAGS.NUM_CLIENTS``
    clients, hand every client an imbalanced per-class subset of the training
    data, then alternate local training and server-side averaging for
    ``FLAGS.COMMUNICATION_ROUNDS`` rounds, evaluating after every round. The
    per-round training history is written to
    ``federated_learning_history.json``.

    Args:
        label_names: Ignored; the value is replaced by the label names
            returned by the dataset loader. Kept for interface compatibility.
        testi: Start index (inclusive) of the test-set slice to evaluate on.
        teste: End index (exclusive) of the test-set slice to evaluate on.
        epochs: Local epochs each client trains per round.

    Returns:
        The ``CLIPFederatedServer`` after the last round, or ``None`` when
        the dataset could not be loaded.
    """
    import itertools  # local import: only needed for the streaming slice below

    train_dataset, test_dataset, label_names = load_and_prepare_fgvc_aircraft()
    if train_dataset is None or test_dataset is None:
        print("Error: Could not load the FGVC Aircraft dataset.")
        print("Please ensure the dataset is installed or use a custom loader.")
        return None

    # Stream the test split and keep only the requested window instead of
    # materialising every decoded test image first. Negative bounds keep the
    # original list-slice semantics, which need the full list.
    test_stream = tfds.as_numpy(test_dataset)
    if all(bound is None or bound >= 0 for bound in (testi, teste)):
        test_samples = list(itertools.islice(test_stream, testi, teste))
    else:
        test_samples = list(test_stream)[testi:teste]
    test_images = [img for img, _ in test_samples]
    test_labels = [label_names[lbl] for _, lbl in test_samples]

    server = CLIPFederatedServer()
    clients = [CLIPFederatedClient(i) for i in range(FLAGS.NUM_CLIENTS)]

    # Group the training samples by integer class index.
    class_samples = defaultdict(list)
    for img, lbl in tfds.as_numpy(train_dataset):
        class_samples[int(lbl)].append((img, lbl))

    sample_buckets = [1, 2, 3, 4, 5, 6, 8, 10, 12, 16]  # Different possible sample counts
    num_classes = len(label_names)
    client_sample_counts = []
    for _ in range(FLAGS.NUM_CLIENTS):
        client_beta = FLAGS.betaa * (0.5 + random.random())  # Between 0.5*beta and 1.5*beta
        # The Dirichlet draw is only used to rank the classes for this client.
        proportions = np.random.dirichlet([client_beta] * num_classes)
        max_index = np.argmax(proportions)
        indices_without_max = [idx for idx in np.argsort(proportions) if idx != max_index]

        client_counts = np.ones(num_classes, dtype=int)  # Start with 1 for all classes
        client_counts[max_index] = FLAGS.maxsam  # The top-ranked class gets maxsam samples

        # Split the remaining classes (ascending by proportion) into equal
        # groups, one per bucket value except the last; the final group also
        # absorbs the remainder of the integer division.
        num_groups = len(sample_buckets) - 1
        bucket_size = len(indices_without_max) // num_groups
        for i, bucket_value in enumerate(sample_buckets[:-1]):
            start_idx = i * bucket_size
            end_idx = (i + 1) * bucket_size if i < num_groups - 1 else len(indices_without_max)
            for idx in indices_without_max[start_idx:end_idx]:
                client_counts[idx] = bucket_value

        client_sample_counts.append(client_counts)

    for client, client_counts in zip(clients, client_sample_counts):
        client_data = []
        for class_idx, count in enumerate(client_counts):
            if class_idx in class_samples:
                available_samples = class_samples[class_idx]
                # In-place shuffle of the shared pool, so each client sees a
                # different random prefix (subsets may overlap across clients).
                random.shuffle(available_samples)
                samples_to_take = min(count, len(available_samples))
                client_data.extend(available_samples[:samples_to_take])
        random.shuffle(client_data)
        client_images = [img for img, _ in client_data]
        client_labels = [label_names[int(lbl)] for _, lbl in client_data]
        client.set_data(client_images, client_labels)

    print("\nClient data distribution statistics:")
    for i, client in enumerate(clients):
        class_counts = {}
        for label in client.train_labels:
            class_counts[label] = class_counts.get(label, 0) + 1
        print(f"Client {i}: {client.num_samples} samples")
        print(f"  Classes with samples: {len(class_counts)}/{len(label_names)}")
        min_samples = min(class_counts.values()) if class_counts else 0
        max_samples = max(class_counts.values()) if class_counts else 0
        print(f"  Samples per class: Min={min_samples}, Max={max_samples}")
        sample_count_distribution = {}
        for count in class_counts.values():
            sample_count_distribution[count] = sample_count_distribution.get(count, 0) + 1
        print(f"  Sample count distribution: " +
              ", ".join([f"{count} samples: {freq} classes"
                         for count, freq in sorted(sample_count_distribution.items())]))
        examples = list(class_counts.items())[:5]
        print(f"  Examples: " + ", ".join([f"{label}: {count}" for label, count in examples]))

    # Federated Learning Process
    fl_history = []
    for round_num in range(FLAGS.COMMUNICATION_ROUNDS):
        print(f"\n--- Round {round_num + 1}/{FLAGS.COMMUNICATION_ROUNDS} ---")
        server.distribute_model(clients)

        client_weights = []
        client_samples = []
        round_history = {"round": round_num + 1, "clients": []}
        for client in clients:
            print(f"\nTraining on client {client.client_id}")
            client_history = client.train(epochs)
            client_weights.append(client.get_model_weights())
            client_samples.append(client.num_samples)
            round_history["clients"].append({
                "client_id": client.client_id,
                "num_samples": client.num_samples,
                "history": client_history
            })
        server.update_global_model(client_weights, client_samples)
        fl_history.append(round_history)
        # Evaluate after every round.
        evaluate_global_model(server, test_images, test_labels)

    class NumpyEncoder(json.JSONEncoder):
        """JSON encoder that also accepts NumPy scalars/arrays and tensors."""

        def default(self, obj):
            if isinstance(obj, (np.integer, np.floating, np.bool_)):
                return obj.item()
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if tf.is_tensor(obj):
                return obj.numpy().tolist()
            return super().default(obj)

    with open("federated_learning_history.json", "w") as f:
        json.dump(fl_history, f, cls=NumpyEncoder)

    return server


def evaluate_global_model(server, test_images, test_labels):
    """Zero-shot style evaluation of the server's global model.

    Every distinct label in ``test_labels`` is turned into a set of text
    prompts; an image is classified as the label whose prompts have the
    highest mean cosine similarity with the image embedding.

    NOTE(review): the candidate classes are only the labels that occur in
    ``test_labels``, not the full label set of the dataset - confirm that
    this is the intended protocol.

    Args:
        server: Object exposing ``global_model`` and ``processor``.
        test_images: Sequence of images accepted by the processor.
        test_labels: Label string of each image, in the same order.

    Returns:
        Accuracy as a float in [0, 1], or ``None`` when there is nothing to
        evaluate.
    """
    num_test = len(test_images)
    if num_test == 0:
        # Avoids encoding nothing and then dividing by zero below.
        print("\nNo test samples to evaluate.")
        return None

    model = server.global_model
    processor = server.processor
    prompt_templates = [
        "a photo of a {}.","a photograph of a {}.","an image of a {}.",
        "a picture of a {}.","a side view of a {}.","a front view of a {}.",
        "an airplane of type {}.","an aircraft model {}.",
        "a commercial {} aircraft.","a military {} aircraft."
    ]
    num_templates = len(prompt_templates)

    # Encode each distinct class once (first-occurrence order) instead of
    # once per test sample; repeated labels produce identical prompts.
    class_names = list(dict.fromkeys(test_labels))
    all_prompts = [
        template.format(name) for name in class_names for template in prompt_templates
    ]

    text_feature_chunks = []
    for start in range(0, len(all_prompts), FLAGS.batch_size):
        batch_prompts = all_prompts[start:start + FLAGS.batch_size]
        text_inputs = processor(text=batch_prompts, return_tensors="tf", padding=True)
        seq_length = tf.shape(text_inputs['input_ids'])[1]
        position_ids = tf.range(0, seq_length, dtype=tf.int32)[tf.newaxis, :]
        text_outputs = model.clip.text_model(
            input_ids=text_inputs['input_ids'],
            attention_mask=text_inputs.get('attention_mask', None),
            position_ids=position_ids,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True
        )
        batch_text_features = model.clip.text_projection(text_outputs.pooler_output)
        text_feature_chunks.append(tf.nn.l2_normalize(batch_text_features, axis=-1))
    all_text_features = tf.concat(text_feature_chunks, axis=0)

    # Inference, one batch of images at a time.
    correct = 0
    for start in range(0, num_test, FLAGS.batch_size):
        batch_images = test_images[start:start + FLAGS.batch_size]
        batch_labels = test_labels[start:start + FLAGS.batch_size]
        inputs = processor(images=batch_images, return_tensors="tf", padding=True)
        vision_outputs = model.clip.vision_model(
            inputs['pixel_values'],
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True
        )
        image_embeds = model.clip.visual_projection(vision_outputs.pooler_output)
        image_features = tf.nn.l2_normalize(image_embeds, axis=-1)
        sims = tf.matmul(image_features, all_text_features, transpose_b=True)
        # Prompts are laid out class-major, so the last axis can be folded
        # into (class, template) and averaged over the templates.
        class_scores = tf.reduce_mean(
            tf.reshape(sims, (-1, len(class_names), num_templates)), axis=-1
        )
        predicted = tf.argmax(class_scores, axis=-1).numpy()
        correct += sum(
            class_names[int(pred)] == true_label
            for pred, true_label in zip(predicted, batch_labels)
        )

    accuracy = correct / num_test
    print(f"\nAccuracy on test set: {correct}/{len(test_images)} = {correct/len(test_images):.2%}")
    return accuracy


def main(argv):
    """absl entry point: run the experiment with the parsed flag values.

    Args:
        argv: Remaining command-line arguments from absl; unused.

    Returns:
        ``1`` when the experiment could not run (the dataset failed to load),
        otherwise ``None``. absl passes this value to ``sys.exit``.
    """
    del argv  # Unused.
    server = run_federated_learning(None, FLAGS.testi, FLAGS.teste, FLAGS.LOCAL_EPOCHS)
    if server is None:
        # run_federated_learning returns None only when the dataset is
        # unavailable; do not report success in that case.
        print("\nFederated Learning aborted: the dataset could not be loaded.")
        return 1
    print("\nFederated Learning completed!")
    return None


# Script entry point: absl parses the command-line flags and then calls main.
if __name__ == "__main__":
    app.run(main)