import wandb
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import random
from model.models.classifier import Classifier
from data.anal_data import get_data
from anal.base import AnalBase
from anal.util import length, init_param, get_w_grads


class AnalFF(AnalBase):
    """Analysis runner that fits ``self.model`` with a cross-entropy loss.

    The same step routine is registered for both training (``fwd``) and
    validation (``fwd_val``); the ``vl`` keyword of that routine decides
    whether the optimiser is stepped.  ``self.model``, ``self.optimizer``,
    ``self.args`` and ``self.cnt_correct`` are not defined here and are
    presumably provided by ``AnalBase`` -- confirm against the base class.
    """

    def __init__(self, p, lenlog):
        super().__init__(p, lenlog)
        # One routine serves both phases; the phase is selected via ``vl``.
        self.fwd = self._fwd
        self.fwd_val = self._fwd
        self.criterion = nn.CrossEntropyLoss()

    def _fwd(self, zs, y, run_cond_v, i_run, i_epoch, is_last,
             zs_db=None, i_batch=None, i_iter=None, vl=False):
        """Run a forward pass and, unless ``vl`` is set, one optimiser step.

        Args:
            zs: sequence whose first element is passed to
                ``self.model.forward_layer``.
            y: targets for the cross-entropy criterion (and for
                ``cnt_correct`` when ``self.args.cls`` is set).
            run_cond_v, i_run, i_epoch, is_last, zs_db, i_batch, i_iter:
                not read by this implementation (presumably required by the
                shared ``fwd`` calling convention).
            vl: when true, ``zero_grad``/``backward``/``step`` are skipped.

        Returns:
            ``(stats, n_correct, n_total, 0)``. ``stats`` holds
            ``'len/w_len'`` and ``'len/b_len'`` (2nd and 3rd results of
            ``get_w_grads``), ``'loss/losses'`` (one-element list with the
            scalar loss) and ``'len/z_len'`` (``length`` of each activation
            returned by ``forward_layer``).  Without ``self.args.cls`` the
            counts are the fixed placeholders ``0`` and ``1``.
        """
        training = not vl

        if training:
            self.optimizer.zero_grad()
        # NOTE(review): in the validation phase this forward pass still runs
        # with autograd enabled unless the caller disables it -- confirm.
        logits, activations = self.model.forward_layer(zs[0])
        loss = self.criterion(logits, y)
        if training:
            loss.backward()
            self.optimizer.step()

        with torch.no_grad():
            _, w_lens, b_lens = get_w_grads(self.model)
            z_lens = [length(act) for act in activations]
            # Insertion order matches the keys' original logging order.
            stats = {
                'len/w_len': w_lens,
                'loss/losses': [loss.item()],
                'len/z_len': z_lens,
                'len/b_len': b_lens,
            }
            if not self.args.cls:
                return stats, 0, 1, 0
            # NOTE(review): accuracy uses the last activation rather than
            # the logits fed to the loss -- confirm these coincide.
            n_correct, n_total = self.cnt_correct(activations[-1], y)
            return stats, n_correct.item(), n_total, 0

    def store_results(self, log_dt_plt):
        """Repack the stats dict from ``_fwd`` into fixed-slot lists.

        Returns ``(len_lst, loss_lst, p_lst)``: an 8-slot list of length
        statistics, a 2-slot loss list and a 5-slot list that is all
        ``None``.  The slot layout is presumably dictated by ``AnalBase``;
        slots this model does not produce are ``None``.
        """
        z_lens = log_dt_plt['len/z_len']
        # NOTE(review): the same z_len list fills the first three slots --
        # confirm the duplication is intended for this model.
        len_lst = [
            z_lens, z_lens, z_lens,
            None,
            log_dt_plt['len/w_len'],
            log_dt_plt['len/b_len'],
            None, None,
        ]
        loss_lst = [log_dt_plt['loss/losses'], None]
        p_lst = [None] * 5
        return len_lst, loss_lst, p_lst
