from re import T
import math
import wandb
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import random
from train.mlbase import MLBase
from model.models.classifier import Classifier
from data.anal_data import get_data
from anal.base import AnalBase
from anal.util import length, length_w, get_w_grads, get_len_t,\
                                get_len_l, set_ps
from anal.log import p_wb, LenLog
from anal.reg import wb_norm, get_coef_vec, get_coef_mat  # , reg_scale
from matplotlib import pyplot as plt
from anal.debug import check_z, check_d_rel, check_w_grad


class AnalPC(AnalBase):
    """Iterative-inference ("pc") training / validation driver.

    For each training batch, ``pc`` repeatedly updates the free latent states
    ``zs[1:L]`` by gradient descent on the loss returned by
    ``model.compare_ls``. ``zs[0]`` and ``zs[L]`` are left unchanged by
    ``update_z``. When ``args.train`` is set, ``pc`` then takes one optimizer
    step on the weights. (The name suggests predictive coding; the exact
    model semantics live in ``model.predict`` / ``model.compare_ls``.)
    """

    def __init__(self, p):
        super().__init__(p)

    def fwd_tr(self, zs, y, run_cond_v, i_run, i_epoch, is_last_e, eta,
               zs_db=None, i_batch=None, i_iter=None, vl=False):
        """Run inference (plus a weight update if ``args.train``) on one batch.

        Args:
            zs: list of per-layer state tensors. When ``args.cls`` is set,
                ``pc`` overwrites ``zs[-1]`` in this list with the one-hot
                target.
            y: class labels for the batch (passed to ``F.one_hot`` and
                ``cnt_correct``).
            run_cond_v, i_run: forwarded to the ``args.v_check`` debug helpers.
            i_epoch: current epoch index. Used for the inference schedule in
                ``pc`` and as the index into the length logs below.
            is_last_e: enables per-sample length logging and weight-gradient
                length logging.
            eta: initial step size for the latent updates.
            zs_db, i_batch: latent cache and batch index, used when
                ``args.z_init == 'use_db'``.
            i_iter: unused; kept for interface compatibility.
            vl: selects ``args.T_vl`` inference steps instead of ``args.T``
                (unless ``args.step_T`` overrides it).

        Returns:
            ``(n_correct, n_total)`` for the batch when ``args.cls`` is set,
            otherwise ``(0, 1)``.
        """
        args = self.args
        self.eta_st = eta
        L = len(zs) - 1
        ds, zhs, loss = self.pc(zs, y, i_epoch, is_last_e, zs_db, i_batch, vl)
        with torch.no_grad():
            ds = [d.detach() for d in ds]
            zhs = [zh.detach() for zh in zhs]

            if is_last_e:
                # Per-sample lengths (as computed by ``length``) of each
                # flattened tensor, stored on ``args`` for later analysis.
                dlengths = [length(d.view(d.size(0), -1)).tolist() for d in ds]
                zhlengths = [length(zh.view(zh.size(0), -1)).tolist()
                             for zh in zhs]
                self._log_epoch_lengths('util_dlengths_train', i_epoch,
                                        dlengths)
                self._log_epoch_lengths('util_zhlengths_train', i_epoch,
                                        zhlengths)

            w_grads = None
            if args.len_all or is_last_e or args.v_check:
                w_grads, wg_len, bg_len, _, _ = get_w_grads(self.model)
                self.lenlog.set_wb(wg_len, bg_len, loss)
            if args.v_check:
                check_z(zs, zhs, ds, self.lenlog, i_run, run_cond_v)
                check_d_rel(self.model, self.optimizer, zs, L, i_run,
                            run_cond_v)
                check_w_grad(w_grads, zs, ds, i_run, run_cond_v)
            if args.cls:
                c, n = self.cnt_correct(zhs[-1], y)
                return c.item(), n
            return 0, 1

    def _log_epoch_lengths(self, key, i_epoch, lengths):
        """Merge one batch of per-layer lengths into ``args.<key>[i_epoch]``.

        ``lengths`` holds one list of per-sample values per layer. Batches of
        the same epoch are concatenated sample-wise (``axis=1``).
        """
        args = self.args
        if key not in vars(args):
            setattr(args, key, [])
        history = getattr(args, key)
        # Pad epochs that were not logged, so the entry for ``i_epoch`` is
        # always at index ``i_epoch``. Without padding, logging that starts
        # at an epoch > 0 appended every batch as a new entry instead of
        # merging it into that epoch's entry.
        while len(history) < i_epoch:
            history.append([[] for _ in lengths])
        if i_epoch == len(history):
            history.append(lengths)
        else:
            history[i_epoch] = np.concatenate(
                [history[i_epoch], lengths], axis=1).tolist()

    def pc(self, zs, y, i_epoch, is_last_e, zs_db=None, i_batch=None, vl=False):
        """Run the inference loop on one batch; update weights if training.

        The free latents are updated ``T`` times by ``update_z``, using the
        gradients of the loss from ``model.compare_ls``. When ``args.train``
        is set, one extra pass (``i == T``) runs without a latent update. The
        weight gradients of the last executed pass are then applied with
        ``self.optimizer.step()``.

        When ``args.cls`` is set, ``zs[-1]`` of the caller's list is
        overwritten with the one-hot target.

        Returns:
            ``(ds, zhs, loss)`` from the last executed pass, with ``loss``
            converted to a Python float.

        Raises:
            ValueError: if no inference pass would run (``T <= 0`` outside
                training).
        """
        args = self.args
        progress = min(float(i_epoch) / float(args.epochs) * 1.5, 1)
        T = args.T_vl if vl else args.T
        if args.step_T:
            # Ramp the number of inference steps from 0.2*args.T up to args.T.
            T = int(0.2 * args.T + 0.8 * progress * args.T)
        if args.cls:
            y_oh = F.one_hot(y, args.n_cls).float()
            zs[-1] = y_oh.detach()
        L = len(zs) - 1
        half_L = int((len(zs) + 1) / 2.)
        ps = set_ps(progress, half_L, L)

        self.eta = self.eta_st
        target_eta = self.eta / args.eta_r
        # Per-step decrement used by args.step_eta. T can be 0 (e.g. step_T
        # with small args.T early in training), so guard the division.
        eta_sz = (self.eta - target_eta) / T if T > 0 else 0.0
        n_iter = T + 1 if args.train else T
        if n_iter <= 0:
            raise ValueError(
                f'pc: no inference pass to run (T={T}, train={args.train})')

        log_len = args.len_all or is_last_e
        loss_prev = 1e10
        cnt = 0
        for i in range(n_iter):
            if log_len:
                self.lenlog.inc('i')
            if args.train and args.step_eta:
                self.eta -= eta_sz
            is_last_t = (i == T - 1)
            is_final = (i == T)
            with torch.enable_grad():
                zhs, _, _ = self.model.predict(self.lenlog, zs, is_last_t,
                                               is_last_e, vl=vl,
                                               check=is_final)
                loss, ds = self.model.compare_ls(self.lenlog, zs, is_last_t,
                                                 None, zhs, vl, y, ps=ps,
                                                 check=is_final, prg=args.prg)
            if args.pcd:
                zhs = zhs[0]
            self.optimizer.zero_grad()
            loss.backward()
            if args.train and args.comp_eta:
                # Shrink eta when the loss increases (at most eta_cnt times);
                # stop early once the change drops below inf_tol * eta.
                loss_val = loss.item()
                if loss_val > loss_prev and cnt < args.eta_cnt:
                    self.eta /= args.step_inf_eta
                    cnt += 1
                elif abs(loss_val - loss_prev) < args.inf_tol * self.eta:
                    break
                loss_prev = loss_val
            if log_len and not is_final:
                with torch.no_grad():
                    _, wg_len, bg_len, _, _ = get_w_grads(self.model)
                    self.lenlog.set_wb_inf(wg_len, bg_len)
            if i < T:
                zs = self.update_z(zs, L)

        if args.train:
            if args.w_norm or args.b_norm:
                wb_norm(args, self.model)
            self.optimizer.step()
            if args.z_init == 'use_db':
                # Write the final hidden latents zs[1 .. n_layers-2] back into
                # the per-batch cache, reshaped to the cached entry's size.
                with torch.no_grad():
                    for j in range(args.n_layers - 2):
                        cached = zs_db[j][i_batch]
                        zs_db[j][i_batch] = (
                            zs[j + 1].detach().view(cached.size()))
        return ds, zhs, loss.item()

    def fwd_vl(self, x, y):
        """Evaluate the model's ``feedfwd`` output on one validation batch.

        Returns:
            ``(log_dt_plt, n_correct, n_total)``. ``log_dt_plt['loss/losses']``
            holds the batch loss. The counts are ``(0, 1)`` when ``args.cls``
            is off.
        """
        args = self.args
        # Evaluation only: no backward pass follows, so don't build a graph.
        with torch.no_grad():
            z = self.model.feedfwd(x)[-1]
            if args.last_cls:
                loss = F.cross_entropy(z, y)
            else:
                loss = F.mse_loss(z, F.one_hot(y, args.n_cls).float())
            log_dt_plt = {'loss/losses': [loss.item()]}
            if args.cls:
                c, n = self.cnt_correct(z, y)
                return log_dt_plt, c.item(), n
            return log_dt_plt, 0, 1

    def update_z(self, zs, L):
        """Take one gradient step of size ``self.eta`` on the free latents.

        Only ``zs[1:L]`` are updated. The other entries are returned as
        detached copies. Every returned tensor is detached from the graph.

        Raises:
            ValueError: if ``args.z_norm`` is set and a latent gradient
                contains NaN.
        """
        args = self.args
        with torch.no_grad():
            if args.z_norm:
                for i in range(1, L):
                    grad = zs[i].grad
                    if grad.isnan().any():
                        raise ValueError('NaN in z grad')
                    zs[i].grad = grad.detach() * get_coef_vec(args, zs[i])
            new_zs = [z.detach() - self.eta * z.grad.detach() if 0 < i < L
                      else z.detach()
                      for i, z in enumerate(zs)]
            if args.z_hard_norm:
                # Renormalise only the free latents. The entries outside
                # zs[1:L] are held fixed by this update. Normalising them
                # here meant every pass after the first saw different fixed
                # values than the first pass did.
                new_zs = [F.normalize(z, dim=-1) if 0 < i < L else z
                          for i, z in enumerate(new_zs)]
        return new_zs
