import logging
import torch
import numpy as np
from src.utils.seed import set_seed
from sklearn import linear_model
import pdb

logger = logging.getLogger(__name__)


def compute_beta_vae(dataset, model, batch_size, num_train, num_eval, loss_fn, args):
    """Compute the Beta-VAE disentanglement score.

    A logistic-regression classifier is trained to predict which generative
    factor was held fixed, from the feature vectors produced by
    ``generate_training_batch``. Its accuracy on a separate set of
    ``num_eval`` points is returned as the score.

    Args:
        dataset: dataset object forwarded to the sampling helpers.
        model: model whose encoder produces the latent codes being scored.
        batch_size: number of sample pairs averaged into each feature vector.
        num_train: number of points used to fit the classifier.
        num_eval: number of points used to score the classifier.
        loss_fn: forwarded unchanged to the sampling helpers.
        args: run configuration passed to ``set_seed`` and the sampling helpers.

    Returns:
        Mean accuracy (``LogisticRegression.score``) on the evaluation points.
    """
    logger.info(
        "*********************Beta-VAE Disentanglement Evaluation*********************"
    )

    # generate_training_batch calls set_seed(args) on every invocation, so
    # drawing train and eval with two separate calls would make the eval set
    # replay the first num_eval training draws (assuming set_seed reseeds the
    # NumPy RNG used for sampling). Draw both sets in one seeded call and split
    # it: training points are unchanged, eval points continue the stream.
    all_points, all_labels = generate_training_batch(
        dataset, model, batch_size, num_train + num_eval, loss_fn, args
    )
    train_points, train_labels = all_points[:num_train], all_labels[:num_train]
    eval_points, eval_labels = all_points[num_train:], all_labels[num_train:]

    set_seed(args)
    regression_model = linear_model.LogisticRegression(max_iter=200)
    regression_model.fit(train_points, train_labels)

    train_accuracy = regression_model.score(train_points, train_labels)
    logger.info("Training set accuracy: %.2g", train_accuracy)

    eval_accuracy = regression_model.score(eval_points, eval_labels)
    logger.info("Evaluation set accuracy: %.2g", eval_accuracy)
    return eval_accuracy


def generate_training_batch(dataset, model, batch_size, num_points, loss_fn, args):
    """Collect ``num_points`` labelled feature vectors for the Beta-VAE classifier.

    ``set_seed(args)`` is called before any sampling. Repeated calls with the
    same ``args`` therefore start from the same seeded state, assuming
    ``set_seed`` seeds the RNG that ``generate_training_sample`` draws from
    (not visible here — confirm).

    Returns:
        ``(points, labels)``. ``points`` is a float64 array with one feature
        vector per row, or ``None`` when no points are requested. ``labels`` is
        an int64 array holding the factor index returned with each sample.
    """
    set_seed(args)
    label_list = []
    feature_rows = []
    for _ in range(num_points):
        label, features = generate_training_sample(
            dataset, model, batch_size, loss_fn, args
        )
        label_list.append(label)
        feature_rows.append(features)

    labels = np.array(label_list, dtype=np.int64)
    # The feature width depends on the representation, so it is only known
    # after sampling; an empty batch yields None instead of an array.
    points = np.array(feature_rows, dtype=np.float64) if feature_rows else None
    return points, labels


def _encode_batch(model, imgs, batch_size, args):
    """Return the latent codes the Beta-VAE metric compares for ``imgs``.

    Takes element ``[0]`` of ``model.encoder(imgs)`` (presumably the posterior
    mean — TODO confirm) and keeps its first ``batch_size`` rows. For any
    ``args.model_type`` containing ``'cmcs'``, the code is further passed
    through ``model.real_to_theta`` and ``model.select_code``.
    """
    latent = model.encoder(imgs)[0][:batch_size]
    # Substring test: this is also true for 'cmcs_unsuper'.
    if 'cmcs' in args.model_type:
        latent = model.real_to_theta(latent)
        latent = model.select_code(latent)
    return latent


def generate_training_sample(dataset, model, batch_size, loss_fn, args):
    """Draw one (label, feature) pair for the Beta-VAE classifier.

    Steps:

    1. Pick a factor index uniformly at random.
    2. Draw two mini-batches of ``batch_size`` dataset indices, each without
       replacement.
    3. Overwrite the second batch's factor values so each pair shares the
       chosen factor, then map those rows back to dataset indices with
       ``find_index_from_factors``.
    4. Encode both image batches. The feature is the mean absolute difference
       of the two latent batches along ``dim=-2``.

    Args:
        dataset: must support ``len()`` and integer indexing returning a tuple
            whose first element is an image tensor. It must also expose
            ``factor_num``, ``latents_classes`` (indexed by sample) and
            ``factor_dict``. Sample order is assumed to match the mixed-radix
            order used by ``find_index_from_factors``.
        model: exposes ``parameters()`` and ``encoder``. ``real_to_theta`` and
            ``select_code`` are also used for ``'cmcs'`` model types.
        batch_size: number of pairs averaged into the feature.
        loss_fn: unused; kept for interface compatibility.
        args: must provide ``model_type``.

    Returns:
        ``(fixed_index, feature_vector)``: the held-fixed factor index (the
        classifier label) and a NumPy array of averaged absolute differences.
    """
    # The factor kept fixed across each pair; it is the classifier label.
    fixed_index = np.random.randint(dataset.factor_num)

    # np.random.choice(n) draws from range(n), so pass len(dataset) itself.
    # The previous bound of len(dataset) - 1 could never select the last sample.
    num_samples = len(dataset)
    idx_1 = np.random.choice(num_samples, batch_size, replace=False)
    idx_2 = np.random.choice(num_samples, batch_size, replace=False)

    # Fancy indexing returns a copy (for NumPy arrays), so the overwrite below
    # does not modify the dataset's own factor table.
    factors_1 = dataset.latents_classes[idx_1]
    factors_2 = dataset.latents_classes[idx_2]
    factors_2[:, fixed_index] = factors_1[:, fixed_index]
    # Map the edited factor rows back to the dataset indices of those images.
    changed_idx = find_index_from_factors(factors_2, dataset).astype(np.int64)

    device = next(model.parameters()).device
    imgs_1 = torch.stack([dataset[i][0] for i in idx_1], dim=0).to(device)
    imgs_2 = torch.stack([dataset[i][0] for i in changed_idx], dim=0).to(device)

    # Inference only: the result is converted to NumPy, so no autograd graph is
    # needed. This saves memory across the many calls made per evaluation.
    with torch.no_grad():
        latent_vector_1 = _encode_batch(model, imgs_1, batch_size, args)
        latent_vector_2 = _encode_batch(model, imgs_2, batch_size, args)

    # Average |z1 - z2| over the batch axis (dim=-2). Assuming latents are
    # (batch, latent_dim), this gives one value per latent dimension.
    feature_vector = torch.mean(torch.abs(latent_vector_1 - latent_vector_2), dim=-2)
    feature_vector = feature_vector.detach().cpu().numpy()  # (latent_dim)
    return fixed_index, feature_vector


def find_index_from_factors(factors, dataset):
    """Map rows of factor values to flat dataset indices (mixed-radix encoding).

    The stride of factor ``k`` is the product of ``dataset.factor_dict[j]`` over
    every ``j > k``, so the last factor varies fastest. Each row of ``factors``
    is dotted with those strides. ``dataset.factor_dict[0]`` is never read,
    because no stride depends on it.

    Args:
        factors: array of factor values with one factor per column
            (presumably shape (n, num_factors) — TODO confirm).
        dataset: object whose ``factor_dict`` supports ``len()`` and integer
            lookup at ``1 .. len - 1`` giving each factor's number of values.

    Returns:
        ``np.dot(factors, strides)``.
    """
    num_factors = len(dataset.factor_dict)
    strides = [1] * num_factors
    # Fill right to left: each stride is the next stride times the next
    # factor's size.
    for pos in reversed(range(num_factors - 1)):
        strides[pos] = strides[pos + 1] * dataset.factor_dict[pos + 1]
    return np.dot(factors, np.array(strides))


