import os
import argparse
from configs.defaults import get_cfgs_defaults
import torch
import torch.nn.functional as F
from torch.utils.data import ConcatDataset, DataLoader
from torch.utils.data.dataset import Subset
from torch.distributions import Categorical
from torchvision import datasets, transforms

from tqdm import tqdm
import pickle
import warnings
import math
from model import GaussianSQVAE, SQVAE

import numpy as np
from sklearn.feature_selection import mutual_info_classif


class NestedDict(dict):
    """Dictionary that creates nested levels on first access.

    Looking up a key that is absent stores a fresh, empty instance of the same
    class under that key and returns it, so chained assignments such as
    ``d[a][b] = value`` work without creating the intermediate level by hand.
    """

    def __missing__(self, key):
        # Called by dict.__getitem__ only when `key` is not present.
        child = self.__class__()
        self[key] = child
        return child

def _str2bool(value):
    """Convert a command-line token to a boolean.

    ``type=bool`` in argparse applies ``bool()`` to the raw string, so every
    non-empty token (including ``"False"``) is parsed as ``True``. This
    converter interprets the usual spellings explicitly. The empty string maps
    to ``False``, which is what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("", "0", "false", "f", "no", "n", "off"):
        return False
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


def arg_parse():
    """Parse the command-line options of the eCMI result-parsing script.

    Returns:
        argparse.Namespace with the attributes ``config_file``, ``dataset``,
        ``save``, ``dbg``, ``gpu``, ``ecmi``, ``n``, ``K``, ``d_dict``,
        ``seed``, ``S_seed`` and ``eval_K_variation``.
    """
    parser = argparse.ArgumentParser(
            description="ecmi_parse_results.py")
    parser.add_argument(
        "-c", "--config_file", default="", help="config file")
    parser.add_argument(
        "-data", "--dataset", default="MNIST", help="dataset")
    parser.add_argument(
        "--save", action="store_true", help="save trained model")
    parser.add_argument(
        "--dbg", action="store_true", help="print losses per epoch")
    parser.add_argument(
        "--gpu", default="0", help="index of gpu to be used")
    parser.add_argument(
        '--ecmi', action='store_true', help="determine whether we use the eCMI settings or not")
    parser.add_argument(
        '--n', '-n', type=int, default=None, help='Number of training examples for the eCMI settings')
    parser.add_argument(
        '--K', '-K', type=int, default=128, help='Number of size_dict')
    parser.add_argument(
        '--d_dict', type=int, default=64, help='Number of dim_dict')
    parser.add_argument(
        "--seed", type=int, default=0, help="seed number for randomness (on 2n examples if we use the eCMI settings)")
    parser.add_argument(
        "--S_seed", type=int, default=0, help="seed number for randomness on train/test data split (ecmi)")
    # `nargs="?"` with `const=True` keeps `--eval_K_variation True` working and
    # additionally allows the bare flag form `--eval_K_variation`.
    parser.add_argument(
        '--eval_K_variation', type=_str2bool, nargs='?', const=True, default=False,
        help='Evaluate models under K variation')
    args = parser.parse_args()
    return args


def load_config(args, seed):
    """Build the frozen experiment configuration for one seed.

    Starts from ``get_cfgs_defaults()``, merges the file
    ``configs/<args.config_file>`` located next to this script, then applies
    the overrides taken from ``args`` and ``seed``.

    Args:
        args: parsed command-line namespace; reads ``config_file``, ``K``,
            ``d_dict``, ``save`` and ``dbg``.
        seed: value stored in ``cfgs.train.seed``.

    Returns:
        ``(cfgs, flgs)`` where ``flgs`` is ``cfgs.flags`` of the frozen config.
    """
    cfgs = get_cfgs_defaults()
    yaml_path = os.path.join(os.path.dirname(__file__), "configs", args.config_file)
    print(yaml_path)
    cfgs.merge_from_file(yaml_path)

    # Overrides coming from the caller / command line.
    cfgs.train.seed = seed
    quantization = cfgs.quantization
    quantization.size_dict = args.K
    quantization.dim_dict = args.d_dict
    flags = cfgs.flags
    flags.save = args.save
    flags.noprint = not args.dbg

    # Keep the original root in `path_data`, then point `path` at the
    # experiment-specific sub-directory.
    root_path = cfgs.path
    cfgs.path_data = root_path
    cfgs.path = os.path.join(root_path, cfgs.path_specific)

    if cfgs.model.name.lower() == "vmfsqvae":
        quantization.dim_dict += 1

    # var_q is switched off only for the "gaussian_1" and "vmf" settings.
    flags.var_q = cfgs.model.param_var_q not in ("gaussian_1", "vmf")

    cfgs.freeze()
    return cfgs, cfgs.flags

def calc_encode_kld(model, z_from_encoder, var_q, codebook, flg_quant_det=True):
    """Quantize encoder outputs against the codebook and compute the KL term.

    Args:
        model: object exposing ``quantizer`` with ``size_dict``,
            ``_calc_distance_bw_enc_codes`` and ``_calc_distance_bw_enc_dec``.
        z_from_encoder: 4-D encoder output; axis 1 is the latent dimension
            (it is moved last before the codebook distances are computed).
        var_q: variance parameter; clamped to at least 1e-10 before inversion.
        codebook: matrix whose rows are the code vectors.
            # assumes shape (size_dict, dim_z) -- TODO confirm
        flg_quant_det: if True pick the arg-max code, otherwise sample a code
            from the softmax over the logits.

    Returns:
        ``(encodings_hard, kld_continuous, z_to_decoder)``: the one-hot code
        assignments, the value returned by ``_calc_distance_bw_enc_dec`` and
        the quantized latents laid out like ``z_from_encoder``.
    """
    # The two spatial axes are only used positionally to restore the layout.
    bs, dim_z, width, height = z_from_encoder.shape
    z_from_encoder_permuted = z_from_encoder.permute(0, 2, 3, 1).contiguous()
    precision_q = 1. / torch.clamp(var_q, min=1e-10)
    weight = 0.5 * precision_q
    quantizer = model.quantizer

    logit = -quantizer._calc_distance_bw_enc_codes(z_from_encoder_permuted, codebook, weight)

    if flg_quant_det:
        # Codes live on the last axis, the same axis the softmax below uses.
        indices = torch.argmax(logit, dim=-1)
    else:
        # The softmax is only needed when sampling.
        probabilities = torch.softmax(logit, dim=-1)
        indices = Categorical(probabilities).sample().view(bs, width, height)

    # one_hot + type_as keeps the encodings on the codebook's device and dtype
    # instead of a separately guessed global device.
    encodings_hard = F.one_hot(indices, num_classes=quantizer.size_dict).type_as(codebook)

    z_quantized = torch.matmul(encodings_hard, codebook).view(bs, width, height, dim_z)
    z_to_decoder = z_quantized.permute(0, 3, 1, 2).contiguous()

    kld_continuous = quantizer._calc_distance_bw_enc_dec(z_from_encoder, z_to_decoder, weight)

    return encodings_hard, kld_continuous, z_to_decoder

def get_loader_evaluation(args, seed, dataset, path_dataset, bs=64, n_work=2):
    """Build the loader over the 2n supersample and the train/test index split.

    Args:
        args: namespace providing ``n`` (number of training examples) and
            ``S_seed`` (seed of the train/test split).
        seed: seed used to draw the 2n examples from the training set.
        dataset: one of "MNIST", "FashionMNIST" or "CIFAR10".
        path_dataset: root directory under which the dataset is stored; it is
            downloaded there if missing.
        bs: batch size of the returned loader.
        n_work: number of loader worker processes.

    Returns:
        ``(train_val_loader, (subset1_indices, subset2_indices))``. The loader
        iterates the 2n selected examples in a fixed order (no shuffling). The
        two index arrays address positions in that order: from every pair
        ``(2i, 2i + 1)`` exactly one position goes to each array.

    Raises:
        ValueError: for an unsupported dataset, a missing ``args.n`` or when
            2n exceeds the size of the training set.

    Note:
        The global NumPy RNG is seeded twice (with ``seed`` and then with
        ``args.S_seed``) so the draws are reproducible.
    """
    train_set_sizes = {"MNIST": 60000, "FashionMNIST": 60000, "CIFAR10": 50000}
    if dataset not in train_set_sizes:
        raise ValueError(
            "Unsupported dataset {!r}; expected one of {}".format(
                dataset, sorted(train_set_sizes)))
    if args.n is None:
        raise ValueError("args.n (number of training examples) must be set")

    all_size = train_set_sizes[dataset]
    n = args.n
    if 2 * n > all_size:
        raise ValueError(
            "2 * n = {} exceeds the {} training examples of {}".format(
                2 * n, all_size, dataset))

    # select 2n examples (tilde{z})
    np.random.seed(seed)
    include_indices = np.random.choice(np.arange(all_size), size=2 * n, replace=False)

    preproc_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    # Plain attribute lookup instead of eval() on a string built from input.
    dataset_cls = getattr(datasets, dataset)
    trainval_dataset = dataset_cls(
        os.path.join(path_dataset, "{}/".format(dataset)),
        train=True, download=True, transform=preproc_transform
    )
    all_examples = Subset(trainval_dataset, include_indices)
    train_val_loader = torch.utils.data.DataLoader(
        all_examples, batch_size=bs, shuffle=False,
        num_workers=n_work, pin_memory=False
    )

    # select the train/val split (S)
    np.random.seed(args.S_seed)
    mask = np.random.randint(2, size=(n,))  ## Ber(1/2)
    pair_starts = 2 * np.arange(n)
    subset1_indices = pair_starts + mask
    subset2_indices = pair_starts + (1 - mask)

    return train_val_loader, (subset1_indices, subset2_indices)

def model_eval(x, model, device, flg_quant_det=True):
    """Encode, quantize and decode one batch and return its diagnostics.

    The variance parameter handed to the quantizer depends on
    ``model.param_var_q``:

    * "vmf": encoder output is L2-normalised along axis 1 and the parameter is
      ``exp(log_param_q_scalar) + 1``.
    * "gaussian_1": log-variance fixed to 0.
    * "gaussian_2" / "gaussian_3" / "gaussian_4": the encoder also returns a
      log-variance, averaged over axes (1, 2, 3), over axis 1, or used as is.

    For the Gaussian settings ``exp(log_param_q_scalar)`` is added to the
    exponentiated log-variance.

    Args:
        x: input batch (indexed as a 4-D tensor).
        model: trained model exposing ``encoder``, ``decoder``, ``codebook``,
            ``quantizer``, ``param_var_q`` and ``log_param_q_scalar``.
        device: device on which helper constants are created.
        flg_quant_det: forwarded to ``calc_encode_kld``.

    Returns:
        ``(enc_hard, kld, loss)`` where ``loss`` is the squared reconstruction
        error summed over all non-batch elements of each sample.

    Raises:
        Exception: if ``model.param_var_q`` is none of the settings above.
    """
    setting = model.param_var_q

    if setting == "vmf":
        z_from_encoder = F.normalize(model.encoder(x), p=2.0, dim=1)
        offset = torch.tensor([1.0], device=device)
        param_q = model.log_param_q_scalar.exp() + offset
    else:
        if setting == "gaussian_1":
            z_from_encoder = model.encoder(x)
            log_var_q = torch.tensor([0.0], device=device)
        else:
            # The encoder is queried before the setting is validated.
            z_from_encoder, log_var = model.encoder(x)
            reducers = {
                "gaussian_2": lambda lv: lv.mean(dim=(1, 2, 3), keepdim=True),
                "gaussian_3": lambda lv: lv.mean(dim=1, keepdim=True),
                "gaussian_4": lambda lv: lv,
            }
            if setting not in reducers:
                raise Exception("Undefined param_var_q")
            log_var_q = reducers[setting](log_var)
        param_q = log_var_q.exp() + model.log_param_q_scalar.exp()

    enc_hard, kld, z_to_decoder = calc_encode_kld(
        model, z_from_encoder, param_q, model.codebook, flg_quant_det=flg_quant_det
    )
    x_reconst = model.decoder(z_to_decoder)

    # Per-sample sum of squared errors.
    n_features = x.shape[1] * x.shape[2] * x.shape[3]
    squared_error = F.mse_loss(x_reconst, x, reduction='none')
    loss = squared_error.view(-1, n_features).sum(1)

    return enc_hard, kld, loss

def calc_model_eval(data_loader, model, device, flg_quant_det=True):
    """Run ``model_eval`` over a whole loader and concatenate the results.

    The model is switched to evaluation mode for the duration of the pass so
    that layers behaving differently in training (e.g. BatchNorm, Dropout)
    neither use batch statistics nor update their running statistics. The
    previous training flag is restored afterwards.

    Args:
        data_loader: iterable yielding ``(x, label)`` batches; the label is
            ignored.
        model: model passed to ``model_eval``.
        device: device the inputs are moved to.
        flg_quant_det: forwarded to ``model_eval``.

    Returns:
        ``(enc_hard, kld, loss)``: the per-batch outputs of ``model_eval``
        concatenated along the first axis, in loader order.
    """
    enc_hard_batches = []
    kld_batches = []
    loss_batches = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, _ in data_loader:
                x = x.to(device)
                enc_hard, kld, loss = model_eval(x, model, device, flg_quant_det=flg_quant_det)

                enc_hard_batches.append(enc_hard)
                kld_batches.append(kld)
                loss_batches.append(loss)

            kld_all = torch.concat(kld_batches)
            enc_hard_all = torch.concat(enc_hard_batches)
            loss_all = torch.concat(loss_batches)
    finally:
        # Leave the model in the mode the caller handed it over in.
        model.train(was_training)

    return enc_hard_all, kld_all, loss_all

def get_fcmi_results_for_fixed_z(args, cfgs, flgs):
    """Evaluate every train/test split of one supersample and average.

    For each ``S_seed`` in ``range(args.S_seed)`` the checkpoint
    ``<cfgs.path>/<run dir>/current.pt`` is loaded, the 2n supersample drawn
    with ``cfgs.train.seed`` is pushed through the model, and the KL term, the
    mean train / test reconstruction loss and the eCMI estimate are collected.
    The eCMI estimate is the summed mutual information between the flattened
    hard encodings and the train/test membership mask. Splits whose checkpoint
    is missing are reported and skipped.

    Args:
        args: namespace; ``args.S_seed`` is the NUMBER of splits to evaluate
            and ``args.n`` the number of training examples.
        cfgs: frozen configuration from ``load_config``.
        flgs: ``cfgs.flags``.

    Returns:
        dict with the keys "ecmi", "kl", "train_loss" and "test_loss", each the
        mean over the evaluated splits (NaN if no checkpoint was found).
    """
    ecmi_list = []
    kl_list = []
    train_loss_list = []
    test_loss_list = []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for S_seed in range(args.S_seed):

        ## Load the trained model
        dir_name = f"ecmi_{cfgs.network.name}, n={args.n}, seed={cfgs.train.seed}, S_seed={S_seed}, K={cfgs.quantization.size_dict}, d_dict={cfgs.quantization.dim_dict}"
        dir_path = os.path.join(cfgs.path, dir_name)
        checkpoint_path = os.path.join(dir_path, 'current.pt')
        if not os.path.isfile(checkpoint_path):
            print(f"Did not find results for {dir_name}")
            continue

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        with open(checkpoint_path, 'rb') as f:
            state_dict = torch.load(f, map_location=device)

        # Drop the 'module.' substring from the keys
        # (presumably added by a DataParallel wrapper -- verify).
        new_state_dict = {
            key.replace('module.', ''): value for key, value in state_dict.items()
        }

        model = GaussianSQVAE(cfgs, flgs)
        model.load_state_dict(new_state_dict)
        model = model.to(device)
        model.eval()

        ## Set train_val_dataset and indexes of supersamples
        # The split must be drawn with the S_seed of the checkpoint being
        # evaluated. `args.S_seed` holds the number of splits, so using it
        # directly would score every checkpoint against the same, unrelated
        # split.
        split_args = argparse.Namespace(**vars(args))
        split_args.S_seed = S_seed
        train_val_loader, idx = get_loader_evaluation(
            split_args, cfgs.train.seed, cfgs.dataset.name, cfgs.path_dataset,
            bs=cfgs.test.bs, n_work=cfgs.nworker)
        train_idx, test_idx = idx

        ## Model evaluate
        enc_hard, kld, loss = calc_model_eval(train_val_loader, model, device, flg_quant_det=True)
        loss = loss.cpu()

        kl_list.append(kld.cpu().sum().item())  ## KL
        train_loss_list.append(loss[train_idx].mean().item())  ## train_loss/n
        test_loss_list.append(loss[test_idx].mean().item())  ## test_loss/n

        ## set mask: 0 for training positions, 1 for test positions
        ms = torch.zeros(2 * args.n)
        ms[test_idx] = 1

        ## eval. eCMI
        features = enc_hard.view(2 * args.n, -1).cpu().numpy()
        cur_mi = max(0, mutual_info_classif(features, ms.numpy(), discrete_features=False).sum())
        print("eCMI:", cur_mi)
        ecmi_list.append(cur_mi)

    def _mean_or_nan(values):
        # Mean over the evaluated splits; NaN (without a NumPy warning) when
        # no checkpoint was found.
        return float(np.mean(values)) if values else float("nan")

    return {
        "ecmi": _mean_or_nan(ecmi_list),  ## empirical mean of ecmi w.r.t. S_seed
        "kl": _mean_or_nan(kl_list),  ## empirical mean of kl w.r.t. S_seed
        "train_loss": _mean_or_nan(train_loss_list),  ## empirical mean of "empirical tr. loss" w.r.t. S_seed
        "test_loss": _mean_or_nan(test_loss_list),  ## empirical mean of "empirical te. loss" w.r.t. S_seed
    }

def get_ecmi_results_for_fixed_model(args):
    """Collect the per-supersample results for every data seed.

    Args:
        args: namespace; ``args.seed`` is the NUMBER of data seeds to
            evaluate (seeds ``0 .. args.seed - 1``).

    Returns:
        list with one result dict (see ``get_fcmi_results_for_fixed_z``) per
        seed, in seed order.
    """
    result = []
    for seed in range(args.seed):
        ## Load the experimental settings/configs
        # load_config starts from the defaults itself, so no separate
        # get_cfgs_defaults() call is needed here.
        cfgs, flgs = load_config(args, seed)
        result.append(get_fcmi_results_for_fixed_z(args, cfgs, flgs))

    return result

def main():
    """Evaluate all trained models and pickle the aggregated eCMI results.

    ``args.seed`` and ``args.S_seed`` are used as COUNTS here: the number of
    data seeds and of train/test splits to evaluate. For MNIST / FashionMNIST
    both are overwritten with fixed values. The results, indexed by ``n``, are
    written to ``checkpoint/<exp_name>/results_ecmi_<exp_name>[_<K>].pkl``.

    Raises:
        ValueError: if ``args.dataset`` is not supported.
    """
    ## Experimental setup
    args = arg_parse()
    args.ecmi = True
    if args.gpu != "":
        # Set before CUDA is queried for the first time; the variable is not
        # re-read once the CUDA runtime has enumerated its devices. Setting it
        # is harmless on machines without a GPU.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    print(args)
    checkpoint_dir = "checkpoint"

    if args.dataset in ["MNIST", "FashionMNIST"]:
        if args.eval_K_variation:
            ns = [4000]
        else:
            ns = [250, 1000, 2000, 4000]
        args.seed = 3
        args.S_seed = 5
        exp_name = 'mnist_sqvae_gaussian_1'
    elif args.dataset in ["CIFAR10"]:
        ns = [1000, 5000, 10000, 20000]
        args.seed = 3
        if args.S_seed < 1:
            # With the default of 0 no split would be evaluated at all.
            # NOTE(review): 5 mirrors the MNIST setting -- confirm.
            args.S_seed = 5
        # NOTE(review): no experiment name was defined for CIFAR10; this one
        # follows the pattern of the MNIST name -- confirm the intended value.
        exp_name = 'cifar10_sqvae_gaussian_1'
    else:
        raise ValueError(
            "Unsupported dataset {!r}; expected MNIST, FashionMNIST or CIFAR10".format(
                args.dataset))

    if args.eval_K_variation:
        results_file_path = os.path.join(checkpoint_dir, exp_name, 'results_ecmi_{0}_{1}.pkl'.format(exp_name, args.K))
    else:
        results_file_path = os.path.join(checkpoint_dir, exp_name, 'results_ecmi_{}.pkl'.format(exp_name))
    # Create the output directory up front so a long evaluation cannot fail
    # at the very end because it is missing.
    os.makedirs(os.path.dirname(results_file_path), exist_ok=True)

    results = NestedDict()  # indexing with n
    for n in tqdm(ns):
        args.n = n
        results[args.n] = get_ecmi_results_for_fixed_model(args)

    with open(results_file_path, 'wb') as f:
        pickle.dump(results, f)

# Run the evaluation only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()