import gc
import os
import pickle
from this import d
import numpy as np
from anal.reg import get_omega, get_beta
from anal.util import get_ps
from data.dl_getter import nDT
from anal.util import transpose_np
from anal.lyap import lyap_r, lyap_e
from scipy.special import softmax
from scipy.stats import entropy


# Debug helper: weight/bias statistics of one model layer.

def p_wb(model, layer=1):
    """Compute weight/bias statistics for ``model.f_loc[layer].fc``.

    The original version computed these values and discarded them (its only
    consumer was a commented-out print); they are now returned so callers can
    log or print them.

    Args:
        model: object exposing ``f_loc[layer].fc.weight`` / ``.bias`` with a
            ``.norm()`` method (presumably torch tensors — TODO confirm).
        layer: index into ``model.f_loc``; defaults to 1 as before.

    Returns:
        Tuple ``(omega, beta, w_norm, b_norm)`` where ``omega = get_omega(w)``,
        ``beta = get_beta(b)`` and the norms come from ``w.norm()`` / ``b.norm()``.
    """
    fc = model.f_loc[layer].fc
    w = fc.weight
    b = fc.bias
    w_norm = w.norm()
    b_norm = b.norm()
    omega = get_omega(w)
    beta = get_beta(b)
    return omega, beta, w_norm, b_norm


# Display names for the first seven arrays of a length list, in the order
# LenT.set_len reads them (index 4, 'loss', is the training loss loss_tr).
name_lst = [
    'z_len', 'zh_len', 'h_len', 'd_len',
    'loss',
    'w_len', 'b_len',
]


class LenT:
    """Holds pre-filled length/loss arrays stacked over a hyper-parameter sweep.

    Judging from the indexing in ``get_stat`` and the disabled setters below,
    the activation-length arrays appear to be 7-D with axes
    (run, sigma_w, sigma_b, eta, inference_iter, layer, sample) — TODO confirm
    against whatever builds ``len_np_lst``.
    """

    def __init__(self, len_np_lst, args, runs, sigma_ws, sigma_bs, etas):
        """Store ``args`` and unpack ``len_np_lst`` via ``set_len``.

        Args:
            len_np_lst: sequence of at least 10 arrays, in ``set_len`` order.
            args: run configuration, stored as ``self.args``.
            runs, sigma_ws, sigma_bs, etas: sweep description; currently not
                used (kept for interface compatibility). The previous version
                derived an inner-iteration count and the sweep sizes from these
                but never used them, behind a bare ``except:``.
        """
        self.args = args
        self.set_len(len_np_lst)
        # Sweep cursors (r: run, w: sigma_w, b: sigma_b, et: eta); -1 = not started.
        self.i_r, self.i_w, self.i_b, self.i_et = -1, -1, -1, -1

    def set_len(self, len_np_lst):
        """Unpack the 10-entry array list into named attributes and groups."""
        self.z_len = len_np_lst[0]
        self.zh_len = len_np_lst[1]
        self.h_len = len_np_lst[2]
        self.d_len = len_np_lst[3]
        self.zds = self.z_len, self.zh_len, self.h_len, self.d_len
        self.loss_tr = len_np_lst[4]
        self.w_len = len_np_lst[5]
        self.b_len = len_np_lst[6]
        # loss_tr is last so get_stat_plot can drop it with [:-1].
        self.wbs = self.w_len, self.b_len, self.loss_tr
        self.loss_inf = len_np_lst[7]
        self.w_len_inf = len_np_lst[8]
        self.b_len_inf = len_np_lst[9]
        self.wbs_inf = self.w_len_inf, self.b_len_inf

    def get_stat(self):
        """Return (mean, std) summaries of the stored arrays, reduced over axis 0.

        Returns:
            zd_stat: list of six (mean, std) pairs:
                [0] z_len reduced over axes (0, -1);
                [1..3] zh_len, h_len, d_len with the last index of axis 5
                    dropped, reduced over axes (0, -1);
                [4] loss_inf from index 1 of axis 5, reduced over axis 0;
                [5] d_len[..axis5 :-1] / (z_len[..axis5 1:] + 1e-9), reduced
                    over axes (0, -1).
            wb_stat: (mean, std) over axis 0 for w_len, b_len, loss_tr.
            wb_stat_inf: (mean, std) over axis 0 for w_len_inf, b_len_inf.
        """
        zd_stat = []
        for i, item in enumerate(self.zds):
            # z_len (i == 0) is kept whole; the others drop the last axis-5 entry.
            sub = item if i == 0 else item[:, :, :, :, :, :-1]
            zd_stat.append((sub.mean((0, -1)), sub.std((0, -1))))
        loss_inf = self.loss_inf[:, :, :, :, :, 1:]
        zd_stat.append((loss_inf.mean(0), loss_inf.std(0)))
        # 1e-9 guards against division by zero-length z.
        qpratio_stat = np.divide(self.d_len[:, :, :, :, :, :-1],
                                 self.z_len[:, :, :, :, :, 1:] + 1e-9)
        zd_stat.append((qpratio_stat.mean((0, -1)), qpratio_stat.std((0, -1))))
        wb_stat = [(item.mean(0), item.std(0)) for item in self.wbs]
        wb_stat_inf = [(item.mean(0), item.std(0)) for item in self.wbs_inf]
        return zd_stat, wb_stat, wb_stat_inf

    def get_stat_plot(self):
        """Return ``get_stat()`` results plus ``transpose_np`` views for plotting.

        Returns:
            (zd_stat_lst, wb_stat_lst, wb_inf_stat_lst, zd, wb, wb_inf) where
            ``zd`` is built from the z_len stats (axis 4 from index 1 onward)
            followed by the zh/h/d stats (loss_inf and qpratio excluded),
            ``wb`` from the w_len/b_len stats (loss_tr excluded) and ``wb_inf``
            from the inference-time w/b stats.
        """
        zd_stat_lst, wb_stat_lst, wb_inf_stat_lst = self.get_stat()
        # NOTE(review): an earlier note here read "44-45 spits error"; the
        # context is unknown — confirm whether an issue remains.
        z_mean, z_std = zd_stat_lst[0]
        zd = transpose_np([(z_mean[:, :, :, :, 1:], z_std[:, :, :, :, 1:])]
                          + zd_stat_lst[1:-2])  # excluding loss_inf and qpratio
        wb = transpose_np(wb_stat_lst[:-1])
        wb_inf = transpose_np(wb_inf_stat_lst)
        return zd_stat_lst, wb_stat_lst, wb_inf_stat_lst, zd, wb, wb_inf

    # Legacy setters, currently disabled.
    # NOTE(review): set_d / set_loss_inf index (i_et, i_b) in the opposite order
    # from the other setters, and i_i / i_ei are never initialised in __init__;
    # fix both before re-enabling.

    # def set_zh(self, i_l, vals, is_L=False):
    #     self.z_len[self.i_r, self.i_w, self.i_b, self.i_et, self.i_i, i_l, :] = vals[0]
    #     if not is_L:
    #         self.zh_len[self.i_r, self.i_w, self.i_b, self.i_et, self.i_i, i_l, :] = vals[1]
    #         self.h_len[self.i_r, self.i_w, self.i_b, self.i_et, self.i_i, i_l, :] = vals[2]

    # def set_d(self, i_l, val):
    #     self.d_len[self.i_r, self.i_w, self.i_et, self.i_b, self.i_i, i_l-1, :] = val

    # def set_loss_inf(self, i_l, loss):
    #     self.loss_inf[self.i_r, self.i_w, self.i_et, self.i_b, self.i_i, i_l] = loss

    # def set_wb(self, w_len, b_len, loss):
    #     self.w_len[self.i_r, self.i_w, self.i_b, self.i_et, self.i_ei, :] = w_len
    #     self.b_len[self.i_r, self.i_w, self.i_b, self.i_et, self.i_ei, :] = b_len
    #     self.loss_tr[self.i_r, self.i_w, self.i_b, self.i_et, self.i_ei, :] = loss

    # def set_wb_inf(self, w_len, b_len):
    #     self.w_len_inf[self.i_r, self.i_w, self.i_b, self.i_et, self.i_i, :] = w_len
    #     self.b_len_inf[self.i_r, self.i_w, self.i_b, self.i_et, self.i_i, :] = b_len

    # def save_pkl(self, i_run, sigma_w, sigma_b, eta):
    #     len_lst = [self.z_len, self.zh_len, self.h_len, self.d_len, self.loss_tr, \
    #                self.loss_inf, self.w_len, self.b_len]
    #     s = f"{i_run},{sigma_w},{sigma_b},{eta}"
    #     log_path = os.path.join(self.args.log_root, s)
    #     with open(log_path,"wb") as fw:
    #         pickle.dump(len_lst, fw)

    # def inc(self, which=''):
    #     if which == 'r':
    #         self.i_r += 1
    #         self.i_w, self.i_b, self.i_et, self.i_ei, self.i_i = -1,-1,-1,-1,-1
    #     elif which == 'w':
    #         self.i_w += 1
    #         self.i_b, self.i_et, self.i_ei, self.i_i = -1,-1,-1,-1
    #     elif which == 'b':
    #         self.i_b += 1
    #         self.i_et, self.i_ei, self.i_i = -1,-1,-1
    #     elif which == 'et':
    #         self.i_et += 1
    #         self.i_ei, self.i_i = -1,-1
    #     elif which == 'e':
    #         self.i_ei += 1
    #         self.i_i = -1
    #     elif which == 'i':
    #         self.i_i += 1
    #     else:
    #         raise NotImplementedError



class LenLog:
    """Log of activation lengths, losses and weight/bias norms for one configuration.

    Buffers allocated in ``__init__``:
        z_len, zh_len, h_len, d_len: (args.T, n_layers, bsz), written by
            ``set_zh`` / ``set_d`` at inference iteration ``i_i``.
        loss_inf: (args.T, n_layers), written by ``set_loss_inf``.
        w_len, b_len, loss_tr: (n_epochs * n_inner_iter, n_layers - 1),
            written by ``set_wb`` at training index ``i_ei``.
        w_len_inf, b_len_inf: (args.T, n_layers - 1), written by ``set_wb_inf``.
    ``zds``, ``wbs`` and ``wbs_inf`` are lists aliasing the arrays above.
    """

    def __init__(self, args):
        """Allocate zero-filled buffers sized from ``args``.

        Reads ``args.dataset``, ``args.bsz``, ``args.n_layers``, ``args.T``,
        ``args.len_all`` and (when ``len_all`` is truthy) ``args.epochs``.
        """
        # Batches per epoch. Fall back to 1 when it cannot be derived (unknown
        # dataset, missing/None/zero batch size) instead of a bare except that
        # would also swallow KeyboardInterrupt and programming errors.
        try:
            n_inner_iter = int(nDT[args.dataset] / args.bsz)
        except (LookupError, AttributeError, TypeError, ValueError, ArithmeticError):
            n_inner_iter = 1
        self.args = args
        self.L = args.n_layers
        self.bsz = args.bsz
        zd_lst = [np.zeros([args.T, self.L, self.bsz]) for _ in range(4)]
        self.zds = zd_lst
        self.z_len, self.zh_len, self.h_len, self.d_len = zd_lst
        self.loss_inf = np.zeros([args.T, self.L])
        # Training-time buffers span every epoch only when args.len_all is set.
        n_epochs = args.epochs if args.len_all else 1
        wb_lst = [np.zeros([n_epochs * n_inner_iter, self.L - 1]) for _ in range(3)]
        self.wbs = wb_lst
        self.w_len, self.b_len, self.loss_tr = wb_lst
        wb_lst_inf = [np.zeros([args.T, self.L - 1]) for _ in range(2)]
        self.wbs_inf = wb_lst_inf
        self.w_len_inf, self.b_len_inf = wb_lst_inf
        # Cursors (ei: epoch & batch iteration, i: inference iteration); -1 = not started.
        self.i_ei, self.i_i = -1, -1
        self.n = n_inner_iter

    def set_zh(self, i_l, vals, is_L=False):
        """Store layer ``i_l`` values at the current inference step.

        ``vals[0]`` goes to z_len; unless ``is_L``, ``vals[1]`` / ``vals[2]``
        go to zh_len / h_len.
        """
        self.z_len[self.i_i, i_l, :] = vals[0]
        if not is_L:
            self.zh_len[self.i_i, i_l, :] = vals[1]
            self.h_len[self.i_i, i_l, :] = vals[2]

    def set_d(self, i_l, val):
        """Store ``val`` in d_len at the current inference step, layer row ``i_l - 1``."""
        self.d_len[self.i_i, i_l - 1, :] = val

    def set_loss_inf(self, i_l, loss):
        """Store the inference loss of layer ``i_l`` at the current inference step."""
        self.loss_inf[self.i_i, i_l] = loss

    def set_wb(self, w_len, b_len, loss):
        """Store training-time weight/bias norms and loss at training index ``i_ei``."""
        self.w_len[self.i_ei, :] = w_len
        self.b_len[self.i_ei, :] = b_len
        self.loss_tr[self.i_ei, :] = loss

    def set_wb_inf(self, w_len, b_len):
        """Store inference-time weight/bias norms at the current inference step."""
        self.w_len_inf[self.i_i, :] = w_len
        self.b_len_inf[self.i_i, :] = b_len

    def save_pkl(self, i_run, sigma_w, sigma_b, eta):
        """Pickle this whole object to ``<log_dir>/<T>/<i_run>,<sigma_w>,<sigma_b>,<eta>``."""
        log_dir = os.path.join(self.args.log_dir, str(self.args.T))
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(log_dir, exist_ok=True)
        log_path = os.path.join(log_dir, f"{i_run},{sigma_w},{sigma_b},{eta}")
        with open(log_path, "wb") as fw:
            pickle.dump(self, fw)

    def get_stat(self):
        """Return (mean, std) summaries of the logged arrays.

        Returns:
            zd_stat: list of six (mean, std) pairs:
                [0] z_len over axes (0, -1);
                [1..3] zh_len, h_len, d_len without the last layer row, over
                    axes (0, -1);
                [4] loss_inf from layer 1 onward, over axis 0;
                [5] d_len[:, :-1] / (z_len[:, 1:] + 1e-9) over axes (0, -1).
            wb_stat: (mean, std) over axis 0 for w_len, b_len, loss_tr.
            wb_stat_inf: (mean, std) over axis 0 for w_len_inf, b_len_inf.
        """
        zd_stat = []
        for i, item in enumerate(self.zds):
            # z_len (i == 0) keeps every layer; zh/h/d drop the last layer row.
            sub = item if i == 0 else item[:, :-1]
            zd_stat.append((sub.mean((0, -1)), sub.std((0, -1))))
        loss_inf = self.loss_inf[:, 1:]
        zd_stat.append((loss_inf.mean(0), loss_inf.std(0)))
        # 1e-9 guards against division by zero-length z.
        qpratio_stat = np.divide(self.d_len[:, :-1], self.z_len[:, 1:] + 1e-9)
        zd_stat.append((qpratio_stat.mean((0, -1)), qpratio_stat.std((0, -1))))
        wb_stat = [(item.mean(0), item.std(0)) for item in self.wbs]
        wb_stat_inf = [(item.mean(0), item.std(0)) for item in self.wbs_inf]
        return zd_stat, wb_stat, wb_stat_inf

    def get_lyap(self, l):
        """Return mean/std of ``lyap_r`` estimates for layer ``l`` of z_len and d_len.

        Each series is transposed to (bsz, T) before the call; this assumes
        ``lyap_r`` accepts such a 2-D input and returns an array — TODO confirm.
        """
        z_series = np.transpose(self.z_len[:, l, :], (1, 0))
        z_lyap_r = lyap_r(z_series, emb_dim=3, lag=1, min_tsep=20)
        d_series = np.transpose(self.d_len[:, l, :], (1, 0))
        d_lyap_r = lyap_r(d_series, emb_dim=3, lag=1, min_tsep=20)
        return z_lyap_r.mean(), z_lyap_r.std(), d_lyap_r.mean(), d_lyap_r.std()

    def cal_lyap(self, series):
        """Return ``lyap_r(series, emb_dim=3, lag=1, min_tsep=20)``."""
        # Do not bind the result to a local named ``lyap_r``: that makes the
        # name local to the whole function, so the call itself would raise
        # UnboundLocalError instead of reaching the imported function.
        return lyap_r(series, emb_dim=3, lag=1, min_tsep=20)

    def get_dist_metric(self):
        """Summarise the spread of ``d_len`` at the last inference step.

        With ``d = d_len[-1, :-1, :]`` (last layer row excluded), computes along
        axis 0 (layers) the sample std (ddof=1), the coefficient of variation
        std / mean, and the base-2 entropy of ``softmax(d)``.

        Returns:
            (std mean, std std, coef-of-variation mean, coef-of-variation std,
             entropy mean, entropy std), each reduced over the remaining axis.
        """
        d = self.d_len[-1, :-1, :]
        d_std = np.std(d, axis=0, ddof=1)
        d_mean = np.mean(d, axis=0)
        d_coef_var = d_std / d_mean
        # NOTE(review): this changes NumPy's global print options for the whole
        # process; kept to preserve existing behaviour.
        np.set_printoptions(precision=5)
        p = softmax(d, axis=0)
        ent = entropy(p, base=2, axis=0)
        return (d_std.mean(), d_std.std(), d_coef_var.mean(), d_coef_var.std(),
                ent.mean(), ent.std())

    def inc(self, which=''):
        """Advance a cursor: 'e' bumps ``i_ei`` and resets ``i_i``; 'i' bumps ``i_i``.

        Raises:
            NotImplementedError: for any other ``which``.
        """
        if which == 'e':
            self.i_ei += 1
            self.i_i = -1
        elif which == 'i':
            self.i_i += 1
        else:
            raise NotImplementedError

    def clip_T(self, to_T):
        """Truncate every logged array to its first ``to_T`` rows (axis 0).

        Each array is replaced by an owning copy so the full-length buffer can
        be freed, and ``zds`` / ``wbs_inf`` are rebuilt to alias the new arrays
        (previously they held views of the old buffers, so later ``set_*``
        writes were invisible to ``get_stat`` and the old memory stayed alive).
        """
        for name in ('z_len', 'zh_len', 'h_len', 'd_len',
                     'loss_tr', 'w_len', 'b_len',
                     'loss_inf', 'w_len_inf', 'b_len_inf'):
            setattr(self, name, getattr(self, name)[:to_T].copy())
        self.zds = [self.z_len, self.zh_len, self.h_len, self.d_len]
        self.wbs_inf = [self.w_len_inf, self.b_len_inf]
        # NOTE(review): loss_tr / w_len / b_len are indexed by training step
        # (i_ei), not inference step, yet are clipped too, while ``self.wbs``
        # still references the unclipped arrays (unchanged from the original
        # behaviour, so get_stat's wb_stat is unaffected). Confirm intent.
        gc.collect()

    def get_len_lst(self):
        """Return the ten logged arrays in the order ``LenT.set_len`` expects."""
        return (self.z_len, self.zh_len, self.h_len, self.d_len,
                self.loss_tr, self.w_len, self.b_len,
                self.loss_inf, self.w_len_inf, self.b_len_inf)
