import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from anal.util import length
from anal.reg import get_alpha
from model.models.pc import LocConv, LocFC, ClsHead, ClsEmb, get_enc
from model.models.pc_dense import Base


class PCL(Base):
    """Model built from a stack of per-layer local modules (``self.f_loc``).

    ``self.L`` and ``self.args`` are read right after ``super().__init__(args)``,
    so they are expected to be set up by ``Base``.  ``self.L - 1`` local
    modules are created with ``get_enc``; each is called as
    ``module(z, c2f_flag)`` and returns a 2-tuple.
    """

    def __init__(self, args):
        """Create the local modules and the per-layer ``c2f`` flags.

        Args:
            args: Configuration object forwarded to ``Base`` and ``get_enc``.
                ``args.lyr`` is indexed per layer and compared against the
                strings ``'cnn'`` and ``'fc'``.
        """
        super().__init__(args)
        # Fall back to the CPU when CUDA is unavailable instead of failing at
        # construction time; with CUDA present this is the same placement as
        # the previous hard-coded 'cuda'.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        n_loc = self.L - 1
        self.f_loc = nn.ModuleList(
            [get_enc(args, l).to(device) for l in range(n_loc)]
        )
        # c2f[l] is True when layer l-1 is 'cnn' and layer l is 'fc'
        # (presumably a conv-to-fc transition flag -- confirm against get_enc).
        # Layer 0 has no predecessor, so its flag is always False.
        self.c2f = [False] + [
            self.args.lyr[l - 1] == 'cnn' and self.args.lyr[l] == 'fc'
            for l in range(1, n_loc)
        ]

    def feedfwd(self, x):
        """Chain ``x`` through every local module and collect the outputs.

        Only the first element of each module's returned pair is propagated.

        Args:
            x: Input handed to the first local module.

        Returns:
            list: The ``self.L - 1`` successive outputs, in layer order.
        """
        z = x
        zs = []
        for f, flag in zip(self.f_loc, self.c2f):
            z, _ = f(z, flag)
            zs.append(z)
        return zs

    def predict(self, lenlog, zs, is_last, is_last_e, check=False, vl=False):
        """Apply each local module to its own entry of ``zs``.

        Args:
            lenlog: Logger exposing ``set_zh``; only used when ``check`` is
                False.
            zs: Sequence of per-layer tensors.  Indices ``0 .. self.L - 1``
                are read, so it must hold at least ``self.L`` entries.
            is_last: Unused in this implementation.
            is_last_e: When truthy, sets ``self.iter_log = True``.
            check: When True, all ``lenlog`` bookkeeping is skipped.
            vl: When True, ``requires_grad_()`` is called on ``zs[-1]``.

        Returns:
            tuple: ``(zhls, None, hls)`` where ``zhls`` holds the first output
            of each local module and ``hls`` the detached second output.
        """
        if is_last_e:
            self.iter_log = True
        zhls = []
        hls = []
        if vl:
            zs[-1].requires_grad_()
        for l in range(self.L - 1):
            # Every entry except the first is switched to require gradients
            # before being fed to its local module.
            if l > 0:
                zs[l].requires_grad_()
            zhl, hl = self.f_loc[l](zs[l], self.c2f[l])
            zhls.append(zhl)
            hls.append(hl.detach())
            if not check:
                with torch.no_grad():
                    bsz = zs[l].size(0)
                    lenlog.set_zh(l, [
                        length(zs[l].view(bsz, -1).detach()),
                        length(zhl.view(bsz, -1).detach()),
                        length(hl.view(bsz, -1).detach()),
                    ])
        if not check:
            with torch.no_grad():
                # Index the final entry explicitly rather than through a
                # loop variable that outlives its loop.
                lenlog.set_zh(self.L - 1, [length(zs[self.L - 1].detach())], True)
        return zhls, None, hls

    def compare_ls(self, lenlog, zs, is_last, deprecated, zhls, vl, y=None, trg_dtc=False,
                   check=False, ps=None, p_loc=1., p_io=1., prg=None):
        """Accumulate the per-layer losses between predictions and targets.

        For every layer but the last, ``self.loss`` compares ``zhls[l]`` with
        ``zs[l + 1]``; the last layer uses ``self.loss_cls`` against ``y``.

        Args:
            lenlog: Logger exposing ``set_d`` and ``set_loss_inf``; only used
                when ``check`` is False.
            zs: Per-layer tensors (same sequence as given to ``predict``).
            is_last: Unused in this implementation.
            deprecated: Unused in this implementation.
            zhls: Per-layer predictions as returned by ``predict``.
            vl: Unused in this implementation.
            y: Target passed to ``self.loss_cls`` for the last layer.
            trg_dtc: When True, the target ``zs[l + 1]`` is detached before
                being passed to ``self.loss``.
            check: When True, all ``lenlog`` bookkeeping is skipped.
            ps: Iterable mapped through ``self.p`` to obtain per-layer
                weights.  Despite the ``None`` default it must be supplied:
                iterating ``None`` raises ``TypeError``.
            p_loc: Unused in this implementation.
            p_io: Unused in this implementation.
            prg: Unused in this implementation.

        Returns:
            tuple: ``(loss, delta_lst)`` with the scaled total loss and the
            list of per-layer deltas.
        """
        loss = 0
        cnt = 0.
        delta_lst = []
        last = self.L - 2
        for l in range(self.L - 1):
            # Re-evaluated on every iteration on purpose: self.p is defined
            # elsewhere and may be stateful, so it is not hoisted.
            bs = [self.p(p) for p in ps]
            if l < last:
                target = zs[l + 1].detach() if trg_dtc else zs[l + 1]
                _loss, delta = self.loss(zhls[l], target, self.f_loc[l])
                # The first layer's loss is left unweighted, but its weight
                # still contributes to the normaliser.
                if l > 0:
                    _loss *= bs[l]
                cnt += bs[l]
            else:
                _loss = self.loss_cls(zhls[-1], y, self.f_loc[-1])[0]
                delta = zs[-1] - zhls[-1]
                # NOTE(review): the output loss is added here and once more
                # below, so after the 0.5 scaling it carries weight 1.
                # Presumably intentional -- confirm.
                loss += _loss
                cnt += 1
            loss += _loss
            layer_loss = _loss.item()
            delta_lst.append(delta)
            if not check:
                lenlog.set_d(l, length(delta.view(delta.size(0), -1).detach()))
                lenlog.set_loss_inf(l + 1, layer_loss)
        loss = 0.5 * loss
        if not self.args.loss_sum:
            loss /= cnt
        if self.args.w_reg or self.args.b_reg:
            loss /= (1. + self.args.reg_coef)
        return loss, delta_lst
