import torch
import torch.distributions as dist
import numpy as np
from torch import nn
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from scipy import stats
from statistics import mean
import shutil
import os
import glob
from fid.inception import InceptionV3
from fid.fid_score import get_activations
from fid.fid_score import calculate_frechet_distance


def train_clf_lr(dl, model, missing, device, n=10):
    """Fit logistic-regression classifiers on sampled latent codes.

    For each batch a latent code ``u`` is sampled from
    ``Normal(*model.a_to_z(svhn))`` (SVHN) or ``Normal(*model.b_to_z(mnist))``
    (MNIST) and split along the last dimension into ``z`` (``model.z_dim``
    entries) and ``w`` (``model.w_dim`` entries). One classifier is fitted
    per modality and per part (``u``, ``z``, ``w``).

    Batches from ``dl['sup']`` are encoded for both modalities. Batches from
    ``dl['unsup']`` are encoded only for the modality that is *not*
    ``missing``; if ``missing`` is neither ``'svhn'`` nor ``'mnist'`` the
    ``'unsup'`` loader is not iterated at all.

    Note: the model is switched to eval mode and left that way.

    Args:
        dl: Mapping with optional ``'sup'`` / ``'unsup'`` entries, each an
            iterable of ``(svhn, mnist, y)`` batches.
        model: Object exposing ``a_to_z``, ``b_to_z``, ``z_dim``, ``w_dim``
            and ``eval()``.
        missing: ``'svhn'`` or ``'mnist'``; the modality skipped for
            ``'unsup'`` batches.
        device: Device the inputs are moved to before encoding.
        n: Number of classes used for the one-hot encoding of ``y``.

    Returns:
        dict mapping ``'svhn_u'``, ``'svhn_z'``, ``'svhn_w'``, ``'mnist_u'``,
        ``'mnist_z'``, ``'mnist_w'`` to fitted ``LogisticRegression`` objects.

    Raises:
        ValueError: If no labelled batch was seen, or if one of the two
            modalities has no encoded batches to train on.
    """
    model.eval()

    encoders = {'svhn': model.a_to_z, 'mnist': model.b_to_z}
    parts = ('us', 'zs', 'ws')
    latent_rep = {name: {part: [] for part in parts} for name in encoders}
    labels_all = []

    def _record_labels(y):
        # The one-hot round trip is kept on purpose: it rejects labels
        # outside [0, n) before they reach the classifier.
        one_hot = F.one_hot(y.long().to(device), num_classes=n).float()
        labels_all.append(one_hot.cpu().numpy().reshape(y.size(0), n))

    def _record_latents(name, x):
        # Sample u ~ q(u | x) for one modality and store u and its z / w split.
        with torch.no_grad():
            u = dist.Normal(*encoders[name](x.to(device))).sample()
            z, w = torch.split(u, [model.z_dim, model.w_dim], dim=-1)
        for part, tensor in zip(parts, (u, z, w)):
            latent_rep[name][part].append(tensor.detach().cpu().numpy())

    if 'sup' in dl:
        for svhn, mnist, y in dl['sup']:
            _record_labels(y)
            # Keep the SVHN-then-MNIST sampling order so RNG use is unchanged.
            _record_latents('svhn', svhn)
            _record_latents('mnist', mnist)

    if 'unsup' in dl and missing in ('svhn', 'mnist'):
        for svhn, mnist, y in dl['unsup']:
            _record_labels(y)
            if missing == 'svhn':
                _record_latents('mnist', mnist)
            else:
                _record_latents('svhn', svhn)

    if not labels_all:
        raise ValueError("train_clf_lr: no labelled batches found in 'dl'")
    for name in encoders:
        if not latent_rep[name]['us']:
            raise ValueError(
                "train_clf_lr: no latent codes collected for modality "
                "'{}' (missing={!r})".format(name, missing))

    gt = np.argmax(np.concatenate(labels_all, axis=0), axis=1).astype(int)

    clf_lr = {}
    for name in encoders:
        feats = {part: np.concatenate(latent_rep[name][part], axis=0)
                 for part in parts}
        # Labels are collected 'sup' first, then 'unsup'. The observed
        # modality has one code per label; the missing modality only has
        # codes for the leading 'sup' samples, hence the prefix slice.
        targets = gt[:feats['us'].shape[0]]
        for part in parts:
            clf = LogisticRegression(random_state=0, solver='lbfgs',
                                     multi_class='auto', max_iter=1000)
            clf.fit(feats[part], targets)
            clf_lr['{}_{}'.format(name, part[0])] = clf
    return clf_lr

def classify_latent_representations(clf_lr, data, labels, split=False):
    """Score pre-trained latent classifiers against one-hot ground truth.

    Args:
        clf_lr: dict of fitted classifiers keyed ``'svhn_u'``, ``'mnist_u'``
            and, when ``split`` is true, ``'svhn_z'``, ``'svhn_w'``,
            ``'mnist_z'``, ``'mnist_w'``. Each must provide ``predict``.
        data: Two-element sequence ``[svhn_reps, mnist_reps]``. With
            ``split=True`` each entry unpacks into ``(u, z, w)`` feature
            arrays; otherwise only its first element (``u``) is used.
        labels: One-hot label array; class ids are taken with ``argmax``
            over axis 1.
        split: Whether to also score the ``z`` / ``w`` classifiers.

    Returns:
        dict of accuracies keyed like ``clf_lr``. With ``split=True`` it also
        contains ``'all'``: the accuracy of the per-sample mode of the SVHN-z
        and MNIST-z predictions. ``'all'`` is omitted when ``split`` is false,
        because no z-predictions exist to vote over (the previous code
        crashed in ``np.stack`` on an empty list in that case).
    """
    gt = np.argmax(labels, axis=1).astype(int)
    accuracies = {}

    def _score(key, features):
        # Predict with the classifier stored under `key`, record its accuracy
        # and hand the raw predictions back for the ensemble vote.
        preds = clf_lr[key].predict(features)
        accuracies[key] = accuracy_score(gt, preds.ravel())
        return preds

    if split:
        rep_au, rep_az, rep_aw = data[0]
        rep_bu, rep_bz, rep_bw = data[1]
    else:
        rep_au = data[0][0]
        rep_bu = data[1][0]

    _score('svhn_u', rep_au)
    _score('mnist_u', rep_bu)

    if split:
        preds_az = _score('svhn_z', rep_az)
        _score('svhn_w', rep_aw)
        preds_bz = _score('mnist_z', rep_bz)
        _score('mnist_w', rep_bw)

        # Per-sample vote across the two z-classifiers. With two voters a
        # disagreement is a tie, settled by scipy.stats.mode's tie rule.
        overall_preds = stats.mode(np.stack([preds_az, preds_bz]), axis=0)[0]
        accuracies['all'] = accuracy_score(gt, overall_preds.ravel())

    return accuracies

def linear_latent_classification(dl, model, clf_lr, device, n=10):
    """Evaluate latent-space linear classifiers on every batch of ``dl``.

    A latent code ``u`` is sampled per modality from the Normal distribution
    parameterised by ``model.a_to_z`` (SVHN) / ``model.b_to_z`` (MNIST) and
    split into ``z`` / ``w`` using ``model.z_dim`` / ``model.w_dim``. The
    collected codes and one-hot labels are passed to
    ``classify_latent_representations`` with ``split=True``.

    Args:
        dl: Iterable of ``(svhn, mnist, y)`` batches.
        model: Object exposing ``a_to_z``, ``b_to_z``, ``z_dim``, ``w_dim``.
        clf_lr: dict of fitted classifiers (see ``train_clf_lr``).
        device: Device the batches are moved to.
        n: Number of classes for the one-hot label encoding.

    Returns:
        dict with the per-modality accuracies ``'{svhn,mnist}_{u,z,w}'``,
        their pairwise means ``'mean_{u,z,w}'`` and ``'z_all'`` (the ``'all'``
        entry returned by ``classify_latent_representations``).
    """
    # One (u, z, w) triple of chunk lists per modality.
    svhn_chunks = ([], [], [])
    mnist_chunks = ([], [], [])
    label_chunks = []

    with torch.no_grad():
        for svhn, mnist, y in dl:
            svhn = svhn.to(device)
            mnist = mnist.to(device)
            y = y.to(device)

            one_hot = F.one_hot(y.long(), num_classes=n).float()
            batch_labels = one_hot.cpu().data.numpy().reshape(y.size(0), n)

            # SVHN is sampled before MNIST, matching the RNG consumption order.
            for encode, batch, store in ((model.a_to_z, svhn, svhn_chunks),
                                         (model.b_to_z, mnist, mnist_chunks)):
                u = dist.Normal(*encode(batch)).sample()
                z, w = torch.split(u, [model.z_dim, model.w_dim], dim=-1)
                for bucket, piece in zip(store, (u, z, w)):
                    bucket.append(piece.cpu().data.numpy())

            label_chunks.append(batch_labels)

        latent_reps = [
            [np.concatenate(chunks, axis=0) for chunks in modality]
            for modality in (svhn_chunks, mnist_chunks)
        ]
        all_labels = np.concatenate(label_chunks, axis=0)

        accuracies = classify_latent_representations(
            clf_lr, latent_reps, all_labels, split=True)

        per_modality_keys = ('svhn_u', 'mnist_u', 'svhn_z',
                             'mnist_z', 'svhn_w', 'mnist_w')
        accuracies_lr = {key: accuracies[key] for key in per_modality_keys}
        for part in ('u', 'z', 'w'):
            accuracies_lr['mean_' + part] = mean(
                [accuracies_lr['svhn_' + part], accuracies_lr['mnist_' + part]])
        accuracies_lr['z_all'] = accuracies['all']

    return accuracies_lr


def cross_coherence(clfs, dl, model, device):
    """Measure how often classifiers recover ``y`` from generated outputs.

    For every batch the four generations ``a_to_a``, ``a_to_b``, ``b_to_a``
    and ``b_to_b`` are produced. The ``*_to_b`` outputs are viewed as
    ``(-1, 1, 28, 28)``, zero-padded by 2 pixels on each spatial side and
    expanded to 3 channels before classification (presumably to match the
    input format ``clfs[1]`` expects).

    Args:
        clfs: Two classifiers; ``clfs[t]`` scores outputs of target index
            ``t`` (0 = the ``*_to_a`` outputs, 1 = the ``*_to_b`` outputs).
            They are put in eval mode and moved to ``device``.
        dl: Iterable of ``(svhn, mnist, y)`` batches.
        model: Object exposing ``a_to_a``, ``a_to_b``, ``b_to_a``, ``b_to_b``.
        device: Device used for inputs and classifiers.

    Returns:
        2x2 nested list ``acc[source][target]`` with the fraction of samples
        whose arg-max prediction equals ``y`` (source 0 = svhn input,
        source 1 = mnist input).
    """
    for classifier in clfs:
        classifier.eval()
        classifier.to(device)

    def _to_padded_rgb(flat):
        # 28x28 single-channel -> zero-padded 32x32, broadcast to 3 channels.
        digits = flat.view(-1, 1, 28, 28)
        padded = F.pad(digits, (2, 2, 2, 2), mode='constant', value=0)
        return padded.expand(-1, 3, -1, -1)

    hits = [[0, 0], [0, 0]]
    seen = 0
    with torch.no_grad():
        for svhn, mnist, y in dl:
            svhn = svhn.to(device)
            mnist = mnist.to(device)
            y = y.to(device)
            seen += y.size(0)

            # generated[source][target]; call order matches a_to_a, a_to_b,
            # b_to_a, b_to_b.
            generated = [
                [model.a_to_a(svhn), _to_padded_rgb(model.a_to_b(svhn))],
                [model.b_to_a(mnist), _to_padded_rgb(model.b_to_b(mnist))],
            ]

            for src, row in enumerate(generated):
                for trg, images in enumerate(row):
                    predicted = clfs[trg](images).argmax(dim=-1)
                    hits[src][trg] += (predicted == y).sum().item()

    return [[count / seen for count in row] for row in hits]


def calculate_inception_features_for_gen_evaluation(inception_state_dict_path, device, dir_fid_base, datadir, dims=2048, batch_size=128):
    """Compute and save Inception activations for real and generated PNGs.

    Real images: for each modality in ('svhn', 'mnist'), activations of
    ``<datadir>/<modality>/*.png`` are written to
    ``<datadir>/real_activations_<modality>.npy`` unless that file already
    exists.

    Generated images: for each source and target modality, activations of
    ``<dir_fid_base>/<source>/<target>/*.png`` are written to
    ``<dir_fid_base>/<source>/<target>_activations.npy`` (always recomputed).

    Args:
        inception_state_dict_path: Passed to ``InceptionV3`` as
            ``path_state_dict``.
        device: Device the Inception network runs on.
        dir_fid_base: Root directory of the generated samples.
        datadir: Root directory of the real samples and cached activations.
        dims: Feature dimensionality; selects the Inception block via
            ``InceptionV3.BLOCK_INDEX_BY_DIM``.
        batch_size: Batch size handed to ``get_activations``.

    Raises:
        RuntimeError: If ``<dir_fid_base>/<source>`` does not exist.
    """
    modalities = ('svhn', 'mnist')

    inception = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]],
                            path_state_dict=inception_state_dict_path)
    inception = inception.to(device)

    def _activations_of(folder):
        # Activations for every PNG directly inside `folder`.
        png_files = glob.glob(os.path.join(folder, '*.png'))
        return get_activations(png_files, inception, device, batch_size, dims,
                               verbose=False)

    # Real-data activations are cached across calls.
    for name in modalities:
        cache_file = os.path.join(datadir, 'real_activations_{}.npy'.format(name))
        if os.path.exists(cache_file):
            continue
        np.save(cache_file, _activations_of(os.path.join(datadir, name)))

    # Generated-data activations are refreshed on every call.
    for source in modalities:
        source_dir = os.path.join(dir_fid_base, source)
        if not os.path.exists(source_dir):
            raise RuntimeError('Invalid path: %s' % source_dir)
        for target in modalities:
            out_file = os.path.join(source_dir, target + '_activations.npy')
            generated_act = _activations_of(os.path.join(source_dir, target))
            np.save(out_file, generated_act)

def calculate_fid(feats_real, feats_gen):
    """Return the Frechet distance between two sets of feature vectors.

    Each set is summarised by its mean over axis 0 and its covariance
    (``np.cov`` with ``rowvar=False``, i.e. rows are observations); the four
    statistics are passed to ``calculate_frechet_distance``.

    Args:
        feats_real: Feature array of the reference images.
        feats_gen: Feature array of the generated images.

    Returns:
        Whatever ``calculate_frechet_distance`` returns for these statistics.
    """
    def _gaussian_stats(feats):
        return np.mean(feats, axis=0), np.cov(feats, rowvar=False)

    real_mu, real_sigma = _gaussian_stats(feats_real)
    gen_mu, gen_sigma = _gaussian_stats(feats_gen)
    return calculate_frechet_distance(real_mu, real_sigma, gen_mu, gen_sigma)

def calculate_fid_routine(dl, model, datadir, fid_path, inception_path, num_fid_samples, device):
    """Generate samples, compute conditional-generation FIDs and clean up.

    Steps:
      1. (Re)create the empty directories ``<fid_path>/<source>/<target>`` for
         every source/target pair of ('svhn', 'mnist').
      2. Call ``model.self_and_cross_modal_generation_for_fid_calculation``
         on batches of ``dl`` until at least ``num_fid_samples`` inputs have
         been processed (counted via ``data[0].size(0)``).
      3. Compute Inception activations and the FID of every source/target
         pair against the real activations of the target modality, printing
         each value and the overall mean.
      4. Wipe and recreate ``fid_path``.

    Args:
        dl: Iterable of batches; ``data[0]`` must be a tensor.
        model: Model providing the generation method named above.
        datadir: Directory holding the real images / cached real activations.
        fid_path: Scratch directory for generated images and activations.
        inception_path: Path to the Inception state dict.
        num_fid_samples: Minimum number of conditioning samples to generate
            from; nothing is generated if it is not positive.
        device: Device used for the Inception network.

    Returns:
        The mean over target modalities of the per-target mean FID (the value
        printed under ``FID/condgen_meanll``). Earlier versions returned
        ``None``.
    """
    modalities = ('svhn', 'mnist')

    # Start from empty output directories so stale images never leak into
    # the statistics.
    for source in modalities:
        for target in modalities:
            out_dir = os.path.join(fid_path, source, target)
            if os.path.exists(out_dir):
                shutil.rmtree(out_dir)
            os.makedirs(out_dir)

    total_cond = 0
    with torch.no_grad():
        for i, data in enumerate(dl):
            # Stop pulling batches once enough samples exist; the original
            # loop kept iterating the whole loader without using the data.
            if total_cond >= num_fid_samples:
                break
            model.self_and_cross_modal_generation_for_fid_calculation(data, datadir, fid_path, i)
            total_cond += data[0].size(0)

        calculate_inception_features_for_gen_evaluation(inception_path, device, fid_path, datadir)

        fid_condgen_list = []
        for modality_target in modalities:
            file_activations_real = os.path.join(
                datadir, 'real_activations_{}.npy'.format(modality_target))
            feats_real = np.load(file_activations_real)

            fid_condgen_target_list = []
            for modality_source in modalities:
                file_activations_gen = os.path.join(
                    fid_path, modality_source, modality_target + '_activations.npy')
                feats_gen = np.load(file_activations_gen)
                fid_val = calculate_fid(feats_real, feats_gen)
                print("FID/{}/{}: ".format(modality_source, modality_target) + str(fid_val))
                fid_condgen_target_list.append(fid_val)
            fid_condgen_list.append(mean(fid_condgen_target_list))
        mean_fid_condgen = mean(fid_condgen_list)
        print("FID/condgen_meanll: " + str(mean_fid_condgen))

    # Drop generated images and activations, leaving an empty scratch dir.
    if os.path.exists(fid_path):
        shutil.rmtree(fid_path)
        os.makedirs(fid_path)

    return mean_fid_condgen