import torch
from torch import nn
import torchvision.transforms as T
import numpy as np

from koopman_layer import KoopmanLayer
from utils import t_to_np, np_to_t

# from utils.utils_disentanglement import get_sorted_indices, static_dynamic_split


class KoopmanCNN(nn.Module):
    """Convolutional Koopman autoencoder over frame sequences.

    Frames are encoded by ``encNet``, propagated in latent space by a
    ``KoopmanLayer`` and decoded back to frames by ``decNet``. Every
    hyper-parameter is read from ``args``.
    """

    def __init__(self, args):
        super(KoopmanCNN, self).__init__()

        self.args = args

        self.encoder = encNet(self.args)
        self.drop = torch.nn.Dropout(self.args.dropout)
        self.dynamics = KoopmanLayer(args)
        self.decoder = decNet(self.args)

        self.loss_func = nn.MSELoss()

    def forward(self, X, train=True):
        """Encode ``X``, apply the Koopman dynamics and decode both latents.

        Args:
            X: input batch, b x t x c x w x h (per the shape note below).
            train: when True and ``args.noise`` is ``"input"``, each sample is
                Gaussian-blurred before encoding.

        Returns:
            ``(X_dec, X_dec2, Z, Z2, Ct)``: ``X_dec`` decodes the encoder
            latents ``Z`` (after dropout), ``X_dec2`` decodes ``Z2``, and
            ``Z2``/``Ct`` are the two outputs of the dynamics layer.
        """
        # input noise
        if train and self.args.noise in ["input"]:
            blurrer = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 3))
            # blur each sample separately; with a (min, max) sigma range,
            # torchvision draws a fresh sigma on every call
            X = torch.stack([blurrer(x) for x in X], dim=0)

        # ----- X.shape: b x t x c x w x h ------
        Z = self.encoder(X)
        Z2, Ct = self.dynamics(Z)
        # dropout only affects the latents of the direct reconstruction path;
        # the dynamics above already consumed the un-dropped Z
        Z = self.drop(Z)

        X_dec = self.decoder(Z)
        X_dec2 = self.decoder(Z2)

        return X_dec, X_dec2, Z, Z2, Ct

    def decode(self, Z):
        """Decode latent codes ``Z`` back to frames with the decoder."""
        X_dec = self.decoder(Z)

        return X_dec

    def loss(self, X, outputs):
        """Return the weighted training loss and its unweighted components.

        Args:
            X: the original input batch.
            outputs: the tuple returned by :meth:`forward`.

        Returns:
            ``(loss, E1, E2, E3, E4)``, where ``E1`` is the MSE reconstruction
            error and ``E2``-``E4`` come from ``self.dynamics.loss``.
        """
        X_dec, X_dec2, Z, Z2, Ct = outputs

        # PENALTIES
        a1 = self.args.w_rec
        a2 = self.args.w_pred
        # NOTE(review): E3 is weighted by w_pred, the same weight as E2;
        # confirm this is intended and not a copy-paste of a separate weight.
        a3 = self.args.w_pred
        a4 = self.args.w_eigs

        # reconstruction
        E1 = self.loss_func(X, X_dec)

        # Koopman losses
        E2, E3, E4 = self.dynamics.loss(X_dec, X_dec2, Z, Z2, Ct)

        # LOSS
        loss = a1 * E1 + a2 * E2 + a3 * E3 + a4 * E4

        return loss, E1, E2, E3, E4

    def swap(self, Zp, I, J, U):
        """Swap the factors ``J`` between samples and decode the result.

        Args:
            Zp: projected latents, indexed as ``Zp[:, :, J]``. Presumably a
                complex numpy array, given the ``np.real`` below (confirm).
            I: index/shuffle along the first axis. Row ``k`` receives the
                ``J`` entries of row ``I[k]``.
            J: indices along the third axis to swap.
            U: matrix that maps the swapped codes back via ``Zp_tmp @ U``.

        Returns:
            The decoded frames, squeezed. Side effect: leaves the model in
            eval mode.
        """
        # swap J factors with shuffle I,
        import copy

        # work on a copy so the caller's Zp is left untouched
        Zp_tmp = copy.deepcopy(Zp)
        Zp_tmp[:, :, J] = Zp_tmp[I][:, :, J]
        Z_tmp = Zp_tmp @ U

        self.eval()
        with torch.no_grad():
            X_tmp = self.decode(np_to_t(np.real(Z_tmp))).squeeze()

        return X_tmp

    def factorial_swap(self, classifier, X, Zp, I, J, U):
        """Swap factors ``J`` across shuffle ``I`` and measure classifier agreement.

        The classifier labels the shuffled inputs ``X[I]`` to get reference
        labels. The swapped latents are then decoded via :meth:`swap` and
        labelled again.

        Returns:
            ``(accs, lbls_tmp, lbls)``: for each classifier output, the fraction
            of swapped labels that match the reference labels, followed by the
            swapped labels and the reference labels.
        """
        def get_lbl(pred):
            return np.argmax(t_to_np(pred), axis=1)

        def get_acc(lbl1, lbl2):
            return np.sum(lbl1 == lbl2) / len(lbl2)

        classifier.eval()
        with torch.no_grad():
            # action, skin, pant, top, hair
            preds = classifier(X[I])
        lbls = list(map(get_lbl, preds))

        # same swap/project/decode as the standalone method; reuse it
        X_tmp = self.swap(Zp, I, J, U)

        classifier.eval()
        with torch.no_grad():
            preds_tmp = classifier(X_tmp)
        lbls_tmp = list(map(get_lbl, preds_tmp))

        accs = list(map(get_acc, lbls_tmp, lbls))
        return accs, lbls_tmp, lbls

    # NOTE: disabled. It relies on get_sorted_indices / static_dynamic_split,
    # whose import is commented out at the top of this file.
    # def forward_fixed_ma_for_classification(self, X, fix_motion, conj_pick=True, pick_type='norm'):
    #     # ----- X.shape: b x t x c x w x h ------
    #     Z = self.encoder(X)
    #     Z2, Ct = self.dynamics(Z)
    #     Z = self.drop(Z)
    #
    #     Z_old_shape = Z.shape
    #
    #     # swap a single pair in batch
    #     bsz, fsz = X.shape[0:2]
    #
    #     # swap contents of samples in indices
    #     X = t_to_np(X)
    #     Z = t_to_np(Z.reshape(bsz, fsz, -1))
    #     C = t_to_np(Ct)
    #     swapped_Z = torch.zeros(Z.shape)
    #
    #     # eig
    #     D, V = np.linalg.eig(C)
    #     U = np.linalg.inv(V)
    #
    #     # static/dynamic split
    #     if pick_type == 'real':
    #         I = np.argsort(np.real(D))
    #     elif pick_type == 'norm':
    #         I = np.argsort(np.abs(D))
    #     else:
    #         raise Exception("no such method")
    #
    #     I = get_sorted_indices(D, pick_type)
    #     Id, Is = static_dynamic_split(D, I, pick_type, self.args.static_size)
    #
    #     for ii in range(bsz):
    #         iir = np.random.randint(bsz)
    #         while iir == ii:
    #             iir = np.random.randint(bsz)
    #         S1, Z1 = X[ii].squeeze(), Z[ii].squeeze()
    #         S2, Z2 = X[iir].squeeze(), Z[iir].squeeze()
    #
    #         # project onto V
    #         Zp1, Zp2 = Z1 @ V, Z2 @ V
    #
    #         # Zp* is in t x k
    #         Z1d, Z1s = Zp1[:, Id] @ U[Id], Zp1[:, Is] @ U[Is]
    #         Z2d, Z2s = Zp2[:, Id] @ U[Id], Zp2[:, Is] @ U[Is]
    #
    #         if fix_motion:
    #             # we fix dynamics thus, use same d for our sample
    #             swapped_Z[ii] = torch.from_numpy(np.real(Z1d + Z2s)).to(self.args.device)
    #         else:
    #             swapped_Z[ii] = torch.from_numpy(np.real(Z2d + Z1s)).to(self.args.device)
    #
    #     ZNs = torch.from_numpy(Z).to(self.args.device)
    #     Z = swapped_Z.to(self.args.device)
    #
    #     X_dec_sample = self.decoder(Z.reshape(Z_old_shape))
    #     X_dec = self.decoder(ZNs.reshape(Z_old_shape))
    #
    #     return X_dec_sample, X_dec


class conv(nn.Module):
    """Downsampling block: 4x4 stride-2 convolution, BatchNorm, LeakyReLU(0.2).

    With kernel 4, stride 2 and padding 1, the spatial size is halved
    (rounding down) while the channel count goes from ``nin`` to ``nout``.
    """

    def __init__(self, nin, nout):
        super().__init__()
        layers = [
            nn.Conv2d(nin, nout, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(nout),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        out = self.net(input)
        return out


class upconv(nn.Module):
    """Upsampling block: 4x4 stride-2 transposed convolution, BatchNorm, LeakyReLU(0.2).

    With kernel 4, stride 2 and padding 1, the spatial size doubles
    ((H - 1) * 2 - 2 + 4 = 2H) while the channel count goes from ``nin``
    to ``nout``.
    """

    def __init__(self, nin, nout):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(nin, nout, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(nout),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        out = self.net(input)
        return out


class encNet(nn.Module):
    """Per-frame convolutional encoder with an optional LSTM over time.

    Time is folded into the batch axis, so every frame is encoded on its own
    by four ``conv`` blocks followed by ``c5`` (conv -> BatchNorm -> Tanh)
    with ``k_dim`` output channels. For 64x64 frames the four stride-2 blocks
    reach 4x4, and ``c5`` (kernel 4, no padding) reduces that to 1x1. When
    ``args.rnn`` is ``"enc"`` or ``"both"``, an LSTM runs over the
    ``n_frames`` codes of each sequence, and its outputs are returned with
    shape (-1, hidden_dim, 1, 1).
    """

    def __init__(self, args):
        super().__init__()

        self.args = args

        # hyper-parameters copied from args
        self.n_frames = args.n_frames
        self.n_channels = args.n_channels
        self.n_height = args.n_height
        self.n_width = args.n_width
        self.conv_dim = args.conv_dim
        self.k_dim = args.k_dim
        self.hidden_dim = args.hidden_dim

        width = self.conv_dim
        self.c1 = conv(self.n_channels, width)
        self.c2 = conv(width, 2 * width)
        self.c3 = conv(2 * width, 4 * width)
        self.c4 = conv(4 * width, 8 * width)
        self.c5 = nn.Sequential(
            nn.Conv2d(8 * width, self.k_dim, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(self.k_dim),
            nn.Tanh(),
        )

        if args.rnn in ("enc", "both"):
            self.lstm = nn.LSTM(
                self.k_dim,
                self.hidden_dim,
                batch_first=True,
                bias=True,
                bidirectional=False,
            )

    def forward(self, x):
        # fold (batch, time) into one axis so frames are encoded independently
        feats = x.reshape(-1, self.n_channels, self.n_height, self.n_width)
        for stage in (self.c1, self.c2, self.c3, self.c4, self.c5):
            feats = stage(feats)

        if self.args.rnn in ("enc", "both"):
            # regroup the per-frame codes into sequences of n_frames;
            # assumes c5 produced 1x1 maps so each frame is k_dim values
            sequence = feats.reshape(-1, self.n_frames, self.k_dim)
            lstm_out, _ = self.lstm(sequence)
            feats = lstm_out.reshape(-1, self.hidden_dim, 1, 1)

        return feats


class decNet(nn.Module):
    """Per-frame convolutional decoder with an optional LSTM over time.

    When ``args.rnn`` is ``"dec"`` or ``"both"``, the incoming codes are first
    regrouped into sequences of ``n_frames`` and passed through an LSTM that
    maps ``hidden_dim`` to ``k_dim``. Each code is then upsampled by
    ``upc1`` (1x1 -> 4x4 for a 1x1 input), three ``upconv`` blocks and
    ``upc5`` (transposed conv -> Sigmoid). The result is viewed as
    (-1, n_frames, n_channels, n_height, n_width). For a 1x1 input the
    spatial output is 64x64, so this view presumably expects 64x64 frames
    (confirm against args).
    """

    def __init__(self, args):
        super().__init__()

        self.args = args

        # hyper-parameters copied from args
        self.n_frames = args.n_frames
        self.n_channels = args.n_channels
        self.n_height = args.n_height
        self.n_width = args.n_width
        self.conv_dim = args.conv_dim
        self.koopman_dim = args.k_dim
        self.lstm_hidden_size = args.hidden_dim

        width = self.conv_dim
        self.upc1 = nn.Sequential(
            nn.ConvTranspose2d(self.koopman_dim, 8 * width, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(8 * width),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.upc2 = upconv(8 * width, 4 * width)
        self.upc3 = upconv(4 * width, 2 * width)
        self.upc4 = upconv(2 * width, width)
        self.upc5 = nn.Sequential(
            nn.ConvTranspose2d(width, self.n_channels, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

        if args.rnn in ("dec", "both"):
            self.lstm = nn.LSTM(
                self.lstm_hidden_size,
                self.koopman_dim,
                batch_first=True,
                bias=True,
            )

    def forward(self, x):
        if self.args.rnn in ("dec", "both"):
            # regroup per-frame codes into sequences before the LSTM
            sequence = x.reshape(-1, self.n_frames, self.lstm_hidden_size)
            lstm_out, _ = self.lstm(sequence)
            x = lstm_out.reshape(-1, self.koopman_dim, 1, 1)

        frames = x
        for stage in (self.upc1, self.upc2, self.upc3, self.upc4, self.upc5):
            frames = stage(frames)

        # unfold the batch axis back into (batch, time)
        return frames.view(-1, self.n_frames, self.n_channels, self.n_height, self.n_width)
