import logging
import os
import torch
import argparse
import math
import time

from Loss_Aware_Sims import *
from Settings import *

# A batch runner may export CUDA_VISIBLE_DEVICES before launching this script;
# only fall back to device "0" when nothing has been exported.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

class FL_Proc:
    """Driver for one federated-learning simulation run.

    The instance stores the run configuration, the global model, the server
    object and the per-client data loaders, and executes the round loop in
    :meth:`main`. Client objects are not kept between rounds: one is built
    for each selected participant and released once its update is collected.
    """

    # Model names for which ``get_loaders`` is asked to use ViT transforms.
    _VIT_MODEL_NAMES = ("vit", "vit_b_16", "vit_b_32", "vit_tiny", "vit_small")

    def __init__(self, configs):
        """Read ``configs``, load the global model and configure file logging.

        Args:
            configs: dict of run settings. Most keys are read with
                ``configs[...]`` and therefore required (a missing key raises
                ``KeyError``). ``topk`` / ``comp_config``,
                ``client_decay_step``, ``client_decay_rate`` and
                ``lr_schedule`` are optional and fall back to defaults.

        Side effects:
            Creates ``./exp/<run tag>/logs`` and calls
            ``logging.basicConfig`` with a log file inside it.
        """
        print("[Stage 1/6] Initialize the FL_Proc Class...")
        # Data and model names
        self.DataName = configs["dname"]
        self.ModelName = configs["mname"]
        # number of total clients and participants in each round
        self.NClients = configs["nclients"]
        self.PClients = configs["pclients"]
        # Whether the data is IID (Independent and Identically Distributed)
        self.IsIID = configs["isIID"]
        # A parameter related to the non-IID data distribution
        self.Alpha = configs["alpha"]
        # Whether to use data augmentation
        self.Aug = configs["aug"]
        # Maximum number of training iterations
        self.MaxIter = configs["iters"]
        # Learning rate
        self.LR = configs["learning_rate"]
        # Whether to normalize the data
        self.Normal = configs["normal"]
        # Optimizer to be used (e.g., "SGD")
        self.Optmzer = configs["optimizer"]
        # Whether to fix the learning rate during training
        self.FixLR = configs["fixlr"]
        # Weight decay (for regularization)
        self.WDecay = configs["wdecay"]
        # Whether to shuffle the data
        self.DShuffle = configs["data_shuffle"]
        # Batch size for training
        self.BatchSize = configs["batch_size"]
        # Number of epochs per client training
        self.Epoch = configs["epoch"]
        # Whether to use global learning rate
        self.GlobalLR = configs["global_lr"]
        # Server optimizer (e.g., "Adam")
        self.SOM = configs["server_optim"]
        # Placeholder for participant ratios (if needed)
        self.Ratios = {}
        # Whether to randomly vary the number of participants
        self.RandNum = configs["rand_num"]
        # Whether to use model compression
        self.CPR = configs["compression"]
        # Number of iterations between logging
        self.LogStep = configs["log_step"]
        # Compression configuration: "topk" wins over "comp_config"; the
        # literal below is used only when neither key is present.
        self.CompConfig = configs.get("topk", configs.get("comp_config", {
            "granularity": "element",
            "rule": "magnitude",
            "ratio": 0.20,
            "mode": "global",
            "include_bias": True
        }))

        # Learning rate decay parameters (client-side only)
        self.ClientDecayStep = configs.get("client_decay_step", 10)
        self.ClientDecayRate = configs.get("client_decay_rate", 0.9)

        # LR schedule configuration (optional)
        self.LRSchedule = configs.get("lr_schedule", {})

        print("[Stage 2/6] Load the global model...")
        # Load the global model (e.g., "alex" for cifar10)
        self.GModel = load_Model(self.ModelName, self.DataName)

        self.Server = None
        self.Clients = {}                       # Empty dictionary - clients will be created on demand
        self.ClientLoaders = None               # Placeholder for data loaders for clients
        self.TrainLoader = None
        self.TestLoader = None                  # Placeholder for global test data loader
        # IDs of the clients whose updates were used most recently; starts as
        # the first PClients IDs and is overwritten after every round.
        self.updateIDs = list(range(self.PClients))

        # Random selection of clients for training
        self.Selection = RandomGet(self.NClients)
        self.TrainRound = 0  # Initialize training round counter

        # ------- Set up logging to a file -------
        # The rank tag is only non-empty for low-rank compression.
        rank_tag = ""
        if self.CompConfig.get('granularity') == 'low_rank':
            rank_mode = self.CompConfig.get('rank_mode', 'ratio')
            if rank_mode == 'fixed':
                rank_tag = f"_rank_fixed{self.CompConfig.get('fixed_rank', 0)}"
            elif rank_mode == 'mixed':
                rank_tag = f"_rank_{self.CompConfig.get('min_rank', 1)}to{self.CompConfig.get('max_rank', 32)}"
            else:
                rank_tag = "_rank_ratio"

        comp_tag = "nocomp"
        if self.CPR:
            comp_tag = f"time_comp_{self.CompConfig['granularity']}_rule_{self.CompConfig['rule']}_mode_{self.CompConfig['mode']}_ratio{self.CompConfig['ratio']}{rank_tag}_calib{self.CompConfig.get('calib_batches', 0)}_{self.CompConfig.get('batch_selection', 'none')}"

        # The experiment directory name encodes the main hyper-parameters.
        base_components = [
            f"dataset_{self.DataName}",
            f"model_{self.ModelName}",
            f"optim_{self.Optmzer}",
            f"lr{self.LR}",
            f"bs{self.BatchSize}",
            f"epoch{self.Epoch}",
            f"NC{self.NClients}",
            f"PC{self.PClients}",
            comp_tag
        ]
        log_file_path = f"./exp/{'_'.join(base_components)}/logs"
        os.makedirs(log_file_path, exist_ok=True)

        if self.IsIID:
            log_file = f'{log_file_path}/IID_NClient{self.NClients}_PClient{self.PClients}.log'
        else:
            log_file = f'{log_file_path}/Alpha{self.Alpha}_NClient{self.NClients}_PClient{self.PClients}.log'
        logging.basicConfig(filename=log_file, level=logging.INFO,
                            format='%(asctime)s - %(levelname)s - %(message)s')
        logging.info(f"Config: {configs}")

    def get_train_datas(self):
        """Populate ``ClientLoaders``, ``TrainLoader`` and ``TestLoader`` via ``get_loaders``."""
        print("[Stage 3/6] Get train and test dataset...")
        is_vit = str(self.ModelName).lower() in self._VIT_MODEL_NAMES
        print(f"Dataset: {self.DataName}, NClients: {self.NClients}, Is_ViT: {is_vit}")

        # NOTE(review): the two positional ``False`` flags are defined by
        # get_loaders, which is not visible in this file - confirm their
        # meaning there. The fourth return value is not used.
        self.ClientLoaders, self.TrainLoader, self.TestLoader, _stat = get_loaders(
            self.DataName, self.NClients, self.IsIID, self.Alpha, self.Aug, False, False, self.Normal, self.DShuffle, self.BatchSize,
            use_vit_transforms=is_vit)

    def logging(self):
        """Evaluate the server on the test loader and log/print loss and accuracy.

        Inside this method ``logging`` still refers to the standard-library
        module (a global), not to this method.
        """
        teloss, teaccu = self.Server.evaluate(self.TestLoader)
        log_msg = f"Iteration {self.TrainRound}, Loss: {teloss}, Accuracy: {teaccu}"
        logging.info(log_msg)
        print(f"[FL-Result] {log_msg}")

    def create_client_on_demand(self, client_id):
        """Build a fresh ``Client_Sim`` for ``client_id`` from its data loader.

        Args:
            client_id: index into ``self.ClientLoaders``.

        Returns:
            A new ``Client_Sim`` configured with the run's optimizer,
            compression and client-side LR-decay settings.
        """
        return Client_Sim(
            self.ClientLoaders[client_id], self.GModel, self.LR, self.WDecay,
            self.Epoch, self.FixLR, self.Optmzer,
            compression_configs=self.CompConfig,
            decay_step=self.ClientDecayStep, decay_rate=self.ClientDecayRate
        )

    @staticmethod
    def _lr_schedule_enabled(schedule):
        """Return True when ``schedule`` requests any warmup or annealing.

        A schedule is active when it is a dict and has a truthy
        ``warmup_steps``, a ``warmup_ratio`` that is not ``None`` (0.0 counts),
        or an ``anneal`` value other than ``"none"``.
        """
        if not isinstance(schedule, dict):
            return False
        has_warmup = bool(schedule.get("warmup_steps", 0)) or (
            schedule.get("warmup_ratio", None) is not None)
        has_anneal = str(schedule.get("anneal", "none")).lower() != "none"
        return has_warmup or has_anneal

    @staticmethod
    def _lr_schedule_factor(schedule, step_idx, default_total_steps):
        """Compute the LR multiplier for a 0-based round index.

        Args:
            schedule: dict with optional keys ``total_steps``,
                ``warmup_steps``, ``warmup_ratio``, ``start_factor``,
                ``anneal`` (``none``/``cosine``/``linear``/``step``),
                ``min_lr_ratio``, ``step_size``, ``gamma``, ``anneal_steps``.
            step_idx: 0-based index of the current round.
            default_total_steps: used when ``total_steps`` is absent.

        Returns:
            A float factor to multiply the base learning rate with.
        """
        total_steps = max(1, int(schedule.get("total_steps", default_total_steps)))
        warmup_steps = int(schedule.get("warmup_steps", 0) or 0)
        warmup_ratio = schedule.get("warmup_ratio", None)
        if warmup_ratio is not None:
            # A ratio, when given, takes precedence over an explicit step count.
            warmup_steps = max(0, int(total_steps * float(warmup_ratio)))
        start_factor = float(schedule.get("start_factor", 0.0) or 0.0)
        anneal = str(schedule.get("anneal", "none")).lower()
        min_lr_ratio = float(schedule.get("min_lr_ratio", 0.0) or 0.0)
        step_size = int(schedule.get("step_size", 1) or 1)
        gamma = float(schedule.get("gamma", 0.9) or 0.9)
        default_anneal_steps = max(1, total_steps - warmup_steps)
        anneal_steps = int(schedule.get("anneal_steps", 0) or 0) or default_anneal_steps

        # Linear warmup from start_factor up to 1.0 at the last warmup step.
        if warmup_steps > 0 and step_idx < warmup_steps:
            return start_factor + (1.0 - start_factor) * (step_idx + 1) / float(max(1, warmup_steps))

        t_after = max(0, step_idx - warmup_steps)
        if anneal == "cosine":
            progress = min(1.0, t_after / float(max(1, anneal_steps)))
            cos_term = 0.5 * (1.0 + math.cos(math.pi * progress))
            return min_lr_ratio + (1.0 - min_lr_ratio) * cos_term
        if anneal == "linear":
            progress = min(1.0, t_after / float(max(1, anneal_steps)))
            return max(min_lr_ratio, 1.0 - (1.0 - min_lr_ratio) * progress)
        if anneal == "step":
            if step_size <= 0:
                return 1.0
            return gamma ** (t_after // step_size)
        # "none" and any unrecognised value leave the LR unchanged.
        return 1.0

    def main(self):
        """Run the whole simulation.

        Loads the data, builds the server, registers all client IDs with the
        selector and then runs ``MaxIter`` rounds. Each round optionally sets
        the server optimizer LR from the schedule, selects participants,
        trains each one locally, collects its (optionally compressed)
        parameter delta, aggregates on the server and, every ``LogStep``
        rounds, evaluates and logs the global model.
        """
        import gc  # function-scope import: only this method needs it

        # Load training and testing data
        self.get_train_datas()

        print("[Stage 4/6] Initialize the server...")
        # Initialize the server (no learning rate decay needed for server)
        self.Server = Server_Sim(
            self.TrainLoader, self.GModel, self.LR, self.WDecay, self.FixLR, self.DataName)

        print(f"[Stage 5/6] Register {self.NClients} clients...")
        # Register clients (but don't create them yet - they will be created on demand)
        for c in range(self.NClients):
            self.Selection.register_client(c)

        NumPartens = self.PClients  # Number of participants per round

        print("[Stage 6/6] Begin FL Training...")
        print(f"Training iteration: {self.MaxIter}, Num of participating clients: {self.PClients}")
        print(f"Top-K compression configuration: {self.CompConfig}")

        self.logging()              # Log the initial state

        # Cumulative timing statistics. Calibration and scoring times are
        # gathered alongside the others but are not reported in any log line.
        total_calib_time = 0.0
        total_score_time = 0.0
        total_compression_time = 0.0
        total_training_time = 0.0
        total_clients_processed = 0
        # Sliding window (last 5 rounds) for the 5-round averages
        recent_calib_times = []
        recent_score_times = []
        recent_compression_times = []
        recent_training_times = []
        recent_clients_processed = []

        schedule_active = bool(self.GlobalLR) and self._lr_schedule_enabled(self.LRSchedule)
        lr_warned = False  # warn only once if the scheduled LR cannot be applied

        for It in range(self.MaxIter):
            self.TrainRound = It + 1

            progress_percent = 100 * self.TrainRound / self.MaxIter
            print(f"[Round {self.TrainRound}/{self.MaxIter}]: {progress_percent:.1f}%")

            # ------- Round-wise LR scheduling (warmup + annealing) -------
            # If configured, set server optimizer LR for this round (used as global LR)
            if schedule_active:
                factor = self._lr_schedule_factor(self.LRSchedule, It, self.MaxIter)
                lr_round = float(self.LR) * float(factor)
                try:
                    self.Server.optimizer.param_groups[0]['lr'] = lr_round
                except (AttributeError, IndexError, KeyError, TypeError) as exc:
                    # Best effort: keep training with the server's own LR, but
                    # make the failure visible instead of hiding it.
                    if not lr_warned:
                        logging.warning(
                            "Could not apply scheduled LR %s to the server optimizer: %r",
                            lr_round, exc)
                        lr_warned = True

            # Select participant
            updateIDs = self.Selection.select_participant(NumPartens)

            # ------- Local update of this iteration -------
            GlobalParms = self.Server.getParas()    # Global model
            LrNow = self.Server.getLR()             # Learning rate
            TransLens = []                          # The local sample size of each participating client
            TransParas = []                         # The compressed parameter update of each participating client
            TransVecs = []                          # The local normalizing vector for FedNova

            # Compression statistics for this round
            round_compression_stats = {'total_elements': 0, 'total_non_zero': 0, 'clients_compressed': 0}

            # Round-level timing statistics
            round_calib_time = 0.0
            round_score_time = 0.0
            round_compression_time = 0.0
            round_training_time = 0.0
            round_clients_processed = 0

            compression_used = bool(self.CPR)
            for num_done_client, ky in enumerate(updateIDs, start=1):
                client = self.create_client_on_demand(ky)

                try:
                    # Distribute global learning rate and global parameters
                    if self.GlobalLR:
                        client.updateLR(LrNow)
                    client.updateParas(GlobalParms)

                    # Get the local normalizing vector for FedNova
                    print(f"{num_done_client}-th Client begin local traning...", end='\r')
                    # perf_counter is monotonic, so the interval cannot be
                    # distorted by system clock adjustments.
                    start_training_time = time.perf_counter()
                    Nvec = client.selftrain()
                    end_training_time = time.perf_counter()

                    if self.Optmzer == "VRL":
                        client.optimizer.update_delta(client.local_steps)

                    # Get the (optionally compressed) parameter update
                    if compression_used:
                        ParasNow = client.getCompDeltaParas()
                    else:
                        ParasNow = client.getDeltaParas()

                    # Collect compression statistics (but don't log per-client details)
                    if compression_used and self.TrainRound % 10 == 0:
                        total_elements = sum(p.numel() for p in ParasNow.values())
                        non_zero_elements = sum((p != 0).sum().item() for p in ParasNow.values())
                        round_compression_stats['total_elements'] += total_elements
                        round_compression_stats['total_non_zero'] += non_zero_elements
                        round_compression_stats['clients_compressed'] += 1

                    # Collect timing statistics for compression and training
                    round_training_time += (end_training_time - start_training_time)
                    if compression_used:
                        round_calib_time += client.calib_time
                        round_score_time += client.score_time
                        round_compression_time += client.total_compression_time
                    round_clients_processed += 1

                    # Return the local sample size, parameter update, and local normalizing vector for FedNova
                    TransLens.append(client.DLen)
                    TransParas.append(ParasNow)
                    TransVecs.append(Nvec)

                finally:
                    # Drop the client, collect garbage first so that memory
                    # held only by reference cycles is freed, then let the
                    # CUDA caching allocator release the unused blocks.
                    del client
                    gc.collect()
                    torch.cuda.empty_cache()

            # Log round-level compression summary instead of per-client details
            if (self.CPR and self.TrainRound % 10 == 0
                    and round_compression_stats['clients_compressed'] > 0
                    and round_compression_stats['total_elements'] > 0):
                avg_sparsity = 1 - (round_compression_stats['total_non_zero'] / round_compression_stats['total_elements'])
                logging.info(f"[FL Compression Summary] Round {self.TrainRound}: "
                            f"{round_compression_stats['clients_compressed']} clients, "
                            f"avg sparsity: {avg_sparsity:.4f}")

            # Update cumulative timing statistics
            total_calib_time += round_calib_time
            total_score_time += round_score_time
            total_compression_time += round_compression_time
            total_training_time += round_training_time
            total_clients_processed += round_clients_processed

            # Store for 5-round averages
            recent_calib_times.append(round_calib_time)
            recent_score_times.append(round_score_time)
            recent_compression_times.append(round_compression_time)
            recent_training_times.append(round_training_time)
            recent_clients_processed.append(round_clients_processed)
            # Keep only the last 5 rounds
            if len(recent_calib_times) > 5:
                recent_calib_times.pop(0)
                recent_score_times.pop(0)
                recent_compression_times.pop(0)
                recent_training_times.pop(0)
                recent_clients_processed.pop(0)

            # Log round-level timing statistics every round
            if round_clients_processed > 0:
                avg_compression_time = round_compression_time / round_clients_processed
                avg_training_time = round_training_time / round_clients_processed
                logging.info(f"[FL Timing Stats - Round {self.TrainRound}] "
                            f"{round_clients_processed} clients, "
                            f"Avg Training Time = {avg_training_time:.4f}s, "
                            f"Avg Total Compression Time = {avg_compression_time:.4f}s")

            # Log 5-round average timing statistics every 5 rounds
            if self.TrainRound % 5 == 0 and len(recent_calib_times) > 0:
                total_recent_clients = sum(recent_clients_processed)
                if total_recent_clients > 0:
                    avg_recent_compression_time = sum(recent_compression_times) / total_recent_clients
                    avg_recent_training_time = sum(recent_training_times) / total_recent_clients
                    logging.info(f"[FL Timing Stats - 5-Round Avg up to Round {self.TrainRound}] "
                                f"Total {total_recent_clients} clients over last {len(recent_calib_times)} rounds, "
                                f"Avg Training Time = {avg_recent_training_time:.4f}s, "
                                f"Avg Total Compression Time = {avg_recent_compression_time:.4f}s")

            # ------- Server Aggregation -------
            TransLens_arr = np.asarray(TransLens, dtype=np.float64)
            # Each client's weight is its share of the round's total sample count.
            Lens_Weights = TransLens_arr / TransLens_arr.sum()

            # TauEff is the sample-weighted mean of the clients' normalizing vectors.
            TauEff = 0.0
            for weight, nvec in zip(Lens_Weights, TransVecs):
                TauEff += weight * nvec

            # Each update is handed to the server together with the ratio of
            # TauEff to that client's own normalizing vector.
            for GPara, GLen, nvec in zip(TransParas, TransLens_arr, TransVecs):
                GNvec = TauEff / nvec
                self.Server.recvInfo(GPara, GLen, GNvec)

            self.Server.aggParas(self.SOM)

            self.updateIDs = updateIDs

            # Logging
            if self.TrainRound % self.LogStep == 0:
                self.logging()

        # Log the Compression config and timing summary at the end of training
        if self.CPR:
            logging.info("FEDERAL LEARNING COMPLETED - Top-K COMPRESSION SUMMARY")
            logging.info(f"Top-K Configuration: {self.CompConfig}")
            logging.info(f"Total Training Rounds: {self.MaxIter}")
            logging.info(f"Compression Enabled: {self.CPR}")
            logging.info("[Final Summary] Client statistics are logged during training rounds")
            if total_clients_processed > 0:
                avg_compression_time = total_compression_time / total_clients_processed
                avg_training_time = total_training_time / total_clients_processed
                logging.info("[Compression Timing Summary]")
                logging.info(f"Total Clients Processed: {total_clients_processed}")
                logging.info(f"Total Training Time: {total_training_time:.4f}s, Avg per Client: {avg_training_time:.4f}s")
                logging.info(f"Total Compression Time: {total_compression_time:.4f}s, Avg per Client: {avg_compression_time:.4f}s")


def build_arg_parser():
    """Build the command-line parser used when this file is run as a script.

    Returns:
        An ``argparse.ArgumentParser`` with the task, LR-schedule,
        compression, rank-selection and calibration options.
    """
    parser = argparse.ArgumentParser(description='Run Loss_Aware_Train directly with simple args')
    # core task
    parser.add_argument('--dname', type=str, default='cifar10', choices=['cifar10','cifar100','mnist','fmnist'])
    parser.add_argument('--mname', type=str, default='alex', choices=['alex','resnet18','resnet34','resnet50','vgg11','vgg16','vit','vit_b_16','vit_b_32','vit_tiny','vit_small'])
    # No default here: when --gpu is omitted, a CUDA_VISIBLE_DEVICES value
    # already present in the environment (e.g. set by a batch runner) is kept.
    parser.add_argument('--gpu', type=str, default=None,
                        help='GPU id(s), e.g., "0" or "0,1"; when omitted, the existing CUDA_VISIBLE_DEVICES is kept')
    parser.add_argument('--nclients', type=int, default=100)
    parser.add_argument('--pclients', type=int, default=10)
    parser.add_argument('--iters', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--optimizer', type=str, default='SGD')#AdamW
    parser.add_argument('--epoch', type=int, default=4)
    # lr warmup + annealing
    # NOTE(review): because --lr_warmup_ratio defaults to 0.1 (not None), the
    # schedule derives the warmup length from the ratio and --lr_warmup_steps
    # has no effect - confirm this is intended.
    parser.add_argument('--lr_warmup_steps', type=int, default=0)
    parser.add_argument('--lr_warmup_ratio', type=float, default=0.1)
    parser.add_argument('--lr_start_factor', type=float, default=0.0)
    parser.add_argument('--lr_anneal', type=str, default='cosine', choices=['none','cosine','linear','step'])
    parser.add_argument('--lr_min_ratio', type=float, default=1e-10)
    parser.add_argument('--lr_step_size', type=int, default=20)
    parser.add_argument('--lr_gamma', type=float, default=0.6)
    # compression core
    # --compression is on by default; --no_compression is the only way to
    # turn it off (a store_true flag with default=True can never be False).
    parser.add_argument('--compression', action='store_true', default=True)
    parser.add_argument('--no_compression', dest='compression', action='store_false',
                        help='Disable update compression')
    parser.add_argument('--granularity', type=str, default='element', choices=['element','low_rank'])
    parser.add_argument('--rule', type=str, default='lossaware', choices=['magnitude','lossaware','magnitude_svd','lossaware_svd'])
    parser.add_argument('--ratio', type=float, default=0.2)
    parser.add_argument('--mode', type=str, default='global', choices=['global','per-layer'])
    # rank selection
    parser.add_argument('--rank_mode', type=str, default='fixed', choices=['ratio','fixed','mixed'])
    parser.add_argument('--fixed_rank', type=int, default=6)
    parser.add_argument('--min_rank', type=int, default=1)
    parser.add_argument('--max_rank', type=int, default=32)
    # calibs
    parser.add_argument('--calib_batches', type=int, default= 4)
    parser.add_argument('--batch_selection', type=str, default='random', choices=['random', 'first_n', 'last_n'], help='Batch selection strategy for calibration data')
    # learning rate decay parameters (client-side only)
    parser.add_argument('--client_decay_step', type=int, default=10, help='Client learning rate decay interval (rounds)')
    parser.add_argument('--client_decay_rate', type=float, default=0.9, help='Client learning rate decay rate')
    parser.add_argument('--alpha', type=float, default=0.2)
    return parser


def build_configs(args):
    """Translate parsed CLI arguments into the config dict ``FL_Proc`` expects.

    Settings that have no CLI option (non-IID data, normalization on,
    augmentation off, ``log_step`` 1, weight decay 1e-4, server optimizer
    "Yogi", ...) are fixed here.

    Args:
        args: namespace produced by the parser from ``build_arg_parser``.

    Returns:
        dict of run settings.
    """
    # Calibration settings only apply to the loss-aware rules.
    loss_aware = 'lossaware' in args.rule
    return {
        'dname': args.dname,
        'mname': args.mname,
        'isIID': False,
        'alpha': args.alpha,
        'normal': True,
        'aug': False,
        'data_shuffle': True,

        'nclients': args.nclients,
        'pclients': args.pclients,
        'iters': args.iters,
        'log_step': 1,
        'learning_rate': args.lr,
        'wdecay': 1e-4,
        'epoch': args.epoch,
        'batch_size': args.batch_size,
        'fixlr': False,
        'global_lr': True,
        'optimizer': args.optimizer,
        # Important
        'server_optim': "Yogi",
        'rand_num': False,

        'compression': bool(args.compression),
        'comp_config': {
            'granularity': args.granularity,
            'rule': args.rule,
            'calib_batches': args.calib_batches if loss_aware else 0,
            'batch_selection': args.batch_selection if loss_aware else 'random',
            'mode': args.mode,
            'ratio': args.ratio,
            'include_bias': True,
            'weight': 'svd_proj_energy' if args.rule=='lossaware_svd' else 'none',
            'ef_enabled': True,
            'log_layer_stats_every': 50,
            # rank selection for low_rank
            'rank_mode': args.rank_mode,
            'fixed_rank': args.fixed_rank,
            'min_rank': args.min_rank,
            'max_rank': args.max_rank,
        },

        # Learning rate decay parameters (client-side only)
        'client_decay_step': args.client_decay_step,
        'client_decay_rate': args.client_decay_rate,

        # Round-wise global LR schedule (optional)
        'lr_schedule': {
            'total_steps': args.iters,
            'warmup_steps': args.lr_warmup_steps,
            'warmup_ratio': args.lr_warmup_ratio,
            'start_factor': args.lr_start_factor,
            'anneal': args.lr_anneal,
            'min_lr_ratio': args.lr_min_ratio,
            'step_size': args.lr_step_size,
            'gamma': args.lr_gamma,
        },
    }


if __name__ == '__main__':
    args = build_arg_parser().parse_args()

    # An explicit --gpu overrides whatever CUDA_VISIBLE_DEVICES currently holds.
    if args.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    Configs = build_configs(args)

    FLSim = FL_Proc(Configs)
    FLSim.main()
