from ast import arg
from cmath import log
from re import T
import wandb
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import random
from train.mlbase import MLBase
from data.anal_data import get_data
from tqdm import tqdm
from anal.util import length, length_w, get_w_grads, get_len_t,\
                                get_len_l, init_param, init_zs_db
from anal.log import p_wb, LenLog
from train.lr import warmup_lr, adjust_lr
from matplotlib import pyplot as plt


class AnalBase(MLBase):
    """Base driver for training/analysis runs with per-layer latent states.

    ``fwd_tr`` (one training step) and ``fwd_vl`` (one validation step) are
    called but not defined in this class; presumably subclasses provide them.
    ``self.args``, ``self.model``, ``self.optimizer``, ``self.tr_dt_dl`` and
    ``self.vl_dt_dl`` are read but never assigned here -- presumably set up by
    ``MLBase`` or a subclass (TODO confirm).
    """

    def __init__(self, p):
        """Initialise from ``p``, which is forwarded to ``MLBase`` as ``other``.

        Caches ``args.T`` and picks the length helper: ``get_len_t`` when
        ``args.tensorized`` is truthy, otherwise ``get_len_l``.
        """
        super().__init__(other=p)
        self.T = self.args.T
        self.get_len = get_len_t if self.args.tensorized else get_len_l
        self.fwd = None
        self.best_acc = 0  # best validation accuracy seen in the current run

    def __call__(self, i_run, sigma_w, sigma_b, eta):
        """Run ``args.epochs`` epochs for one (sigma_w, sigma_b, eta) setting.

        Re-initialises the model parameters with ``init_param`` and builds the
        training loader via ``get_data``. For every batch it initialises the
        latent states with ``init_zs`` and hands them to ``self.fwd_tr``. When
        ``args.train`` is set, the learning rate is warmed up every iteration
        and validation runs every 50 iterations.

        Args:
            i_run: index of this run (used in log keys and the saved pickle).
            sigma_w, sigma_b: initialisation scales passed to ``init_param``.
            eta: forwarded to ``fwd_tr`` and ``LenLog.save_pkl``.

        Returns:
            The ``LenLog`` populated during this run.
        """
        import time  # only needed for the pauses around ``save_pkl``

        args = self.args
        self.lenlog = LenLog(args)
        with torch.no_grad():
            init_param(self.model, args, sigma_w, sigma_b)
            dt_dl = self.tr_dt_dl if args.cls else None
            ldt, tr_dl = get_data(args, dt_dl)
        zs_db = init_zs_db(args, ldt) if args.z_init == 'use_db' else None
        # Batches carry their own batch index only in these modes.
        has_batch_idx = args.z_init == 'use_db' or args.dataset == 'random'
        self.best_acc = 0
        for i_epoch in range(args.epochs):
            is_last = i_epoch == args.epochs - 1
            c, n = 0, 0  # correct / seen counts since the last report
            self.model.train()
            for i_iter, dt in enumerate(tqdm(tr_dl)):
                if has_batch_idx:
                    x, y, i_batch = dt
                else:
                    x, y = dt
                    i_batch = None
                if args.len_all or is_last:
                    self.lenlog.inc('e')
                if args.train:
                    # warmup_lr is used here instead of the step schedule adjust_lr.
                    warmup_lr(args, i_epoch, tr_dl, i_iter, self.optimizer)
                x = x.to(args.device, non_blocking=True)
                y = y.to(args.device, non_blocking=True)
                if args.arch == 'fc':
                    # Flatten by the actual batch size so a short final batch
                    # is not mis-reshaped against args.bsz.
                    x = x.reshape(x.shape[0], -1)
                zs = self.init_zs(x, y, i_batch, zs_db)
                _c, _n = self.fwd_tr(zs, y, sigma_w, i_run, i_epoch, is_last,
                                     eta, zs_db, i_batch, i_iter)
                c += _c
                n += _n
                if args.train and i_iter % 50 == 0:
                    if args.debug:
                        p_wb(self.model)
                    print(f'training accuracy at epoch {i_epoch} iter {i_iter}:',
                          c / n * 100)
                    with torch.no_grad():
                        self.eval_acc(args, sigma_w, sigma_b, i_run, i_epoch, i_iter)
                    c, n = 0, 0
        # NOTE(review): the sleeps presumably give pending logging/IO time to
        # settle around the pickle save -- confirm they are still needed.
        time.sleep(3)
        self.lenlog.save_pkl(i_run, sigma_w, sigma_b, eta)
        time.sleep(5)
        return self.lenlog

    @torch.no_grad()
    def init_zs(self, x, y, i_batch=0, zs_db=None):
        """Build the list of per-layer states ``[x, z_1, ..., z_{L-2}, y]``.

        The ``args.n_layers - 2`` hidden states are initialised according to
        ``args.z_init``:
          * ``'use_db'``   -- stored states ``zs_db[i][i_batch]`` (detached)
            plus 0.05 * standard-normal noise;
          * ``'gaussian'`` -- standard-normal samples of shape
            ``(args.bsz, args.ds[i + 1])``;
          * ``'ff'``       -- all but the last output of ``model.feedfwd(x)``.

        Raises:
            NotImplementedError: for any other ``args.z_init``.
            ValueError: if any returned state contains NaN.
        """
        args = self.args
        n_hidden = args.n_layers - 2

        def noise(i):
            # Sampled on CPU and then moved, which keeps the CPU RNG stream.
            sample = torch.normal(torch.zeros([args.bsz, args.ds[i + 1]]), 1.0)
            return sample.to(args.device)

        if args.z_init == 'use_db':
            zs = [zs_db[i][i_batch].detach() + 0.05 * noise(i)
                  for i in range(n_hidden)]
        elif args.z_init == 'gaussian':
            zs = [noise(i) for i in range(n_hidden)]
        elif args.z_init == 'ff':
            # list() so the result can be extended even if feedfwd returns a tuple.
            zs = list(self.model.feedfwd(x)[:-1])
        else:
            raise NotImplementedError
        zs = [x, *zs, y]
        for z in zs:
            if z.isnan().any().item():
                raise ValueError('NaN in zs')
        return zs

    def eval_acc(self, args, sigma_w, sigma_b, i_run, i_epoch, i_iter):
        """Evaluate on the validation loader ``self.vl_dt_dl[1]``.

        Runs ``self.fwd_vl`` on every validation batch and prints the last
        batch's loss. When ``args.cls`` is set, it also logs accuracy, best
        accuracy and mean loss to wandb and updates ``self.best_acc``.

        The model's train/eval mode is restored on exit, so the caller's
        training loop continues in the mode it was in before the call.
        ``sigma_b`` is accepted for signature compatibility but unused here.
        """
        step = self.lenlog.n * i_epoch + i_iter
        if args.verbose:
            print(f'\n{i_run}-th {sigma_w}, validation')
        losses = []
        last_loss = None
        c, n = 0, 0
        was_training = self.model.training
        self.model.eval()
        try:
            for x, y in self.vl_dt_dl[1]:
                x = x.to(args.device, non_blocking=True)
                y = y.to(args.device, non_blocking=True)
                if args.arch == 'fc':
                    x = x.reshape(x.shape[0], -1)
                log_dt, _c, _n = self.fwd_vl(x, y)
                last_loss = log_dt['loss/losses'][-1]
                losses.append(last_loss)
                c += _c
                n += _n
        finally:
            # Without this the rest of the training epoch would run in eval mode.
            self.model.train(was_training)
        if last_loss is None:
            print('no validation batches; skipping evaluation')
            return
        print('current loss:', last_loss)
        if args.cls:
            acc = c / n * 100
            self.best_acc = max(self.best_acc, acc)
            wandb.log({"val/step": step,
                       f'val/acc_r{i_run}_w{sigma_w}': acc,
                       f'val/acc_best_r{i_run}_w{sigma_w}': self.best_acc,
                       f'val/loss_r{i_run}_w{sigma_w}': np.mean(losses)},
                      commit=True)
            wandb.log({'log/epoch': i_epoch}, commit=True)
            print('validation accuracy:', acc)
            print('best validation accuracy:', self.best_acc)

    @torch.no_grad()
    def cnt_correct(self, output, target):
        """Count correct top-1 predictions in a batch.

        With ``args.last_cls`` the prediction is the arg-max over dim 1 of
        ``output``. Otherwise it is the column whose value is closest to 1,
        i.e. the arg-max of ``-(output - 1) ** 2``. Rows of ``output`` are
        samples (compared against ``target.view(1, -1)``); assumes ``target``
        holds class indices -- TODO confirm.

        Returns:
            ``(n_correct, batch_size)``: a scalar tensor and a float.
        """
        bsz = target.shape[0]
        scores = output if self.args.last_cls else -(output - 1) ** 2
        _, pred = scores.topk(1, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        return correct.reshape(-1).float().sum(0), float(bsz)
