import torch
import torch.distributions as dist
import numpy as np
from torch import nn
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from scipy import stats
from statistics import mean
import shutil
import os
import glob
from fid.inception import InceptionV3
from fid.fid_score import get_activations
from fid.fid_score import calculate_frechet_distance


def _encode_sampled_latents(encoder, x, model):
    """Sample one latent code from ``encoder(x)`` and split it into (u, z, w).

    ``encoder(x)`` must return the positional arguments of
    ``torch.distributions.Normal`` (e.g. ``(loc, scale)``). The sampled code
    ``u`` is split along its last dimension into its first ``model.z_dim``
    entries (``z``) and the following ``model.w_dim`` entries (``w``).

    Returns:
        Tuple ``(u, z, w)`` of NumPy arrays on the CPU.
    """
    u = dist.Normal(*encoder(x)).sample()
    z, w = torch.split(u, [model.z_dim, model.w_dim], dim=-1)
    return u.detach().cpu().numpy(), z.detach().cpu().numpy(), w.detach().cpu().numpy()


def _labels_to_numpy(y, n):
    """Return the class labels in tensor ``y`` as a flat integer NumPy array.

    Raises:
        ValueError: if any label lies outside ``[0, n)``.
    """
    labels = y.long().reshape(-1).cpu()
    if labels.numel() > 0:
        lo, hi = labels.min().item(), labels.max().item()
        if lo < 0 or hi >= n:
            raise ValueError('class labels must lie in [0, {}), got values in [{}, {}]'.format(n, lo, hi))
    return labels.numpy()


def train_clf_lr(dl, model, missing, device, n=8):
    """Fit logistic-regression probes on sampled latent codes of both modalities.

    One classifier is fitted per (modality, latent part) pair: the full code
    ``u`` and its ``z`` / ``w`` parts, for images and for sentences.

    Args:
        dl: mapping with optional keys ``'sup'`` and ``'unsup'``; each value is
            an iterable of ``(image, sentence, y)`` batches. Both modalities of
            every ``'sup'`` batch are encoded.
        model: provides ``a_to_z`` (applied to images) and ``b_to_z`` (applied
            to sentences), each returning ``Normal`` parameters, plus the split
            sizes ``z_dim`` and ``w_dim``. It is switched to ``eval()`` mode.
        missing: controls ``dl['unsup']``: ``'image'`` -> only sentences are
            encoded, ``'sentence'`` -> only images are encoded, any other value
            -> ``dl['unsup']`` is ignored.
        device: device the inputs are moved to before encoding.
        n: number of classes; labels must lie in ``[0, n)``.

    Returns:
        dict with keys ``'image_u'``, ``'image_z'``, ``'image_w'``,
        ``'sentence_u'``, ``'sentence_z'``, ``'sentence_w'``, each a fitted
        ``LogisticRegression``.

    Raises:
        ValueError: if a modality ends up with no samples or a label is out of
            range.
    """
    model.eval()

    modalities = ('image', 'sentence')
    parts = ('u', 'z', 'w')
    encoders = {'image': model.a_to_z, 'sentence': model.b_to_z}
    latents = {m: {p: [] for p in parts} for m in modalities}
    # Labels are tracked per modality so every latent array is paired with
    # exactly the labels of the batches it was computed from, independent of
    # the order in which the loaders are consumed.
    labels = {m: [] for m in modalities}

    def collect(loader, observed):
        # Encode the ``observed`` modalities of every batch in ``loader``.
        for image, sentence, y in loader:
            y_np = _labels_to_numpy(y, n)
            inputs = {'image': image, 'sentence': sentence}
            with torch.no_grad():
                for modality in observed:
                    codes = _encode_sampled_latents(encoders[modality], inputs[modality].to(device), model)
                    for part, code in zip(parts, codes):
                        latents[modality][part].append(code)
                    labels[modality].append(y_np)

    if 'sup' in dl:
        collect(dl['sup'], modalities)
    if 'unsup' in dl:
        if missing == 'image':
            collect(dl['unsup'], ('sentence',))
        elif missing == 'sentence':
            collect(dl['unsup'], ('image',))

    clf_lr = {}
    for modality in modalities:
        if not labels[modality]:
            raise ValueError('train_clf_lr: no {} samples available to fit the latent classifiers'.format(modality))
        targets = np.concatenate(labels[modality], axis=0)
        for part in parts:
            features = np.concatenate(latents[modality][part], axis=0)
            # multi_class is left at its default ('auto'); passing it
            # explicitly is deprecated since scikit-learn 1.5.
            clf = LogisticRegression(random_state=0, solver='lbfgs', max_iter=10000)
            clf.fit(features, targets)
            clf_lr['{}_{}'.format(modality, part)] = clf
    return clf_lr

def classify_latent_representations(clf_lr, data, labels, split=False):
    """Score pre-fitted latent classifiers and their per-sample majority vote.

    Args:
        clf_lr: dict of fitted classifiers exposing ``predict``.
            ``'image_u'`` and ``'sentence_u'`` are always used;
            ``'image_z'``, ``'image_w'``, ``'sentence_z'`` and ``'sentence_w'``
            are used only when ``split`` is true.
        data: pair ``(image_reps, sentence_reps)``. Element 0 of each is the u
            representation; with ``split``, elements 1 and 2 are z and w.
        labels: 2-D array whose per-row argmax is the ground-truth class
            (e.g. one-hot labels).
        split: whether the z/w representations are provided and scored.

    Returns:
        dict of accuracies keyed like ``clf_lr`` plus ``'all'``, the accuracy
        of a per-sample majority vote over the z predictions of both
        modalities (with ``split``) or over their u predictions (without).
    """
    gt = np.argmax(labels, axis=1).astype(int)
    accuracies = dict()

    def score(key, reps):
        # Predict with clf_lr[key], record its accuracy, return the predictions.
        preds = clf_lr[key].predict(reps)
        accuracies[key] = accuracy_score(gt, preds.ravel())
        return preds

    image_reps, sentence_reps = data[0], data[1]
    preds_image_u = score('image_u', image_reps[0])
    preds_sentence_u = score('sentence_u', sentence_reps[0])

    if split:
        voters = [score('image_z', image_reps[1])]
        score('image_w', image_reps[2])
        voters.append(score('sentence_z', sentence_reps[1]))
        score('sentence_w', sentence_reps[2])
    else:
        # Without split only u predictions exist; they form the vote (stacking
        # an empty list of predictions would raise).
        voters = [preds_image_u, preds_sentence_u]

    # Per-sample mode across voters. When voters disagree, the tie is broken
    # by scipy.stats.mode. .ravel() accepts both the keepdims=True and
    # keepdims=False result shapes of different SciPy versions.
    overall_preds = stats.mode(np.stack(voters), axis=0)[0]
    accuracies['all'] = accuracy_score(gt, overall_preds.ravel())
    return accuracies

def linear_latent_classification(dl, model, clf_lr, device, n=8):
    """Evaluate latent linear classifiers on every batch of ``dl``.

    Each ``(image, sentence, y)`` batch is encoded with ``model.a_to_z``
    (images) and ``model.b_to_z`` (sentences). One latent code ``u`` is
    sampled per input and split into ``z`` (first ``model.z_dim`` dims) and
    ``w`` (next ``model.w_dim`` dims). All codes are scored with
    ``classify_latent_representations(..., split=True)``, so ``clf_lr`` needs
    the six ``'<modality>_<part>'`` classifiers.

    Returns:
        dict with the accuracies ``'<modality>_<part>'`` (modality in
        image/sentence, part in u/z/w), the cross-modality averages
        ``'mean_u'``, ``'mean_z'``, ``'mean_w'`` and the vote accuracy
        ``'z_all'``.
    """
    parts = ('u', 'z', 'w')
    codes = {'image': {p: [] for p in parts}, 'sentence': {p: [] for p in parts}}
    label_batches = []

    with torch.no_grad():
        for image, sentence, y in dl:
            image = image.to(device)
            sentence = sentence.to(device)
            y = y.to(device)
            one_hot = nn.functional.one_hot(y.long(), num_classes=n).float()
            batch_labels = one_hot.cpu().data.numpy().reshape(y.size(0), n)

            # Images are encoded before sentences, fixing the sampling order.
            for modality, encoder, x in (('image', model.a_to_z, image),
                                         ('sentence', model.b_to_z, sentence)):
                u = dist.Normal(*encoder(x)).sample()
                z, w = torch.split(u, [model.z_dim, model.w_dim], dim=-1)
                for part, tensor in zip(parts, (u, z, w)):
                    codes[modality][part].append(tensor.cpu().data.numpy())
            label_batches.append(batch_labels)

        latent_reps = [[np.concatenate(codes[m][p], axis=0) for p in parts]
                       for m in ('image', 'sentence')]
        all_labels = np.concatenate(label_batches, axis=0)
        accuracies = classify_latent_representations(clf_lr, latent_reps, all_labels, split=True)

        accuracies_lr = {}
        for part in parts:
            for modality in ('image', 'sentence'):
                key = modality + '_' + part
                accuracies_lr[key] = accuracies[key]
        for part in parts:
            accuracies_lr['mean_' + part] = mean([accuracies_lr['image_' + part],
                                                  accuracies_lr['sentence_' + part]])
        accuracies_lr['z_all'] = accuracies['all']

    return accuracies_lr


def cross_coherence(clf, dl, model, device):
    """Return how often ``clf`` recovers the label from model-generated images.

    For each ``(image, sentence, y)`` batch in ``dl`` the model produces
    ``model.a_to_a(image)`` (from the image) and ``model.b_to_a(sentence)``
    (from the sentence). ``clf`` classifies both outputs, and its argmax over
    the last dimension is compared with ``y``. ``clf`` is switched to eval
    mode and ``clf.to(device)`` is called on it.

    Returns:
        ``[acc_from_image, acc_from_sentence]``: fraction of correct
        predictions over all samples in ``dl``.

    Raises:
        ValueError: if ``dl`` yields no samples.
    """
    clf.eval()
    clf.to(device)

    correct = [0, 0]
    total = 0
    with torch.no_grad():
        for image, sentence, y in dl:
            image, sentence = image.to(device), sentence.to(device)
            # Flatten (B,) or (B, 1) targets to (B,) so they compare
            # element-wise with the (B,) predictions instead of broadcasting
            # to a (B, B) comparison.
            y = y.to(device).reshape(y.size(0))
            total += y.size(0)
            generated = (model.a_to_a(image), model.b_to_a(sentence))
            for idx, px in enumerate(generated):
                preds = torch.argmax(clf(px), dim=-1)
                correct[idx] += (preds == y).sum().item()

    if total == 0:
        raise ValueError('cross_coherence: the data loader yielded no samples')
    return [count / total for count in correct]


def calculate_inception_features_for_gen_evaluation(inception_state_dict_path, device, dir_fid_base, datadir, dims=2048, batch_size=128):
    """Compute and save InceptionV3 activations for real and generated images.

    Real images are read from ``<datadir>/image/*.png``. Their activations are
    saved to ``<datadir>/real_activations_image.npy`` only if that file does
    not exist yet; otherwise the existing file is reused untouched.
    For each source in ('image', 'sentence'), the generated images in
    ``<dir_fid_base>/<source>/image/*.png`` are encoded, and the activations
    are written to ``<dir_fid_base>/<source>/image_activations.npy``,
    overwriting any previous file.

    Raises:
        RuntimeError: if ``<dir_fid_base>/<source>`` does not exist.
    """
    inception = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]],
                            path_state_dict=inception_state_dict_path).to(device)

    modality = 'image'
    real_cache = os.path.join(datadir, 'real_activations_{}.npy'.format(modality))
    if not os.path.exists(real_cache):
        real_files = glob.glob(os.path.join(datadir, modality, '*.png'))
        real_acts = get_activations(real_files, inception, device, batch_size, dims, verbose=False)
        np.save(real_cache, real_acts)

    for source in ('image', 'sentence'):
        source_dir = os.path.join(dir_fid_base, source)
        if not os.path.exists(source_dir):
            raise RuntimeError('Invalid path: %s' % source_dir)

        gen_files = glob.glob(os.path.join(source_dir, modality, '*.png'))
        gen_acts = get_activations(gen_files, inception, device, batch_size, dims, verbose=False)
        np.save(os.path.join(source_dir, modality + '_activations.npy'), gen_acts)

def calculate_fid(feats_real, feats_gen):
    """Return the Frechet distance between two sets of activations.

    Each set is summarised by its per-feature mean and its covariance; rows
    are treated as observations (``rowvar=False``). The two Gaussian summaries
    are compared with ``calculate_frechet_distance``.
    """
    def moments(feats):
        # Column-wise mean and covariance with samples along axis 0.
        return np.mean(feats, axis=0), np.cov(feats, rowvar=False)

    mu_real, sigma_real = moments(feats_real)
    mu_gen, sigma_gen = moments(feats_gen)
    return calculate_frechet_distance(mu_real, sigma_real, mu_gen, sigma_gen)

def calculate_fid_routine(dl, model, datadir, fid_path, inception_path, num_fid_samples, device):
    """Generate images, compute conditional-generation FID scores and clean up.

    Steps:
      1. Recreate empty ``<fid_path>/<source>/image`` directories for source
         in ('image', 'sentence').
      2. Call ``model.self_and_cross_modal_generation_for_fid_calculation(data,
         datadir, fid_path, i)`` on successive batches of ``dl`` until at
         least ``num_fid_samples`` samples (counted via ``data[0].size(0)``)
         have been processed. The method is expected to write PNGs into the
         directories from step 1; this is not verified here.
      3. Compute Inception activations with
         ``calculate_inception_features_for_gen_evaluation``. Then compute the
         FID of each source against ``<datadir>/real_activations_image.npy``.
      4. Wipe ``fid_path`` and recreate it empty. This also runs when an
         earlier step raises, so generated images do not pile up on disk.

    Each FID and their mean are printed.

    Returns:
        dict ``{'image': fid, 'sentence': fid, 'mean': mean_fid}``.
    """
    target = 'image'
    sources = ('image', 'sentence')

    for source in sources:
        out_dir = os.path.join(fid_path, source, target)
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir)

    try:
        with torch.no_grad():
            generated = 0
            for i, data in enumerate(dl):
                # Stop once enough samples were generated instead of loading
                # every remaining batch of the loader for nothing.
                if generated >= num_fid_samples:
                    break
                model.self_and_cross_modal_generation_for_fid_calculation(data, datadir, fid_path, i)
                generated += data[0].size(0)

            calculate_inception_features_for_gen_evaluation(inception_path, device, fid_path, datadir)

            feats_real = np.load(os.path.join(datadir, 'real_activations_{}.npy'.format(target)))
            fids = {}
            for source in sources:
                feats_gen = np.load(os.path.join(fid_path, source, target + '_activations.npy'))
                fid_val = calculate_fid(feats_real, feats_gen)
                print("FID/{}/{}: ".format(source, target) + str(fid_val))
                fids[source] = fid_val
            fids['mean'] = mean([fids[source] for source in sources])
            print("FID/condgen_meanll: " + str(fids['mean']))
    finally:
        if os.path.exists(fid_path):
            shutil.rmtree(fid_path)
            os.makedirs(fid_path)

    return fids