import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torch.nn.init as init
import yaml

def read_yaml_and_pass_to_argparse(path, args):
    """Overlay the top-level keys of a YAML file onto ``args`` as attributes.

    Args:
        path: Path of the YAML file to read.
        args: Object updated in place (e.g. an ``argparse.Namespace``).
            Every top-level key of the YAML mapping is set as an attribute,
            overwriting any existing attribute of the same name.

    Returns:
        None. ``args`` is mutated in place.

    An empty YAML file is parsed by ``yaml.safe_load`` as ``None``; in that
    case ``args`` is left untouched instead of failing with
    ``AttributeError``.
    """
    with open(path) as f:
        my_dict = yaml.safe_load(f)

    # Empty document -> nothing to override.
    if my_dict is None:
        return

    for k, v in my_dict.items():
        setattr(args, k, v)




class InfiniteIterator:
    """Endless iterator that restarts ``loader`` every time it is exhausted.

    Args:
        loader: Any re-iterable object (``iter(loader)`` is called once per
            pass over the data).

    Only ``StopIteration`` triggers a restart. Any other exception raised
    while fetching an item propagates to the caller instead of being
    swallowed and hidden behind a fresh pass over the loader.
    """

    def __init__(self, loader):
        self.loader = loader
        # Created on the first ``next`` call so that building the wrapper
        # does not start iterating the loader.
        self.iter = None

    def __iter__(self):
        """Return ``self`` so the object can be used in ``for`` / ``zip``."""
        return self

    def __next__(self):
        """Return the next item, starting a new pass when the current one ends.

        Raises:
            StopIteration: Only if a freshly created pass over ``loader``
                yields nothing (i.e. the loader is empty).
        """
        if self.iter is None:
            self.iter = iter(self.loader)
        try:
            return next(self.iter)
        except StopIteration:
            # Current pass is finished: begin a new one and serve its
            # first element.
            self.iter = iter(self.loader)
            return next(self.iter)

def weights_init(init_type='gaussian'):
    """Build a per-module initialiser, suitable for ``nn.Module.apply``.

    Args:
        init_type: One of ``'gaussian'`` (normal, mean 0, std 0.02),
            ``'xavier'`` (Xavier normal, gain 0.02), ``'kaiming'``
            (Kaiming normal, ``a=0``, fan-in), ``'orthogonal'`` (gain
            ``sqrt(2)``) or ``'default'`` (weights left untouched).

    Returns:
        A function ``init_fun(m)`` that acts only on modules whose class
        name starts with ``Conv`` or ``Linear`` and that have a ``weight``
        attribute. For those it initialises the weight according to
        ``init_type`` and then zeroes the bias if one is present.

    Raises:
        AssertionError: Raised by the returned function (not by this
            factory) when it reaches a matching module and ``init_type``
            is not one of the supported names.
    """
    def init_fun(m):
        classname = m.__class__.__name__
        if not classname.startswith(('Conv', 'Linear')) or not hasattr(m, 'weight'):
            return

        if init_type == 'gaussian':
            init.normal_(m.weight.data, 0.0, 0.02)
        elif init_type == 'xavier':
            init.xavier_normal_(m.weight.data, gain=0.02)
        elif init_type == 'kaiming':
            init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        elif init_type == 'orthogonal':
            init.orthogonal_(m.weight.data, gain=math.sqrt(2))
        elif init_type != 'default':
            # Raised explicitly rather than via an ``assert`` statement so
            # the check is not stripped under ``python -O``; the exception
            # type and message are the ones an ``assert`` would produce.
            raise AssertionError("Unsupported initialization: {}".format(init_type))

        if hasattr(m, 'bias') and m.bias is not None:
            init.constant_(m.bias.data, 0.0)

    return init_fun

from copy import deepcopy
from collections import OrderedDict
from sys import stderr

# for type hint
from torch import Tensor


class EMA(nn.Module):
    """Keep an exponential moving average (EMA) copy of a model's parameters.

    Args:
        model: Model to track. It is stored as ``self.model`` and a deep
            copy of it is stored as ``self.shadow``; both are registered
            as sub-modules of this wrapper.
        decay: Averaging factor used by :meth:`update`; each call moves the
            shadow parameters a fraction ``1 - decay`` towards the source
            parameters.

    Calling the wrapper runs the shadow copy, not the tracked model.
    """

    def __init__(self, model: nn.Module, decay: float):
        super().__init__()
        self.decay = decay

        self.model = model
        self.shadow = deepcopy(self.model)

        # The shadow is only ever changed by ``update``; cut its parameters
        # out of autograd.
        for param in self.shadow.parameters():
            param.detach_()

    @torch.no_grad()
    def update(self, model):
        """Blend the current state of ``model`` into the shadow copy.

        Parameters follow
        ``shadow -= (1 - decay) * (shadow - param)``
        (see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage).
        Buffers (for example BatchNorm running statistics) are copied
        verbatim from ``model`` so the shadow does not keep the values it
        had when it was deep-copied.

        Args:
            model: Module whose parameters and buffers are read. They are
                paired with the shadow's by iteration order, so it is
                expected to have the same structure as the tracked model.

        Does nothing except print a warning to stderr when this wrapper is
        not in training mode.
        """
        if not self.training:
            print("EMA update should only be called during training", file=stderr, flush=True)
            return

        for shadow_param, param in zip(self.shadow.parameters(), model.parameters()):
            shadow_param.sub_((1. - self.decay) * (shadow_param - param))

        # Keep non-parameter state in sync with the source model.
        for shadow_buffer, buffer in zip(self.shadow.buffers(), model.buffers()):
            shadow_buffer.copy_(buffer)

    def forward(self, *args, **kwargs) -> Tensor:
        """Run the shadow (averaged) model on the given inputs."""
        return self.shadow(*args, **kwargs)


class Logger(object):
    """Append-only text logger that writes one line of scalars per step.

    Args:
        log_path: File to append to; it is opened in ``'a'`` mode when the
            logger is constructed.

    The file stays open until :meth:`close` is called or the logger is
    used as a context manager (``with Logger(path) as log: ...``).
    """

    def __init__(self, log_path):
        self.log = open(log_path, 'a')

    def write(self, dict, step):
        """Append ``Step: <step> <key>: <value> ...`` as one line and flush.

        Args:
            dict: Mapping of name to a value formattable with ``%.4f``.
                (The parameter name shadows the builtin; it is kept so
                existing keyword callers continue to work.)
            step: Integer step, rendered zero-padded to six digits.
        """
        fields = ''.join('%s: %.4f ' % (k, v) for k, v in dict.items())
        self.log.write('Step: %06d ' % step + fields + '\n')
        self.log.flush()

    def flush(self):
        """No-op: :meth:`write` already flushes after every line."""
        pass

    def close(self):
        """Close the underlying file; safe to call more than once."""
        self.log.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()