import gc
import os
import pickle
from this import d
import numpy as np
from anal.reg import get_omega, get_beta
from anal.util import get_ps
from data.dl_getter import nDT
from anal.util import transpose_np


# Debug / printing helpers

def p_wb(model, layer=1, verbose=False):
    """Compute regularisation statistics for one layer's ``fc`` weight and bias.

    Args:
        model: object exposing ``model.f_loc[layer].fc.weight`` and
            ``model.f_loc[layer].fc.bias`` (presumably torch tensors, since
            ``.norm()`` / ``.item()`` are used -- TODO confirm).
        layer: index into ``model.f_loc``; defaults to 1.
        verbose: if True, print omega, beta and the weight/bias norms.

    Returns:
        Tuple ``(omega, beta, weight_norm, bias_norm)`` where omega/beta come
        from ``get_omega`` / ``get_beta``.
    """
    fc = model.f_loc[layer].fc
    w, b = fc.weight, fc.bias
    wn, bn = w.norm(), b.norm()
    omega = get_omega(w)
    beta = get_beta(b)
    if verbose:
        print(omega.item(), beta.item(), wn.item(), bn.item())
    return omega, beta, wn, bn


# Short names of the logged length / norm quantities (no usage is visible in
# this module; order as written here).
name_lst = [
    'z_len',
    'zh_len',
    'h_len',
    'd_len',
    'loss',
    'w_len',
    'wg_len',
    'bg_len',
]


class LenT:
    """Aggregate of ``LenLog`` arrays stacked across a hyper-parameter sweep.

    The arrays in ``len_np_lst`` follow the order of ``LenLog.get_len_lst``.
    Each one presumably has four leading sweep axes (run, sigma_w, sigma_b,
    eta -- TODO confirm against the code that stacks them) in front of the
    per-run ``LenLog`` axes, so e.g. ``z_len`` is 7-D with the layer axis at
    index 5 and the batch axis last.
    """

    # Attribute names for the entries of ``len_np_lst``, in positional order.
    _LEN_FIELDS = (
        'z_len', 'zh_len', 'h_len', 'd_len',
        'loss_tr', 'wg_len', 'bg_len',
        'loss_inf', 'wg_len_inf', 'bg_len_inf',
        'z_len_tr', 'd_len_tr',
        'w_len', 'w_len_inf',
    )
    # Offset added to every stored array (presumably to keep later
    # logs/divisions finite).
    _EPS = 1e-30

    def __init__(self, len_np_lst, args, runs, sigma_ws, sigma_bs, etas):
        """Store ``args`` and the stacked arrays from ``len_np_lst``.

        ``runs``, ``sigma_ws``, ``sigma_bs`` and ``etas`` are accepted for
        interface compatibility but are not used by this class.
        """
        self.args = args
        self.set_len(len_np_lst)
        # Sweep cursors -- r: run, w: sigma_w, b: sigma_b, et: eta.
        self.i_r, self.i_w, self.i_b, self.i_et = -1, -1, -1, -1

    def set_len(self, len_np_lst):
        """Assign each array of ``len_np_lst`` (plus ``_EPS``) to its attribute.

        Also rebuilds the grouping tuples ``zds``, ``zds_tr``, ``wbs`` and
        ``wbs_inf`` used by ``get_stat``.
        """
        for idx, name in enumerate(self._LEN_FIELDS):
            setattr(self, name, len_np_lst[idx] + self._EPS)
        self.zds = self.z_len, self.zh_len, self.h_len, self.d_len
        self.zds_tr = self.z_len_tr, self.d_len_tr
        self.wbs = self.wg_len, self.bg_len, self.w_len, self.loss_tr
        self.wbs_inf = self.wg_len_inf, self.bg_len_inf, self.w_len_inf

    def get_stat(self):
        """Return ``(mean, std)`` summaries of the stored arrays.

        Returns:
            zd_stat: stats for z, zh, h, d (reduced over axes 0 and -1), then
                loss_inf (layer 0 dropped, reduced over axis 0), then the
                ratio d[l] / z[l+1] (reduced over axes 0 and -1).
            wb_stat: stats of ``wbs`` reduced over axis 0.
            wb_inf_stat: stats of ``wbs_inf`` reduced over axis 0.
            zd_tr_stat: stats of ``zds_tr`` reduced over axes 0 and -1.
        """
        def mean_std(arr, axes):
            return arr.mean(axes), arr.std(axes)

        first, *rest = self.zds
        zd_stat = [mean_std(first, (0, -1))]
        # In LenLog, zh/h are not written for the last layer (is_L) and d is
        # written at slot i_l - 1, so the last layer slot (axis 5) is dropped.
        zd_stat += [mean_std(item[:, :, :, :, :, :-1], (0, -1)) for item in rest]
        zd_stat.append(mean_std(self.loss_inf[:, :, :, :, :, 1:], 0))
        qpratio = np.divide(self.d_len[:, :, :, :, :, :-1],
                            self.z_len[:, :, :, :, :, 1:] + 1e-9)
        zd_stat.append(mean_std(qpratio, (0, -1)))
        wb_stat = [mean_std(item, 0) for item in self.wbs]
        wb_inf_stat = [mean_std(item, 0) for item in self.wbs_inf]
        zd_tr_stat = [mean_std(item, (0, -1)) for item in self.zds_tr]
        return zd_stat, wb_stat, wb_inf_stat, zd_tr_stat

    def get_stat_plot(self):
        """Return ``get_stat`` results plus ``transpose_np``-regrouped variants.

        ``zd`` uses the z stats with layer 0 dropped together with the zh/h/d
        stats (loss_inf and qpratio excluded); ``wb`` excludes the last
        ``wbs`` entry (loss_tr); ``wb_inf`` uses all ``wbs_inf`` stats.
        """
        zd_stat_lst, wb_stat_lst, wb_inf_stat_lst, zd_tr_stat_lst = self.get_stat()
        z_mean, z_std = zd_stat_lst[0]
        zd = transpose_np([(z_mean[:, :, :, :, 1:], z_std[:, :, :, :, 1:])]
                          + zd_stat_lst[1:-2])  # excluding loss_inf and qpratio
        wb = transpose_np(wb_stat_lst[:-1])
        wb_inf = transpose_np(wb_inf_stat_lst)
        return zd_stat_lst, wb_stat_lst, wb_inf_stat_lst, zd, wb, wb_inf, zd_tr_stat_lst


class LenLog:
    """Per-run logger of length / norm measurements stored in numpy arrays.

    Arrays written at cursor ``i_i`` have first axis ``args.T`` (inference
    steps); arrays written at cursor ``i_ei`` have first axis
    ``n_epochs * n_inner_iter`` (training steps).
    """

    def __init__(self, args):
        """Allocate zero-filled log arrays sized from ``args``.

        Reads ``args.dataset``, ``args.bsz``, ``args.n_layers``, ``args.T``,
        ``args.epochs`` and ``args.len_all``.
        """
        try:
            n_inner_iter = int(nDT[args.dataset] / args.bsz)
        except Exception:
            # Best effort: fall back to a single inner iteration when the
            # dataset size is unavailable or the division fails.
            n_inner_iter = 1
        self.args = args
        self.L = args.n_layers
        self.bsz = args.bsz
        zd_lst = [np.zeros([args.T, self.L, self.bsz]) for _ in range(4)]
        self.zds = self.z_len, self.zh_len, self.h_len, self.d_len = zd_lst
        self.loss_inf = np.zeros([args.T, self.L])
        n_epochs = args.epochs if args.len_all else 1
        zd_lst_tr = [np.zeros([n_epochs * n_inner_iter, self.L, self.bsz]) for _ in range(2)]
        self.zds_tr = self.z_len_tr, self.d_len_tr = zd_lst_tr
        wb_lst = [np.zeros([n_epochs * n_inner_iter, self.L - 1]) for _ in range(4)]
        self.wbs = self.wg_len, self.bg_len, self.w_len, self.loss_tr = wb_lst
        wb_lst_inf = [np.zeros([args.T, self.L - 1]) for _ in range(3)]
        self.wbs_inf = self.wg_len_inf, self.bg_len_inf, self.w_len_inf = wb_lst_inf
        # Cursors -- ei: epoch & batch iteration, i: inference iteration.
        self.i_ei, self.i_i = -1, -1
        self.n = n_inner_iter

    def set_zh(self, i_l, vals, is_L=False):
        """Record z (``vals[0]``) and, unless ``is_L``, zh/h (``vals[1:3]``)."""
        self.z_len[self.i_i, i_l, :] = vals[0]
        if not is_L:
            self.zh_len[self.i_i, i_l, :] = vals[1]
            self.h_len[self.i_i, i_l, :] = vals[2]

    def set_d(self, i_l, val):
        """Record d for layer ``i_l``; it is stored at layer slot ``i_l - 1``."""
        self.d_len[self.i_i, i_l - 1, :] = val

    def set_loss_inf(self, i_l, loss):
        """Record the inference-time loss for layer ``i_l``."""
        self.loss_inf[self.i_i, i_l] = loss

    def set_zd_tr(self, i_l, z_len, d_len=None):
        """Record training-time z for layer ``i_l``; d only when ``i_l != 0``."""
        self.z_len_tr[self.i_ei, i_l, :] = z_len
        if i_l:
            self.d_len_tr[self.i_ei, i_l, :] = d_len

    def set_wb(self, wg_len, bg_len, w_len, loss):
        """Record training-time gradient/weight lengths and loss at ``i_ei``."""
        self.wg_len[self.i_ei, :] = wg_len
        self.bg_len[self.i_ei, :] = bg_len
        self.w_len[self.i_ei, :] = w_len
        self.loss_tr[self.i_ei, :] = loss

    def set_wb_inf(self, wg_len, bg_len, w_len):
        """Record inference-time gradient/weight lengths at ``i_i``."""
        self.wg_len_inf[self.i_i, :] = wg_len
        self.bg_len_inf[self.i_i, :] = bg_len
        self.w_len_inf[self.i_i, :] = w_len

    def save_pkl(self, i_run, sigma_w, sigma_b, eta):
        """Pickle this logger to ``<log_dir>/<T>/<i_run>,<sigma_w>,<sigma_b>,<eta>``.

        The file is opened in append mode, so repeated saves to the same path
        stack several pickles in one file; a reader must call ``pickle.load``
        repeatedly to reach the latest one.
        """
        s = f"{i_run},{sigma_w},{sigma_b},{eta}"
        log_dir = os.path.join(self.args.log_dir, str(self.args.T))
        # exist_ok avoids the check-then-create race of os.path.exists.
        os.makedirs(log_dir, exist_ok=True)
        with open(os.path.join(log_dir, s), "ab") as fw:
            pickle.dump(self, fw)

    # deprecated
    def get_stat(self):
        """Deprecated: (mean, std) summaries of the single-run arrays."""
        zd_stat = [(item[:, :-(i > 0)].mean((0, -1)),
                    item[:, :-(i > 0)].std((0, -1)))
                   if i > 0 else (item.mean((0, -1)), item.std((0, -1)))
                   for i, item in enumerate(self.zds)]
        zd_stat.append((self.loss_inf[:, 1:].mean(0),
                        self.loss_inf[:, 1:].std(0)))
        qpratio_stat = np.divide(self.d_len[:, :-1],
                                 self.z_len[:, 1:] + 1e-9)
        zd_stat.append((qpratio_stat.mean((0, -1)), qpratio_stat.std((0, -1))))
        wb_stat = [(item.mean(0), item.std(0)) for item in self.wbs]
        wb_stat_inf = [(item.mean(0), item.std(0)) for item in self.wbs_inf]
        zd_stat_tr = [(item.mean(0), item.std(0)) for item in self.zds_tr]
        return zd_stat, wb_stat, wb_stat_inf, zd_stat_tr

    def inc(self, which=''):
        """Advance a cursor: ``'e'`` bumps ``i_ei`` and resets ``i_i``; ``'i'`` bumps ``i_i``.

        Raises:
            NotImplementedError: for any other ``which``.
        """
        if which == 'e':
            self.i_ei += 1
            self.i_i = -1
        elif which == 'i':
            self.i_i += 1
        else:
            raise NotImplementedError(f"unknown counter {which!r}; expected 'e' or 'i'")

    def clip_T(self, to_T):
        """Truncate every inference-indexed array to its first ``to_T`` steps.

        Copies are taken so the full-size arrays can be released. The grouping
        tuples ``zds`` and ``wbs_inf`` are rebuilt as well; otherwise they would
        keep pointing at (and keeping alive) the unclipped arrays.
        """
        import copy
        for name in ('z_len', 'zh_len', 'h_len', 'd_len', 'loss_inf',
                     'w_len_inf', 'wg_len_inf', 'bg_len_inf'):
            setattr(self, name, copy.deepcopy(getattr(self, name)[:to_T]))
        self.zds = self.z_len, self.zh_len, self.h_len, self.d_len
        self.wbs_inf = self.wg_len_inf, self.bg_len_inf, self.w_len_inf
        gc.collect()

    def get_len_lst(self):
        """Return all logged arrays in the positional order ``LenT.set_len`` reads."""
        return self.z_len, self.zh_len, self.h_len, self.d_len, \
                   self.loss_tr, self.wg_len, self.bg_len, \
                   self.loss_inf, self.wg_len_inf, self.bg_len_inf, \
                   self.z_len_tr, self.d_len_tr, \
                   self.w_len, self.w_len_inf
