import math
import torch
import numpy as np

def adjust_learning_rate(optimizer, epoch, args):
    """Apply a cosine-annealed learning rate to every optimizer param group.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dicts
            whose ``'lr'`` entry is overwritten.
        epoch: current epoch index used as the position in the schedule.
        args: namespace providing ``lr``, ``n_epochs`` and ``algorithm``;
            ``lr_min`` is only read when ``algorithm == 'simclr'``.
    """
    new_lr = args.lr
    # Cosine factor: 1.0 at epoch 0, falling to 0.0 at epoch == n_epochs.
    angle = math.pi * epoch / args.n_epochs
    factor = 0.5 * (1. + math.cos(angle))

    if not (args.algorithm == 'simclr' and args.lr_min is not None):
        # Plain cosine decay towards zero.
        new_lr *= factor
    else:
        # Cosine decay from the base rate down to the lr_min floor.
        floor = args.lr_min
        new_lr = floor + (new_lr - floor) * factor

    for group in optimizer.param_groups:
        group['lr'] = new_lr


class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes set by ``reset``/``update``:
        val: most recent value passed to ``update``.
        sum: running total of ``val * n`` over all updates.
        count: running total of ``n`` over all updates.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """
    def __init__(self, name, fmt=':f'):
        # name and fmt are used by __str__ to label and format the values.
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero all running statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against an accumulated count of zero (e.g. n=0 on a fresh
        # meter), which would otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        # Builds e.g. '{name} {val:f} ({avg:f})' and fills it from the
        # instance attributes.
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """Prints one tab-separated progress line made of a batch counter and meters."""
    def __init__(self, num_batches, meters, prefix=""):
        # Template such as '[{:3d}/100]'; the batch index is filled in later.
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the formatted batch counter and every meter's str()."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        columns = [header, *map(str, self.meters)]
        print('\t'.join(columns))

    def _get_batch_fmtstr(self, num_batches):
        """Return '[{:Nd}/TOTAL]' where N is the printed width of the total."""
        width = len(str(num_batches // 1))
        slot = f'{{:{width}d}}'
        total = slot.format(num_batches)
        return f'[{slot}/{total}]'


def save_checkpoint(args, epoch, model, optimizer, acc, filename, msg):
    """Write a training snapshot to ``filename`` and print ``msg`` with the path.

    The saved object is a dict with the keys ``'epoch'``, ``'arch'``
    (from ``args.arch``), ``'state_dict'`` (``model.state_dict()``),
    ``'optimizer'`` (``optimizer.state_dict()``) and ``'top1_acc'`` (``acc``).
    """
    arch = args.arch
    model_state = model.state_dict()
    optimizer_state = optimizer.state_dict()
    snapshot = {
        'epoch': epoch,
        'arch': arch,
        'state_dict': model_state,
        'optimizer': optimizer_state,
        'top1_acc': acc,
    }
    torch.save(snapshot, filename)
    print(msg, filename)


def _candidate_state_dicts(state_dict):
    """Yield key-renamed variants of ``state_dict`` in the order they are tried.

    First the full mapping, then the mapping with every key starting with one
    of the ``backbone.``/``projector.`` prefixes (optionally preceded by
    ``module.``) dropped. Each of the two is yielded as-is, with ``'module.'``
    removed from its keys, and with ``'module.'`` prepended to its keys.
    """
    dropped_prefixes = ('module.backbone.', 'module.projector.', 'backbone.', 'projector.')
    for drop in (False, True):
        source = state_dict
        if drop:
            source = {k: v for k, v in state_dict.items() if not k.startswith(dropped_prefixes)}
        yield source
        yield {k.replace('module.', ''): v for k, v in source.items()}
        yield {f'module.{k}': v for k, v in source.items()}


def load_checkpoint(model, optimizer, filename):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    The checkpoint is expected to be a dict with ``'epoch'`` and
    ``'state_dict'`` entries, plus ``'optimizer'`` when ``optimizer`` is given.
    Several key layouts of the state dict are tried in turn (see
    ``_candidate_state_dicts``) until ``model.load_state_dict`` accepts one.

    Args:
        model: object exposing ``load_state_dict``.
        optimizer: object exposing ``load_state_dict``, or ``None`` to skip
            restoring optimizer state.
        filename: path passed to ``torch.load``.

    Returns:
        Tuple ``(start_epoch, model, optimizer)``.

    Raises:
        RuntimeError: if no candidate layout could be loaded; the error from
            the last attempt is re-raised.
    """
    # Keep the original GPU target when CUDA is present, but fall back to CPU
    # so tensor checkpoints can still be deserialized on CPU-only hosts.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(filename, map_location=device)
    start_epoch = checkpoint['epoch']
    print(f'loaded checkpoint from epoch {start_epoch} --- {filename}')

    last_error = None
    for candidate in _candidate_state_dicts(checkpoint['state_dict']):
        try:
            model.load_state_dict(candidate)
        except RuntimeError as err:
            # Key or shape mismatch for this layout; try the next one.
            last_error = err
        else:
            break
    else:
        raise last_error

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    return start_epoch, model, optimizer

def load_pretrained_checkpoint(model, optimizer, filename):
    """Restore model and optimizer state from a pretrained checkpoint file.

    The checkpoint is expected to be a dict with ``'epoch'`` and
    ``'state_dict'`` entries, plus ``'optimizer'`` when ``optimizer`` is given.
    If the state dict is rejected as stored, it is retried with ``'module.'``
    removed from every key.

    Args:
        model: object exposing ``load_state_dict``.
        optimizer: object exposing ``load_state_dict``, or ``None`` to skip
            restoring optimizer state.
        filename: path passed to ``torch.load``.

    Returns:
        Tuple ``(start_epoch, model, optimizer)``.

    Raises:
        RuntimeError: if neither key layout can be loaded into ``model``.
    """
    # Keep the original GPU target when CUDA is present, but fall back to CPU
    # so tensor checkpoints can still be deserialized on CPU-only hosts.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(filename, map_location=device)
    start_epoch = checkpoint['epoch']
    state_dict = checkpoint['state_dict']
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Key mismatch as stored; retry with the 'module.' substring removed.
        model.load_state_dict({k.replace('module.', ''): v for k, v in state_dict.items()})

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    print(f'loaded checkpoint from epoch {start_epoch} --- {filename}')
    return start_epoch, model, optimizer
