import torch
import torch.optim as optim
from utils.ot_functions import sim2real, reverse_val
import ot
import numpy as np
from dataloader.dataloader import DataLoader
from sklearn.neighbors import KNeighborsClassifier
import scipy as sc
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from network.train_functions_global import standard_scaler
from sklearn.model_selection import train_test_split

class CoupledAE(torch.nn.Module):
    """Two coupled autoencoders plus a classifier, trained jointly on two datasets.

    Autoencoder 0 (``encoder_1``/``decoder_1``) encodes the first dataset
    (loaded with tag 'source') and autoencoder 1 (``encoder_2``/``decoder_2``)
    the second one (tag 'target'). The experiment-specific loss and the
    batch/latent/probability accessors are taken from ``aeconfig``.
    """

    def __init__(
            self,
            model,
            aeconfig,
            encoder_1,
            encoder_2,
            decoder_1,
            decoder_2,
            classifier):
        """Register the sub-networks, build the optimizer and initialise weights.

        Args:
            model: experiment container; this class reads its
                ``training_params``, ``dataset_params_train_1/2``,
                ``housekeeping`` and ``epoch`` and calls ``save_model``.
            aeconfig: provides ``loss_function_ae``, ``get_data``, ``get_z``
                and ``get_l_probs`` -- the experiment specific loss and
                housekeeping functions specified in the
                network/train_functions.py of each experiment.
            encoder_1, decoder_1: networks of autoencoder 0.
            encoder_2, decoder_2: networks of autoencoder 1.
            classifier: network used for the supervised classification loss.
        """
        super(CoupledAE, self).__init__()
        self.vc = aeconfig
        self.model = model
        # Registration order fixes the order of self.parameters() / state_dict().
        self.encoder_1 = encoder_1
        self.decoder_1 = decoder_1
        self.encoder_2 = encoder_2
        self.decoder_2 = decoder_2
        self.classifier = classifier

        self.loss_function_ae = self.vc.loss_function_ae
        self.classification_loss = torch.nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(
            self.parameters(),
            lr=self.model.training_params['learning_rate'])
        # NOTE(review): this scheduler is never stepped inside this class, so
        # the learning rate stays constant unless an outside caller steps it.
        self.scheduler = optim.lr_scheduler.ExponentialLR(
            self.optimizer, gamma=.95)

        self.get_data = self.vc.get_data
        self.get_z = self.vc.get_z
        self.get_l_probs = self.vc.get_l_probs

        self.apply(self.init_weights)

    def init_weights(self, m):
        """Xavier-initialise ``Linear`` weights and fill their bias (if any) with 0.01."""
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            # Linear(..., bias=False) has m.bias set to None.
            if m.bias is not None:
                m.bias.data.fill_(0.01)

    def forward(self, o_in, ae_idx):
        """Encode then decode ``o_in`` with autoencoder ``ae_idx``; return (reconstruction, latent)."""
        z = self.encode(o_in, ae_idx)
        o_pred = self.decode(z, ae_idx)
        return o_pred, z

    def encode(self, o_in, ae_idx):
        """Encode with ``encoder_1`` when ``ae_idx == 0``, otherwise with ``encoder_2``."""
        if ae_idx == 0:
            return self.encoder_1(o_in)
        return self.encoder_2(o_in)

    def decode(self, z, ae_idx):
        """Decode with ``decoder_1`` when ``ae_idx == 0``, otherwise with ``decoder_2``."""
        if ae_idx == 0:
            return self.decoder_1(z)
        return self.decoder_2(z)

    def save_model(self, epoch):
        """Store the current state_dict and epoch on ``self.model`` and persist it."""
        self.model.model_file = self.state_dict()
        self.model.epoch = epoch
        # Update pkl file
        self.model.save_model()

    def load_model(self):
        """Load the state_dict previously stored in ``self.model.model_file``."""
        self.load_state_dict(self.model.model_file)

    def train(self):
        """Train both autoencoders jointly, evaluating before training and after each epoch.

        Training resumes at ``self.model.epoch + 1`` and runs up to
        ``training_params['train_epochs']`` inclusive. The larger training set
        drives each epoch; the smaller one is restarted whenever it is
        exhausted, so every main batch is paired with a batch of the other
        domain.

        NOTE(review): this shadows ``torch.nn.Module.train(mode=True)`` with an
        incompatible signature, so ``self.eval()`` (which torch implements as
        ``self.train(False)``) raises ``TypeError``. The name is kept because
        callers depend on it.
        """
        print("Starting training")
        print("Setup dataloaders...")

        params = self.model.training_params
        train_loader_1, train_dataset_1 = DataLoader(
            params, self.model.dataset_params_train_1, tag='source').get_loader()
        train_loader_2, train_dataset_2 = DataLoader(
            params, self.model.dataset_params_train_2, tag='target').get_loader()

        n_1, n_2 = len(train_dataset_1), len(train_dataset_2)
        if abs(n_1 - n_2) > 4 * n_1 or abs(n_1 - n_2) > 4 * n_2:
            print("network/doubleae.py@76: Warning, Datasets are imbalanced!")

        # The larger dataset drives the epoch ("main"); the smaller one is
        # cycled alongside it ("secondary").
        flip = n_2 > n_1
        if flip:
            main_loader, secondary_loader = train_loader_2, train_loader_1
        else:
            main_loader, secondary_loader = train_loader_1, train_loader_2

        print("Dataloaders up and running!")

        if self.model.epoch >= params['train_epochs']:
            print(
                "Model cannot continue, set a higher train epochs variable than load epoch")
            print("Training was not resumed, exiting ...")
            return

        self._evaluate(0)

        if params['invariant']:
            latent_map = (torch.eye(params['latent_dim']), 0)
        else:
            latent_map = None

        # Loop invariants.
        same_feature_dim = (train_dataset_1.features.shape[1]
                            == train_dataset_2.features.shape[1])
        scale = params['experiment'] == 'nn_features'

        for epoch in range(self.model.epoch + 1, params['train_epochs'] + 1):
            epoch_loss = 0

            secondary_iter = iter(secondary_loader)
            for main_batch in main_loader:
                # The secondary loader usually has fewer batches: restart it
                # whenever it runs out.
                try:
                    secondary_batch = next(secondary_iter)
                except StopIteration:
                    secondary_iter = iter(secondary_loader)
                    secondary_batch = next(secondary_iter)

                # Dataset 1 always goes through autoencoder 0 and dataset 2
                # through autoencoder 1, whichever of them is "main".
                if flip:
                    batch_1, batch_2 = secondary_batch, main_batch
                else:
                    batch_1, batch_2 = main_batch, secondary_batch
                d1, l1, idx_s = self.get_data(batch_1)
                d2, l2, idx_t = self.get_data(batch_2)

                orig_c = self._batch_cost(
                    train_dataset_1, train_dataset_2, idx_s, idx_t, d1, d2,
                    same_feature_dim, params['device'])

                # Positions (within this target batch) of supervised target samples.
                if train_dataset_2.idx_supervision is not None:
                    idx_to_use = [k for k, x in enumerate(idx_t)
                                  if x in train_dataset_2.idx_supervision]
                else:
                    idx_to_use = []

                z1 = self.get_z(d1, self, 0)
                z2 = self.get_z(d2, self, 1)

                loss_rec, hk_data, _, _ = self.loss_function_ae(
                    self.model, z1, z2, l1, l2, idx_to_use, C=orig_c, map=latent_map, scale=scale,
                    use_target=params['use_target'],
                    device=params['device'])

                if params['supervision']:
                    # Classify the standardised dataset-1 latents (z1[2]) against labels l1.
                    l_probs = self.get_l_probs(standard_scaler(z1[2]), self.classifier)
                    loss = loss_rec + self.classification_loss(l_probs, l1)
                else:
                    loss = loss_rec

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.parameters(), 2.)
                self.optimizer.step()
                epoch_loss += loss.item()

                self.model.housekeeping.add('batch_train_loss_rec_1', hk_data[0])
                self.model.housekeeping.add('batch_train_loss_rec_2', hk_data[1])
                self.model.housekeeping.add('batch_train_loss_affinity', hk_data[2])

            epoch_loss = epoch_loss / len(main_loader)

            if params['verbose']:
                print(
                    'Epoch {} done, epoch loss: {}'.format(
                        epoch, epoch_loss))

            self.model.housekeeping.add('epoch_train_loss', epoch_loss)
            # save_model also writes everything added to housekeeping so far
            # to the pkl file.
            self.save_model(epoch)

            self._evaluate(epoch)

    def _evaluate(self, epoch):
        """Run ``test`` with dataset 1 / autoencoder 0 as main and dataset 2 / autoencoder 1 as secondary."""
        self.test(
            self.model.dataset_params_train_1,
            self.model.dataset_params_train_2,
            0,
            1, epoch)

    def _batch_cost(self, dataset_1, dataset_2, idx_s, idx_t, d1, d2,
                    same_feature_dim, device):
        """Return the cross-domain cost matrix for the current pair of batches, or None.

        With equal feature dimensions: ``ot.dist`` between the raw features.
        Otherwise, with equal batch sizes: ``ot.dist`` between the rows of each
        domain's intra-batch cosine-distance matrix, moved to ``device``.
        Otherwise: ``None``.
        """
        if same_feature_dim:
            return ot.dist(dataset_1.features[idx_s], dataset_2.features[idx_t])
        if d1.shape[0] == d2.shape[0]:
            feats_1 = dataset_1.features[idx_s].detach().cpu().numpy()
            feats_2 = dataset_2.features[idx_t].detach().cpu().numpy()
            cx = sc.spatial.distance.cdist(feats_1, feats_1, metric='cosine')
            cy = sc.spatial.distance.cdist(feats_2, feats_2, metric='cosine')
            return torch.from_numpy(ot.dist(cx, cy)).to(device=device)
        return None

    def test(
            self,
            main_dataset_params,
            secondary_dataset_params,
            main_ae,
            secondary_ae, epoch):
        """Encode both datasets without gradients and log latent/transfer metrics.

        Builds non-shuffled loaders that keep the last partial batch, stores
        all latents and labels in housekeeping (``z_<ae+1>``, ``l_<ae+1>``),
        and prints a within-domain 1-NN accuracy on a random 70/30 split. For
        the ``nn_features`` experiment it also scores classifiers fitted on the
        main-domain latents (aligned with ``sim2real`` unless ``invariant``)
        against the secondary-domain latents.

        Args:
            main_dataset_params, secondary_dataset_params: dataset configs.
                Their ``batch_size`` must match the loaders' batch size,
                because rows are written at offset ``i * batch_size``.
            main_ae, secondary_ae: autoencoder index used for each dataset.
            epoch: current epoch; reverse validation is skipped at epoch 0.
        """
        training_params = self.model.training_params
        device = training_params['device']

        test_loader_1, test_dataset_1, test_loader_2, test_dataset_2 = self._eval_loaders(
            main_dataset_params, secondary_dataset_params)

        with torch.no_grad():
            z_main_store, l_main_store = self._collect_latents(
                test_loader_1, len(test_dataset_1),
                main_dataset_params['batch_size'], main_ae)
            z_secondary_store, l_secondary_store = self._collect_latents(
                test_loader_2, len(test_dataset_2),
                secondary_dataset_params['batch_size'], secondary_ae)

            hk = self.model.housekeeping
            hk.add('z_{}'.format(main_ae + 1), z_main_store)
            hk.add('z_{}'.format(secondary_ae + 1), z_secondary_store)
            hk.add('l_{}'.format(main_ae + 1), l_main_store)
            hk.add('l_{}'.format(secondary_ae + 1), l_secondary_store)

            # Within-domain check: 1-NN from latents to labels on a 70/30 split.
            z_1_train, z_1_test, labels_1_train, labels_1_test = train_test_split(
                z_main_store, l_main_store, shuffle=True, test_size=.3)
            z_2_train, z_2_test, labels_2_train, labels_2_test = train_test_split(
                z_secondary_store, l_secondary_store, shuffle=True, test_size=.3)

            clf = KNeighborsClassifier(n_neighbors=1)
            clf.fit(z_1_train, labels_1_train)
            print(
                "z_1->labels_1: {}".format(clf.score(z_1_test, labels_1_test)))

            clf = KNeighborsClassifier(n_neighbors=1)
            clf.fit(z_2_train, labels_2_train)
            print(
                "z_2->labels_2: {}".format(clf.score(z_2_test, labels_2_test)))

            z1T = torch.from_numpy(z_main_store).to(device=device)
            z2T = torch.from_numpy(z_secondary_store).to(device=device)
            if training_params['experiment'] == 'nn_features':
                z1T = standard_scaler(z1T)
                z2T = standard_scaler(z2T)

            if training_params['invariant']:
                source_aligned = z1T
            else:
                source_aligned, _, _, _ = sim2real(
                    z1T, z2T, return_map=True,
                    do_procrustes=False, device=device)

            if training_params['experiment'] == 'nn_features':
                self._report_transfer(
                    source_aligned, z2T, l_main_store, l_secondary_store,
                    test_dataset_2.idx_supervision, epoch)

    def _eval_loaders(self, main_dataset_params, secondary_dataset_params):
        """Build non-shuffled loaders that keep the last uneven batch.

        ``drop_last`` is switched off on both training dataset configs while
        the loaders are built, then restored to its previous value (or removed
        if it was absent), even if loader construction raises.

        Returns:
            (loader_1, dataset_1, loader_2, dataset_2)
        """
        train_configs = (self.model.dataset_params_train_1,
                         self.model.dataset_params_train_2)
        missing = object()
        previous = [cfg.get('drop_last', missing) for cfg in train_configs]
        for cfg in train_configs:
            cfg['drop_last'] = False
        try:
            test_loader_1, test_dataset_1 = DataLoader(
                self.model.training_params, main_dataset_params, shuffle=False).get_loader()
            test_loader_2, test_dataset_2 = DataLoader(
                self.model.training_params, secondary_dataset_params, shuffle=False).get_loader()
        finally:
            for cfg, value in zip(train_configs, previous):
                if value is missing:
                    cfg.pop('drop_last', None)
                else:
                    cfg['drop_last'] = value
        return test_loader_1, test_dataset_1, test_loader_2, test_dataset_2

    def _collect_latents(self, loader, n_samples, batch_size, ae_idx):
        """Encode every batch of ``loader`` with autoencoder ``ae_idx``.

        Element [0] of ``get_data``'s result is encoded and element [1] is
        taken as the labels. Labels with more than one dimension are reduced
        with argmax over the last axis. Element [2] of ``get_z``'s result is
        stored as the latent code. (Presumably the latent itself -- TODO
        confirm against aeconfig.get_z.)

        Returns:
            (latents float32 array of shape (n_samples, latent_dim),
             labels float64 array of shape (n_samples,))
        """
        latents = np.zeros(
            (n_samples, self.model.training_params['latent_dim'])).astype(np.float32)
        labels = np.zeros((n_samples,))
        for i, batch in enumerate(loader):
            d = self.get_data(batch)
            z = self.get_z(d[0], self, ae_idx)
            start = i * batch_size

            latents[start:start + z[0].shape[0]] = z[2].detach().cpu().numpy()

            batch_labels = d[1].detach().cpu().numpy()
            if batch_labels.ndim > 1:
                batch_labels = np.argmax(batch_labels, -1)
            labels[start:start + d[1].shape[0]] = batch_labels
        return latents, labels

    def _report_transfer(self, source_aligned, z2T, l_main_store,
                         l_secondary_store, idx_supervision, epoch):
        """Score classifiers fitted on source latents when applied to target latents.

        Logs ``acc_nn`` (only with supervision), ``acc_knn_1``, ``acc_knn_3``
        and ``acc_svc_ot`` as percentages. After epoch 0 with ``reverse``
        enabled, it also logs ``rev_*_ot`` via ``reverse_val``.
        """
        training_params = self.model.training_params
        device = training_params['device']
        hk = self.model.housekeeping

        if training_params['supervision']:
            l_probs = self.get_l_probs(z2T, self.classifier)
            yt_estimated = torch.argmax(l_probs, dim=1)
            target_labels = torch.from_numpy(l_secondary_store).to(device=device)
            acc = torch.sum(target_labels == yt_estimated) / len(target_labels)
            if training_params['verbose']:
                print('Transfer score (NN): {:.2f}'.format(acc * 100))
            hk.add('acc_nn', acc * 100)

        # sklearn converts inputs with np.asarray, which fails on CUDA tensors,
        # so both fit and score inputs are moved to host memory explicitly.
        source_np = source_aligned.detach().cpu().numpy()
        target_np = z2T.detach().cpu().numpy()

        clf_1 = KNeighborsClassifier(n_neighbors=1)
        clf_3 = KNeighborsClassifier(n_neighbors=3)

        param_grid = {'kernel': ['linear'], 'C': [.001, .01, .1]}
        clf_SVC = GridSearchCV(
            SVC(), param_grid, scoring='accuracy', n_jobs=-1)

        clf_1.fit(source_np, l_main_store)
        acc = clf_1.score(target_np, l_secondary_store)
        hk.add('acc_knn_1', acc * 100)
        if training_params['verbose']:
            print("Transfer score (1NN): {:.2f}".format(acc * 100))

        clf_3.fit(source_np, l_main_store)
        acc = clf_3.score(target_np, l_secondary_store)
        if training_params['verbose']:
            print("Transfer score (3NN): {:.2f}".format(acc * 100))
        hk.add('acc_knn_3', acc * 100)

        clf_SVC.fit(source_np, l_main_store)
        acc = clf_SVC.score(target_np, l_secondary_store)
        hk.add('acc_svc_ot', acc * 100)
        if training_params['verbose']:
            print("Transfer score SVC: {:.2f}".format(acc * 100))

        if training_params['reverse'] and epoch > 0:
            for key, fitted in (('rev_1nn_ot', clf_1),
                                ('rev_3nn_ot', clf_3),
                                ('rev_svc_ot', clf_SVC)):
                hk.add(
                    key,
                    reverse_val(
                        source_aligned,
                        l_main_store,
                        z2T,
                        l_secondary_store,
                        idx_supervision,
                        fitted))
