#iclr

import copy
import torch
import logging
from datasets.data_loading import get_test_loader_fed
from conf import get_num_classes  # Ensure this is properly defined in 'conf'
from utils.registry import ADAPTATION_REGISTRY
import wandb
# NOTE: these two calls sit at module level, so they execute as soon as this
# module is imported. The key below is a placeholder string and has to be
# replaced with a real Weights & Biases API key before running.
# NOTE(review): project/name are hard-coded ("abc"/"def") — confirm intended.
wandb.login(key="Input your WANDB API here")
wandb.init(project="abc", name="def")
from utils.bn_layers import BalancedRobustBN2dV5, BalancedRobustBN2dEMA, BalancedRobustBN2dV6, BalancedRobustBN1dV5


    
import logging
import os
import time
from datetime import datetime, timedelta

import numpy as np
import torch
import random


import methods
from conf import cfg, ckpt_path_to_domain_seq, get_num_classes, load_cfg_from_args
from datasets.data_loading import get_test_loader
from models.model import get_model
from utils.eval_utils import eval_domain_dict, get_accuracy
#from utils.misc import print_memory_info
from utils.registry import ADAPTATION_REGISTRY
from utils.extra import *

from FedMethod.pfedgraph_group import FedGraphServer
from FedMethod.fedamp_group import FedAMPServer
from FedMethod.fedavg_group import FedAvgServer
from FedMethod.fedavgm_group import FedAvgMServer
from FedMethod.fedgradsim_group import FedGradSimServer
from FedMethod.fedbnstat_group import FedBNStatServer


import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score



# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class Client:
    """One participant in federated test-time adaptation.

    Wraps ``model`` in the adaptation method registered under
    ``cfg.MODEL.ADAPTATION`` and evaluates it batch by batch on whatever
    dataloader was attached with :meth:`set_dataloader`, recording the
    accuracy of every processed batch.
    """

    def __init__(self, model, setting, domain_sequence, severities, cfg, model_preprocess, device, c_id):
        """Wrap ``model`` in the configured adaptation method and reset all counters.

        Args:
            model: Base network handed to the adaptation method.
            setting: Evaluation setting name (e.g. ``"mixed"``).
            domain_sequence: Ordered domain names for this client.
            severities: Corruption severities; only the first one is used by
                :meth:`_get_dataloader`.
            cfg: Experiment configuration.
            model_preprocess: Preprocessing forwarded to the data loader.
            device: Device the model is moved to while a batch is processed.
            c_id: Identifier stored as ``client_id``.

        Raises:
            AssertionError: If ``cfg.MODEL.ADAPTATION`` is not a registered
                adaptation method.
        """
        # Setup test-time adaptation method
        available_adaptations = ADAPTATION_REGISTRY.registered_names()
        assert cfg.MODEL.ADAPTATION in available_adaptations, \
            f"The adaptation '{cfg.MODEL.ADAPTATION}' is not supported! Choose from: {available_adaptations}"
        self.num_classes = get_num_classes(cfg.CORRUPTION.DATASET)
        self.model = ADAPTATION_REGISTRY.get(cfg.MODEL.ADAPTATION)(cfg=cfg, model=model, num_classes=self.num_classes)
        print(f"Successfully prepared test-time adaptation method: {cfg.MODEL.ADAPTATION}")

        self.setting = setting
        self.domain_sequence = domain_sequence
        self.severities = severities
        self.cfg = cfg
        self.model_preprocess = model_preprocess
        self.device = device
        self.current_domain_idx = 0
        self.current_step = 0
        self.accuracies = []  # one accuracy value per processed batch
        self.dataloader = None
        self.dataloader_iter = None
        self.client_id = c_id
        self.global_state_dict = None
        self.sample_num = 200
        self.grad_direction = None
        self.bn_signature = None

    def set_dataloader(self, dataloader):
        """Attach ``dataloader`` and start a fresh iterator over it."""
        self.dataloader = dataloader
        self.dataloader_iter = iter(dataloader)

    def _get_dataloader(self):
        """Build the test loader(s) for this client's current domain.

        NOTE(review): when ``self.setting`` is ``"mixed"`` this rewrites it to
        ``"mixed_domains"`` as a side effect, so a second call takes the
        per-domain branch instead — confirm this is intended.

        Returns:
            Whatever ``get_test_loader_fed`` returns for the current settings.
        """
        if self.setting == "mixed":
            domain_name = "mixed"
            self.setting = "mixed_domains"
        else:
            domain_name = self.domain_sequence[self.current_domain_idx]
        severity = self.severities[0]

        print(self.setting)

        return get_test_loader_fed(
            setting=self.setting,
            adaptation=self.cfg.MODEL.ADAPTATION,
            dataset_name=self.cfg.CORRUPTION.DATASET,
            preprocess=self.model_preprocess,
            data_root_dir=self.cfg.DATA_DIR,
            domain_name=domain_name,
            domain_names_all=self.domain_sequence,
            severity=severity,
            num_examples=self.cfg.CORRUPTION.NUM_EX,
            rng_seed=random.randint(0, 1000000),  # Generate a new random seed
            use_clip=self.cfg.MODEL.USE_CLIP,
            n_views=self.cfg.TEST.N_AUGMENTATIONS,
            delta_dirichlet=self.cfg.TEST.DELTA_DIRICHLET,
            batch_size=self.cfg.TEST.BATCH_SIZE,
            shuffle=False,
            # os.cpu_count() may return None; fall back to a single worker then.
            workers=min(self.cfg.TEST.NUM_WORKERS, os.cpu_count() or 1),
            cfg=self.cfg
        )

    def process_batch(self, global_state_dict):
        """Process the next batch of the attached dataloader.

        The model is moved to ``self.device`` for the forward pass and back to
        the CPU afterwards. ``current_step`` is incremented on every call,
        including the call that finds the iterator exhausted.

        Args:
            global_state_dict: Stored on ``self.global_state_dict``; for the
                ``'bbn_tta'`` adaptation it is also passed to the model call.

        Returns:
            ``None`` when the dataloader iterator is exhausted. Otherwise the
            tuple ``(class_frequencies, state_dict, batch_accuracy, False,
            (labels, outputs), grad_direction)`` where ``class_frequencies`` is
            the per-class label count of the batch, ``state_dict`` is
            ``self.model.state_dict()`` and ``grad_direction`` is ``None``
            unless the adaptation is ``'bbn_tta'``.
        """
        self.model.to(self.device)
        self.global_state_dict = global_state_dict
        with torch.no_grad():
            self.current_step = self.current_step + 1

            try:
                batch = next(self.dataloader_iter)
            except StopIteration:
                # Nothing left to process: release the device just like the
                # normal path does, so exhausted clients do not keep holding
                # accelerator memory.
                self.model.to("cpu")
                torch.cuda.empty_cache()
                return None

            vlm_flag = False
            if vlm_flag:
                # Multi-view input: a list of image tensors.
                imgs, labels = batch[0], batch[1].long().to(self.device)
                imgs = [img.to(self.device) for img in imgs]
            else:
                # Convert labels to Long dtype
                imgs, labels = batch[0].to(self.device), batch[1].long().to(self.device)

            # Count occurrences of each class in this batch
            class_counts = torch.bincount(labels, minlength=self.num_classes)
            class_frequencies = class_counts.tolist()

            # Use the client's own config rather than the module-level one.
            if self.cfg.MODEL.ADAPTATION == 'bbn_tta':
                outputs, grad_direction = self.model(imgs, labels, self.global_state_dict)
            else:
                outputs = self.model(imgs, labels)
                grad_direction = None

            batch_accuracy = (outputs.argmax(1) == labels).float().mean().item()
        self.accuracies.append(batch_accuracy)

        # NOTE(review): the key uses id(self), which differs between runs;
        # self.client_id may have been intended. Left unchanged so existing
        # wandb keys keep their shape.
        wandb.log({f"Client {id(self)} {self.setting} Batch Accuracy": batch_accuracy})

        # Free up GPU memory
        self.model.to("cpu")
        torch.cuda.empty_cache()

        return class_frequencies, self.model.state_dict(), batch_accuracy, False, (labels, outputs), grad_direction

    def extract_bn_stats(self):
        """Collect flattened ``global_mean`` / ``global_var`` tensors per module.

        Returns:
            dict mapping the name of every sub-module that has both a
            ``global_mean`` and a ``global_var`` attribute to
            ``{'mean': ..., 'var': ...}``, each a detached, flattened copy.
        """
        stats = {}
        for name, module in self.model.named_modules():
            if hasattr(module, 'global_mean') and hasattr(module, 'global_var'):
                stats[name] = {
                    'mean': module.global_mean.clone().detach().flatten(),
                    'var': module.global_var.clone().detach().flatten()
                }
        return stats

    def overall_accuracy(self):
        """Return the mean of all recorded batch accuracies (0.0 if none) and log it to wandb."""
        overall_acc = sum(self.accuracies) / len(self.accuracies) if self.accuracies else 0.0
        wandb.log({f"Client {id(self)} Overall Accuracy": overall_acc})
        return overall_acc


from sklearn.metrics import f1_score  # <<< add this at the top of your file


def run_federated_tta(clients, cfg, domain_seq_indices):
    """Run federated test-time adaptation for all clients over a sequence of domain shifts.

    For every shift each client is pointed at the dataloader of its scheduled
    domain; the clients are then stepped together, one batch per client per
    round, until no client returns a batch any more. Every
    ``cfg.fed.fed_interval`` rounds the configured server aggregates.
    Per-class, per-client and per-domain statistics are accumulated along the
    way, logged to wandb, printed, and (for ``cifar10_c`` only) plotted into
    the ``plots`` directory.

    Args:
        clients: List of clients; a client's position in the list is its id.
        cfg: Experiment configuration. ``cfg.fed.fed_tech`` selects the server;
            a value that matches none of the known techniques yields no server
            and therefore no aggregation.
        domain_seq_indices: 2-D array indexed as ``[client_id, shift_id]``
            whose entries index into ``cfg.CORRUPTION.TYPE``.

    Returns:
        None.
    """
    domain_names_all = cfg.CORRUPTION.TYPE
    num_classes = get_num_classes(cfg.CORRUPTION.DATASET)
    global_model = copy.deepcopy(clients[0].model).to('cpu')
    global_state_dict = global_model.state_dict()
    cumulative_class_frequencies = {client: [0] * num_classes for client in clients}
    num_clients = len(clients)
    num_shifts = domain_seq_indices.shape[1]
    severities = clients[0].severities

    client_class_correct = {i: [0]*num_classes for i in range(num_clients)}
    client_class_total = {i: [0]*num_classes for i in range(num_clients)}
    global_class_correct = [0] * num_classes
    global_class_total = [0] * num_classes
    domain_class_correct = {d: [0]*num_classes for d in domain_names_all}
    domain_class_total = {d: [0]*num_classes for d in domain_names_all}
    domain_client_accuracy = {d: [0]*num_clients for d in domain_names_all}
    domain_client_counts = {d: [0]*num_clients for d in domain_names_all}
    domain_conf_matrices = {d: np.zeros((num_classes, num_classes), dtype=int) for d in domain_names_all}
    imbalance_history = {i: [] for i in range(num_clients)}
    global_class_freq_timeline = []
    global_class_freq_timeline_each_round = []
    round_accuracies = []

    # Predictions/labels of every processed sample, kept for the final F1 score
    client_preds_overall = {i: [] for i in range(num_clients)}
    client_labels_overall = {i: [] for i in range(num_clients)}

    def get_server(cfg, clients, global_state_dict):
        """Return the server for ``cfg.fed.fed_tech`` (None if it is not a known technique)."""
        if cfg.fed.fed_tech == "fedavg" or cfg.fed.fed_tech == "fedprox":
            return FedAvgServer(clients)
        elif cfg.fed.fed_tech == "fedavgm":
            return FedAvgMServer(clients, global_state_dict)
        elif cfg.fed.fed_tech == "pfedgraph":
            return FedGraphServer(clients, global_state_dict)
        elif cfg.fed.fed_tech == "fedamp":
            return FedAMPServer(clients, global_state_dict)
        elif cfg.fed.fed_tech == "fedbnstat":
            return FedBNStatServer(clients, global_state_dict)
        return None

    server = get_server(cfg, clients, global_state_dict)

    # Build every (client, domain) dataloader up front.
    # NOTE(review): loaders are keyed by domain only, so with several
    # severities the loader of the last severity replaces the earlier ones.
    full_dataloader = [dict() for _ in range(num_clients)]
    domain_set = set(idx for row in domain_seq_indices for idx in row)
    # os.cpu_count() may return None; fall back to a single worker then.
    num_workers = min(cfg.TEST.NUM_WORKERS, os.cpu_count() or 1)
    for domain_idx in domain_set:
        domain_name = domain_names_all[domain_idx]
        for severity in severities:
            test_dataloaders = get_test_loader_fed(
                setting=cfg.SETTING,
                adaptation=cfg.MODEL.ADAPTATION,
                dataset_name=cfg.CORRUPTION.DATASET,
                preprocess=clients[0].model_preprocess,
                data_root_dir=cfg.DATA_DIR,
                domain_name=domain_name,
                domain_names_all=domain_names_all,
                severity=severity,
                num_examples=cfg.CORRUPTION.NUM_EX,
                rng_seed=cfg.RNG_SEED,
                use_clip=cfg.MODEL.USE_CLIP,
                n_views=cfg.TEST.N_AUGMENTATIONS,
                delta_dirichlet=cfg.TEST.DELTA_DIRICHLET,
                batch_size=cfg.TEST.BATCH_SIZE,
                shuffle=False,
                workers=num_workers,
                cfg=cfg
            )
            for client_id, dataloader in enumerate(test_dataloaders):
                full_dataloader[client_id][domain_name] = dataloader

    fed_round = 0  # number of completed rounds across all shifts

    for shift_id in range(num_shifts):
        if shift_id == 0 or "reset_each_shift" in cfg.SETTING:
            for client in clients:
                try:
                    client.model.reset()
                    logger.info("resetting model")
                except AttributeError:
                    logger.warning("not resetting model")

        # Point every client at the dataloader of its domain for this shift.
        for client_id, client in enumerate(clients):
            domain_idx = domain_seq_indices[client_id, shift_id]
            domain_name = domain_names_all[domain_idx]

            dataloader = full_dataloader[client_id][domain_name]
            client.set_dataloader(dataloader)

        while True:
            all_exhausted = True
            batch_accuracies = []

            # Predictions/labels of this round only, for the per-round F1 score
            client_preds_round = {i: [] for i in range(num_clients)}
            client_labels_round = {i: [] for i in range(num_clients)}

            for client_id, client in enumerate(clients):
                result = client.process_batch(global_state_dict)
                if result is None:
                    # This client's dataloader is exhausted for the current shift.
                    continue
                class_freqs, _model_state, batch_acc, exhausted, (labels, outputs), client.grad_direction = result

                if not exhausted:
                    all_exhausted = False
                    domain_name = domain_names_all[domain_seq_indices[client_id, shift_id]]

                    preds = outputs.argmax(1).cpu()
                    true = labels.cpu()

                    for c in range(num_classes):
                        correct = ((preds == c) & (true == c)).sum().item()
                        total = (true == c).sum().item()
                        client_class_correct[client_id][c] += correct
                        client_class_total[client_id][c] += total
                        global_class_correct[c] += correct
                        global_class_total[c] += total
                        domain_class_correct[domain_name][c] += correct
                        domain_class_total[domain_name][c] += total

                    domain_client_accuracy[domain_name][client_id] += (preds == true).sum().item()
                    domain_client_counts[domain_name][client_id] += len(true)

                    domain_conf_matrices[domain_name] += confusion_matrix(
                        true.numpy(), preds.numpy(), labels=list(range(num_classes))
                    )

                    cumulative_class_frequencies[client] = [
                        cumulative_class_frequencies[client][i] + class_freqs[i] for i in range(num_classes)
                    ]
                    batch_accuracies.append(batch_acc)

                    imbalance_history[client_id].append(class_freqs)

                    pred_list = preds.tolist()
                    true_list = true.tolist()
                    client_preds_round[client_id].extend(pred_list)
                    client_labels_round[client_id].extend(true_list)
                    client_preds_overall[client_id].extend(pred_list)
                    client_labels_overall[client_id].extend(true_list)

            if all_exhausted:
                break

            # Calculate global class frequencies (sum over all clients)
            current_global_class_frequencies = [0] * num_classes
            for client in clients:
                for i in range(num_classes):
                    current_global_class_frequencies[i] += cumulative_class_frequencies[client][i]

            if global_class_freq_timeline:
                prev = global_class_freq_timeline[-1]
                delta_class_frequencies = [current_global_class_frequencies[i] - prev[i] for i in range(num_classes)]
            else:
                delta_class_frequencies = current_global_class_frequencies.copy()

            print(f"Class frequencies for this round (delta): {delta_class_frequencies}")
            global_class_freq_timeline.append(current_global_class_frequencies.copy())
            global_class_freq_timeline_each_round.append(delta_class_frequencies.copy())

            fed_round += 1
            if fed_round % cfg.fed.fed_interval == 0:
                if cfg.fed.fed_tech in ["fedavg", "fedprox"]:
                    global_state_dict = server.aggregate()
                elif cfg.fed.fed_tech == 'fedavgm':
                    server.aggregate(shift_id)
                elif cfg.fed.fed_tech == "pfedgraph":
                    server.aggregate()
                elif cfg.fed.fed_tech in ('fedamp', 'fedbnstat'):
                    # Pass the round counter; the builtin `round` function was
                    # being passed here by mistake.
                    server.aggregate(plot=True, round_num=fed_round)

            avg_batch_accuracy = sum(batch_accuracies) / len(batch_accuracies)
            round_accuracies.append(avg_batch_accuracy)
            logger.info(f"Shift {shift_id+1}/{num_shifts} | Round {fed_round} Accuracy: {avg_batch_accuracy:.2%}")
            wandb.log({"Global Average Batch Accuracy": avg_batch_accuracy})

            all_preds = [p for client_preds in client_preds_round.values() for p in client_preds]
            all_labels = [t for client_labels in client_labels_round.values() for t in client_labels]
            if all_labels:
                global_f1 = f1_score(all_labels, all_preds, average="macro")
                wandb.log({f"Global F1 Round {fed_round}": global_f1})
                print(f"Global | Round {fed_round} F1: {global_f1:.4f}")

    if not round_accuracies:
        # No client produced a single batch: there is nothing to average,
        # summarise or plot.
        logger.warning("No batches were processed; skipping summary and plots.")
        return

    # Compute final average across all rounds
    average_accuracy_all_rounds = sum(round_accuracies) / len(round_accuracies)
    print(f"\n✅ Average Accuracy over All Rounds: {average_accuracy_all_rounds:.2%}\n")

    # Final global F1 across all rounds
    overall_preds = [p for client_preds in client_preds_overall.values() for p in client_preds]
    overall_labels = [t for client_labels in client_labels_overall.values() for t in client_labels]
    if overall_labels:
        final_f1 = f1_score(overall_labels, overall_preds, average="macro")
        print(f"\n✅ Final Global F1 Score over All Rounds: {final_f1:.4f}\n")
        wandb.log({"Final Global F1": final_f1})

    # Print results summary before plotting
    print_results_summary(num_clients, num_classes, domain_names_all,
                      client_class_correct, client_class_total,
                      global_class_correct, global_class_total,
                      domain_class_correct, domain_class_total,
                      domain_client_accuracy, domain_client_counts,
                      global_class_freq_timeline)

    save_dir = "plots"
    os.makedirs(save_dir, exist_ok=True)

    # The plotting helpers are only run for CIFAR-10-C.
    if cfg.CORRUPTION.DATASET == 'cifar10_c':
        plot_round_accuracies(round_accuracies, save_dir)
        plot_client_class_accuracy(client_class_correct, client_class_total, num_clients, num_classes, save_dir)
        plot_global_class_accuracy(global_class_correct, global_class_total, save_dir)
        plot_domain_class_accuracy(domain_class_correct, domain_class_total, domain_names_all, num_classes, save_dir)
        plot_domain_client_accuracy(domain_client_accuracy, domain_client_counts, domain_names_all, num_clients, save_dir)
        plot_confusion_matrices(domain_conf_matrices, domain_names_all, num_classes, save_dir=os.path.join(save_dir, "conf_matrices"))

        plot_local_and_global_imbalance(global_class_freq_timeline, imbalance_history, [1], save_dir)
        plot_class_distribution_pie(global_class_freq_timeline, save_dir)
        plot_client_major_minor_accuracy(client_class_correct, client_class_total, num_clients, num_classes, save_dir)
        plot_domain_major_minor_accuracy(domain_class_correct, domain_class_total, domain_names_all, num_classes, save_dir)

def plot_round_accuracies(round_accuracies, save_dir):
    """Plot the average accuracy of every federated round and log the image to wandb.

    The figure is written to ``<save_dir>/round_wise_accuracy.png``; a dashed
    horizontal line marks the mean over all rounds.

    Args:
        round_accuracies: Sequence with one average accuracy per round.
        save_dir: Directory the PNG is written to.

    Returns:
        None. Nothing is plotted when ``round_accuracies`` is empty.
    """
    if len(round_accuracies) == 0:
        # The mean below would divide by zero; there is nothing to draw.
        print("Error: No round accuracies to plot.")
        return

    plt.figure(figsize=(7, 3))
    rounds = list(range(1, len(round_accuracies)+1))
    plt.plot(rounds, round_accuracies, marker='o', linestyle='-', color='blue')
    plt.xlabel("Federated Round")
    plt.ylabel("Average Accuracy")
    plt.title("Round-wise Average Accuracy")
    plt.ylim(0, 1.05)
    plt.grid(True)

    avg_acc = sum(round_accuracies) / len(round_accuracies)
    plt.axhline(y=avg_acc, color='red', linestyle='--', label=f'Average Accuracy: {avg_acc:.2%}')
    plt.legend()
    plt.tight_layout()

    save_path = os.path.join(save_dir, "round_wise_accuracy.png")
    plt.savefig(save_path, dpi=300)
    wandb.log({"Round-wise Accuracy": wandb.Image(save_path)})
    plt.close()


def plot_domain_major_minor_accuracy(domain_class_correct, domain_class_total, domain_names_all, num_classes, save_dir=None):
    """
    Plot, for each domain, the accuracy of its major and minor class as grouped bars.

    The major class is the one with the highest sample count in the domain; the
    minor class is the one with the lowest non-zero count. Domains without any
    samples get two zero-height bars and no class annotation.

    Args:
        domain_class_correct (dict): domain name -> per-class correct counts.
        domain_class_total (dict): domain name -> per-class sample counts.
        domain_names_all (list): domain names, in plotting order.
        num_classes (int): unused; kept for interface compatibility.
        save_dir (str, optional): directory for
            ``domain_major_minor_class_accuracy.png``. When omitted the current
            figure is logged to wandb directly and shown.
    """
    major_classes, minor_classes, major_accuracies, minor_accuracies = [], [], [], []

    for domain in domain_names_all:
        total_counts = np.array(domain_class_total[domain])
        if total_counts.sum() == 0:
            major_classes.append(None)
            minor_classes.append(None)
            major_accuracies.append(0)
            minor_accuracies.append(0)
            continue

        # Avoid zero counts for minor by masking them
        nonzero_indices = np.where(total_counts > 0)[0]
        major_class = total_counts.argmax()
        minor_class = nonzero_indices[np.argmin(total_counts[nonzero_indices])]

        major_acc = domain_class_correct[domain][major_class] / domain_class_total[domain][major_class] if domain_class_total[domain][major_class] > 0 else 0
        minor_acc = domain_class_correct[domain][minor_class] / domain_class_total[domain][minor_class] if domain_class_total[domain][minor_class] > 0 else 0

        major_classes.append(major_class)
        minor_classes.append(minor_class)
        major_accuracies.append(major_acc)
        minor_accuracies.append(minor_acc)

    # Plotting
    x = np.arange(len(domain_names_all))
    width = 0.35

    plt.figure(figsize=(14, 6))
    plt.bar(x - width/2, major_accuracies, width, label='Major Class Accuracy')
    plt.bar(x + width/2, minor_accuracies, width, label='Minor Class Accuracy')

    plt.xlabel("Domain")
    plt.ylabel("Accuracy")
    plt.title("Major and Minor Class Accuracy per Domain")
    plt.xticks(x, domain_names_all, rotation=45)
    plt.ylim(0, 1.05)
    plt.legend()
    plt.grid(axis='y', linestyle='--', alpha=0.6)

    # Annotate class IDs above bars; domains without samples have no class to name.
    for i in range(len(domain_names_all)):
        if major_classes[i] is None:
            continue
        plt.text(x[i] - width/2, major_accuracies[i]+0.02, f"C{major_classes[i]}", ha='center', fontsize=9)
        plt.text(x[i] + width/2, minor_accuracies[i]+0.02, f"C{minor_classes[i]}", ha='center', fontsize=9)

    plt.tight_layout()
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "domain_major_minor_class_accuracy.png")
        plt.savefig(save_path, dpi=300)
        wandb.log({"domain_major_minor_class_accuracy.png": wandb.Image(save_path)})
        plt.close()
    else:
        # No file was written, so log the live figure instead of a (None) path.
        wandb.log({"domain_major_minor_class_accuracy.png": wandb.Image(plt)})

        plt.show()
        


def plot_client_major_minor_accuracy(client_class_correct, client_class_total, num_clients, num_classes, save_dir=None):
    """
    Plot, for each client, the accuracy of its major and minor class as grouped bars.

    The major class is the one with the highest sample count for the client;
    the minor class is the one with the lowest non-zero count. Clients without
    any samples get two zero-height bars and no class annotation.

    Args:
        client_class_correct (dict): client id -> per-class correct counts.
        client_class_total (dict): client id -> per-class sample counts.
        num_clients (int): clients ``0 .. num_clients-1`` are plotted.
        num_classes (int): unused; kept for interface compatibility.
        save_dir (str, optional): directory for
            ``client_major_minor_class_accuracy.png``. When omitted the current
            figure is logged to wandb directly and shown.
    """
    major_classes, minor_classes, major_accuracies, minor_accuracies = [], [], [], []

    for i in range(num_clients):
        total_counts = np.array(client_class_total[i])
        if total_counts.sum() == 0:
            major_classes.append(None)
            minor_classes.append(None)
            major_accuracies.append(0)
            minor_accuracies.append(0)
            continue

        # Avoid zero counts for minor by masking them
        nonzero_indices = np.where(total_counts > 0)[0]
        major_class = total_counts.argmax()
        minor_class = nonzero_indices[np.argmin(total_counts[nonzero_indices])]

        major_acc = client_class_correct[i][major_class] / client_class_total[i][major_class] if client_class_total[i][major_class] > 0 else 0
        minor_acc = client_class_correct[i][minor_class] / client_class_total[i][minor_class] if client_class_total[i][minor_class] > 0 else 0

        major_classes.append(major_class)
        minor_classes.append(minor_class)
        major_accuracies.append(major_acc)
        minor_accuracies.append(minor_acc)

    # Plotting
    x = np.arange(num_clients)
    width = 0.35

    plt.figure(figsize=(14, 6))
    plt.bar(x - width/2, major_accuracies, width, label='Major Class Accuracy')
    plt.bar(x + width/2, minor_accuracies, width, label='Minor Class Accuracy')

    plt.xlabel("Client ID")
    plt.ylabel("Accuracy")
    plt.title("Major and Minor Class Accuracy per Client")
    plt.xticks(x)
    plt.ylim(0, 1.05)
    plt.legend()
    plt.grid(axis='y', linestyle='--', alpha=0.6)

    # Annotate class IDs above bars; clients without samples have no class to name.
    for i in range(num_clients):
        if major_classes[i] is None:
            continue
        plt.text(x[i] - width/2, major_accuracies[i]+0.02, f"C{major_classes[i]}", ha='center', fontsize=9)
        plt.text(x[i] + width/2, minor_accuracies[i]+0.02, f"C{minor_classes[i]}", ha='center', fontsize=9)

    plt.tight_layout()

    # NOTE(review): the wandb key says "class_major_minor..." while the file is
    # named "client_major_minor..."; the key is left unchanged so existing
    # dashboards keep working.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "client_major_minor_class_accuracy.png")
        plt.savefig(save_path, dpi=300)
        wandb.log({"class_major_minor_class_accuracy.png": wandb.Image(save_path)})

        plt.close()
    else:
        wandb.log({"class_major_minor_class_accuracy.png": wandb.Image(plt)})

        plt.show()


def plot_class_distribution_pie(freq_timeline, save_path=None):
    """
    Plot a pie chart of class-wise ratio from the last entry of global_class_freq_timeline,
    using CIFAR-10 class names (class ids beyond those names are labelled by their id).

    Args:
        freq_timeline (list of list): global_class_freq_timeline (cumulative frequency per round).
        save_path (str, optional): Directory in which ``pie_global_ratio.png`` is
            written. When omitted the current figure is logged to wandb directly
            and shown.

    Returns:
        None. Nothing is plotted when ``freq_timeline`` is empty.
    """

    cifar10_class_names = [
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck'
    ]

    if not freq_timeline:
        print("Error: Frequency timeline is empty.")
        return

    latest_freqs = freq_timeline[-1]
    total = sum(latest_freqs)

    # Filter out zero-frequency classes
    non_zero_indices = [i for i, val in enumerate(latest_freqs) if val > 0]
    values = [latest_freqs[i] for i in non_zero_indices]
    # Fall back to the numeric id so more than ten classes do not raise IndexError.
    labels = [
        cifar10_class_names[i] if i < len(cifar10_class_names) else str(i)
        for i in non_zero_indices
    ]

    percentages = [v / total * 100 for v in values]

    # Plot pie chart
    plt.figure(figsize=(6, 6))
    plt.pie(percentages, labels=labels, autopct='%1.1f%%', startangle=140)
    plt.title("Class Distribution (Global Cumulative - CIFAR-10)")
    plt.tight_layout()

    if save_path:
        os.makedirs(save_path, exist_ok=True)
        file_path = os.path.join(save_path, "pie_global_ratio.png")
        plt.savefig(file_path, dpi=300)
        wandb.log({"Pie Global Class Ratio": wandb.Image(file_path)})
        plt.close()
    else:
        # No file was written, so log the live figure instead of a (None) path.
        wandb.log({"Pie Global Class Ratio": wandb.Image(plt)})

        plt.show()

def print_results_summary(num_clients, num_classes, domain_names_all,
                          client_class_correct, client_class_total,
                          global_class_correct, global_class_total,
                          domain_class_correct, domain_class_total,
                          domain_client_accuracy, domain_client_counts,
                          global_class_freq_timeline):
    """Print a plain-text report of all accumulated accuracy statistics.

    Sections, in order: per-client class accuracy, global class accuracy,
    per-domain class accuracy, per-domain client accuracy, the final cumulative
    class frequencies and the normalized imbalance of every recorded step.
    Averages over classes treat a class without samples as accuracy 0.

    Args:
        num_clients: Number of clients (ids ``0 .. num_clients-1``).
        num_classes: Number of classes (ids ``0 .. num_classes-1``).
        domain_names_all: Domain names to report on.
        client_class_correct / client_class_total: client id -> per-class counts.
        global_class_correct / global_class_total: per-class counts.
        domain_class_correct / domain_class_total: domain -> per-class counts.
        domain_client_accuracy / domain_client_counts: domain -> per-client
            correct counts and sample counts.
        global_class_freq_timeline: List of cumulative per-class frequency
            lists, one per step; may be empty.
    """

    def _mean(values):
        # Mean that tolerates empty input (no clients / no classes / no steps).
        return sum(values) / len(values) if values else 0

    def _print_class_accuracies(correct_counts, total_counts):
        # Print one line per class and return the list of class accuracies.
        accs = []
        for c in range(num_classes):
            total = total_counts[c]
            correct = correct_counts[c]
            acc = correct / total if total > 0 else 0
            accs.append(acc)
            print(f"  Class {c}: {correct}/{total} ({acc:.2%})")
        return accs

    print("\n========== Detailed Results Summary ==========\n")

    # Per-client class accuracy
    client_avg_accuracies = []
    for i in range(num_clients):
        print(f"\nClient {i} Class Accuracy:")
        client_avg = _mean(_print_class_accuracies(client_class_correct[i], client_class_total[i]))
        client_avg_accuracies.append(client_avg)
        print(f"  → Average Accuracy for Client {i}: {client_avg:.2%}")

    overall_client_avg = _mean(client_avg_accuracies)
    print(f"\n→ Overall Average Client Accuracy: {overall_client_avg:.2%}")

    # Global class accuracy
    print("\nGlobal Class-wise Accuracy:")
    global_avg_acc = _mean(_print_class_accuracies(global_class_correct, global_class_total))
    print(f"→ Average Global Class Accuracy: {global_avg_acc:.2%}")

    # Per-domain class accuracy
    for domain in domain_names_all:
        print(f"\nDomain '{domain}' Class Accuracy:")
        domain_avg_acc = _mean(_print_class_accuracies(domain_class_correct[domain], domain_class_total[domain]))
        print(f"  → Average Accuracy in Domain '{domain}': {domain_avg_acc:.2%}")

    # Domain-wise client accuracy
    for domain in domain_names_all:
        print(f"\nDomain '{domain}' Client Accuracy:")
        client_accs = []
        for cid in range(num_clients):
            total = domain_client_counts[domain][cid]
            correct = domain_client_accuracy[domain][cid]
            acc = correct / total if total > 0 else 0
            client_accs.append(acc)
            print(f"  Client {cid}: {correct}/{total} ({acc:.2%})")
        domain_client_avg = _mean(client_accs)
        print(f"  → Average Client Accuracy in Domain '{domain}': {domain_client_avg:.2%}")

    # Final cumulative global class frequencies
    print("\nFinal Cumulative Global Class Frequencies:")
    if global_class_freq_timeline:
        last_freqs = global_class_freq_timeline[-1]
        for c, f in enumerate(last_freqs):
            print(f"  Class {c}: {f}")
        mean_freq = _mean(last_freqs)
        print(f"→ Mean Final Global Class Frequency: {mean_freq:.2f}")
    else:
        # Without any recorded step there is no "last" entry to report.
        print("  (no steps recorded)")

    # Global imbalance over time
    print("\nGlobal Normalized Imbalance Over Time:")
    if global_class_freq_timeline:
        ratios = [compute_normalized_imbalance(freqs) for freqs in global_class_freq_timeline]
        for step, r in enumerate(ratios):
            print(f"  Step {step+1}: {r:.3f}")
        mean_imbalance = _mean(ratios)
        print(f"→ Average Normalized Imbalance: {mean_imbalance:.3f}")
    else:
        print("  (no steps recorded)")

    print("\n========== End of Detailed Results Summary ==========\n")




def compute_normalized_imbalance(freqs):
    """Return the smallest positive frequency divided by the largest one.

    Entries that are not strictly positive are ignored. When fewer than two
    positive entries remain, the result is 1.0.
    """
    positive_counts = list(filter(lambda count: count > 0, freqs))
    if len(positive_counts) >= 2:
        smallest, largest = min(positive_counts), max(positive_counts)
        return smallest / largest
    return 1.0

def plot_local_and_global_imbalance(global_timeline, local_timeline, selected_clients=None, save_path=None):
    """Plot the normalized class imbalance over time, globally and for selected clients.

    Every frequency list is reduced to one ratio with
    ``compute_normalized_imbalance``; the global series is drawn as a dashed
    black line and each selected client as its own line.

    Args:
        global_timeline: List of per-class frequency lists, one per step.
        local_timeline: dict client id -> list of per-class frequency lists.
        selected_clients: Client ids to draw; ``None`` draws every client.
        save_path: Directory in which ``local_global_imbalance_ratio.png`` is
            written. When omitted the current figure is logged to wandb
            directly and shown.
    """
    plt.figure(figsize=(7, 3))
    save_path = os.path.join(save_path, "local_global_imbalance_ratio.png") if save_path else None

    # Plot Global
    global_ratios = [compute_normalized_imbalance(freqs) for freqs in global_timeline]
    plt.plot(global_ratios, label='Global', color='black', linewidth=2, linestyle='--')

    # Plot Local (selected clients)
    for client_id, freq_list in local_timeline.items():
        if selected_clients is None or client_id in selected_clients:
            local_ratios = [compute_normalized_imbalance(freqs) for freqs in freq_list]
            plt.plot(local_ratios, label=f'Client {client_id}')

    plt.xlabel("Federated Round")
    plt.ylabel("Imbalance Ratio")
    plt.title("Class Imbalance Over Time (Global and a Random Client)")
    plt.ylim(0, 1.05)
    plt.grid(True)
    plt.legend(ncol=2, fontsize='small')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300)
        wandb.log({"Class Imbalance": wandb.Image(save_path)})
        plt.close()
    else:
        # No file was written, so log the live figure instead of a (None) path.
        wandb.log({"Class Imbalance": wandb.Image(plt)})
        plt.show()



import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_heatmap(data, freqs, xlabel, ylabel, title, xticks, yticks, save_path=None):
    """Draw ``data`` as an annotated viridis heatmap.

    ``freqs`` is passed as the cell annotations and ``xticks`` / ``yticks`` as
    the tick labels. The figure is saved to ``save_path`` and closed when a
    path is given; otherwise it is shown.
    """
    plt.figure(figsize=(14, 6))
    heatmap_options = dict(
        annot=freqs,
        fmt="",
        cmap="viridis",
        xticklabels=xticks,
        yticklabels=yticks,
    )
    sns.heatmap(data, **heatmap_options)
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel(xlabel)
    plt.tight_layout()

    if not save_path:
        plt.show()
        return
    plt.savefig(save_path, dpi=300)
    plt.close()

def plot_client_class_accuracy(client_class_correct, client_class_total, num_clients, num_classes, save_dir=None):
    """Render a client x class heatmap of per-class accuracy.

    Cell colour is ``correct / total`` for that client and class (0 when the
    client saw no samples of the class). The cell annotation is the share of
    the client's samples that belong to the class, formatted to two decimals.

    Args:
        client_class_correct: Indexed as ``[client][class]``; correct counts.
        client_class_total: Indexed as ``[client][class]``; sample counts.
        num_clients: Number of rows (client ids ``0..num_clients-1``).
        num_classes: Number of columns (class ids ``0..num_classes-1``).
        save_dir: Optional directory; when given the plot is saved as
            ``client_class_accuracy.png`` inside it, otherwise it is shown.
    """
    accuracy_rows = []
    share_rows = []
    for cid in range(num_clients):
        client_samples = sum(client_class_total[cid])
        accuracies = []
        shares = []
        for cls in range(num_classes):
            count = client_class_total[cid][cls]
            if count > 0:
                accuracy = client_class_correct[cid][cls] / count
            else:
                accuracy = 0
            if client_samples > 0:
                share = count / client_samples
            else:
                share = 0
            accuracies.append(accuracy)
            shares.append(f"{share:.2f}")
        accuracy_rows.append(accuracies)
        share_rows.append(shares)

    if save_dir:
        target = os.path.join(save_dir, "client_class_accuracy.png")
    else:
        target = None

    class_ticks = list(range(num_classes))
    client_ticks = list(range(num_clients))
    plot_heatmap(accuracy_rows, share_rows, "Class ID", "Client ID", "Client-wise Class Accuracy",
                 class_ticks, client_ticks, target)

def plot_global_class_accuracy(global_class_correct, global_class_total, save_dir=None):
    """Plot a bar chart of global per-class accuracy and log it to wandb.

    Args:
        global_class_correct: Sequence of correct-prediction counts per class.
        global_class_total: Sequence of sample counts per class, aligned with
            ``global_class_correct``. Classes with no samples get accuracy 0.
        save_dir: Optional directory. When given, the figure is saved as
            ``global_class_accuracy.png`` inside it (the directory is created
            if missing), logged to wandb from that file, and closed. When
            falsy, the in-memory figure is logged to wandb and shown.
    """
    cifar10_class_names = [
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck'
    ]

    num_classes = len(global_class_correct)
    accs = [
        global_class_correct[i] / global_class_total[i] if global_class_total[i] > 0 else 0
        for i in range(num_classes)
    ]

    # The CIFAR-10 names only fit a 10-class problem. For any other class
    # count plt.xticks would reject the mismatched label list, so fall back
    # to numeric class ids.
    # NOTE(review): a different 10-class dataset would still get CIFAR-10
    # names here; confirm whether names should come from the dataset config.
    if num_classes == len(cifar10_class_names):
        tick_labels = cifar10_class_names
    else:
        tick_labels = [str(i) for i in range(num_classes)]

    fig = plt.figure(figsize=(7, 3))
    bar_width = 0.2  # thinner bars
    x = np.arange(num_classes)

    plt.bar(x, accs, width=bar_width, color='skyblue', edgecolor='black')
    plt.xlabel("Class")
    plt.ylabel("Accuracy")
    plt.title("Global Class-wise Accuracy")
    plt.xticks(x, tick_labels, rotation=30, ha='right')
    plt.ylim(0, 1.05)
    plt.grid(axis='y', linestyle='--', alpha=0.6)
    plt.tight_layout()

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "global_class_accuracy.png")
        plt.savefig(save_path, dpi=300)
        wandb.log({"Global Class Accuracy": wandb.Image(save_path)})
        plt.close(fig)
    else:
        # No file exists in this branch; the previous code referenced an
        # unassigned save_path here. Log the figure object instead.
        wandb.log({"Global Class Accuracy": wandb.Image(fig)})
        plt.show()

def plot_domain_class_accuracy(domain_class_correct, domain_class_total, domain_names_all, num_classes, save_dir=None):
    """Render a domain x class heatmap of per-class accuracy.

    Cell colour is ``correct / total`` for that domain and class (0 when the
    domain has no samples of the class). The cell annotation is the share of
    the domain's samples that belong to the class, formatted to two decimals.

    Args:
        domain_class_correct: Indexed as ``[domain][class]``; correct counts.
        domain_class_total: Indexed as ``[domain][class]``; sample counts.
        domain_names_all: Domain keys, in row order; also used as row labels.
        num_classes: Number of columns (class ids ``0..num_classes-1``).
        save_dir: Optional directory; when given the plot is saved as
            ``domain_class_accuracy.png`` inside it, otherwise it is shown.
    """
    accuracy_rows = []
    share_rows = []
    for name in domain_names_all:
        domain_samples = sum(domain_class_total[name])
        accuracies = []
        shares = []
        for cls in range(num_classes):
            count = domain_class_total[name][cls]
            if count > 0:
                accuracy = domain_class_correct[name][cls] / count
            else:
                accuracy = 0
            if domain_samples > 0:
                share = count / domain_samples
            else:
                share = 0
            accuracies.append(accuracy)
            shares.append(f"{share:.2f}")
        accuracy_rows.append(accuracies)
        share_rows.append(shares)

    if save_dir:
        target = os.path.join(save_dir, "domain_class_accuracy.png")
    else:
        target = None

    class_ticks = list(range(num_classes))
    plot_heatmap(accuracy_rows, share_rows, "Class ID", "Domain", "Domain-wise Class Accuracy",
                 class_ticks, domain_names_all, target)

def plot_domain_client_accuracy(domain_client_accuracy, domain_client_counts, domain_names_all, num_clients, save_dir=None):
    """Render a domain x client heatmap of accuracy.

    Each cell is ``domain_client_accuracy[domain][client]`` divided by
    ``domain_client_counts[domain][client]`` (0 when the count is not
    positive), and is annotated with that same value to two decimals.

    Args:
        domain_client_accuracy: Indexed as ``[domain][client]``; the numerator
            of the ratio (presumably a correct-prediction count; verify
            against the caller).
        domain_client_counts: Indexed as ``[domain][client]``; the denominator.
        domain_names_all: Domain keys, in row order; also used as row labels.
        num_clients: Number of columns (client ids ``0..num_clients-1``).
        save_dir: Optional directory; when given the plot is saved as
            ``domain_client_accuracy.png`` inside it, otherwise it is shown.
    """
    value_rows = []
    label_rows = []  # text shown inside each cell

    for name in domain_names_all:
        values = []
        labels = []
        for client in range(num_clients):
            seen = domain_client_counts[name][client]
            if seen > 0:
                ratio = domain_client_accuracy[name][client] / seen
            else:
                ratio = 0
            values.append(ratio)
            labels.append(f"{ratio:.2f}")
        value_rows.append(values)
        label_rows.append(labels)

    if save_dir:
        target = os.path.join(save_dir, "domain_client_accuracy.png")
    else:
        target = None

    client_ticks = list(range(num_clients))
    plot_heatmap(value_rows, label_rows, "Client ID", "Domain", "Domain-wise Client Accuracy",
                 client_ticks, domain_names_all, target)

def plot_confusion_matrices(domain_conf_matrices, domain_names_all, num_classes, save_dir=None):
    """Plot one row-normalised confusion matrix per domain.

    Args:
        domain_conf_matrices: Mapping ``domain -> confusion matrix``. Each
            matrix must support ``.sum(axis=1, keepdims=True)`` and
            ``.astype`` (e.g. a NumPy array).
        domain_names_all: Domains to plot, in order.
        num_classes: Number of classes; used for the tick labels on both axes.
        save_dir: Optional directory. When given it is created if missing and
            each figure is saved as ``confusion_matrix_<domain>.png`` and
            closed; otherwise each figure is shown.
    """
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for name in domain_names_all:
        counts = domain_conf_matrices[name]

        # Normalise each row by its total; rows that sum to zero are divided
        # by 1 instead so they stay all-zero rather than producing NaN.
        row_totals = counts.sum(axis=1, keepdims=True)
        row_totals[row_totals == 0] = 1
        rates = np.nan_to_num(counts.astype('float') / row_totals)

        plt.figure(figsize=(8, 6))
        sns.heatmap(
            rates,
            annot=True,
            cmap='Blues',
            xticklabels=list(range(num_classes)),
            yticklabels=list(range(num_classes)),
        )
        plt.title(f"Confusion Matrix - {name}")
        plt.xlabel("Predicted")
        plt.ylabel("True")
        plt.tight_layout()

        if not save_dir:
            plt.show()
            continue

        plt.savefig(os.path.join(save_dir, f"confusion_matrix_{name}.png"), dpi=300)
        plt.close()



if __name__ == "__main__":
    # Script entry point: load the config, build one Client per configured
    # setting, name the wandb run, then start the federated TTA loop.
    load_cfg_from_args('fed tta')
    # User-defined settings
    n = cfg.fed.client_num  # Total clients (not referenced again in the live code of this block)
    n1, n2, n3, n4 = cfg.fed.client_num_continual,cfg.fed.client_num_mixed,cfg.fed.client_num_reset, 0   # Client counts for continual, mixed, reset_each_shift, correlated (correlated is fixed at 0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_classes = get_num_classes(dataset_name=cfg.CORRUPTION.DATASET)

    # Disabled earlier approach: build one base model and deep-copy it n times.
    # # Initialize base model and preprocessing
    # base_model, model_preprocess = get_model(cfg, num_classes, device)

    # # Initialize models and clients
    # models = [copy.deepcopy(base_model) for _ in range(n)]
    domain_sequence = cfg.CORRUPTION.TYPE
    severities = cfg.CORRUPTION.SEVERITY


    #cfg.MODEL.ADAPTATION = 'source'
    c_id = 0  # Running client id, incremented once per created client

    print(cfg.CORRUPTION.DATASET)

    clients = []
    # NOTE(review): zip() over a single list yields 1-tuples, so `setting` is
    # e.g. ("continual",) rather than the string "continual". Confirm that
    # Client expects a tuple here before changing this.
    for setting in zip(

        ["continual"] * n1 + ["mixed"] * n2 + ["reset_each_shift"] * n3 + ["correlated"] * n4
    ):
        # A model is loaded via get_model for every client, then deep-copied
        # and moved to CPU before being handed to the Client.
        base_model, model_preprocess = get_model(cfg, num_classes, device)

        clients.append(Client(copy.deepcopy(base_model).to('cpu'), setting, domain_sequence, severities, cfg, model_preprocess, device, c_id))
        c_id += 1  # Next client gets the next id


    print(cfg.TEST.BATCH_SIZE)




    # Set up wandb logging: encode method, federation technique, dataset and
    # Dirichlet delta in the run name.

    wandb.run.name = "fed-" + cfg.MODEL.ADAPTATION + "-" + cfg.fed.fed_tech + "-" + cfg.CORRUPTION.DATASET + "-dirchlet " +  str(cfg.TEST.DELTA_DIRICHLET)

    # NOTE(review): free-text tag appended to the run name. It says
    # "10 correlated" although n4 above is 0, so no "correlated" clients are
    # created by this block. Confirm the tag is still accurate.
    information = "10 correlated"
    wandb.run.name += "-" + information

    # Append a timestamp to the run name. The value is the machine's naive
    # local time shifted by +11 hours; the original comment called this
    # "Bangladesh time", which presumably assumes a specific server timezone.
    # TODO confirm the offset.
    now = datetime.now()
    new_time = now + timedelta(hours=11)
    wandb.run.name += "-" + new_time.strftime("%Y-%m-%d-%H-%M-%S")

    wandb.config.update(cfg)

    # NOTE(review): `arr` is not defined in this block; presumably it comes
    # from `from utils.extra import *`. Verify. Only its first
    # cfg.fed.client_num entries are passed on as domain_seq_indices.
    run_federated_tta(clients, cfg, domain_seq_indices=arr[:cfg.fed.client_num])

