import os
import random
import numpy as np
import torch
import yaml
import math

import wandb

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler

from torch.utils.tensorboard import SummaryWriter


def over_write_args_from_dict(args, dict):
    """
    Overwrite arguments according to a config mapping.

    Every key of ``dict`` is set as an attribute of ``args`` with the
    corresponding value. Attributes of ``args`` that are not named in
    ``dict`` are left untouched.
    """
    for name, value in dict.items():
        setattr(args, name, value)


def over_write_args_from_file(args, yml):
    """
    Overwrite arguments according to a YAML config file.

    Args:
        args: object whose attributes are overwritten in place.
        yml: path of the YAML file. An empty string or None means
            "no config file" and leaves ``args`` untouched.

    Raises:
        TypeError: if the top level of the YAML document is not a mapping.
    """
    if not yml:
        return
    with open(yml, 'r', encoding='utf-8') as f:
        dic = yaml.safe_load(f)
    # safe_load returns None for an empty (or comment-only) document;
    # treat that as "nothing to overwrite" instead of failing on iteration.
    if dic is None:
        return
    if not isinstance(dic, dict):
        raise TypeError(
            f"config file {yml!r} must contain a mapping at the top level, "
            f"got {type(dic).__name__}"
        )
    for k, v in dic.items():
        setattr(args, k, v)


def count_parameters(model):
    """Return the total number of elements over all trainable parameters.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


class TBLog:
    """
    Construct a tensorboard writer (self.writer).
    The tensorboard is saved at os.path.join(tb_dir, file_name).

    When ``use_tensorboard`` is False no writer is created: ``self.writer``
    is None and ``update`` does nothing.
    """

    def __init__(self, tb_dir, file_name, use_tensorboard=False):
        self.tb_dir = tb_dir
        self.use_tensorboard = use_tensorboard
        # Always define the attribute so a disabled logger can be inspected
        # without raising AttributeError.
        self.writer = None
        if self.use_tensorboard:
            self.writer = SummaryWriter(os.path.join(self.tb_dir, file_name))

    def update(self, tb_dict, it, suffix=None, mode="train"):
        """
        Write every scalar of ``tb_dict`` to tensorboard.

        Args
            tb_dict: contains scalar values for updating tensorboard
            it: iteration (int), passed as the step of each scalar.
            suffix: If not None, it is prepended to every key (despite its
                name it is used as a prefix of the tag).
            mode: accepted for interface compatibility; not used here.
        """
        if not self.use_tensorboard:
            return
        prefix = '' if suffix is None else suffix
        for key, value in tb_dict.items():
            self.writer.add_scalar(prefix + key, value, it)


class Bn_Controller:
    """
    Batch Norm controller.

    ``freeze_bn`` snapshots the running statistics of every
    ``nn.SyncBatchNorm`` / ``nn.BatchNorm2d`` layer of a model and
    ``unfreeze_bn`` writes the snapshot back, so whatever happened to the
    statistics between the two calls is undone.
    """

    # Layer types whose running statistics are backed up.
    _BN_TYPES = (nn.SyncBatchNorm, nn.BatchNorm2d)
    # Per-layer tensors that make up the running statistics.
    _STAT_NAMES = ('running_mean', 'running_var', 'num_batches_tracked')

    def __init__(self):
        """
        freeze_bn and unfreeze_bn must appear in pairs
        """
        self.backup = {}

    def _tracked_bn_layers(self, model):
        """Yield (name, module) for BN layers that actually hold running stats."""
        for name, m in model.named_modules():
            # Layers built with track_running_stats=False keep None instead
            # of tensors; there is nothing to back up or restore for them.
            if isinstance(m, self._BN_TYPES) and m.running_mean is not None:
                yield name, m

    def freeze_bn(self, model):
        """Store a copy of the running statistics of ``model``'s BN layers.

        Raises:
            AssertionError: if a previous backup has not been consumed by
                ``unfreeze_bn`` yet.
        """
        assert self.backup == {}, "freeze_bn called again before unfreeze_bn"
        for name, m in self._tracked_bn_layers(model):
            for stat in self._STAT_NAMES:
                self.backup[name + '.' + stat] = getattr(m, stat).data.clone()

    def unfreeze_bn(self, model):
        """Write the stored running statistics back and clear the backup.

        Raises:
            KeyError: if ``freeze_bn`` was not called on the same model
                first (no backup entry exists for a layer).
        """
        for name, m in self._tracked_bn_layers(model):
            for stat in self._STAT_NAMES:
                getattr(m, stat).data = self.backup[name + '.' + stat]
        self.backup = {}


class EMA:
    """
    EMA model (parameters + buffers)

    ``shadow`` maps every parameter/buffer name of the model to its
    exponential moving average. Typical order of calls: ``register`` (or
    ``load``) once, ``update`` after each training step, and
    ``apply_shadow`` / ``restore`` around evaluation.
    """

    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

    def _named_tensors(self, model=None):
        """Yield (name, tensor) for all parameters, then all buffers."""
        source = self.model if model is None else model
        yield from source.named_parameters()
        yield from source.named_buffers()

    def load(self, ema_model):
        """Initialise the shadow values from the tensors of ``ema_model``."""
        for name, tensor in self._named_tensors(ema_model):
            self.shadow[name] = tensor.data.clone()

    def register(self):
        """Initialise the shadow values from the current state of the model."""
        for name, tensor in self._named_tensors():
            self.shadow[name] = tensor.data.clone()

    def update(self):
        """Blend the current model state into the shadow values.

        Floating-point tensors follow
        ``shadow = (1 - decay) * current + decay * shadow``.
        """
        for name, tensor in self._named_tensors():
            current = tensor.data
            if current.is_floating_point() or current.is_complex():
                self.shadow[name] = (1.0 - self.decay) * current + self.decay * self.shadow[name]
            else:
                # Integer/bool buffers (e.g. BatchNorm's num_batches_tracked)
                # would be promoted to float by the weighted sum and then
                # change the model buffer's dtype in apply_shadow; track the
                # latest value instead.
                self.shadow[name] = current.clone()

    def apply_shadow(self):
        """Swap the shadow values into the model, backing up the live ones.

        The model tensors alias the shadow tensors afterwards (no copy is
        made), so call ``restore`` before the model is modified again.
        """
        self.backup = {}
        for name, tensor in self._named_tensors():
            self.backup[name] = tensor.data.clone()
            tensor.data = self.shadow[name]

    def restore(self):
        """Put back the values saved by ``apply_shadow`` and drop the backup."""
        for name, tensor in self._named_tensors():
            tensor.data = self.backup[name]
        self.backup = {}
        

def set_seed(seed):
    """Seed every random number generator used here and pin cuDNN behaviour.

    Seeds Python's ``random``, NumPy and torch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode with
    benchmarking disabled.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    os.environ["PYTHONHASHSEED"] = str(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    

class CosineAnnealingWarmupRestarts(_LRScheduler):
    """
    Cosine annealing with linear warmup and warm restarts.

    Within a cycle the learning rate rises linearly from ``min_lr`` towards
    ``max_lr`` during the first ``warmup_steps`` steps and then follows a
    half cosine from ``max_lr`` down to ``min_lr``, which is reached when
    ``step_in_cycle == cur_cycle_steps``. The step after that starts a new
    cycle of length ``int(first_cycle_steps * cycle_mult ** cycle)`` whose
    peak is ``base_max_lr * gamma ** cycle``.

    The same learning rate is applied to every parameter group.

    Args:
        optimizer: wrapped optimizer.
        first_cycle_steps: length of the first cycle.
        cycle_mult: factor by which the cycle length grows per restart.
        max_lr: peak learning rate of the first cycle.
        min_lr: lowest learning rate.
        warmup_steps: number of linear warmup steps at the start of a cycle.
        gamma: factor applied to the peak learning rate per restart.
        last_epoch: index of the last epoch, forwarded to the base class.
    """

    def __init__(self, optimizer, first_cycle_steps, cycle_mult=1.0, max_lr=0.1, min_lr=0.001,
                 warmup_steps=0, gamma=1.0, last_epoch=-1):
        self.first_cycle_steps = first_cycle_steps
        self.cycle_mult = cycle_mult
        self.base_max_lr = max_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.gamma = gamma

        self.cur_cycle_steps = first_cycle_steps
        self.cycle = 0
        self.step_in_cycle = last_epoch

        # The base constructor performs an initial step(), so all attributes
        # read by step()/get_lr() must exist before this call.
        super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of the current position, one per group."""
        if self.step_in_cycle == -1:
            return self.base_lrs
        if self.step_in_cycle < self.warmup_steps:
            lr = self.min_lr + (self.max_lr - self.min_lr) * (self.step_in_cycle / self.warmup_steps)
        else:
            span = self.cur_cycle_steps - self.warmup_steps
            # span is 0 when warmup fills the whole cycle; treat that point
            # as the start of the cosine (peak lr) instead of dividing by 0.
            progress = (self.step_in_cycle - self.warmup_steps) / span if span > 0 else 0.0
            lr = self.min_lr + (self.max_lr - self.min_lr) * (1 + math.cos(math.pi * progress)) / 2
        return [lr for _ in self.base_lrs]

    def step(self, epoch=None):
        """Advance the schedule by one step and update the optimizer.

        NOTE: an explicit ``epoch`` only overwrites ``last_epoch``; the
        position inside the cycle always advances by exactly one.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        self.step_in_cycle += 1

        if self.step_in_cycle >= self.cur_cycle_steps + 1:
            self.cycle += 1
            self.step_in_cycle = 0
            self.cur_cycle_steps = int(self.first_cycle_steps * (self.cycle_mult ** self.cycle))
            self.max_lr = self.base_max_lr * (self.gamma ** self.cycle)  # Apply gamma to reduce the max_lr

        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

        # The base class's step() is bypassed, so record the applied rates
        # here; otherwise get_last_lr() fails with AttributeError.
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
            
            
def log_client_metrics_to_wandb(results, round_num, log_keys=("loss", "accuracy"), log_selected_ids=True):
    """
    Send per-client metrics of one round to wandb.

    For every (client, fit result) pair, each key of ``log_keys`` present in
    the result's metrics is logged as ``client_<cid>_<key>`` at step
    ``round_num``. Optionally the comma-joined ids of all clients in
    ``results`` are logged as ``round_<round_num>_selected_clients``.

    Args:
        results: list of (ClientProxy, FitRes) tuples.
        round_num: current FL round, used as the wandb step.
        log_keys: metric names to look up in each result's metrics.
        log_selected_ids: whether to log the selected client IDs.
    """
    for proxy, fit_result in results:
        cid = proxy.cid
        # A missing/empty metrics object is treated as "no metrics".
        client_metrics = fit_result.metrics or {}

        for key in log_keys:
            if key not in client_metrics:
                continue
            wandb.log({f"client_{cid}_{key}": client_metrics[key]}, step=round_num)

    if not log_selected_ids:
        return

    joined_ids = ",".join(str(proxy.cid) for proxy, _ in results)
    wandb.log({f"round_{round_num}_selected_clients": joined_ids}, step=round_num)
