import os
import shutil
import copy

import torch
import yaml


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    When ``is_best`` is truthy, the file just written is also copied to
    ``'model_best.pth.tar'``. That destination is a bare relative path, so
    it resolves against the current working directory, not against the
    directory of ``filename``.
    """
    torch.save(state, filename)
    if not is_best:
        return
    # Keep a second copy of the checkpoint under the fixed "best" name.
    shutil.copyfile(filename, 'model_best.pth.tar')


def save_config_file(model_checkpoints_folder, args):
    """Dump ``args`` as YAML to ``config.yml`` inside ``model_checkpoints_folder``.

    The folder is created (including missing parents) when it does not
    exist. The config file is written in either case: previously the write
    was nested under the "folder does not exist" check, so nothing was
    saved whenever the folder had already been created.

    Args:
        model_checkpoints_folder: Directory that will contain ``config.yml``.
        args: Object handed to ``yaml.dump``.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(model_checkpoints_folder, exist_ok=True)
    config_path = os.path.join(model_checkpoints_folder, 'config.yml')
    with open(config_path, 'w') as outfile:
        yaml.dump(args, outfile, default_flow_style=False)


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy percentages, one entry per value in ``topk``.

    ``output`` is ranked along dimension 1 and the ``k`` highest-scoring
    indices of each row are compared against ``target``. Each entry of the
    returned list is a one-element tensor holding the percentage of rows
    whose target appears among the top ``k`` indices.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the largest_k highest scores per row, best first.
        top_indices = output.topk(largest_k, dim=1, largest=True, sorted=True)[1]
        # Transpose so that row i holds every sample's i-th ranked index.
        ranked = top_indices.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        # The first k rows of ``hits`` cover exactly the top-k predictions.
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / num_samples)
            for k in topk
        ]

    

class ModelWrapper:
    """Hold a deep copy of a network and forward common calls to it.

    The copy is stored as ``ema_net``; calling the wrapper, switching
    train/eval mode and reading the state dict all delegate to that copy.
    """

    def __init__(self, net):
        # Deep copy: the wrapped network does not share objects with ``net``.
        self.ema_net = copy.deepcopy(net)

    def __call__(self, *inputs, **options):
        """Run the wrapped copy on the given arguments and return its result."""
        return self.ema_net(*inputs, **options)

    def state_dict(self):
        """Return the state dict of the wrapped copy."""
        return self.ema_net.state_dict()

    def train(self):
        """Put the wrapped copy into training mode."""
        self.ema_net.train()

    def eval(self):
        """Put the wrapped copy into evaluation mode."""
        self.ema_net.eval()

    def to(self, device):
        """Return ``self`` unchanged; ``device`` is ignored.

        NOTE(review): nothing is moved here -- presumably the copy is
        expected to already live on the same device as the source network;
        confirm with callers.
        """
        return self
    
class ModelEMA(ModelWrapper):
    """Exponential moving average of a live network's state.

    ``ema_net`` starts as a deep copy of ``net`` (made by
    ``ModelWrapper.__init__``) and is blended towards ``net`` on every call
    to ``update()``.

    Args:
        net: The live network being tracked; kept by reference as ``self.net``.
        ema: Decay factor applied to the averaged values on each update.
    """

    def __init__(self, net, ema):
        # The base class already stores copy.deepcopy(net) in self.ema_net;
        # copying a second time here only doubled construction time/memory.
        super().__init__(net)
        self.net = net
        self.ema = ema

    def update(self):
        """Blend the current state of ``net`` into ``ema_net`` in place.

        Floating-point (and complex) entries become
        ``ema * averaged + (1 - ema) * live``. Every other entry (integer
        or bool tensors) is copied over verbatim, because scaling such
        tensors in place by a float is not possible.
        """
        decay = self.ema
        averaged_state = self.ema_net.state_dict().values()
        live_state = self.net.state_dict().values()
        for averaged, live in zip(averaged_state, live_state):
            if averaged.is_floating_point() or averaged.is_complex():
                averaged.mul_(decay).add_(live, alpha=1 - decay)
            else:
                averaged.copy_(live)


class ModelAverage(ModelWrapper):
    """Running average of a live network's state with a decaying step weight.

    ``ema_net`` starts as a deep copy of ``net`` (made by
    ``ModelWrapper.__init__``). On the ``t``-th call to ``update()``
    (``t`` starting at 1) the live state is mixed in with weight
    ``(gamma + 1) / (gamma + t)``, so the first update replaces the average
    with the live values and later updates contribute progressively less.

    Args:
        net: The live network being tracked; kept by reference as ``self.net``.
        gamma: Constant in the step-weight formula above.
    """

    def __init__(self, net, gamma=8.0):
        # The base class already stores copy.deepcopy(net) in self.ema_net;
        # copying a second time here only doubled construction time/memory.
        super().__init__(net)
        self.net = net
        self.gamma = gamma
        self.t = 1

    def update(self):
        """Mix the current state of ``net`` into ``ema_net`` and advance ``t``.

        Floating-point (and complex) entries become
        ``(1 - w) * averaged + w * live`` with
        ``w = (gamma + 1) / (gamma + t)``. Every other entry (integer or
        bool tensors) is copied over verbatim, because scaling such tensors
        in place by a float is not possible.
        """
        # The weight is the same for every tensor in this step; compute once.
        weight = (self.gamma + 1) / (self.gamma + self.t)
        averaged_state = self.ema_net.state_dict().values()
        live_state = self.net.state_dict().values()
        for averaged, live in zip(averaged_state, live_state):
            if averaged.is_floating_point() or averaged.is_complex():
                averaged.mul_(1 - weight).add_(live, alpha=weight)
            else:
                averaged.copy_(live)
        self.t += 1
