import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from networks import utils

class Helper:
    """Plain container for the working network and its per-parameter buffers.

    Attributes are created lazily: ``create`` sets ``full_net``; ``initialize``
    sets ``mom1``, ``mom2`` and ``f_full_net``. Calling ``initialize`` before
    ``create`` fails because ``full_net`` does not exist yet.
    """

    def create(self, full_net):
        """Store ``full_net`` as the network this helper operates on."""
        self.full_net = full_net

    def initialize(self, device):
        """Move the stored network to ``device`` and build the derived state.

        ``mom1`` and ``mom2`` each receive one detached tensor per network
        parameter (``Tensor.detach`` shares storage with its source), and
        ``f_full_net`` is whatever ``utils.module2functional`` returns for the
        network.
        """
        net = self.full_net
        net.to(device)
        self.mom1 = [weight.detach() for weight in net.parameters()]
        self.mom2 = [weight.detach() for weight in net.parameters()]
        self.f_full_net = utils.module2functional(net)

# Module-level Helper instance. NIWMeta.__init__ calls helper.create(full_net),
# NIWMeta.initialize_helper calls helper.initialize(device), and NIWMeta's
# forward/evaluate read helper.full_net, helper.mom1, helper.mom2 and
# helper.f_full_net. Because it is a module global, every NIWMeta instance
# shares (and overwrites) the same helper state.
helper = Helper()

def run_sgld_gaussian_fit(x, y, full_net, mom1, mom2, init_params=None, lossfun=nn.CrossEntropyLoss(), steps=10, burnin=5, alp=1e-3, ai_max=1e3, xs=None, ys=None):
    """Run noisy-gradient Adam steps on ``full_net`` and fit a diagonal Gaussian to the iterates.

    The network parameters are first overwritten in place: copied from
    ``init_params`` when given, otherwise re-initialised (Kaiming-normal for
    tensors with more than one dimension, zeros for the rest). Then, for
    ``steps`` iterations, the loss on ``(x, y)`` is back-propagated, every
    gradient is scaled by ``N / 2`` and perturbed with Gaussian noise of
    standard deviation ``sqrt(1 / alp)``, and an Adam step with learning rate
    ``alp`` is taken, where ``N = x.shape[0]`` and ``alp`` has been divided by
    ``N``. Starting at iteration index ``burnin``, running means of the
    flattened parameter vector and of its element-wise square are maintained.

    Args:
        x, y: Inputs and targets passed to ``full_net`` and ``lossfun``.
        full_net: Network to optimise; its parameters are modified in place.
        mom1, mom2: Iterables of tensors shaped like ``full_net``'s parameters.
            Their ``.data`` is rebound to the results via
            ``nn.utils.vector_to_parameters``.
        init_params: Optional iterable of tensors copied into the network
            before optimisation.
        lossfun: Callable ``lossfun(logits, y)`` returning a scalar loss.
        steps: Total number of optimisation steps.
        burnin: Index of the first iterate that contributes to the moments.
        alp: Step size before division by ``N``.
        ai_max: Upper clamp for the returned inverse variances.
        xs, ys: When both are given, the network is called as
            ``full_net.forward(x, xs, ys)`` instead of ``full_net.forward(x)``.

    Returns:
        ``(mom1, mom2)``: the same objects that were passed in. ``mom1`` holds
        the mean of the collected iterates; ``mom2`` holds the reciprocal of
        their ``cnt / (cnt - 1)``-corrected variance, clamped to
        ``[1e-4, ai_max]``.

    Raises:
        ValueError: If fewer than two iterates would be collected
            (``burnin < 0`` or ``steps < burnin + 2``). The corrected variance
            divides by ``cnt - 1`` and is undefined in that case.
    """
    # Fail before touching the network: with fewer than two post-burn-in
    # iterates the variance correction below would divide by zero (or the
    # moments would never be initialised at all).
    if burnin < 0 or steps < burnin + 2:
        raise ValueError(
            f"need at least two iterates after burn-in: got steps={steps}, "
            f"burnin={burnin} (require burnin >= 0 and steps >= burnin + 2)"
        )

    derived_head = (xs is not None) and (ys is not None)
    N = x.shape[0]
    alp /= N

    with torch.no_grad():
        if init_params is not None:
            for src, tgt in zip(init_params, full_net.parameters()):
                tgt.copy_(src)
        else:
            for p in full_net.parameters():
                if len(p.shape) > 1:
                    nn.init.kaiming_normal_(p, nonlinearity='relu')
                else:
                    nn.init.zeros_(p)

    optim = torch.optim.Adam(params=full_net.parameters(), lr=alp)
    moment1, moment2 = None, None
    cnt = 0  # number of iterates folded into the running moments so far
    for i in range(steps):
        logits = full_net.forward(x, xs, ys) if derived_head else full_net.forward(x)
        loss = lossfun(logits, y)
        optim.zero_grad()
        loss.backward()

        # Rescale the gradient and inject noise before the optimizer sees it.
        with torch.no_grad():
            for p in full_net.parameters():
                if p.grad is not None:
                    p.grad *= N/2.
                    p.grad += np.sqrt(1/alp)*torch.randn_like(p)
        optim.step()

        if i >= burnin:
            with torch.no_grad():
                src = nn.utils.parameters_to_vector(parameters=full_net.parameters())
                if cnt == 0:
                    # First collected iterate seeds both moments.
                    moment1 = src*1.0
                    moment2 = src**2
                else:
                    # Incremental mean over the cnt + 1 iterates seen so far.
                    moment1 = (src + cnt*moment1) / (cnt+1)
                    moment2 = (src**2 + cnt*moment2) / (cnt+1)
            cnt += 1

    with torch.no_grad():
        # Inverse of the corrected per-coordinate variance; the clamp also
        # absorbs zero or slightly negative variances caused by round-off.
        moment2 = (1./((cnt/(cnt-1))*(moment2-moment1**2))).clamp(min=1e-4, max=ai_max)
        nn.utils.vector_to_parameters(moment1, mom1)
        nn.utils.vector_to_parameters(moment2, mom2)
    return mom1, mom2

class NIWMeta(nn.Module):
    """Meta-learner with a learnable per-weight Gaussian (mean, precision) over a network.

    The learnable state is ``m0`` (means) and ``ugam0`` (log-precisions,
    ``gam0 = exp(ugam0)``), each a ``ParameterList`` shaped like
    ``full_net.parameters()``. The network itself is handed to the
    module-level ``helper`` and accessed through it; it is not stored as an
    attribute of this module.
    """

    def __init__(self, full_net, derived_head, initialize=True, spiky=0, gam0_init=1e4, gam0_max=1e8,sgld_steps=5, sgld_burnin=2, sgld_alp=0.001, sgld_ai_max=1e3, ):
        """Create ``m0`` / ``ugam0`` from ``full_net`` and register it with ``helper``.

        Args:
            full_net: Network whose parameters define the shapes (and, unless
                ``initialize`` is true, the initial values) of ``m0``.
            derived_head: Truthy if the network is called with the extra
                ``(xs, ys)`` arguments.
            initialize: If true, re-initialise ``m0`` (Kaiming-normal for
                tensors with more than one dimension, zeros otherwise).
            spiky: Truthy to use the mean directly instead of sampling
                weights, and to drop the precision terms of the regulariser.
            gam0_init: Initial value of every entry of ``gam0``.
            gam0_max: Stored on the instance.
            sgld_steps, sgld_burnin, sgld_alp, sgld_ai_max: Forwarded to
                ``run_sgld_gaussian_fit`` as ``steps``, ``burnin``, ``alp``
                and ``ai_max``.
        """
        super().__init__()
        helper.create(full_net)

        self.m0 = nn.ParameterList([nn.Parameter(p*1.0) for p in full_net.parameters()])
        self.ugam0 = nn.ParameterList([nn.Parameter(p*0.0+np.log(gam0_init)) for p in full_net.parameters()])
        if initialize:
            for p in self.m0:
                if len(p.shape) > 1:
                    nn.init.kaiming_normal_(p, nonlinearity='relu')
                else:
                    nn.init.zeros_(p)
        self.derived_head = derived_head
        self.spiky = spiky
        # NOTE(review): gam0_max is stored but never read in this class;
        # confirm whether get_gam0 was meant to clamp to it.
        self.gam0_max = gam0_max
        self.sgld_opts = {'steps': sgld_steps,'burnin': sgld_burnin,'alp': sgld_alp,'ai_max': sgld_ai_max}
        # Total number of scalar weights; appears as the constant in the KL terms.
        self.d = nn.utils.parameters_to_vector(full_net.parameters()).numel()

    def get_gam0(self):
        """Return ``exp(ugam0)`` as a list of tensors (one per parameter)."""
        return [torch.exp(p) for p in self.ugam0]

    def initialize_helper(self, device):
        """Initialise the module-level ``helper`` on ``device``."""
        helper.initialize(device)

    @staticmethod
    def _loss_batch(xs, ys, xq, yq, query_only_for_loss):
        """Return the query batch, or support and query concatenated on dim 0."""
        if query_only_for_loss:
            return xq, yq
        return torch.cat([xs, xq], 0), torch.cat([ys, yq], 0)

    @staticmethod
    def _sample_params(means, precisions):
        """Draw one weight sample ``m + eps / sqrt(g)`` per (mean, precision) pair."""
        return [m_ + torch.randn_like(m_)/g_.sqrt() for m_, g_ in zip(means, precisions)]

    def forward(self, xs, ys, xq, yq, lossfun, query_only_for_loss=True):
        """Compute the meta-training loss for one episode.

        ``run_sgld_gaussian_fit`` is started from ``m0`` on the loss batch and
        yields a mean ``m_bar`` and inverse variance ``A_bar`` per weight.
        These are combined with the prior into ``m_star`` (precision-weighted
        mean) and ``gam_star = A_bar + gam0``. The data loss is evaluated
        through ``helper.f_full_net`` at ``m_star`` (spiky) or at a sample
        around it, and a regulariser divided by the batch size is added. In
        the non-spiky case that regulariser is the diagonal-Gaussian KL
        divergence from ``(m_star, gam_star)`` to ``(m0, gam0)``.

        Args:
            xs, ys: Support inputs and targets.
            xq, yq: Query inputs and targets.
            lossfun: Callable ``lossfun(logits, y)``.
            query_only_for_loss: If true the loss batch is the query set;
                otherwise it is support and query concatenated along dim 0.

        Returns:
            ``(meta_loss, logits, loss_report)`` where ``loss_report`` is the
            data loss as a Python float.
        """
        # The original else-branch read x / y before assignment; the intended
        # operands are the query tensors.
        x, y = self._loss_batch(xs, ys, xq, yq, query_only_for_loss)
        N = x.shape[0]
        head_xs = xs if self.derived_head else None
        head_ys = ys if self.derived_head else None
        m_bar, A_bar = run_sgld_gaussian_fit(
            x, y, helper.full_net, helper.mom1, helper.mom2, init_params=self.m0,
            lossfun=lossfun, **self.sgld_opts, xs=head_xs, ys=head_ys,
        )
        gam0 = self.get_gam0()
        m_star = []
        gam_star = None if self.spiky else []
        for m_bar_p, A_bar_p, m0_p, gam0_p in zip(m_bar, A_bar, self.m0, gam0):
            m_star.append((A_bar_p*m_bar_p + gam0_p*m0_p)/(A_bar_p + gam0_p))
            if not self.spiky:
                gam_star.append(A_bar_p + gam0_p)
        if self.spiky:
            th = m_star
        else:
            th = self._sample_params(m_star, gam_star)
        if self.derived_head:
            logits = helper.f_full_net.forward(x, xs, ys, params=th)
        else:
            logits = helper.f_full_net.forward(x, params=th)
        loss = lossfun(logits, y)
        loss_report = loss.item()
        m0_vec = nn.utils.parameters_to_vector(parameters=self.m0)
        gam0_vec = nn.utils.parameters_to_vector(parameters=gam0)
        m_star_vec = nn.utils.parameters_to_vector(parameters=m_star)
        if self.spiky:
            meta_loss = loss + 0.5 * (-gam0_vec.log().sum() + (gam0_vec*(m_star_vec-m0_vec)**2).sum() - self.d)/N
        else:
            gam_star_vec = nn.utils.parameters_to_vector(parameters=gam_star)
            meta_loss = loss +0.5*(-gam0_vec.log().sum() + gam_star_vec.log().sum()+ (gam0_vec/gam_star_vec).sum() + (gam0_vec*(m_star_vec-m0_vec)**2).sum() - self.d
            )/N
        return meta_loss, logits, loss_report

    def evaluate(self, xs, ys, xq, yq, lossfun, steps=20, lr=1e-3, nsamps=0, regr=True):
        """Adapt to the support set, then predict on the query set.

        Copies of ``m0`` (and of ``ugam0`` unless spiky) are optimised with
        Adam for ``steps`` steps on the support loss plus the same regulariser
        as in ``forward``, taken against the frozen prior and divided by the
        support size. Prediction then writes weights into ``helper.full_net``
        and runs it in eval mode; the network is switched back to train mode
        afterwards.

        Args:
            xs, ys: Support inputs and targets used for adaptation.
            xq, yq: Query inputs and targets used for the reported loss.
            lossfun: Callable ``lossfun(logits, y)``.
            steps: Number of Adam adaptation steps.
            lr: Adam learning rate.
            nsamps: Number of weight samples for prediction. ``0`` (or a spiky
                model) uses the adapted means only.
            regr: With sampling, average raw outputs if true; otherwise
                combine the samples as the log of the mean softmax probability.

        Returns:
            ``(loss, logits, logits_all)``. With sampling, ``logits_all`` is
            the per-sample outputs stacked on axis 2. Without sampling it is
            ``logits`` with a trailing singleton sample axis.
        """
        Ns = xs.shape[0]
        # .data detaches the prior so adaptation does not back-propagate into it.
        m0_vec = nn.utils.parameters_to_vector(parameters=self.m0).data
        gam0_vec = nn.utils.parameters_to_vector(parameters=self.get_gam0()).data

        m = nn.ParameterList([nn.Parameter(p*1.0) for p in self.m0])
        if self.spiky:
            optim = torch.optim.Adam(params=list(m.parameters()), lr=lr)
        else:
            ugam = nn.ParameterList([nn.Parameter(p*1.0) for p in self.ugam0])
            optim = torch.optim.Adam(params=list(m.parameters())+list(ugam.parameters()), lr=lr)

        for _ in range(steps):
            if self.spiky:
                # A list (not a one-shot generator), matching what every other
                # call site passes as ``params``.
                th = list(m.parameters())
            else:
                gam = [torch.exp(p) for p in ugam]
                th = self._sample_params(m, gam)
            if self.derived_head:
                logits = helper.f_full_net.forward(xs, xs, ys, params=th)
            else:
                logits = helper.f_full_net.forward(xs, params=th)
            loss = lossfun(logits, ys)
            m_vec = nn.utils.parameters_to_vector(parameters=m)
            if self.spiky:
                loss = loss + 0.5 * ((gam0_vec*(m_vec-m0_vec)**2).sum() - self.d - gam0_vec.log().sum()) /Ns
            else:
                gam_vec = nn.utils.parameters_to_vector(parameters=gam)
                loss = loss + 0.5 *(gam_vec.log().sum() + (gam0_vec/gam_vec).sum() + (gam0_vec*(m_vec-m0_vec)**2).sum() - self.d-gam0_vec.log().sum()
                )/Ns
            optim.zero_grad()
            loss.backward()
            optim.step()

        if nsamps == 0 or self.spiky:
            # Point prediction at the adapted means.
            with torch.no_grad():
                for src, tgt in zip(m, helper.full_net.parameters()):
                    tgt.copy_(src)
            helper.full_net.eval()
            with torch.no_grad():
                if self.derived_head:
                    logits = helper.full_net(xq, xs, ys)
                else:
                    logits = helper.full_net(xq)
            loss = lossfun(logits, yq)
            helper.full_net.train()
            # Previously unassigned on this path, which made the return raise.
            # Expose the single prediction with a trailing sample axis.
            logits_all = logits.unsqueeze(-1)
        else:
            gam = [torch.exp(p) for p in ugam]

            logits_all = []
            for _ in range(nsamps):
                th = self._sample_params(m, gam)
                with torch.no_grad():
                    for src, tgt in zip(th, helper.full_net.parameters()):
                        tgt.copy_(src)
                helper.full_net.eval()
                with torch.no_grad():
                    if self.derived_head:
                        logits_ = helper.full_net(xq, xs, ys)
                    else:
                        logits_ = helper.full_net(xq)
                helper.full_net.train()
                logits_all.append(logits_)
            logits_all = torch.stack(logits_all, 2)
            if regr:
                logits = logits_all.mean(-1)
            else:
                # log of the sample-averaged softmax probability
                logits = F.log_softmax(logits_all,1).logsumexp(-1) - np.log(nsamps)
            loss = lossfun(logits, yq)
        return loss, logits, logits_all
