from ast import arg
from cmath import log
from re import T
import wandb
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import random
from train.mlbase import MLBase
from data.anal_data import get_data
from tqdm import tqdm
from anal.util import length, length_w, get_w_grads, get_len_t,\
                                get_len_l, init_param, init_zs_db
from anal.log import p_wb, LenLog
from train.lr import warmup_lr, adjust_lr
from matplotlib import pyplot as plt


class AnalBase(MLBase):
    """Driver that runs one (sigma_w, sigma_b, eta) configuration.

    Calling an instance re-initialises the model parameters, iterates over the
    training loader for ``args.epochs`` epochs, records statistics through a
    ``LenLog`` and publishes batch-size-weighted means of those statistics as
    ``util_*`` attributes on ``self.args``.

    ``self.model``, ``self.args``, ``self.optimizer``, ``self.tr_dt_dl``,
    ``self.vl_dt_dl``, ``self.fwd_tr`` and ``self.fwd_vl`` are not defined in
    this class; presumably they come from ``MLBase`` or a subclass.
    """

    # Key order matches the unpacking order of ``lenlog.get_dist_metric()``.
    _DIST_KEYS = ('std_m', 'std_s', 'coef_var_m', 'coef_var_s',
                  'entropy_m', 'entropy_s')
    # Key order matches the unpacking order of ``lenlog.get_lyap(...)``.
    _LYAP_KEYS = ('z_lyap_m', 'z_lyap_s', 'd_lyap_m', 'd_lyap_s')

    def __init__(self, p):
        """Build the runner from ``p``, which is forwarded to ``MLBase``."""
        super().__init__(other=p)
        self.T = self.args.T
        self.get_len = get_len_t if self.args.tensorized else get_len_l
        self.fwd = None
        self.best_acc = 0

    @staticmethod
    def _weighted_mean(total, weight):
        """Return ``total / weight``, or NaN when nothing was accumulated."""
        return total / weight if weight else float('nan')

    def __call__(self, i_run, sigma_w, sigma_b, eta):
        """Run every epoch for one configuration and return the ``LenLog``.

        Args:
            i_run: Index of the current run (used for logging and saving).
            sigma_w: Value passed to ``init_param`` and ``fwd_tr``.
            sigma_b: Value passed to ``init_param``.
            eta: Value passed to ``fwd_tr`` and stored with the pickle.

        Returns:
            The ``LenLog`` instance populated during this call.

        Side effects:
            Sets ``self.args.util_dist_*`` and ``self.args.util_*_lyap_*`` to
            the batch-size-weighted means gathered in the last epoch. A
            statistic that was never measured (for example the Lyapunov values
            when ``args.train`` is set) is published as NaN.
        """
        args = self.args
        self.lenlog = LenLog(args)
        with torch.no_grad():
            init_param(self.model, args, sigma_w, sigma_b)
            dt_dl = self.tr_dt_dl if args.cls else None
            ldt, tr_dl = get_data(args, dt_dl)
        zs_db = init_zs_db(args, ldt) if args.z_init == 'use_db' else None
        self.best_acc = 0
        # In these two configurations each batch is a triple (x, y, i_batch).
        has_batch_idx = args.z_init == 'use_db' or args.dataset == 'random'

        # Defined up-front so the aggregation below also works for 0 epochs.
        dist_sum, dist_cnt = dict.fromkeys(self._DIST_KEYS, 0), 0
        lyap_sum, lyap_cnt = dict.fromkeys(self._LYAP_KEYS, 0), 0

        for i_epoch in range(args.epochs):
            is_last = i_epoch == args.epochs - 1
            c, n = 0, 0
            self.model.train()
            # Accumulators restart every epoch; only the last epoch's totals
            # survive to the aggregation after the loop.
            dist_sum, dist_cnt = dict.fromkeys(self._DIST_KEYS, 0), 0
            lyap_sum, lyap_cnt = dict.fromkeys(self._LYAP_KEYS, 0), 0

            for i_iter, dt in enumerate(tqdm(tr_dl)):
                if has_batch_idx:
                    x, y, i_batch = dt
                else:
                    x, y = dt
                    i_batch = None
                # NOTE(review): this stops every epoch after its first batch;
                # looks like a debugging/analysis shortcut - confirm intended.
                if i_iter == 1:
                    break
                if args.len_all or is_last:
                    self.lenlog.inc('e')
                if args.train:
                    warmup_lr(args, i_epoch, tr_dl, i_iter, self.optimizer)
                x = x.to(args.device, non_blocking=True)
                y = y.to(args.device, non_blocking=True)
                if args.arch == 'fc':
                    # Flatten per sample using the real batch size so a short
                    # final batch is handled as well.
                    x = x.reshape(x.size(0), -1)
                n_batch = x.size(0)
                zs = self.init_zs(x, y, i_batch, zs_db)
                _c, _n = self.fwd_tr(zs, y, sigma_w, i_run, i_epoch, is_last,
                                     eta, zs_db, i_batch, i_iter)
                c += _c
                n += _n
                if args.train and i_iter % 50 == 0:
                    if args.debug:
                        p_wb(self.model)
                    print(f'training accuracy at epoch {i_epoch} iter {i_iter}:',
                          c/n * 100)
                    with torch.no_grad():
                        self.eval_acc(args, sigma_w, sigma_b, i_run, i_epoch,
                                      i_iter)
                    # eval_acc leaves the model in eval mode; switch back so
                    # the following training steps run in train mode.
                    self.model.train()
                    c, n = 0, 0
                if (not args.train) and args.lyap:
                    # NOTE(review): 15 / 27 are passed straight to get_lyap;
                    # their meaning is not visible here - confirm.
                    lyap_arg = 15 if args.dataset == 'random' else 27
                    z_lyap_m, z_lyap_s, d_lyap_m, d_lyap_s = \
                        self.lenlog.get_lyap(lyap_arg)
                    lyap_vals = (z_lyap_m, z_lyap_s, d_lyap_m, d_lyap_s)
                    for key, val in zip(self._LYAP_KEYS, lyap_vals):
                        lyap_sum[key] += val * n_batch
                    lyap_cnt += n_batch
                if is_last:
                    std_m, std_s, coef_var_m, coef_var_s, entropy_m, entropy_s = \
                        self.lenlog.get_dist_metric()
                    dist_vals = (std_m, std_s, coef_var_m, coef_var_s,
                                 entropy_m, entropy_s)
                    for key, val in zip(self._DIST_KEYS, dist_vals):
                        dist_sum[key] += val * n_batch
                    dist_cnt += n_batch

        # Publish weighted means on args (picked up by external logging).
        for key in self._DIST_KEYS:
            setattr(self.args, f'util_dist_{key}',
                    self._weighted_mean(dist_sum[key], dist_cnt))
        for key in self._LYAP_KEYS:
            setattr(self.args, f'util_{key}',
                    self._weighted_mean(lyap_sum[key], lyap_cnt))

        self.lenlog.save_pkl(i_run, sigma_w, sigma_b, eta)

        return self.lenlog

    @torch.no_grad()
    def init_zs(self, x, y, i_batch=0, zs_db=None):
        """Build the initial list of states ``[x, z_1, ..., z_k, y]``.

        The hidden entries are chosen by ``args.z_init``:

        * ``'ff'``: the detached outputs of ``self.model.feedfwd(x)`` without
          the last one.
        * ``'use_db'``: ``zs_db[i][i_batch]`` plus Gaussian noise scaled by
          0.05, viewed with the size of the matching feed-forward output.
        * ``'gaussian'``: standard-normal tensors with the size of the
          matching feed-forward output.

        Args:
            x: Input batch, placed first in the returned list.
            y: Target batch, placed last in the returned list.
            i_batch: Index into each ``zs_db[i]`` (only used for ``'use_db'``).
            zs_db: Stored states (only used for ``'use_db'``).

        Returns:
            A list of tensors starting with ``x`` and ending with ``y``.

        Raises:
            NotImplementedError: For any other ``args.z_init`` value.
            ValueError: If any tensor in the result contains NaN.
        """
        args = self.args
        n_hidden = args.n_layers - 2

        zs_ff = [z.detach() for z in self.model.feedfwd(x)[:-1]]
        if args.z_init == 'ff':
            zs = zs_ff
        elif args.z_init == 'use_db':
            zs = []
            for i in range(n_hidden):
                stored = zs_db[i][i_batch].detach()
                # Noise is sampled first and then moved to the device, as
                # before, so the random stream is unchanged.
                noise = torch.normal(mean=0.0, std=1.0,
                                     size=stored.size()).to(args.device)
                zs.append((stored + 0.05 * noise).view(zs_ff[i].size()))
        elif args.z_init == 'gaussian':
            zs = [torch.normal(mean=0.0, std=1.0,
                               size=zs_ff[i].size()).to(args.device)
                  for i in range(n_hidden)]
        else:
            raise NotImplementedError(f'unknown z_init: {args.z_init!r}')

        zs.insert(0, x)
        zs.append(y)
        for z in zs:
            if z.isnan().any().item():
                raise ValueError('NaN in zs')
        return zs

    def eval_acc(self, args, sigma_w, sigma_b, i_run, i_epoch, i_iter):
        """Run the validation loader once and log loss/accuracy to wandb.

        Puts ``self.model`` in eval mode and leaves it there. Gradients are
        not disabled here, so callers should wrap the call in
        ``torch.no_grad()``.

        Args:
            args: Run configuration (``device``, ``arch``, ``cls``, ``verbose``).
            sigma_w: Used in the wandb metric names.
            sigma_b: Unused; kept for interface compatibility.
            i_run: Run index, used in the wandb metric names.
            i_epoch: Current epoch, used for the logged step and epoch.
            i_iter: Current training iteration, used for the logged step.
        """
        step = self.lenlog.n * i_epoch + i_iter
        if args.verbose:
            print(f'\n{i_run}-th {sigma_w}, validation')
        self.model.eval()
        losses = []
        c, n = 0, 0
        for x, y in self.vl_dt_dl[1]:
            x = x.to(args.device, non_blocking=True)
            y = y.to(args.device, non_blocking=True)
            if args.arch == 'fc':
                # Use the real batch size so a short final batch still works.
                x = x.reshape(x.size(0), -1)
            log_dt, _c, _n = self.fwd_vl(x, y)
            losses.append(log_dt['loss/losses'][-1])
            c += _c
            n += _n
        if not losses:
            # Nothing was evaluated; there is no loss or accuracy to report.
            print('validation loader is empty; skipping validation metrics')
            return
        print('current loss:', losses[-1])
        if args.cls:
            acc = c/n * 100
            self.best_acc = max(self.best_acc, acc)
            wandb.log({"val/step": step,
                       f'val/acc_r{i_run}_w{sigma_w}': acc,
                       f'val/acc_best_r{i_run}_w{sigma_w}': self.best_acc,
                       f'val/loss_r{i_run}_w{sigma_w}': np.mean(losses)},
                      commit=True)
            wandb.log({'log/epoch': i_epoch}, commit=True)
            print('validation accuracy:', acc)
            print('best validation accuracy:', self.best_acc)
            # Accuracies are stashed on args so external logging can read them.
            args.util_best_acc = self.best_acc
            args.util_acc = acc

    @torch.no_grad()
    def cnt_correct(self, output, target):
        """Count the predictions in ``output`` that match ``target``.

        With ``args.last_cls`` the prediction is the arg-max along dim 1;
        otherwise it is the entry along dim 1 whose value is closest to 1.

        Returns:
            A tuple ``(n_correct, batch_size)`` where ``n_correct`` is a
            scalar float tensor and ``batch_size`` is ``float(target.shape[0])``.
        """
        bsz = target.shape[0]
        if self.args.last_cls:
            scores = output
        else:
            # Negated squared distance to 1: the largest score is the entry
            # closest to 1.
            scores = -(output-1)**2
        _, pred = scores.topk(1, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        return correct.reshape(-1).float().sum(0), float(bsz)