import torch
import numpy as np
import torch.nn as nn


def transpose_np(lst_lst):
    """Transpose a list of rows into a list of NumPy column arrays.

    Column ``j`` of the result is ``np.array`` of the ``j``-th element of
    every row. Rows of unequal length are truncated to the shortest one,
    because ``zip`` stops at the shortest iterable.
    """
    columns = zip(*lst_lst)
    return [np.array(list(column)) for column in columns]


def mean_np(t):
    """Average tensor ``t`` over its last dimension and return it as a NumPy array (on CPU)."""
    averaged = t.mean(dim=-1)
    return averaged.cpu().numpy()


def length(obj):
    """Return the squared L2 norm over the last dim, divided by that dim's size.

    Computed without autograd tracking and returned as a NumPy array; ``0``
    when ``obj`` is ``None``.
    """
    if obj is None:
        return 0
    with torch.no_grad():
        squared_norm = torch.norm(obj, dim=-1) ** 2
        normalized = squared_norm / obj.shape[-1]
        return normalized.cpu().numpy()


def length_w(obj):
    """Return the squared Frobenius norm of ``obj`` divided by its element count.

    Works on tensors of any rank; ``0`` when ``obj`` is ``None``.
    """
    if obj is None:
        return 0
    with torch.no_grad():
        # temporary for arbitrary dimension
        total_sq = (torch.norm(obj) ** 2).cpu().numpy()
    return total_sq / np.prod(obj.shape)


def get_w_grads(model):
    """Collect weight gradients and norm statistics for Conv2d/Linear layers.

    Modules whose qualified name contains ``'decs'``, ``'out'`` or ``'in'``
    are skipped. NOTE(review): this is a plain substring test, so ``'in'``
    also matches names such as ``'lin1'`` — confirm that is intended.

    For each remaining ``nn.Conv2d`` / ``nn.Linear`` module, one entry is
    appended to each list:

    * ``w_grads``: the detached weight-gradient tensor,
    * ``wg_len`` / ``w_len``: ``length_w`` of the weight gradient / weight,
    * ``bg_len`` / ``b_len``: ``length`` of the bias gradient / bias, or
      ``0`` for both when the layer has no bias. This keeps every list
      index-aligned with the same layer.

    Assumes ``backward()`` has already run. Otherwise ``m.weight.grad`` is
    ``None`` and ``.detach()`` raises ``AttributeError``.

    Returns:
        tuple: ``(w_grads, wg_len, bg_len, w_len, b_len)``.
    """
    w_grads, wg_len, bg_len = [], [], []
    w_len, b_len = [], []
    for name, m in model.named_modules():
        if 'decs' in name or 'out' in name or 'in' in name:
            continue  # to include ff along with pc
        if not isinstance(m, (nn.Conv2d, nn.Linear)):
            continue
        w_grad = m.weight.grad.detach()
        w = m.weight.detach()
        w_grads.append(w_grad)
        wg_len.append(length_w(w_grad))
        w_len.append(length_w(w))
        if m.bias is not None:
            b_grad = m.bias.grad.detach()
            b = m.bias.detach()
            bg_len.append(length(b_grad))
            b_len.append(length(b))
        else:
            # Placeholder in BOTH bias lists so they stay aligned per layer.
            bg_len.append(0)
            b_len.append(0)
    return w_grads, wg_len, bg_len, w_len, b_len


def get_sigmas(args, min_val_s, step_val_s):
    """Build ``args.n_conds`` sigma values, each rounded to 3 decimals.

    When ``args.step_exp`` is truthy the sequence is geometric
    (``min_val_s * step_val_s**i``); otherwise it is arithmetic
    (``min_val_s + step_val_s*i``), for ``i = 0 .. n_conds-1``.
    """
    sigmas = []
    for idx in range(args.n_conds):
        if args.step_exp:
            value = min_val_s * (step_val_s ** idx)
        else:
            value = min_val_s + step_val_s * idx
        sigmas.append(round(value, 3))
    return sigmas


def get_etas(args, min_val_e, step_val_e):
    """Build ``args.n_conds_e`` linearly spaced eta values, rounded to 3 decimals.

    Element ``k`` is ``round(min_val_e + step_val_e * k, 3)``.
    """
    etas = []
    for k in range(args.n_conds_e):
        etas.append(round(min_val_e + step_val_e * k, 3))
    return etas



def init_zs_db(args, ldt):
    """Sample one standard-normal tensor per hidden layer.

    For ``i`` in ``0 .. args.n_layers - 3``, draws a tensor of shape
    ``(ldt, args.ds[i + 1])`` from N(0, 1) and moves it to ``args.device``.
    """
    samples = []
    for layer in range(args.n_layers - 2):
        mean = torch.zeros([ldt, args.ds[layer + 1]])
        samples.append(torch.normal(mean, 1.0).to(args.device))
    return samples


def set_ps(progress, half_L, L):
    """Return a length-``L`` 0/1 mask with ones growing inward from both ends.

    With ``k = int((progress + 1e-9) * half_L)``, the first ``k + 1`` and the
    last ``k + 1`` entries are set to 1. The tiny epsilon keeps float
    round-off (e.g. ``0.3 * 10 == 2.999...``) from truncating ``k`` one too low.
    """
    nudged = progress + 0.000000001
    k = int(nudged * half_L)
    ps = np.zeros(L)
    ps[:k + 1] = 1.
    ps[L - k - 1:] = 1.
    return ps


def get_len_t(w_grads, vs):
    """Return ``length`` of each tensor in ``vs``, followed by ``length_w(w_grads)``."""
    lengths = list(map(length, vs))
    lengths += [length_w(w_grads)]
    return lengths


def get_len_l(T, w_grads, hs, zs, zs_t, zhs, deltas, deltas_t):
    """Compute per-layer and per-step ``length`` statistics for logging.

    Args:
        T: number of steps to read from ``zs_t`` / ``deltas_t``.
        w_grads: iterable of weight-gradient tensors; each is reduced with
            ``length_w``.
        hs, zs, zhs, deltas: per-layer tensors, walked in lockstep (``zip``
            stops at the shortest); each is reduced with ``length``.
        zs_t, deltas_t: indexable by step ``t`` in ``range(T)``; each step
            holds per-layer tensors walked in lockstep and reduced with
            ``length``.

    Returns:
        tuple: ``(len_hs, len_zs, len_zs_t, len_zhs, len_deltas,
        len_deltas_t, len_wds)``. ``len_zs_t`` and ``len_deltas_t`` are lists
        (one per step) of per-layer lists.
    """
    len_wds = [length_w(w_grad) for w_grad in w_grads]
    len_hs, len_zs, len_zhs, len_deltas = [], [], [], []
    # assumed layout: n_layers x bsz x dim — confirm with caller
    for h, z, zh, delta in zip(hs, zs, zhs, deltas):
        len_hs.append(length(h))
        len_zs.append(length(z))
        len_zhs.append(length(zh))
        len_deltas.append(length(delta))
    # assumed layout: T x n_layers x bsz x dim — confirm with caller
    len_zs_t, len_deltas_t = [], []
    for t in range(T):
        len_zs_, len_deltas_ = [], []
        for z_t, delta_t in zip(zs_t[t], deltas_t[t]):
            # Measure the step-t state itself, not a variable left over
            # from the per-layer loop above.
            len_zs_.append(length(z_t))
            len_deltas_.append(length(delta_t))
        len_zs_t.append(len_zs_)
        len_deltas_t.append(len_deltas_)
    return len_hs, len_zs, len_zs_t, len_zhs, len_deltas, len_deltas_t, \
        len_wds


def get_ps(log_dt, z_encs, zs, model):
    """Compute per-layer mean products and store them, with copied lengths, in ``log_dt``.

    Weights come from every ``nn.Conv2d`` / ``nn.Linear`` module whose name
    does not contain ``'decs'``, ``'in'`` or ``'out'`` (a plain substring test).
    All quantities are computed under ``torch.no_grad()`` and reduced with
    ``mean_np`` (mean over the last dim):

    * ``'ps/p_hat'``: ``zs[l + 1] * z_encs[l]``,
    * ``'ps/p_tld'``: ``zs[i] * (zs[i + 1] @ W_i)``,
    * ``'ps/p_chk'``: ``zs[i] * (z_encs[i] @ W_i)``.

    ``'ps/p'``, ``'ps/q'`` and ``'ps/v'`` are then copied from
    ``log_dt['len/z_len']``, ``log_dt['len/d_len']`` and ``log_dt['len/w_len']``.
    Those keys must already exist, otherwise ``KeyError`` is raised.
    """
    excluded = ('decs', 'in', 'out')
    with torch.no_grad():
        weights = [
            module.weight
            for module_name, module in model.named_modules()
            if not any(tag in module_name for tag in excluded)
            and isinstance(module, (nn.Conv2d, nn.Linear))
        ]
        # idxs: 1 ~ L-1
        hat_vals = [mean_np(z_next * z_enc)
                    for z_next, z_enc in zip(zs[1:], z_encs)]
        # idxs: 0 ~ L-2
        tld_vals = [mean_np(zs[idx] * torch.matmul(zs[idx + 1], weight))
                    for idx, weight in enumerate(weights)]
        # idxs: 0 ~ L-2
        chk_vals = [mean_np(z_cur * torch.matmul(z_enc, weight))
                    for z_cur, z_enc, weight in zip(zs[:-1], z_encs, weights)]
        log_dt['ps/p_hat'] = hat_vals
        log_dt['ps/p_tld'] = tld_vals
        log_dt['ps/p_chk'] = chk_vals
        log_dt['ps/p'] = log_dt['len/z_len']
        log_dt['ps/q'] = log_dt['len/d_len']
        log_dt['ps/v'] = log_dt['len/w_len']


def tensorize(objs):
    """Stack each inner sequence of tensors in ``objs`` into one tensor (``torch.stack``)."""
    return [torch.stack(group) for group in objs]


def _zero_or_normal_bias(bias, sigma_b):
    """Draw ``bias`` from N(0, sigma_b), or zero it when ``sigma_b == 0``; no-op if ``bias`` is None."""
    if bias is None:
        return
    if sigma_b != 0:
        torch.nn.init.normal_(bias, mean=0, std=sigma_b)
    else:
        torch.nn.init.zeros_(bias)


def init_param(model, args, sigma_w, sigma_b):
    """Re-initialize the weights and biases of every Conv2d/Linear layer in ``model``.

    Modules whose qualified name contains ``'decs'`` are skipped. The scheme
    is chosen from ``args``:

    * ``args.orthogonal_testing == True``: orthogonal weights scaled by
      ``sigma_w``; bias ~ N(0, sigma_b), or zero when ``sigma_b == 0``.
    * else ``args.act == 'relu'``: Kaiming-normal weights
      (``mode='fan_out'``, ``nonlinearity='relu'``). The bias already held by
      the module is multiplied by ``sigma_b``, or zeroed when ``sigma_b == 0``.
    * else, by ``args.param_init``:

      - ``'normal'``: weight ~ N(0, sigma_w * sqrt(d_rev)), where
        ``d_rev = 2 / (ds[l] + ds[l + 1])`` if ``args.train`` else
        ``1 / ds[l]``. ``l`` is the integer second component of the module
        name (e.g. ``'layers.3'`` -> 3). Bias as in the orthogonal case.
      - ``'uniform'``: Xavier-uniform weights with ``gain=args.gain``. When
        ``args.test_sw`` is set, ``args.gain`` is first overwritten with
        ``sigma_w`` (this mutates ``args``). Bias is zeroed.

    Biases are only touched when the layer has one (``bias=False`` layers
    are fine). Zeroing goes through ``torch.nn.init.zeros_``, which runs under
    ``no_grad`` internally, so this works whether or not the caller disabled
    autograd.

    Raises:
        NotImplementedError: for an unrecognized ``args.param_init``.
    """
    for name, m in model.named_modules():
        if 'decs' in name:
            continue  # to include ff along with pc
        if not isinstance(m, (nn.Conv2d, nn.Linear)):
            continue
        bias = m.bias  # None for layers built with bias=False
        if args.orthogonal_testing == True:  # noqa: E712 (keep exact-True test)
            torch.nn.init.orthogonal_(m.weight)
            m.weight.data *= sigma_w
            _zero_or_normal_bias(bias, sigma_b)
        elif args.act == 'relu':
            # kaiming init
            torch.nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                          nonlinearity='relu')
            if bias is not None:
                if sigma_b != 0:
                    # Rescale whatever values the bias already holds
                    # (e.g. PyTorch's default init); it is not redrawn here.
                    bias.data *= sigma_b
                else:
                    torch.nn.init.zeros_(bias)
        elif args.param_init == 'normal':
            # Layer index is only needed here; parse it lazily so other
            # schemes do not depend on the module-naming convention.
            l = int(name.split('.')[1])
            if args.train:
                d_rev = 2.0 / (args.ds[l] + args.ds[l + 1])
            else:
                d_rev = 1 / args.ds[l]
            torch.nn.init.normal_(m.weight, mean=0,
                                  std=sigma_w * np.sqrt(d_rev))
            _zero_or_normal_bias(bias, sigma_b)
        elif args.param_init == 'uniform':
            if args.test_sw:
                args.gain = sigma_w
            torch.nn.init.xavier_uniform_(m.weight, gain=args.gain)
            if bias is not None:
                torch.nn.init.zeros_(bias)
        else:
            raise NotImplementedError
