import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import higher

class BaseNet(nn.Module):
    """Prototype-based cosine classifier on top of a feature backbone.

    Class prototypes are the per-class means of the support features; features
    are then scored against the prototypes with a scaled, biased cosine
    similarity.
    """

    def __init__(self, backbone):
        """Wrap ``backbone`` and create the learnable classifier bias and scale.

        Args:
            backbone: module mapping an image batch to a feature matrix with
                one row per image (required by the matrix products in
                ``forward``).
        """
        super().__init__()
        self.backbone = backbone
        # Switch off running-statistics tracking on every 2-D batch-norm layer.
        for m in self.backbone.modules():
            if isinstance(m, torch.nn.BatchNorm2d):
                m.track_running_stats = False
        self.bias = nn.Parameter(torch.FloatTensor(1).fill_(0), requires_grad=True)
        self.scale_cls = nn.Parameter(torch.FloatTensor(1).fill_(10), requires_grad=True)

    def cos_classifier(self, w, f):
        """Return ``scale_cls * (cos(f, w) + bias)`` for every (feature, class) pair.

        w.shape = nC, d
        f.shape = M, d
        returns shape = M, nC
        """
        f = F.normalize(f, p=2, dim=f.dim()-1, eps=1e-12)
        w = F.normalize(w, p=2, dim=w.dim()-1, eps=1e-12)

        cls_scores = f @ w.transpose(0, 1)
        cls_scores = self.scale_cls * (cls_scores + self.bias)
        return cls_scores

    def forward(self, supp_x, supp_y, x=None):
        """Score ``x`` (or the support set itself) against support prototypes.

        supp_x.shape = [nSupp, C, H, W]
        supp_y.shape = [nSupp]  (integer class ids, 0-based)
        x.shape = [nQry, C, H, W]; when ``x`` is None the support features are
        classified instead.

        Returns logits of shape [nQry, nC] (or [nSupp, nC] when ``x`` is None),
        where nC = supp_y.max() + 1.
        """
        # F.one_hot expects a plain int for the number of classes.
        num_classes = int(supp_y.max()) + 1  # NOTE: assume B==1
        supp_f = self.backbone.forward(supp_x)

        supp_y_1hot = F.one_hot(supp_y, num_classes).transpose(0, 1)
        # compute cluster centroids
        prototypes = supp_y_1hot.float() @ supp_f
        # A class id with no support image has count 0; clamping to 1 keeps its
        # prototype at the zero vector instead of producing 0/0 = NaN logits.
        counts = supp_y_1hot.sum(dim=1, keepdim=True).clamp(min=1)
        prototypes = prototypes / counts

        # Identity test: ``x == None`` on a tensor goes through tensor equality.
        if x is None:
            feat = supp_f
        else:
            feat = self.backbone.forward(x)
        logits = self.cos_classifier(prototypes, feat)
        return logits

def module2functional(torch_net):
    """Return a ``higher`` functional copy of ``torch_net``.

    The copy starts with an empty fast-parameter slot and with higher-order
    gradient tracking turned off.  Every ``BatchNorm2d`` inside it has its
    running-statistics attributes set to ``None``.
    """
    functional = higher.patch.make_functional(module=torch_net)
    functional._fast_params = [[]]
    functional.track_higher_grads = False

    stat_names = ("running_mean", "running_var", "num_batches_tracked")
    for layer in functional.modules():
        if not isinstance(layer, torch.nn.BatchNorm2d):
            continue
        for stat in stat_names:
            setattr(layer, stat, None)

    return functional

class NIWMetaHelper:
    """Container for the base network, its functional twin and moment buffers.

    ``initialize`` must be called before ``to``; ``to`` creates ``mom1``,
    ``mom2`` and ``f_base_net``.
    """

    def initialize(self, backbone):
        """Build the ``BaseNet`` wrapping ``backbone``."""
        self.base_net = BaseNet(backbone)

    def to(self, device):
        """Move the base net to ``device`` and (re)build the derived objects.

        ``mom1`` and ``mom2`` become lists with one detached tensor per base-net
        parameter; ``f_base_net`` is the functional version of the base net.
        """
        self.base_net.to(device)
        params = list(self.base_net.parameters())
        self.mom1 = [p.detach() for p in params]
        self.mom2 = [p.detach() for p in params]
        self.f_base_net = module2functional(self.base_net)

# Module-level NIWMetaHelper instance.  NIWMeta accesses it as a global
# (initialize / to / base_net / f_base_net / mom1 / mom2) rather than storing
# the base network as one of its own attributes.
helper = NIWMetaHelper()

def mvdigamma(a, d=1):
    """Sum of ``digamma(a + (1 - i) / 2)`` for ``i = 1 .. d``.

    ``a`` must be a tensor; the index range is cast to its dtype and device.
    """
    idx = torch.arange(1, d + 1).to(a)
    shifted = a + 0.5 * (1. - idx)
    return torch.digamma(shifted).sum()

def mvpolygamma(a, d, order=1):
    """Sum of ``polygamma(order, a + (1 - i) / 2)`` for ``i = 1 .. d``.

    ``a`` must be a tensor; the index range is cast to its dtype and device.
    """
    half_offsets = 0.5 * (1. - torch.arange(1, d + 1).to(a))
    terms = torch.polygamma(order, a + half_offsets)
    return terms.sum()

def run_sgld_gaussian_fit(xs, ys, x, y, base_net, mom1, mom2, init_params=None, 
    lossfun=torch.nn.CrossEntropyLoss(), steps=100, burnin=50, alp=0.0001,ai_max=1e3):
    """Fit a diagonal Gaussian to the parameter trajectory of a noisy optimiser.

    ``base_net`` is trained in place for ``steps`` iterations with Adam on
    ``lossfun(base_net(xs, ys, x), y)``; before each step every gradient is
    scaled by ``N / 2`` and perturbed with Gaussian noise of standard deviation
    ``sqrt(N / alp)`` (``N = x.shape[0]``).  From iteration ``burnin`` onward
    the flattened parameters are accumulated into running first and second
    moments.

    Args:
        xs, ys, x: forwarded unchanged to ``base_net``.
        y: targets passed to ``lossfun`` together with the network output.
        base_net: module whose parameters are overwritten and optimised.
        mom1, mom2: iterables of parameter-shaped tensors; their ``.data`` is
            re-pointed at the results (via ``vector_to_parameters``).
        init_params: optional iterable of tensors copied into ``base_net``
            before optimisation.  When None, parameters with more than one
            dimension get Kaiming-normal init and the rest are zeroed.
        lossfun: callable ``(logits, y) -> scalar loss``.
        steps: total number of optimiser steps.
        burnin: index of the first step whose parameters are recorded.
        alp: learning rate before division by ``N``.
        ai_max: upper clamp for the returned precisions.

    Returns:
        ``(mom1, mom2)`` where ``mom1`` holds the mean of the recorded
        parameters and ``mom2`` the inverse of their unbiased sample variance,
        clamped to ``[1e-4, ai_max]``.

    Raises:
        ValueError: if fewer than two post-burn-in samples would be recorded
            (the unbiased variance needs at least two).
    """
    # Fail early, before base_net is touched: the original code ran into a
    # NameError / ZeroDivisionError at the very end for such schedules.
    if burnin < 0 or steps - burnin < 2:
        raise ValueError(
            "run_sgld_gaussian_fit needs 0 <= burnin <= steps - 2 "
            "(got steps=%r, burnin=%r)" % (steps, burnin))

    N = x.shape[0]
    alp /= N
    with torch.no_grad():
        if init_params is not None:
            for src, tgt in zip(init_params, base_net.parameters()):
                tgt.copy_(src)
        else:
            for p in base_net.parameters():
                if len(p.shape) > 1:
                    nn.init.kaiming_normal_(p, nonlinearity='relu')
                else:
                    nn.init.zeros_(p)
    optim = torch.optim.Adam(params=base_net.parameters(), lr=alp)

    # Loop-invariant gradient scale and noise standard deviation.
    grad_scale = N/2.
    noise_scale = float(np.sqrt(1/alp))

    moment1, moment2 = None, None
    cnt = 0  # number of parameter snapshots accumulated so far
    for i in range(steps):
        logits = base_net(xs, ys, x)
        loss = lossfun(logits, y)
        optim.zero_grad()
        loss.backward()
        with torch.no_grad():
            for p in base_net.parameters():
                if p.grad is not None:
                    p.grad *= grad_scale
                    p.grad += noise_scale*torch.randn_like(p)

        optim.step()
        if i < burnin:
            continue
        with torch.no_grad():
            src = nn.utils.parameters_to_vector(parameters=base_net.parameters())
            if cnt == 0:
                moment1 = src*1.0
                moment2 = src**2
            else:
                # Incremental running averages of x and x**2.
                moment1 = (src + cnt*moment1) / (cnt+1)
                moment2 = (src**2 + cnt*moment2) / (cnt+1)
        cnt += 1

    with torch.no_grad():
        # Precision = 1 / (Bessel-corrected variance), clamped.
        # NOTE(review): E[x^2] - E[x]^2 can come out slightly negative through
        # rounding, in which case the precision is clamped to 1e-4 rather than
        # ai_max -- confirm whether that is intended.
        moment2 = (1./((cnt/(cnt-1))*(moment2-moment1**2))).clamp(min=1e-4, max=ai_max)
        nn.utils.vector_to_parameters(moment1, mom1)
        nn.utils.vector_to_parameters(moment2, mom2)
    return mom1, mom2

class NIWMeta(nn.Module):
    """Meta-learner holding a learnable prior over the base-network parameters.

    Learnable quantities (one entry per base-net parameter unless noted):
      * ``m0``    -- prior mean, initialised from the base-net parameters.
      * ``ugam0`` -- log of the per-element prior precision ``gam0``.
      * ``un0``   -- log of the scalar ``n0`` (single element).

    The base network itself lives in the module-level ``helper`` object and is
    not registered as a sub-module of this class.
    """

    def __init__(self, backbone,  use_gami, n0_init, gam0_init,steps, burnin, alp, ai_max, lossfun):
        """Create the prior parameters and store the inner-loop options.

        Args:
            backbone: feature extractor handed to ``helper.initialize``.
            use_gami: truthy to also use per-episode precisions (sampled
                parameters and the extra precision terms in the losses).
            n0_init, gam0_init: positive initial values for ``n0`` / ``gam0``
                (stored as logs).
            steps, burnin, alp, ai_max, lossfun: forwarded to
                ``run_sgld_gaussian_fit``; ``steps``, ``alp`` and ``lossfun``
                are also used by ``evaluate``.
        """
        super().__init__()
        helper.initialize(backbone)
        self.m0 = nn.ParameterList([nn.Parameter(p*1.0) for p in helper.base_net.parameters()])
        self.ugam0 = nn.ParameterList([nn.Parameter(p*0.0+np.log(gam0_init)) for p in helper.base_net.parameters()])
        self.un0 = nn.Parameter((torch.zeros(1)+np.log(n0_init)).to(self.m0[0].device))
        self.use_gami = use_gami
        self.lossfun = lossfun
        self.steps = steps
        self.alp = alp

        self.opts = {'steps': steps,'burnin': burnin, 'alp': alp,'lossfun': lossfun,'ai_max': ai_max}
        # Total number of scalar parameters in the base network.
        self.d = nn.utils.parameters_to_vector(helper.base_net.parameters()).numel()

    def helper_to(self, device):
        """Move the shared helper (base net, moment buffers) to ``device``."""
        helper.to(device)

    def get_gam0(self):
        """Return the prior precisions ``exp(ugam0)`` as a list of tensors."""
        return [torch.exp(p) for p in self.ugam0]

    def get_n0(self):
        """Return ``exp(un0)`` as a one-element tensor."""
        return torch.exp(self.un0)

    def forward(self, xs, ys, x, y):
        """Compute the meta-training loss for one episode.

        The base net is first fitted on the episode by
        ``run_sgld_gaussian_fit`` (starting from ``m0``), giving a mean
        ``m_bar`` and precision ``gam_bar`` per parameter.  These are combined
        with the prior into ``m_star`` (and ``gam_star`` when ``use_gami``),
        the functional base net is evaluated at those parameters, and a
        regulariser divided by ``N = x.shape[0]`` is added to the data loss.

        Returns:
            ``(meta_loss, loss_report)`` -- the differentiable total loss and
            the data-loss value as a Python float.
        """
        N = x.shape[0]
        m_bar, gam_bar = run_sgld_gaussian_fit(xs, ys, x, y, helper.base_net, helper.mom1, helper.mom2, init_params=self.m0, **self.opts)
        n0, gam0 = self.get_n0(), self.get_gam0()
        m_star = []
        if self.use_gami:
            gam_star = []
        for m_bar_p, gam_bar_p, m0_p, gam0_p in zip(m_bar, gam_bar, self.m0, gam0):
            # Precision-weighted average of the fitted mean and the prior mean.
            m_star.append((gam_bar_p*m_bar_p + n0*gam0_p*m0_p) /(gam_bar_p + n0*gam0_p))
            if self.use_gami:
                gam_star.append(gam_bar_p + n0*gam0_p)
        if self.use_gami:
            # Draw one parameter sample around m_star with std 1/sqrt(gam_star).
            th = [ m_ + torch.randn_like(m_)/g_.sqrt() for m_, g_ in zip(m_star, gam_star) ]
        else:
            th = m_star
        logits = helper.f_base_net.forward(xs, ys, x, params=th)
        loss = self.lossfun(logits, y)
        loss_report = loss.item()
        m0_vec = nn.utils.parameters_to_vector(parameters=self.m0)
        gam0_vec = nn.utils.parameters_to_vector(parameters=gam0)
        m_star_vec = nn.utils.parameters_to_vector(parameters=m_star)
        if self.use_gami:
            gam_star_vec = nn.utils.parameters_to_vector(parameters=gam_star)
            meta_loss = loss + 0.5 *(
                -gam0_vec.log().sum() + gam_star_vec.log().sum() - self.d*mvdigamma(n0/2) + n0*(gam0_vec/gam_star_vec).sum() + 
                n0*(gam0_vec*(m_star_vec-m0_vec)**2).sum() - self.d*(np.log(2)+1))/ N
        else:
            meta_loss = loss + 0.5 * (-gam0_vec.log().sum() - self.d*mvdigamma(n0/2) + 
                n0*(gam0_vec*(m_star_vec-m0_vec)**2).sum() - self.d*(np.log(2)+1))/ N
        return meta_loss, loss_report

    def evaluate(self, xs, ys, xq, yq, nsamps=0):
        """Adapt to a support set and score the query set.

        Fresh copies of the prior mean (and, with ``use_gami``, log-precisions)
        are optimised with Adam for ``self.steps`` steps on the support loss
        plus a prior-based regulariser divided by ``Ns = xs.shape[0]``.  The
        adapted parameters are then copied into ``helper.base_net`` (which is
        therefore overwritten) and the query logits are computed in eval mode.

        Args:
            xs, ys: support inputs / labels.
            xq, yq: query inputs / labels.
            nsamps: with ``use_gami`` and ``nsamps > 0``, average the predictive
                distribution over ``nsamps`` sampled parameter sets; otherwise
                use the adapted mean only.

        Returns:
            ``(loss, logits)`` on the query set.  In the sampling case
            ``logits`` are log-probabilities of the averaged prediction.
        """
        Ns = xs.shape[0]
        # Prior quantities are constants during adaptation.
        with torch.no_grad():
            n0 = self.get_n0()
            m0_vec = nn.utils.parameters_to_vector(parameters=self.m0)
            gam0_vec = nn.utils.parameters_to_vector(parameters=self.get_gam0())

        # The adaptation loop calls backward(); enable_grad keeps it working
        # even when the caller wraps evaluation in torch.no_grad().
        with torch.enable_grad():
            m = nn.ParameterList([nn.Parameter(p*1.0) for p in self.m0])
            trainable = list(m.parameters())
            if self.use_gami:
                ugam = nn.ParameterList([nn.Parameter(p+torch.log(n0+self.d+2)) for p in self.ugam0])
                trainable = trainable + list(ugam.parameters())
            optim = torch.optim.Adam(params=trainable, lr=self.alp)

            for _ in range(self.steps):
                if self.use_gami:
                    gam = [torch.exp(p) for p in ugam]
                    th = [ m_ + torch.randn_like(m_)/g_.sqrt() for m_, g_ in zip(m, gam) ]
                else:
                    th = m.parameters()
                logits = helper.f_base_net.forward(xs, ys, x=None, params=th)
                loss = self.lossfun(logits, ys)
                m_vec = nn.utils.parameters_to_vector(parameters=m)
                if self.use_gami:
                    gam_vec = nn.utils.parameters_to_vector(parameters=gam)
                    loss = loss + 0.5 * (gam_vec.log().sum() + (n0+self.d+2)*(gam0_vec/gam_vec).sum() + 
                        (n0+self.d+2)*(gam0_vec*(m_vec-m0_vec)**2).sum())/ Ns
                else:
                    loss = loss + 0.5 * ((n0+self.d+2)*(gam0_vec*(m_vec-m0_vec)**2).sum()) /Ns
                optim.zero_grad()
                loss.backward()
                optim.step()

        # Query-time prediction: always put the shared base net back into
        # train mode, even if the forward pass raises.
        helper.base_net.eval()
        try:
            with torch.no_grad():
                if nsamps == 0 or not self.use_gami:
                    for src, tgt in zip(m, helper.base_net.parameters()):
                        tgt.copy_(src)
                    logits = helper.base_net(xs, ys, x=xq)
                else:
                    gam = [torch.exp(p) for p in ugam]
                    sampled_logits = []
                    for _ in range(nsamps):
                        th = [ m_ + torch.randn_like(m_)/g_.sqrt() for m_, g_ in zip(m, gam) ]
                        for src, tgt in zip(th, helper.base_net.parameters()):
                            tgt.copy_(src)
                        sampled_logits.append(helper.base_net(xs, ys, x=xq))
                    # [nQry, nC, nsamps] -> log of the sample-averaged softmax.
                    logits = torch.stack(sampled_logits, 2)
                    logits = F.log_softmax(logits,1).logsumexp(-1) - np.log(nsamps)
        finally:
            helper.base_net.train()

        loss = self.lossfun(logits, yq)
        return loss, logits
        