from DMM import DMM
from DMM_categorical import DMM_categorical
import time
import pyro
import numpy as np
from sklearn.cluster import KMeans
import torch
from torch.nn.utils.rnn import pad_sequence
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions import TransformedDistribution
from pyro.distributions.transforms import affine_autoregressive
from pyro.optim import ClippedAdam
from torch.optim import Adam
import sys
import os
import argparse
import pyro.contrib.examples.polyphonic_data_loader as poly
from pyro.infer import (
    SVI,
    JitTrace_ELBO,
    Trace_ELBO,
    TraceEnum_ELBO,
    TraceTMC_ELBO,
    config_enumerate,
)
sys.path.append('../../../')
from utils import evaluate_dist, line_plot, load_data

class Trainer:
    """Take SVI gradient steps on shuffled mini-batches of a training set.

    Args:
        svi: object exposing ``step`` and ``evaluate_loss`` (a ``pyro.infer.SVI``
            in this script).
        training_data_sequences: indexable training set; its ``len`` bounds the
            final (possibly smaller) mini-batch.
        training_seq_lengths: per-sequence lengths, forwarded to
            ``poly.get_mini_batch``.
        cuda: flag forwarded to ``poly.get_mini_batch``.
    """

    def __init__(self, svi, training_data_sequences, training_seq_lengths, cuda):
        self.svi = svi
        self.training_data_sequences = training_data_sequences
        self.training_seq_lengths = training_seq_lengths
        self.cuda = cuda

    def _mini_batch_indices(self, which_mini_batch, shuffled_indices, mini_batch_size):
        """Return the training-set indices of one mini-batch as an int64 tensor.

        The slice end is clamped to the size of the training set, so the last
        mini-batch may hold fewer than ``mini_batch_size`` indices. Indices are
        converted straight to int64 rather than through a float32
        ``torch.Tensor``, which would round indices >= 2**24.
        """
        start = which_mini_batch * mini_batch_size
        end = min((which_mini_batch + 1) * mini_batch_size,
                  len(self.training_data_sequences))
        return torch.as_tensor(np.asarray(shuffled_indices[start:end], dtype=np.int64))

    def process_minibatch(self, epoch, which_mini_batch, shuffled_indices, mini_batch_size):
        """Take one SVI step on the ``which_mini_batch``-th mini-batch.

        Args:
            epoch: current epoch. Accepted for interface compatibility but
                unused: no KL annealing factor is passed to ``svi.step``.
            which_mini_batch: zero-based mini-batch position within the epoch.
            shuffled_indices: permutation of training-set indices for this epoch.
            mini_batch_size: sequences per mini-batch; the last one may be smaller.

        Returns:
            The value returned by ``self.svi.step`` (for pyro's ``SVI``, an
            estimate of the loss on this mini-batch).
        """
        mini_batch_indices = self._mini_batch_indices(
            which_mini_batch, shuffled_indices, mini_batch_size)
        # Fetch sequences, reversed sequences, mask and lengths for these indices.
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths = \
            poly.get_mini_batch(mini_batch_indices, self.training_data_sequences,
                                self.training_seq_lengths, cuda=self.cuda)
        return self.svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                             mini_batch_seq_lengths)

    def compute_log_likelihood(self, mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths):
        """Return ``self.svi.evaluate_loss`` for the given batch without a gradient step.

        Despite the name, with pyro's ``SVI`` + ``Trace_ELBO`` this is an
        estimate of the *negative* ELBO (a loss), not a log likelihood.
        """
        return self.svi.evaluate_loss(mini_batch, mini_batch_reversed,
                                      mini_batch_mask, mini_batch_seq_lengths)



def _guide_latents(dmm, batch, batch_reversed, batch_mask, batch_seq_lengths):
    """Run ``dmm.guide`` in inference mode and return the stacked latents as NumPy.

    The tensor is moved to the CPU first, so this also works when the model
    and data live on a CUDA device (``.numpy()`` alone raises there).
    """
    zs = torch.stack(dmm.guide(batch, batch_reversed, batch_mask,
                               batch_seq_lengths, inference=True))
    return zs.detach().cpu().numpy()


def _flatten_latents(latents):
    """Collapse all leading axes so each row is one latent vector for KMeans."""
    return latents.reshape(-1, latents.shape[-1])


def train_model(x_train, z_train, train_lens, x_valid, z_valid, valid_lens, x_test, z_test, test_lens, input_dim, z_dim, transition_dim, rnn_dim, lr, epochs, clip_norm, args, n_cats=None, mini_batch_size=32):
    """Train a DMM with SVI and periodically evaluate its inferred latent states.

    Each epoch takes one SVI step per mini-batch covering the whole (shuffled)
    training set, then evaluates the SVI loss on the full validation and test
    sets. Every 10th epoch (starting at 0) the loss curves are plotted. A
    KMeans model with ``n_cats`` clusters is fit to the guide's training
    latents, and its labels are scored against ``z_*`` with ``evaluate_dist``.
    Metrics are saved to ``./<args.data>/<args.short_name>/checkpoint_<args.cv>.npz``.

    Args:
        x_train, x_valid, x_test: observation sequences per split; ``len`` and
            ``shape[0]``/``shape[1]`` are used.
        z_train, z_valid, z_test: ground-truth states passed to ``evaluate_dist``.
        train_lens, valid_lens, test_lens: per-sequence lengths.
        input_dim, z_dim, transition_dim, rnn_dim: DMM constructor sizes.
        lr: ``ClippedAdam`` learning rate.
        epochs: number of passes over the training set.
        clip_norm: ``ClippedAdam`` gradient-norm bound; ``None`` keeps the
            optimizer's own default.
        args: namespace providing ``data``, ``short_name`` and ``cv`` (output paths).
        n_cats: number of KMeans clusters / discrete states.
        mini_batch_size: training sequences per SVI step (default 32).

    Raises:
        ValueError: if ``mini_batch_size`` is not positive.
    """
    if mini_batch_size < 1:
        raise ValueError("mini_batch_size must be a positive integer")

    cuda = torch.cuda.is_available()
    dmm = DMM(input_dim, z_dim, transition_dim, rnn_dim, 0.0, cuda)

    all_train, all_train_reversed, all_train_mask, all_train_seq_lengths = poly.get_mini_batch(np.arange(len(x_train)), x_train, train_lens, cuda=cuda)
    all_valid, all_valid_reversed, all_valid_mask, all_valid_seq_lengths = poly.get_mini_batch(np.arange(len(x_valid)), x_valid, valid_lens, cuda=cuda)
    all_test, all_test_reversed, all_test_mask, all_test_seq_lengths = poly.get_mini_batch(np.arange(len(x_test)), x_test, test_lens, cuda=cuda)

    # Honour --clip_norm when given; otherwise keep ClippedAdam's default.
    adam_params = {"lr": lr}
    if clip_norm is not None:
        adam_params["clip_norm"] = clip_norm
    optimizer = ClippedAdam(adam_params)

    # Ceil division: the final, smaller mini-batch is still visited
    # (Trainer clamps its end index), so no training sequences are dropped.
    n_train = len(x_train)
    n_mini_batches = (n_train + mini_batch_size - 1) // mini_batch_size

    svi = SVI(dmm.model, dmm.guide, optimizer, Trace_ELBO())
    trainer = Trainer(svi, x_train, train_lens, cuda)
    times = [time.time()]
    train_nlls = []
    valid_nlls = []
    test_nlls = []
    val_hammings = []
    test_hammings = []
    for epoch in range(epochs):
        # Fresh random visiting order for this epoch.
        shuffled_indices = list(np.arange(n_train))
        np.random.shuffle(shuffled_indices)

        epoch_train_nll = 0.0
        for which_mini_batch in range(n_mini_batches):
            epoch_train_nll += trainer.process_minibatch(epoch, which_mini_batch, shuffled_indices, mini_batch_size=mini_batch_size)

        epoch_val_nll = trainer.compute_log_likelihood(all_valid, all_valid_reversed, all_valid_mask, all_valid_seq_lengths)
        epoch_test_nll = trainer.compute_log_likelihood(all_test, all_test_reversed, all_test_mask, all_test_seq_lengths)

        times.append(time.time())
        epoch_time = times[-1] - times[-2]
        # NOTE(review): losses are divided by shape[1] (presumably the padded
        # time axis), not by the total number of observed time steps, and the
        # splits sum over different numbers of sequences, so these curves may
        # not be directly comparable -- confirm the intended normalization.
        train_nlls.append(epoch_train_nll / x_train.shape[1])
        valid_nlls.append(epoch_val_nll / x_valid.shape[1])
        test_nlls.append(epoch_test_nll / x_test.shape[1])

        print("[training epoch %04d]  Train NLL: %.4f \t Valid NLL: %.4f \t Test NLL: %.4f \t\t\t\t(dt = %.3f sec)" %
              (epoch, train_nlls[-1], valid_nlls[-1], test_nlls[-1], epoch_time))
        if epoch % 10 == 0:
            line_plot(train_nlls, f'{args.data}/{args.short_name}/train_NLL.pdf', 'Train NLL', 'epoch', 'NLL')
            line_plot(valid_nlls, f'{args.data}/{args.short_name}/valid_NLL.pdf', 'Valid NLL', 'epoch', 'NLL')
            line_plot(test_nlls, f'{args.data}/{args.short_name}/test_NLL.pdf', 'Test NLL', 'epoch', 'NLL')
            dmm.eval()
            with torch.no_grad():
                kmeans = KMeans(n_clusters=n_cats)
                # NOTE(review): cluster labels are reshaped to
                # (x.shape[0], x.shape[1]); this assumes the stacked guide
                # output flattens in that same order -- confirm against DMM.guide.
                train_latents = _guide_latents(dmm, all_train, all_train_reversed, all_train_mask, all_train_seq_lengths)
                train_state_preds = kmeans.fit(_flatten_latents(train_latents)).labels_.reshape(x_train.shape[0], x_train.shape[1])
                # ``mapper`` aligns cluster ids to true states; reused for valid/test.
                train_hamming, mapper = evaluate_dist(z_train, train_state_preds, train_lens, n_cats)

                valid_latents = _guide_latents(dmm, all_valid, all_valid_reversed, all_valid_mask, all_valid_seq_lengths)
                val_state_preds = kmeans.predict(_flatten_latents(valid_latents)).reshape(x_valid.shape[0], x_valid.shape[1])
                val_hamming, _ = evaluate_dist(z_valid, val_state_preds, valid_lens, n_cats, mapper)

                test_latents = _guide_latents(dmm, all_test, all_test_reversed, all_test_mask, all_test_seq_lengths)
                test_state_preds = kmeans.predict(_flatten_latents(test_latents)).reshape(x_test.shape[0], x_test.shape[1])
                test_hamming, _ = evaluate_dist(z_test, test_state_preds, test_lens, n_cats, mapper)

                val_hammings.append(val_hamming)
                test_hammings.append(test_hamming)
                print(f"Train Hamming: {np.mean(train_hamming):{6}.5f} | Valid Hamming: {np.mean(val_hamming):{6}.5f}, Test Hamming: {np.mean(test_hamming):{6}.5f}")
                np.savez('./%s/%s/checkpoint_'%(args.data, args.short_name) + str(args.cv) + '.npz', loglik=[-1*loss for loss in test_nlls], val_loglik=[-1*loss for loss in valid_nlls],
                         validation_hammings=val_hammings, test_hammings=test_hammings)
            dmm.train()  # put back in training mode
    

def main(args, data_load_config):
    """Load the dataset named by ``args.data`` and train a DMM on it.

    When ``load_data`` returns NumPy arrays, the observation and state arrays
    are turned into float tensors. The three length containers are always
    turned into int tensors. The model's input dimensionality is taken from
    the last axis of the first training sequence.

    Args:
        args: parsed command-line namespace; ``data``, ``z_dim``,
            ``transition_dim``, ``rnn_dim``, ``epochs``, ``lr``, ``clip_norm``
            and ``n_cats`` are read here (and ``args`` is passed on).
        data_load_config: per-dataset loader settings forwarded to ``load_data``.
    """
    print('Selecting data...')
    splits = load_data(data_type=args.data,
                       config=data_load_config,
                       normalize=False,
                       pad_ragged=True,
                       path_to_data='../../../data/')
    (x_train, z_train, train_lens,
     x_valid, z_valid, valid_lens,
     x_test, z_test, test_lens) = splits

    # Only dense ndarray output is converted; any other container is used as-is.
    if type(x_train) is np.ndarray:
        x_train, x_valid, x_test, z_train, z_valid, z_test = [
            torch.Tensor(arr)
            for arr in (x_train, x_valid, x_test, z_train, z_valid, z_test)
        ]
    print('type(x_train): ', type(x_train))
    print('type(x_train[0]): ', type(x_train[0]))

    train_lens, valid_lens, test_lens = [
        torch.Tensor(lens).int() for lens in (train_lens, valid_lens, test_lens)
    ]

    observation_dim = x_train[0].shape[-1]
    train_model(
        x_train=x_train, z_train=z_train, train_lens=train_lens,
        x_valid=x_valid, z_valid=z_valid, valid_lens=valid_lens,
        x_test=x_test, z_test=z_test, test_lens=test_lens,
        input_dim=observation_dim,
        z_dim=args.z_dim,
        transition_dim=args.transition_dim,
        rnn_dim=args.rnn_dim,
        epochs=args.epochs,
        lr=args.lr,
        clip_norm=args.clip_norm,
        args=args,
        n_cats=args.n_cats,
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Run DMM')
    parser.add_argument('--data', type=str, default='sim_hard',
                        help='dataset name passed to load_data')

    # Training hyperparams. --epochs and --lr have no usable fallback
    # (range(None) / an optimizer without a learning rate fail at runtime),
    # so argparse rejects command lines that omit them up front.
    parser.add_argument('--epochs', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--clip_norm', type=float,
                        help='gradient-norm bound passed to train_model')

    # Model hyperparams
    parser.add_argument('--z_dim', type=int)
    parser.add_argument('--rnn_dim', type=int)
    parser.add_argument('--transition_dim', type=int)

    # Misc. --n_cats is the KMeans cluster count used at epoch 0, which
    # cannot be None, so it is required as well.
    parser.add_argument('--n_cats', type=int, required=True)
    parser.add_argument('--cv', type=int,
                        help='cross-validation fold id used in the checkpoint filename')
    parser.add_argument('--short_name', type=str,
                        help='run name used as the output sub-directory')
    parser.add_argument('--ds_factor', type=int,
                        help="downsampling factor for the 'har' datasets")

    args = parser.parse_args()
    # Per-dataset loader settings; only the entry for args.data is relevant.
    data_load_config = {'sim_easy': {'n_train': 100, 'n_valid': 50, 'n_test': 50},
                        'sim_hard': {'n_train': 100, 'n_valid': 50, 'n_test': 50},
                        'sim_semi_markov': {'n_train': 60, 'n_valid': 20, 'n_test': 20},
                        'har': {'ds_factor': args.ds_factor},
                        'har_70': {'ds_factor': args.ds_factor},
                        }
    main(args, data_load_config)




