# %%
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from anal.util import length
from anal.reg import get_alpha, get_omega, get_beta
from train.loss.util_loss import LabelSmoothingLoss


# Activation lookup keyed by the `args.act` string. Every entry is a stateless
# callable (no parameters), so one instance is shared by all layers using it.
nonlinearity = dict(
    relu=nn.ReLU(),
    tanh=torch.tanh,
    leaky_relu=nn.LeakyReLU(),
    selu=nn.SELU(),
    silu=nn.SiLU(),
    linear=nn.Identity(),
    logsig=nn.LogSigmoid(),
)
# Previous encoder choice, kept for reference:
#     LocConv(args, l) if args.cnn else LocFC(args, l)


def get_enc(args, l):
    """Build the local encoder layer for level ``l`` (always a ``LocPC``)."""
    return LocPC(args, l)


def get_dec(args, l):
    """Build the decoder layer for level ``l``.

    Uses a transposed-conv ``LocConv`` when ``args.cnn`` is set, otherwise a
    backward (``enc_f=False``) ``LocFC`` indexed at ``l + 1``.
    """
    if not args.cnn:
        return LocFC(args, l + 1, enc_f=False)
    return LocConv(args, l, enc=False)

def linear(in_d, out_d, bias=True):
    """Return an ``nn.Linear`` mapping ``in_d`` features to ``out_d`` features."""
    layer = nn.Linear(in_features=in_d, out_features=out_d, bias=bias)
    return layer


def conv2d(in_d, out_d, k_size, stride, padding):
    """Return an ``nn.Conv2d`` with ``in_d -> out_d`` channels and the given geometry."""
    geometry = dict(kernel_size=k_size, stride=stride, padding=padding)
    return nn.Conv2d(in_d, out_d, **geometry)


class LocPC(nn.Module):
    """One local layer of the stack: a linear or conv map plus an activation.

    Args:
        args: config namespace. Read here: ``ds``, ``lyr``, ``ecs``/``eks``/``ess``/
            ``eps`` (conv geometry), ``skip_bias``, ``act``, ``skip_con``,
            ``dropout``, ``do_drop``, ``act_first`` (and ``reg_coef``, ``w_reg``,
            ``b_reg``, ``loss_sum`` in ``get_reg``).
        l: index of this layer in the stack.
        x2z: map from ``ds[0]`` to ``ds[l]`` (always linear).
        z2y: map from ``ds[l+1]`` to ``ds[-1]`` (always linear, no activation).
        enc_f: forward direction ``l -> l+1`` if True, else ``l -> l-1``.
        cls, o2c: unused here.
    """

    def __init__(self, args, l, x2z=False, z2y=False, enc_f=True, cls=False, o2c=False):
        super().__init__()
        self.args = args
        if x2z:
            l1, l2 = (0, l)
        elif z2y:
            l1, l2 = (l+1, -1)
        else:
            l1, l2 = (l, l+1) if enc_f else (l, l-1)

        if x2z or z2y or args.lyr[l] == 'fc':
            self.layer = linear(args.ds[l1], args.ds[l2], bias=not args.skip_bias)
        else:
            self.layer = conv2d(args.ecs[l], args.ecs[l+1], args.eks[l],
                                args.ess[l], args.eps[l])

        # The z2y head never gets an activation. The last stack layer (l == L-1)
        # only skips it when the activation FOLLOWS the map (act_first=False), so
        # the network output stays linear. With act_first=True the activation is
        # applied to the layer's input; an Identity there would make the last two
        # layers compose into a single linear map.
        is_last = l == len(args.ds) - 2
        self.act = nn.Identity() if z2y or (is_last and not args.act_first) \
                else nonlinearity[args.act]
        # A residual connection is only possible when the width is preserved.
        self.skip_con = args.skip_con and args.ds[l1] == args.ds[l2]
        self.dropout = nn.Dropout(p=args.dropout)
        self.do_drop = args.do_drop
        self.L = len(args.ds) - 1
        self.l = l

        # act_first: activation -> layer, except layer 0 which sees the raw input
        # and gets no activation; otherwise: layer -> activation.
        self.transform = (self.layer_no_act if l == 0 else self.act_layer) \
                         if args.act_first else self.layer_act
        # torch.nn.init.kaiming_normal_(self.layer.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, z, c2f=False):
        """Return ``(z_hat, h)`` for input ``z``; ``c2f`` flattens to ``(batch, -1)``."""
        z_hat, h = self.transform(z, c2f)
        if self.do_drop and self.l != self.L-1: # except for the last layer
            z_hat = self.dropout(z_hat)
        if self.skip_con:
            # Out-of-place: an in-place ``+=`` would overwrite the activation
            # output that e.g. ReLU/tanh backward needs, and would also modify
            # ``h`` when z_hat and h are the same tensor (Identity / no act).
            z_hat = z_hat + z
        return z_hat, h

    def act_layer(self, z, c2f=False):
        """Activation, optional flatten, then layer; returns (layer output, activated input)."""
        h = self.act(z)
        if c2f: h = h.view(h.size(0), -1)
        z_hat = self.layer(h)
        return z_hat, h

    def layer_act(self, z, c2f=False):
        """Optional flatten, layer, then activation; returns (activated output, pre-activation)."""
        if c2f: z = z.view(z.size(0), -1)
        h = self.layer(z)
        z_hat = self.act(h)
        return z_hat, h

    def layer_no_act(self, z, c2f=False):
        """Optional flatten then layer only; both returned values are the layer output."""
        if c2f: z = z.view(z.size(0), -1)
        z_hat = h = self.layer(z)
        return z_hat, h

    def get_reg(self):
        """Return ``reg_coef`` times the weight (``w_reg``) and bias (``b_reg``) penalties."""
        reg = 0
        if self.args.w_reg:
            omega = get_omega(self.layer.weight)
            reg += omega.sum() if self.args.loss_sum else omega.mean()
        if (not self.args.skip_bias) and self.args.b_reg:
            len_b = get_beta(self.layer.bias)
            reg += len_b.sum() if self.args.loss_sum else len_b.mean()
        return self.args.reg_coef * reg


class LocFC(nn.Module): # local encoder
    """Local fully connected layer: ``nn.Linear`` plus an activation.

    Args:
        args: config namespace. Read here: ``ds`` (``ds_o`` when ``o2c``),
            ``skip_bias``, ``act``, ``skip_con``, ``dropout``, ``do_drop``,
            ``act_first`` (and ``reg_coef``, ``w_reg``, ``b_reg``, ``loss_sum``
            in ``get_reg``).
        l: index of this layer in the stack.
        x2z: map from size index 0 to ``l``.
        z2y: map from size index ``l+1`` to ``-1`` (no activation).
        enc_f: forward direction ``l -> l+1`` if True, else ``l -> l-1``.
        cls: unused.
        o2c: take the sizes from ``args.ds_o`` (this layer always has a bias).
    """

    def __init__(self, args, l, x2z=False, z2y=False, enc_f=True, cls=False, o2c=False):
        super().__init__()
        self.args = args
        if x2z:
            l1, l2 = (0, l)
        elif z2y:
            l1, l2 = (l+1, -1)
        else:
            l1, l2 = (l, l+1) if enc_f else (l, l-1)
        # Size table this layer is actually built from.
        dims = args.ds_o if o2c else args.ds
        if not o2c:
            self.fc = nn.Linear(dims[l1], dims[l2], bias=not args.skip_bias)
        else:
            self.fc = nn.Linear(dims[l1], dims[l2])
        # Same rule as LocPC: the last layer drops its activation only when the
        # activation follows the fc (act_first=False); with act_first=True it is
        # applied to the input and is needed to keep the last two layers nonlinear.
        is_last = l == len(args.ds) - 2
        self.act = nn.Identity() if z2y or (is_last and not args.act_first) \
                else nonlinearity[args.act]
        # Residual add requires the fc to preserve the width of the sizes used above.
        self.skip_con = args.skip_con and dims[l1] == dims[l2]
        self.dropout = nn.Dropout(p=args.dropout)
        self.do_drop = args.do_drop
        self.L = len(args.ds) - 1
        self.l = l
        # act_first: activation -> fc (layer 0: fc only); otherwise fc -> activation.
        self.transform = (self.fc_f if l == 0 else self.act_fc) if args.act_first else self.fc_act
        # torch.nn.init.kaiming_normal_(self.fc.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, z):
        """Return ``(z_hat, h)``; dropout (if enabled) skips the last layer."""
        z_hat, h = self.transform(z)
        # except for the last layer
        if self.do_drop and self.l != self.L-1:
            z_hat = self.dropout(z_hat)
        if self.skip_con:
            # Out-of-place so the activation output needed by backward (ReLU,
            # tanh, ...) is not modified, and ``h`` is never aliased and changed.
            z_hat = z_hat + z
        return z_hat, h

    def act_fc(self, z):
        """Activation first, then fc; returns (fc output, activated input)."""
        h = self.act(z)
        z_hat = self.fc(h)
        return z_hat, h

    def fc_act(self, z):
        """fc first, then activation; returns (activated output, pre-activation)."""
        h = self.fc(z)
        z_hat = self.act(h)
        return z_hat, h

    def fc_f(self, z):
        """fc only; both returned values are the fc output."""
        z_hat = h = self.fc(z)
        return z_hat, h

    def get_reg(self):
        """Return ``reg_coef`` times the weight (``w_reg``) and bias (``b_reg``) penalties."""
        reg = 0
        if self.args.w_reg:
            omega = get_omega(self.fc.weight)
            reg += omega.sum() if self.args.loss_sum else omega.mean()
        if (not self.args.skip_bias) and self.args.b_reg:
            len_b = get_beta(self.fc.bias)
            reg += len_b.sum() if self.args.loss_sum else len_b.mean()
        return self.args.reg_coef * reg


class LocConv(nn.Module): # local convolutional encoder
    """Local convolutional layer followed by an activation.

    Encoder (``enc=True``): ``Conv2d`` with channels ``ecs[l] -> ecs[l+1]`` and
    kernel/stride/padding ``eks[l]``/``ess[l]``/``eps[l]``. Decoder: a
    ``ConvTranspose2d`` built from ``dcs``/``dks``/``dss``/``dps``.
    ``forward`` returns ``(act(conv(z)) [+ z], conv(z))``.
    """

    def __init__(self, args, l, enc=True):
        super().__init__()
        # TODO: flatten the last layer for classification
        if enc:
            self.conv = nn.Conv2d(args.ecs[l], args.ecs[l+1],
                                  kernel_size=args.eks[l], stride=args.ess[l],
                                  padding=args.eps[l])
        else:
            self.conv = nn.ConvTranspose2d(args.dcs[l], args.dcs[l+1],
                                           kernel_size=args.dks[l],
                                           stride=args.dss[l], padding=args.dps[l])
        self.act = nonlinearity[args.act]
        # NOTE(review): unlike LocPC/LocFC there is no shape check here; the
        # residual add only works when the conv preserves the input's shape.
        self.skip_con = args.skip_con
        torch.nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, z):
        h = self.conv(z)
        z_hat = self.act(h)
        if self.skip_con:
            # Out-of-place: ``z_hat += z`` would overwrite the activation output
            # that ReLU/tanh backward needs and break autograd.
            z_hat = z_hat + z
        return z_hat, h


class ClsEmb(nn.Module):
    """Learned table mapping class indices to vectors of the top-layer size ``ds[-1]``."""

    def __init__(self, args):
        super().__init__()
        n_classes, emb_dim = args.n_cls, args.ds[-1]
        self.emb = nn.Embedding(n_classes, emb_dim)
        nn.init.kaiming_normal_(self.emb.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, y):
        # Plain row lookup: one embedding vector per class index in ``y``.
        table = self.emb
        return table(y)


class ClsHead(nn.Module):
    """Linear classifier from the top-layer state (size ``ds[-1]``) to ``n_cls`` logits."""

    def __init__(self, args):
        super().__init__()
        in_dim, n_out = args.ds[-1], args.n_cls
        self.fc = nn.Linear(in_dim, n_out)
        nn.init.kaiming_normal_(self.fc.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, z):
        # Raw logits; no output activation is applied.
        logits = self.fc(z)
        return logits


class Base(nn.Module):
    """Shared parent of ``AE`` and ``PC``.

    Builds the local encoder stack ``self.f_loc`` (``L = len(args.ds) - 1``
    layers from ``get_enc``), an optional classifier head (``cls_mth == 'z2y'``)
    or class embedding (``cls_mth == 'emb'``), the classification loss
    ``self.loss_cls``, and the per-layer prediction loss ``loss``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.L = len(args.ds) - 1
        self.update_x = args.update_x
        self.pred_enc_dec = args.pred_enc_dec
        # The encoders used to be hard-wired to 'cuda'. An optional ``args.device``
        # now overrides that (e.g. 'cpu'); without it the old behaviour is kept.
        device = getattr(args, 'device', None)
        if device is None:
            device = 'cuda'
        encs = [get_enc(args, l).to(device) for l in range(self.L)]
        # decs = [get_dec(args, l).to('cuda') for l in (reversed(range(self.L)) \
        #     if (args.cnn and args.pc) or (args.cnn==False and args.ae) else range(self.L))]
        self.f_loc = nn.ModuleList(encs)
        # self.decs = nn.ModuleList(decs)
        self.cls = ClsHead(args) if args.cls_mth=='z2y' else None
        self.emb = ClsEmb(args) if args.cls_mth=='emb' else None
        if args.smooth_label:
            self.loss_cls_smooth = LabelSmoothingLoss(args)
            self.loss_cls = lambda prd, tg: self.loss_cls_smooth(prd, tg)
        else:
            # Cross-entropy summed (not averaged) over the batch.
            self.loss_cls = lambda prd, tg: F.cross_entropy(prd, tg, reduction='sum')

    def loss(self, prd, tg, f, is_last_z=False):
        """Squared-error loss between prediction ``prd`` and target ``tg``.

        The squared L2 norm of ``tg - prd`` over the last dimension is summed
        (``args.loss_sum``) or averaged over the remaining entries. If
        ``args.z_reg`` is set and ``is_last_z`` is False, ``reg_coef`` times the
        ``get_alpha(tg)`` penalty is added. If ``w_reg`` or ``b_reg`` is set,
        ``f.get_reg()`` of layer ``f`` is added.

        Returns:
            ``(loss, delta)`` where ``delta = tg - prd``.
        """
        delta = tg - prd
        n = torch.linalg.norm(delta, dim=-1)
        n_sq = torch.square(n)
        loss = n_sq.sum() if self.args.loss_sum else n_sq.mean()
        if self.args.z_reg and not is_last_z:
            alpha = get_alpha(tg)
            alpha_ = alpha.sum() if self.args.loss_sum else alpha.mean()
            loss += self.args.reg_coef * alpha_
        if self.args.w_reg or self.args.b_reg:
            loss += f.get_reg()
        return loss, delta



class AE(Base):
    """Autoencoder built on the ``Base`` encoder stack."""

    def __init__(self, args):
        super().__init__(args)

    def enc(self, x):
        """Encode ``x``.

        Returns:
            ``(zs, z)``: ``zs = [x, z_1, ..., z_L]`` with every state after ``x``
            detached, and the final (non-detached) ``z``.
        """
        z = x; zs = [z]
        for l in range(self.L):
            # Base stores the encoder layers in ``self.f_loc``; ``self.encs`` is
            # never defined. NOTE(review): no conv->fc flattening is done here
            # (PC passes a per-layer c2f flag) — confirm AE stacks are all 'fc'.
            z, _ = self.f_loc[l](z)
            zs.append(z.detach())
        return zs, z

    def dec(self, z):
        """Decode ``z`` through ``self.decs``; same return convention as ``enc``.

        Raises:
            NotImplementedError: when no decoder stack ``self.decs`` exists
                (its construction in ``Base.__init__`` is commented out).
        """
        decs = getattr(self, 'decs', None)
        if decs is None:
            raise NotImplementedError(
                "AE.dec requires self.decs, but Base.__init__ does not build decoders")
        zs = [z]
        for l in range(self.L):
            z, _ = decs[l](z)
            zs.append(z.detach())
        return zs, z

    def forward(self, x):
        """Return ``(zs_enc, zs_dec, x_hat)``."""
        zs_enc, z = self.enc(x)
        zs_dec, x_hat = self.dec(z)
        return zs_enc, zs_dec, x_hat


class PC(Base):
    """Predictive-coding model over the ``Base`` encoder stack ``self.f_loc``."""

    def __init__(self, args):
        super().__init__(args)
        self.args = args; self.iter_log = not args.not_iter_log
        self.last_cls = args.last_cls
        # c2f[l] is True when layer l-1 is 'cnn' and layer l is 'fc', i.e. the
        # input of layer l has to be flattened to (batch, -1) first.
        self.c2f = [False] + [self.args.lyr[l-1] == 'cnn' and self.args.lyr[l] == 'fc' \
                            for l in range(1, self.L)]

    def enc(self, x):
        """Run the encoder stack on ``x``.

        Returns ``[x, z_1, ..., z_L]`` with every state after ``x`` detached.
        """
        z = x; zs = [z]
        for l in range(self.L):
            # Layers live in ``self.f_loc`` (``self.encs`` does not exist) and need
            # the per-layer flatten flag, exactly as in ``feedfwd``.
            z, _ = self.f_loc[l](z, self.c2f[l])
            zs.append(z.detach())
        return zs

    # def dec(self, y):
    #     z = y; zs = [y]
    #     for l in reversed(range(self.L)):
    #         z, _ = self.decs[l](z)
    #         zs.append(z.detach())
    #     zs.reverse()
    #     return zs

    def feedfwd(self, x):
        """Feed ``x`` forward; return the (non-detached) outputs of all L layers."""
        z = x
        zs = list()
        for l in range(self.L):
            z, _ = self.f_loc[l](z, self.c2f[l])
            zs.append(z)
        return zs

    def predict(self, lenlog, zs, is_last, is_last_e, dec_only=False, enc_only=True, check=False,
                vl=False, log_dt=None, log_dt_plt=None, ps=None):
        """Compute each layer's local prediction ``f_loc[l](zs[l])`` of ``zs[l+1]``.

        classification: zs[0], zs[-1] don't require grad
            i.e. (method == 'pdec' and cls_mth == 'y2z') or
                 (method == 'penc' and cls == True)
        reconstruction: zs[0] doesn't require grad
        enc_only: doesn't compare with zs[0]
        dec_only: doesn't compare with zs[-1]
        during validation: zs[-1] requires grad (self-update)

        Side effects: marks zs[1..L-1] (and zs[-1] when ``not enc_only`` or
        ``vl``) as requiring grad in place; with ``vl`` the detached predictions
        are appended to ``zs``; unless ``check``, lengths are logged to ``lenlog``;
        ``is_last_e`` switches ``self.iter_log`` on.

        ``is_last``, ``log_dt``, ``log_dt_plt`` and ``ps`` are unused.

        Returns:
            ``(z_encs, z_decs, h_encs)``: per-layer predictions, an always-empty
            decoder list, and the detached intermediate ``h`` of each layer.
        """
        if is_last_e: self.iter_log = True
        z_encs = list(); z_decs = list(); h_encs = list()
        if (not enc_only) or vl: zs[-1].requires_grad_()
        for l in range(self.L+1):
            if l > 0 and l < self.L:
                zs[l].requires_grad_()
            if l < self.L and not dec_only:
                z_enc_hat, h_enc = self.f_loc[l](zs[l], self.c2f[l])
                z_encs.append(z_enc_hat)
                h_encs.append(h_enc.detach())
                if vl: zs.append(z_enc_hat.detach())
                if not check:
                    with torch.no_grad():
                        lenlog.set_zh(l,
                            [length(zs[l].view(zs[l].size(0), -1).detach()),
                            length(z_enc_hat.view(zs[l].size(0), -1).detach()),
                            length(h_enc.view(zs[l].size(0), -1).detach())])
        if not check:
            with torch.no_grad():
                # Top state zs[L] (the loop variable ended at L). The original
                # author flagged this as possibly overwritten — TODO confirm.
                lenlog.set_zh(self.L, [length(zs[self.L].detach())], True)
        return z_encs, z_decs, h_encs

    def compare_ls(self, lenlog, zs, is_last, z_decs, z_encs, vl, y=None, trg_dtc=False,
                   check=False, p_loc=1., p_io=1, ps =None, prg=None):
        """Sum the local prediction losses of all layers.

        For l = 1..L the prediction ``z_encs[l-1]`` is compared with ``zs[l]``.
        With ``self.last_cls`` (and not ``vl``), the top layer uses
        ``loss_cls(z_encs[-1], y)`` instead. ``trg_dtc`` detaches the targets.
        The total is halved and, unless ``args.loss_sum``, divided by the number
        of terms. ``is_last``, ``p_loc``, ``p_io``, ``ps`` and ``prg`` are unused.

        Returns:
            ``(loss, delta_lst)`` with the per-layer ``zs[l] - z_encs[l-1]``.

        Raises:
            NotImplementedError: if ``z_decs`` is non-empty (decoder comparison
                is not implemented; no decoders are built).
        """
        loss = 0; delta_lst = list(); cnt = 0.
        for l in range(self.L+1):
            layer_loss = 0; delta = None
            if l > 0 and z_encs is not None and len(z_encs) > 0:
                if self.last_cls and l == self.L and vl == False:
                    _loss = self.loss_cls(z_encs[-1], y)
                    delta = zs[l] - z_encs[l-1]
                else:
                    target = zs[l].detach() if trg_dtc else zs[l]
                    _loss, delta = self.loss(z_encs[l-1], target, self.f_loc[l-1])
                loss += _loss; cnt += 1
                layer_loss = _loss.item()
                delta_lst.append(delta)
                if not check:
                    lenlog.set_d(l, length(delta.view(delta.size(0), -1).detach()))
            if l < self.L and z_decs is not None and len(z_decs) > 0:
                # Previously ``self.loss(zs[l], z_decs[l])``: missing the required
                # ``f`` argument and treating the returned tuple as a scalar, so it
                # always crashed. Fail explicitly instead.
                raise NotImplementedError(
                    "decoder comparison is not implemented (no decoders are built)")
            if not check:
                lenlog.set_loss_inf(l, layer_loss)
        loss = 0.5 * loss
        # Guard against 0/0 when there was nothing to compare.
        if not self.args.loss_sum and cnt:
            loss /= cnt
        return loss, delta_lst

    def compare(self, zs, z_decs, z_encs, y=None):
        """Return only the total loss of ``compare_ls`` (``vl=False``, no logging).

        ``y`` is required when ``self.last_cls`` is set.
        """
        return self.compare_ls(None, zs, False, z_decs, z_encs, False, y=y, check=True)[0]
