import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from anal.util import length
from anal.reg import get_alpha, get_omega, get_beta
from model.models.pc import LocPC, ClsHead, ClsEmb, get_enc
from train.loss.util_loss import LabelSmoothingLoss


class Base(nn.Module):
    """Shared base module holding the run configuration and loss helpers.

    Subclasses are expected to provide ``compare_ls``; this class only
    defines the regression loss (``loss``), the classification loss
    (``loss_cls``), a Bernoulli sampler (``p``) and a thin ``compare``
    wrapper.
    """

    def __init__(self, args):
        """Store ``args`` and build the optional classification modules.

        Args:
            args: configuration object. Attributes read here:
                ``not_iter_log``, ``ds``, ``cls_mth``, ``last_cls``, ``cls``.
        """
        super().__init__()
        self.args = args
        self.iter_log = not args.not_iter_log
        # Number of entries in ``args.ds``.
        self.L = len(args.ds)

        # At most one of the two classification modules is instantiated,
        # selected by ``args.cls_mth``.
        head_kind = args.cls_mth
        self.cls = ClsHead(args) if head_kind == 'z2y' else None
        self.emb = ClsEmb(args) if head_kind == 'emb' else None

        self.last_cls = args.last_cls
        self.half_L = self.L // 2

        # ``loss_last`` aliases the loss used for the final layer:
        # classification when ``args.cls`` is set, regression otherwise.
        if args.cls:
            self.loss_cls_smooth = LabelSmoothingLoss(args)
            self.loss_last = self.loss_cls
        else:
            self.loss_last = self.loss

    def loss(self, prd, tg, f, is_last_z=False):
        """Squared L2 distance between ``tg`` and ``prd`` plus optional terms.

        Args:
            prd: prediction tensor.
            tg: target tensor (same shape as ``prd`` or broadcastable).
            f: object exposing ``get_reg()``; only used when ``args.w_reg``
                or ``args.b_reg`` is truthy.
            is_last_z: when True the ``z_reg`` term is skipped.

        Returns:
            ``(loss, delta)`` where ``delta = tg - prd`` and ``loss`` is the
            sum (``args.loss_sum``) or mean of the squared norms taken along
            the last dimension, plus any enabled regularisers.
        """
        cfg = self.args

        def _reduce(t):
            # Sum or mean, depending on the configured reduction.
            return t.sum() if cfg.loss_sum else t.mean()

        delta = tg - prd
        sq_norm = torch.square(torch.linalg.norm(delta, dim=-1))
        loss = _reduce(sq_norm)

        if cfg.z_reg and not is_last_z:
            # Regulariser computed from the target by ``get_alpha``.
            loss += cfg.reg_coef * _reduce(get_alpha(tg))
        if cfg.w_reg or cfg.b_reg:
            loss += f.get_reg()
        return loss, delta

    def loss_cls(self, prd, tg, f):
        """Classification loss (cross-entropy or label-smoothed variant).

        Args:
            prd: logits passed to the loss.
            tg: targets passed to the loss.
            f: object exposing ``get_reg()``; only used when both
                ``args.w_reg`` and ``args.b_reg`` are truthy.

        Returns:
            ``(loss, None)`` -- the second element mirrors the ``delta`` slot
            of ``loss`` so both can be used interchangeably.
        """
        cfg = self.args
        reduction = 'sum' if cfg.loss_sum else 'mean'

        if cfg.smooth_label:
            loss = self.loss_cls_smooth(prd, tg)
        else:
            loss = F.cross_entropy(prd, tg, reduction=reduction)

        # NOTE(review): this gate uses ``and`` whereas ``loss`` uses ``or``
        # for the same regulariser -- confirm whether that is intentional.
        if cfg.w_reg and cfg.b_reg:
            loss += f.get_reg().mean()
        return loss, None

    def p(self, p_):
        """Draw a Bernoulli sample with success probability ``p_``."""
        prob = torch.tensor(p_).float()
        return torch.bernoulli(prob)

    def compare(self, zs, z_decs, z_encs):
        """Return the first element of ``self.compare_ls(zs, z_decs, z_encs)``.

        NOTE(review): ``compare_ls`` is not defined on this class; the
        subclass implementation must accept these three positional
        arguments for this wrapper to work -- verify against subclasses.
        """
        result = self.compare_ls(zs, z_decs, z_encs)
        return result[0]


class PCD(Base):
    """Stack of local modules with optional input-side and output-side heads.

    Sub-modules:
        f_loc: ``L-1`` modules built by ``get_enc``; module ``l`` is applied
            to ``zs[l]``.
        f_out: ``l_io`` heads built with ``LocPC(..., z2y=True)``; head ``l``
            is applied to ``zs[l+1]`` and scored with the classification loss.
        f_in: ``half_L`` heads built with ``LocPC(..., x2z=True)``; only
            created when ``seq_pc`` is False, otherwise ``None``.
    """

    def __init__(self, args):
        """Build the sub-modules described in the class docstring.

        Args:
            args: configuration object. In addition to what ``Base`` reads,
                this uses ``method``, ``z_init``, ``scon`` and ``lyr``.
        """
        super().__init__(args)
        pcd2 = args.method == 'pcd2'
        z_ff = args.z_init == 'ff'
        self.seq_pc = z_ff or pcd2 or not args.scon

        # Number of layers that carry an in/out head.
        if not args.scon:
            self.l_io = 0
        elif self.seq_pc:
            self.l_io = self.L - 3
        else:
            self.l_io = self.half_L

        # Use the GPU when present; fall back to CPU so the model can still
        # be constructed on hosts without CUDA.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Construction order (f_loc, f_in, f_out) is kept so parameter
        # initialisation consumes the RNG in the same sequence as before.
        f_loc = [get_enc(args, l).to(device) for l in range(self.L - 1)]
        f_in = None
        if not self.seq_pc:
            f_in = [LocPC(args, self.L - 2 - l, x2z=True).to(device)
                    for l in range(self.half_L)]
        f_out = [LocPC(args, l, z2y=True).to(device) for l in range(self.l_io)]

        self.f_loc = nn.ModuleList(f_loc)
        self.f_in = nn.ModuleList(f_in) if f_in is not None else None
        self.f_out = nn.ModuleList(f_out)

        # c2f[l] is True when layer l-1 is 'cnn' and layer l is 'fc'; it is
        # passed as the second argument to f_loc[l]. Entry 0 is always False.
        self.c2f = [False] + [
            self.args.lyr[l - 1] == 'cnn' and self.args.lyr[l] == 'fc'
            for l in range(1, self.L - 1)
        ]

    def feedfwd(self, x):
        """Chain the ``f_loc`` modules starting from ``x``.

        Returns:
            List of the ``L-1`` successive outputs (``x`` itself is not
            included).
        """
        z = x
        zs = []
        for l in range(self.L - 1):
            z, _ = self.f_loc[l](z, self.c2f[l])
            zs.append(z)
        return zs

    def predict(self, lenlog, zs, is_last, is_last_e, check=False, vl=False):
        """Run every local module and in/out head on the states in ``zs``.

        Args:
            lenlog: logger exposing ``set_zh``; unused when ``check`` is True.
            zs: list of state tensors. ``zs[l]`` feeds ``f_loc[l]``; ``zs[0]``
                (flattened) feeds every ``f_in`` head. The list and its
                tensors are mutated: interior entries get
                ``requires_grad_()`` and, when ``vl`` is True, each detached
                ``f_loc`` output is appended.
            is_last: unused.
            is_last_e: when truthy, ``self.iter_log`` is set to True.
            check: when True, skip all ``lenlog`` calls.
            vl: when True, ``zs`` is grown on the fly from the ``f_loc``
                outputs instead of being fully supplied by the caller.

        Returns:
            ``([zhls, zhis, zhos], None, [hls, his, hos])`` -- first outputs
            of the local / input / output modules, and the detached second
            outputs of the same modules.
        """
        if is_last_e:
            self.iter_log = True
        zhls, zhis, zhos = [], [], []
        hls, his, hos = [], [], []
        last = self.L - 1

        if vl:
            zs[-1].requires_grad_()
        for l in range(self.L):
            if 0 < l < last:
                zs[l].requires_grad_()
            if l < last:
                zhl, hl = self.f_loc[l](zs[l], self.c2f[l])
                zhls.append(zhl)
                hls.append(hl.detach())
                if vl:
                    zs.append(zhl.detach())
                if not check:
                    with torch.no_grad():
                        batch = zs[l].size(0)
                        lenlog.set_zh(l, [
                            length(zs[l].view(batch, -1).detach()),
                            length(zhl.view(batch, -1).detach()),
                            length(hl.view(batch, -1).detach()),
                        ])

            if l < self.l_io:
                if not self.seq_pc:
                    # Input-side heads always see the flattened first state.
                    zhi, hi = self.f_in[l](zs[0].view(zs[0].size(0), -1))
                    zhis.append(zhi)
                    his.append(hi.detach())
                zho, ho = self.f_out[l](zs[l + 1], c2f=True)
                zhos.append(zho)
                hos.append(ho.detach())

        if not check:
            with torch.no_grad():
                # Explicit index instead of relying on the loop variable
                # still being bound after the loop.
                lenlog.set_zh(last, [length(zs[last].detach())], True)
        return [zhls, zhis, zhos], None, [hls, his, hos]

    def compare_ls(self, lenlog, zs, is_last, deprecated, zhs, vl, y=None, trg_dtc=False,
                   check=False, ps=None, p_loc=1., p_io=1., prg=None):
        """Accumulate the per-layer losses for the outputs of ``predict``.

        Args:
            lenlog: logger exposing ``set_d`` and ``set_loss_inf``; unused
                when ``check`` is True.
            zs: list of state tensors (same list that was given to
                ``predict``).
            is_last, deprecated, vl, p_loc, p_io: unused.
            zhs: ``[zhls, zhis, zhos]`` as returned first by ``predict``.
            y: classification targets for the output heads and last layer.
            trg_dtc: when True the regression target ``zs[l+1]`` is detached.
            check: when True, skip all ``lenlog`` calls.
            ps: per-layer gate probabilities, indexed by layer. ``None`` is
                treated as all ones (previously this raised ``TypeError``).
            prg: when truthy and ``seq_pc`` is False, gates are sampled from
                ``ps``; otherwise every gate is drawn with probability 1.

        Returns:
            ``(loss, delta_lst)`` -- the combined scalar loss and one delta
            tensor per layer.
        """
        zhls, zhis, zhos = zhs
        loss = 0
        cnt = 0.
        delta_lst = []

        if ps is None:
            # No gate probabilities supplied: keep every term.
            ps = np.ones(self.L - 1)
        sample_gates = prg and not self.seq_pc
        last = self.L - 2

        for l in range(self.L - 1):
            # Gates are redrawn on every iteration (only bs[l] is used); this
            # is kept so the random stream is consumed exactly as before.
            gate_probs = ps if sample_gates else np.ones_like(ps)
            bs = [self.p(p) for p in gate_probs]

            if l < self.l_io:  # input-side and output-side heads
                if not self.seq_pc:
                    tgt_in = zs[-(l + 2)]
                    loss += bs[l] * self.loss(
                        zhis[l], tgt_in.view(tgt_in.size(0), -1), self.f_in[l])[0]
                loss += bs[l] * self.loss_cls(zhos[l], y, self.f_out[l])[0]
                # NOTE(review): counts two terms even when seq_pc is True and
                # only the output-head term was added -- confirm intent.
                cnt += 2 * bs[l]

            if l < last:
                target = zs[l + 1].detach() if trg_dtc else zs[l + 1]
                _loss, delta = self.loss(zhls[l], target, self.f_loc[l])
                if l > 0:
                    _loss *= bs[l]
                    cnt += bs[l]
                # NOTE(review): layer 0 is never gated and never counted in
                # ``cnt`` -- confirm intent.
            else:
                _loss = self.loss_cls(zhls[-1], y, self.f_loc[-1])[0]
                delta = zs[-1] - zhls[-1]
                # NOTE(review): together with the shared accumulation below
                # the last-layer loss is added twice while ``cnt`` grows by
                # one. Behaviour preserved -- confirm intent.
                loss += _loss
                cnt += 1

            loss += _loss
            layer_loss = _loss.item()
            delta_lst.append(delta)
            if not check:
                lenlog.set_d(l, length(delta.view(delta.size(0), -1).detach()))
                lenlog.set_loss_inf(l + 1, layer_loss)

        loss = 0.5 * loss
        if not self.args.loss_sum:
            loss /= cnt
        if self.args.w_reg or self.args.b_reg:
            loss /= (1. + self.args.reg_coef)
        return loss, delta_lst
