# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. 
# All rights reserved.
#
# Code is originally from the EDM (https://arxiv.org/pdf/2206.00364) implementation
# from https://github.com/NVlabs/edm by NVIDIA which is licensed under CC-BY-NC-SA 4.0.
# You may obtain a copy of the License at
#
# https://creativecommons.org/licenses/by-nc-sa/4.0/
#

'''
Code adapted from https://github.com/NVlabs/LSGM/blob/main/util/ema.py
'''
import warnings

import torch
from torch.optim.optimizer import Optimizer

class EMA(Optimizer):
    """Optimizer wrapper that maintains an exponential moving average (EMA) of parameters.

    Each ``step`` first runs the wrapped optimizer's ``step``. Once ``ema_start`` has
    enabled averaging, it then updates a per-parameter EMA copy. That copy is stored in
    the wrapped optimizer's state under the key ``'ema'``::

        ema = ema_decay * ema + (1 - ema_decay) * param

    ``state`` and ``param_groups`` are the wrapped optimizer's own objects. As a result,
    the EMA copies travel with ``state_dict()`` / ``load_state_dict()``.

    Args:
        opt: The optimizer whose parameters are averaged.
        ema_decay: Decay factor used after ``ema_start`` is called. A value ``<= 0``
            leaves averaging disabled.
    """

    def __init__(self, opt, ema_decay):
        super().__init__(opt.param_groups, opt.defaults)  # Initialize base class
        # Ensure the hook registry exists even if this torch version's
        # Optimizer.__init__ does not create it.
        if not hasattr(self, '_optimizer_state_dict_pre_hooks'):
            self._optimizer_state_dict_pre_hooks = {}
        self.defaults = opt.defaults
        # The requested decay is parked here; averaging stays off (decay 0) until
        # ema_start() is called.
        self.ema_decay_tmp = ema_decay
        self.ema_decay = 0
        self.apply_ema = False
        self.optimizer = opt
        # Share state and param_groups with the wrapped optimizer. Both objects then
        # see the same parameters and the same per-parameter state (including 'ema').
        self.state = opt.state
        self.param_groups = opt.param_groups

    def ema_start(self):
        """Enable averaging with the decay given at construction (if it is > 0)."""
        self.ema_decay = self.ema_decay_tmp
        self.apply_ema = self.ema_decay > 0.

    def step(self, *args, **kwargs):
        """Run the wrapped optimizer's step, then update the EMA copies.

        All arguments are forwarded to the wrapped optimizer's ``step``; its return
        value is returned unchanged. Only parameters that currently have a gradient
        are averaged. A parameter's EMA copy is initialized from its post-step value
        the first time it is seen.
        """
        retval = self.optimizer.step(*args, **kwargs)

        # stop here if we are not applying EMA
        if not self.apply_ema:
            return retval

        decay = self.ema_decay
        for group in self.optimizer.param_groups:
            # Bucket parameters by (shape, dtype, device). Each bucket can then be
            # stacked into one tensor and updated with a single in-place op.
            # torch.stack needs all inputs on the same device, so the shape alone is
            # not a sufficient key.
            buckets = {}
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.optimizer.state[p]
                if 'ema' not in state:
                    state['ema'] = p.data.clone()
                buckets.setdefault((p.shape, p.dtype, p.device), []).append(p)

            for bucket in buckets.values():
                stacked_params = torch.stack([p.data for p in bucket], dim=0)
                stacked_ema = torch.stack(
                    [self.optimizer.state[p]['ema'] for p in bucket], dim=0)
                stacked_ema.mul_(decay).add_(stacked_params, alpha=1. - decay)
                # Split along the leading (stacking) dim only. This also works for
                # 0-dim parameters, whose stacked tensor is 1-D and cannot be
                # indexed with [idx, :].
                for p, new_ema in zip(bucket, stacked_ema.unbind(0)):
                    self.optimizer.state[p]['ema'] = new_ema

        return retval

    def load_state_dict(self, state_dict):
        """Load optimizer state and propagate it to the wrapped optimizer.

        This does not change ``ema_decay`` / ``apply_ema``.
        """
        super().load_state_dict(state_dict)
        # load_state_dict loads the data to self.state and self.param_groups. We need
        # to pass this data to the underlying optimizer too.
        self.optimizer.state = self.state
        self.optimizer.param_groups = self.param_groups

    def swap_parameters_with_ema(self, store_params_in_ema):
        """Replace parameter values with their EMA values.

        Args:
            store_params_in_ema: If true, the current parameter values are kept as
                the new ``'ema'`` entries. Calling this again then swaps them back.
                If false, the current values are discarded.

        Parameters with ``requires_grad=False``, and parameters that have no EMA copy
        yet (e.g. they never received a gradient), are left untouched.
        """

        # stop here if we are not applying EMA
        if not self.apply_ema:
            warnings.warn('swap_parameters_with_ema was called when there is no EMA weights.')
            return

        for group in self.optimizer.param_groups:
            for p in group['params']:
                if not p.requires_grad:
                    continue
                # Use .get() so that no empty state entries are inserted into the
                # defaultdict for parameters the optimizer never touched.
                state = self.optimizer.state.get(p)
                if not state or 'ema' not in state:
                    continue
                ema = state['ema']
                if store_params_in_ema:
                    tmp = p.data.detach()
                    p.data = ema.detach()
                    state['ema'] = tmp
                else:
                    p.data = ema.detach()
        print('Swapped parameters with EMA parameters.')