import torch
from model import *
from dataloader import *
import numpy as np
import argparse
from pytorch_models.utils.logger import *
import random
from preprocess import *
from extract_dfa import *
from utils import *


def get_params():
    """Parse the command-line options of the DFA-extraction script.

    Options are registered from an ordered table so that the ``--help``
    listing keeps the same order as the registration order.

    Returns:
        argparse.Namespace: the parsed options, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description='tomita_lstm')

    # (flag, keyword arguments forwarded to ``add_argument``), in display order.
    option_table = [
        # --- run identification / model description ---
        ('--suffix', dict(default="")),
        ('--seed', dict(type=int, default=42)),
        ('--state_dim', dict(type=int, default=32)),
        ('--load_model', dict(default="")),
        ('--model', dict(default="rnn")),
        ('--rnn_cell', dict(default=None)),
        ('--hinit', dict(type=int, default=0)),
        ('--mtype', dict(default="char_prediction")),
        ('--dropout', dict(type=float, default=0)),
        ('--nlayers', dict(type=int, default=1)),
        # --- data selection ---
        ('--dataset', dict(type=str, default="msi")),
        ('--grammar', dict(type=str, default="tomita-1")),
        # --- DFA search range and clustering ---
        ('--dfa_min_state', dict(type=int, default=1)),
        ('--dfa_max_state', dict(type=int, default=100)),
        ('--dfa_state_step', dict(type=int, default=1)),
        ('--cluster_method', dict(type=str, default="som")),
        # argparse runs string defaults through ``type``, so this parses to 1024.
        ('--batch_size', dict(type=int, default="1024")),
        ('--valid_batch_size', dict(type=int, default=-1)),
        # --- boolean switches and plotting ---
        ('--normalize_states', dict(action="store_true")),
        ('--tsne', dict(action="store_true")),
        ('--tsne_fname', dict(type=str, default=None)),
        ('--pos_data_only', dict(action="store_true")),
        ('--partial_strings', dict(action="store_true")),
        ('--remove_start', dict(action="store_true")),
        ('--remove_end', dict(action="store_true")),
        ('--shuffle_data', dict(action="store_true")),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()


def get_hidden_states(model, dataset, device, partial_strings = False, remove_start=False, remove_end=False, only_correct_preds = False):
    """Run ``model`` over ``dataset`` and gather inputs, lengths, hidden states, masks and labels.

    Every batch of ``dataset`` must unpack into ``(input_seq, out_seq, seq_len)``;
    each item is moved to ``device``. Hidden states come from
    ``model.get_hidden(input_seq, seq_len)`` and are indexed batch-first.

    Args:
        model: object exposing ``get_hidden(input_seq, seq_len)`` and, when
            ``only_correct_preds`` is set, ``model(input_seq, seq_len)`` returning
            ``(logits, _)``.
        dataset: iterable of batches as described above.
        device: device the batch tensors are moved to.
        partial_strings: if True, expand every sample into one row per prefix
            (lengths ``1 .. seq_len``), zeroing the symbols after the prefix.
            Requires ``remove_end``.
        remove_start: if True, drop the first input column / state / mask column
            and decrement the lengths by one after collection.
        remove_end: if True, decrement every sequence length by one.
        only_correct_preds: if True, keep only the rows whose rounded logits
            equal the target label.

    Returns:
        tuple ``(x, x_len, hlog, mlog, y)`` of CPU tensors. ``hlog`` has one more
        step along dim 1 than ``x`` because a step of ones is prepended so that
        each input symbol is paired with the state *before* it is consumed.
        Rows whose length is zero are removed.

    Raises:
        ValueError: if ``partial_strings`` is requested without ``remove_end``.
        AssertionError: if predictions and labels disagree in shape, or the
            state/input length invariant does not hold.

    Side effects: prints shapes, a label sum, per-batch accuracy and samples to
    stdout, and raises torch's print threshold in the full-string path.
    """
    # Partial strings only work when the end symbol is removed.
    if partial_strings and not remove_end:
        raise ValueError("partial_strings requires remove_end=True")

    x = []
    hlog = []
    mlog = []
    x_len = []
    y = []

    # Pure inference: no autograd graph is needed for the collected tensors.
    with torch.no_grad():
        for batch_data in dataset:
            input_seq, out_seq, seq_len = [item.to(device) for item in batch_data]
            hidden = model.get_hidden(input_seq, seq_len)

            if remove_end:
                seq_len = seq_len - 1

            if partial_strings:
                for _inp, _out, _slen, _hidden in zip(input_seq, out_seq, seq_len, hidden):
                    # Row i of the mask keeps the first i + 1 positions and zeroes the rest.
                    mask = 1 - torch.triu(torch.ones([_slen.item(), _inp.shape[0]]), diagonal=1).to(device)
                    _x = _inp.unsqueeze(dim=0).expand(mask.shape[0], mask.shape[1])
                    _x = _x * mask
                    _x = _x.type(torch.long)
                    _h = _hidden.unsqueeze(dim=0).expand(mask.shape[0], _hidden.shape[0], _hidden.shape[1])
                    # Prefix lengths 1 .. _slen, created on the same device as the
                    # index tensor used to filter them below.
                    _sl = torch.arange(1, _slen.item() + 1, device=device)
                    _y = _out[:_slen]

                    if only_correct_preds:
                        # ``_slen + 1`` undoes the remove_end decrement above.
                        logits, _ = model.model(_inp, _slen + 1)
                        y_pred = logits.round().squeeze()[:_slen]
                        assert _y.shape == y_pred.shape
                        # reshape(-1) keeps a 1-D index even when exactly one prefix
                        # is correct (squeeze() would produce a 0-dim index and drop
                        # the row dimension of everything indexed with it).
                        crr_indices = (y_pred == _y).nonzero().reshape(-1)

                        _x = _x[crr_indices]
                        _h = _h[crr_indices]
                        mask = mask[crr_indices]
                        # Filter this sample's prefix lengths, not the accumulator list.
                        _sl = _sl[crr_indices]
                        _y = _y[crr_indices]

                    x.append(_x.detach().cpu())
                    hlog.append(_h.detach().cpu())
                    mlog.append(mask.detach().cpu())
                    x_len.append(_sl.detach().cpu())
                    y.append(_y.detach().cpu())

            else:
                positions = torch.arange(input_seq.shape[1]).to(device)
                # True for every position inside the (possibly shortened) sequence.
                mask = positions[None] <= (seq_len - 1)[:, None]
                # True only at the last position of each sequence.
                ymask = positions[None] == (seq_len - 1)[:, None]
                if out_seq.dim() > 1:
                    # Per-position labels: keep the label at the last position.
                    _y = torch.masked_select(out_seq, ymask)
                else:
                    _y = out_seq
                torch.set_printoptions(threshold=10_000)
                if only_correct_preds:
                    # NOTE(review): when remove_end is False this passes a length one
                    # larger than the original seq_len -- confirm against model.model.
                    logits, _ = model.model(input_seq, seq_len + 1)

                    assert logits.shape[0] == input_seq.shape[0]
                    # reshape(-1) instead of squeeze() so a batch holding a single
                    # sample still yields a 1-D prediction vector.
                    y_pred = logits.round().reshape(-1)

                    assert _y.shape == y_pred.shape
                    is_correct = _y == y_pred
                    crr_indices = is_correct.nonzero().reshape(-1)
                    accuracy = torch.mean(is_correct.type(torch.float))
                    print("train accuracy : ", accuracy)
                    input_seq = input_seq[crr_indices]
                    hidden = hidden[crr_indices]
                    seq_len = seq_len[crr_indices]
                    mask = mask[crr_indices]
                    _y = _y[crr_indices]

                x.append(input_seq.detach().cpu())
                hlog.append(hidden.detach().cpu())
                mlog.append(mask.detach().cpu())
                x_len.append(seq_len.detach().cpu())
                y.append(_y.detach().cpu())

    x = torch.cat(x)
    hlog = torch.cat(hlog)
    mlog = torch.cat(mlog)
    x_len = torch.cat(x_len)
    y = torch.cat(y)

    print(x.shape)
    print(torch.sum(y))

    #--------------------------------------------------------
    # we need to pair each x with input hidden state rather than output hidden state,
    # thus we need to move states one step right
    # NOTE(review): the prepended step is all ones -- confirm this matches the
    # model's actual initial hidden state.
    #--------------------------------------------------------

    hlog = torch.cat([torch.ones([hlog.shape[0], 1, hlog.shape[-1]]), hlog], dim = 1)

    if remove_start:
        x = x[:, 1:]
        hlog = hlog[:, 1:]
        mlog = mlog[:, 1:]
        x_len = x_len - 1

    #------ remove any empty strings ------
    # reshape(-1) keeps the index 1-D even when a single string survives.
    indices = x_len.nonzero().reshape(-1)
    x = x[indices]
    hlog = hlog[indices]
    mlog = mlog[indices]
    x_len = x_len[indices]
    y = y[indices]

    print(x.shape, hlog.shape, mlog.shape, x_len.shape, y.shape)

    # There should be one more state than the input
    assert x.shape[1] + 1 == hlog.shape[1]

    print(x[:10])
    print(y[:10])

    return x, x_len, hlog, mlog, y


if __name__ == "__main__":
    # Script entry point: load a trained model, collect its hidden states on the
    # training data, extract candidate DFAs by clustering those states, pick the
    # best one on the validation data and save it.
    session_args = get_params()

    # Seed every random number generator used, and force deterministic cuDNN.
    torch.manual_seed(session_args.seed)
    np.random.seed(session_args.seed)
    random.seed(session_args.seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # The run suffix names both the log file and the saved DFA file.
    model_name = "dfa_{}".format(session_args.model)
    suffix = "{0}_{1}_{2}_seed_{3}".format(model_name, session_args.suffix, session_args.cluster_method, session_args.seed)
    if session_args.normalize_states:
        suffix += "_ns"

    lfname = '../logs/{0}.log'.format(suffix)

    logger = create_logger(suffix, lfname, seed = session_args.seed)

    logger.info(session_args)

    print(session_args)

    #================================================================================================================================
    # Dataset, vocab
    #================================================================================================================================
    batch_size = session_args.batch_size
    # -1 means "use the training batch size for validation too".
    if session_args.valid_batch_size == -1:
        session_args.valid_batch_size = batch_size

    vocab = get_vocab(session_args)
    session_args.vocab_size = len(vocab)

    train_data, val_data = load_dataset(session_args, logger)

    #================================================================================================================================
    # Model
    #================================================================================================================================

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    session_args.device = device

    model = load_model(session_args)

    _n_params = model.count_parameters()
    logger.info("Model Size : {}".format(_n_params))

    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
    checkpoint = torch.load(session_args.load_model, map_location=device)
    model.load_state_dict(checkpoint['model_state'])
    # NOTE(review): the model is never switched to eval mode here; confirm whether
    # load_model / get_hidden handle that when --dropout is non-zero.

    #================================================================================================================================
    # Extract Hidden from model
    #================================================================================================================================

    # Only training strings the model classifies correctly are used for extraction.
    x, x_len, hidden, mask, y = get_hidden_states(model, train_data, device,
                                    remove_start = session_args.remove_start,
                                    remove_end=session_args.remove_end,
                                    partial_strings = session_args.partial_strings, only_correct_preds=True)

    # Invert the vocabulary: id -> symbol.
    alphabet = {v : k for k,v in vocab.items()}
    print(alphabet)
    # --------- remove Start End symbols from alphabet -------------------------
    # NOTE(review): the ids below are presumably the PAD / UNK / start / end
    # entries of each grammar's vocabulary -- verify against get_vocab.

    if session_args.grammar.startswith("tomita"):
        # remove PAD
        del alphabet[0]
        if session_args.remove_start:
            del alphabet[3]
        if session_args.remove_end:
            del alphabet[4]
    elif session_args.grammar.startswith("dyck"):
        # remove PAD
        del alphabet[0]
        # remove UNK
        del alphabet[3]
        if session_args.remove_start:
            del alphabet[1]
        if session_args.remove_end:
            del alphabet[2]

    #================================================================================================================================
    # Extract DFA from model
    #================================================================================================================================
    print("alphabet : ", alphabet)

    dfa_list, _ = extract_dfas(logger, session_args.dfa_min_state, session_args.dfa_max_state, session_args.dfa_state_step,
                            session_args.cluster_method, x.numpy(), x_len.numpy(), hidden.numpy(), y.numpy(),
                            input_dim = session_args.vocab_size, input_range = list(alphabet.keys()), alphabet = alphabet, normalize_states = session_args.normalize_states,
                            pos_data_only=session_args.pos_data_only, tsne = session_args.tsne, tsne_fname = session_args.tsne_fname)

    #================================================================================================================================
    # Select best DFA and save
    #================================================================================================================================

    # Validation uses full strings and all samples (no prediction filtering).
    val_x, val_x_len, val_hidden, val_mask, val_y= get_hidden_states(model, val_data, device, remove_start = session_args.remove_start, remove_end=session_args.remove_end)

    best_dfa, best_k = get_best_dfa(logger, dfa_list, val_x.numpy(), val_x_len.numpy(), val_y.numpy(), alphabet)
    dfa_path = "../dfa/" + suffix + ".dfa"
    save_dfa(best_dfa.dfa, dfa_path)

    # To visualise the saved automaton:
    # dfa = load_dfa(dfa_path)
    # create_diagraph(dfa, alphabet, suffix)
