import torch
import os
import shutil
import pandas as pd

def adjust_lr_lambda(args, epoch, optimizer, global_round=None):
    """Set the learning rate from a step schedule and reset ``args.lamda``.

    Args:
        args: config namespace. Reads ``args.lr`` (base LR) and ``args.epochs``
            (selects the epoch schedule); mutated: ``args.lamda`` is set to 0.
        epoch: epoch index; incremented by 1 before being compared with the
            schedule thresholds (callers presumably pass a zero-based index).
        optimizer: object with ``param_groups``; every group's ``'lr'`` is
            overwritten with the computed value.
        global_round: federated round index. Whenever it is not ``None``
            (round 0 included) the round schedule is used instead of the
            epoch schedule.

    Returns:
        The learning rate written into the optimizer.
    """
    epoch = epoch + 1
    base_lr = args.lr
    # Every schedule below disables the lambda term. Zeroing it for all paths
    # also keeps resumed runs (which may skip the warm-up epochs) consistent
    # with fresh runs, where warm-up had already zeroed it.
    args.lamda = 0

    if global_round is not None:
        # Round schedule: x1 before round 90, x0.1 before round 160, then x0.01.
        if global_round >= 160:
            lr = base_lr * 0.01
        elif global_round >= 90:
            lr = base_lr * 0.1
        else:
            lr = base_lr
    elif args.epochs == 120:
        # Single step decay after epoch 90, no warm-up.
        lr = base_lr * 0.1 if epoch > 90 else base_lr
    elif args.epochs == 200:
        if epoch <= 5:
            lr = base_lr * epoch / 5  # linear warm-up over the first 5 epochs
        elif epoch >= 180:
            lr = base_lr * 0.0001
        elif epoch >= 160:
            lr = base_lr * 0.01
        else:
            lr = base_lr
    else:
        if epoch <= 5:
            lr = base_lr * epoch / 5  # linear warm-up over the first 5 epochs
        elif epoch >= 60:
            lr = base_lr * 0.1
        else:
            lr = base_lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return lr


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: score tensor whose dim 1 indexes classes (presumably
            ``(batch, num_classes)``).
        target: tensor holding one ground-truth class index per sample;
            ``target.size(0)`` is taken as the batch size.
        topk: iterable of k values to evaluate.

    Returns:
        A list of 0-d float tensors, one per k, in the order given by ``topk``.
    """
    largest_k = max(topk)
    n_samples = target.size(0)
    # Indices of the largest_k best scores per sample, transposed to (k, batch).
    ranked = output.float().topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))
    return [hits[:k].float().sum().mul_(100.0 / n_samples) for k in topk]


class ResultsLog(object):
    """Accumulate result rows in a pandas DataFrame and persist them as CSV.

    The plotting hooks (``show``, ``plot``, ``image``) are no-op placeholders;
    ``figures`` is kept for interface compatibility but nothing fills it here.
    """

    def __init__(self, path='results.csv', plot_path=None):
        self.path = path
        self.plot_path = plot_path or (self.path + '.html')
        self.figures = []
        self.results = None  # DataFrame of all rows, or None before the first add

    def add(self, **kwargs):
        """Append one row; each keyword becomes a column (new columns are added
        as needed and missing cells are filled with NaN)."""
        df = pd.DataFrame([kwargs])
        if self.results is None:
            self.results = df
        else:
            # Outer join so rows with different column sets (e.g. federated vs.
            # non-federated results) can share one table.
            self.results = pd.concat([self.results, df], axis=0, join='outer', ignore_index=True)

    def save(self, title='Training Results'):
        """Write all rows to ``self.path`` as CSV without the index column.

        ``title`` is currently unused because plot rendering is disabled.
        Nothing is written when no row has been added yet.
        """
        if self.figures and os.path.isfile(self.plot_path):
            # Rendering is disabled; only a stale plot file is removed.
            os.remove(self.plot_path)
        if self.results is None:
            return
        self.results.to_csv(self.path, index=False)

    def load(self, path=None):
        """Replace ``self.results`` with the CSV at ``path`` (default
        ``self.path``) if that file exists; otherwise leave it unchanged."""
        path = path or self.path
        if os.path.isfile(path):
            self.results = pd.read_csv(path)

    def log_global_average(self, client_results):
        """Append one row holding the column-wise mean over ``client_results``.

        ``client_results`` is anything ``pd.DataFrame`` accepts (e.g. a list of
        per-client dicts). Non-numeric columns are skipped rather than making
        ``mean`` raise.
        """
        avg_results = pd.DataFrame(client_results).mean(axis=0, numeric_only=True).to_dict()
        self.add(**avg_results)

    def show(self):
        pass

    def plot(self, xs, ys, *kargs, **kwargs):
        pass

    def image(self, *kargs, **kwargs):
        pass


class AverageMeter(object):
    """Track the most recent value and the running weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, the mean, the weighted sum and the weight count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed with weight ``n`` (e.g. batch size) and
        refresh the running mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count


def model_resume(args, model, model_new):
    """Restore model weights from a checkpoint file.

    Args:
        args: namespace; ``args.resume`` is only used in the log message, and
            ``args.resumed_epoch`` is set from the checkpoint's ``'epoch'``
            entry (0 when the checkpoint has none).
        model: checkpoint location passed straight to ``torch.load``; the
            loaded object must provide ``'best_acc1'`` and ``'state_dict'``.
        model_new: module receiving the weights. If it wraps the real model in
            a ``.module`` attribute (e.g. ``DataParallel``), the inner module
            is loaded; otherwise ``model_new`` itself is.
    """
    checkpoint = torch.load(model)
    best_acc1 = checkpoint['best_acc1']
    # The epoch, not the accuracy, is what the "(epoch ...)" log line reports.
    args.resumed_epoch = checkpoint.get('epoch', 0)
    target = getattr(model_new, 'module', model_new)
    target.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {}), acc {}".format(args.resume, args.resumed_epoch, best_acc1))


def save_checkpoint_epoch(state, is_best, path='.', global_round=None):
    """Save ``state`` under ``path``, naming the file by round or by epoch.

    Args:
        state: object passed to ``torch.save``; must contain ``'epoch'`` when
            ``global_round`` is ``None`` (it is used in the file name).
        is_best: when true, also copy the checkpoint to
            ``<path>/model_best.pth.tar``.
        path: directory to write into.
        global_round: federated round index; whenever it is not ``None``
            (round 0 included) the file is named after the round.
    """
    if global_round is not None:
        filename = os.path.join(path, f'{global_round}-global_round_checkpoint.pth.tar')
    else:
        filename = os.path.join(path, f'{state["epoch"]}-th_epoch_checkpoint.pth.tar')

    torch.save(state, filename)
    if is_best:
        # Copy (not move) so the per-epoch/round file stays loadable, and keep
        # the best model inside ``path`` where load_checkpoint_best reads it.
        shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar'))


def load_checkpoint_epoch(epoch, path='.'):
    """Return what ``torch.load`` reads from ``<path>/<epoch>-th_epoch_checkpoint.pth.tar``."""
    checkpoint_file = '{}-th_epoch_checkpoint.pth.tar'.format(epoch)
    return torch.load(os.path.join(path, checkpoint_file))


def load_checkpoint_iter(iter, path='.'):
    """Return what ``torch.load`` reads from ``<path>/<iter>-th_iter_checkpoint.pth.tar``.

    The parameter name shadows the ``iter`` builtin but is kept for callers
    that pass it by keyword.
    """
    checkpoint_file = '{}-th_iter_checkpoint.pth.tar'.format(iter)
    return torch.load(os.path.join(path, checkpoint_file))


def load_checkpoint_best(path='.'):
    """Return what ``torch.load`` reads from ``<path>/model_best.pth.tar``."""
    best_file = os.path.join(path, 'model_best.pth.tar')
    checkpoint = torch.load(best_file)
    return checkpoint
