import torch
import sys
sys.path.append('../../code/')
from LSTM import LSTMClassifier
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import argparse
import numpy as np
from torch import nn
from sklearn.metrics import confusion_matrix, hamming_loss
sys.path.append('../../../')
from utils import evaluate_dist, load_data

def compute_persample_hamming(preds, labels, lens):
    """Return the Hamming loss of every sequence, restricted to its true length.

    For each index ``idx`` the prediction row and the label row are cut to
    their first ``lens[idx]`` entries before scoring, so positions beyond the
    sequence length (e.g. padding) are not counted.

    Args:
        preds: indexable collection of per-sequence predicted labels.
        labels: indexable collection of per-sequence true labels, aligned
            with ``preds``.
        lens: length of each sequence, aligned with ``preds``.

    Returns:
        1-D numpy array holding one Hamming loss per sequence in ``preds``.
    """
    scores = [
        hamming_loss(seq_pred[:lens[idx]], labels[idx][:lens[idx]])
        for idx, seq_pred in enumerate(preds)
    ]
    return np.array(scores)

def get_loss(dataloader, model, criterion):
    """Return the per-sample average of ``criterion`` over a whole dataloader.

    The model is evaluated in eval mode without gradient tracking. Its previous
    train/eval mode is restored before returning, so callers that interleave
    evaluation with training (as ``train`` does every epoch) keep training
    with dropout enabled.

    Args:
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        model: module mapping ``inputs.float()`` to logits whose last axis is
            the class axis (it is moved to axis 1 before the loss).
        criterion: loss module; assumed mean-reduced over the batch, like the
            default ``nn.CrossEntropyLoss()`` used by ``train``.

    Returns:
        ``sum(batch_loss * batch_size) / total_samples``, i.e. a true
        per-sample mean, or ``nan`` if the dataloader yields no samples.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_samples = 0
    try:
        with torch.no_grad():
            for inputs, labels in dataloader:
                outputs = model(inputs.float())
                # change shape to (batch_size, num_classes, sequence_length),
                # the layout CrossEntropyLoss expects for sequence targets
                outputs = outputs.permute(0, 2, 1)
                batch_size = len(inputs)
                # The criterion returns the batch *mean*; re-weight it by the
                # batch size so dividing by total_samples gives a per-sample
                # mean and a short final batch is not over-weighted.
                total_loss += criterion(outputs, labels.long()).item() * batch_size
                total_samples += batch_size
    finally:
        # Don't leak eval mode (dropout disabled) back into training.
        model.train(was_training)
    if total_samples == 0:
        return float('nan')
    return total_loss / total_samples

def _predict_labels(model, data):
    """Return argmax class predictions (last axis) of ``model`` on ``data``.

    Runs in eval mode under ``torch.no_grad()`` and restores the model's
    previous train/eval mode afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(torch.Tensor(data).float())
    finally:
        model.train(was_training)
    return np.argmax(logits.numpy(), axis=-1)


def train(model, train_data, train_labels, train_lens, val_data, val_labels,
          val_lens, test_data, test_labels, test_lens, num_epochs, lr,
          checkpoint_path=None):
    """Train ``model`` with Adam + cross-entropy, report metrics, save results.

    Each epoch runs one shuffled pass over the training set. It then prints
    the last batch loss and the Hamming loss of the flattened train and
    validation predictions. It also records the average cross-entropy on the
    validation and test sets. After training, the per-sequence Hamming loss
    on the test set (truncated to ``test_lens``) is printed and all results
    are written with ``np.savez``.

    Args:
        model: module mapping float inputs to per-timestep class logits with
            the class axis last.
        train_data, val_data, test_data: input sequences (presumably padded
            to a common length — TODO confirm).
        train_labels, val_labels, test_labels: integer label arrays aligned
            with the data; must support ``.reshape``.
        train_lens, val_lens: sequence lengths (currently unused).
        test_lens: true length of each test sequence.
        num_epochs: number of passes over the training set.
        lr: Adam learning rate.
        checkpoint_path: output ``.npz`` path. When None, it defaults to
            ``./<data>/<short_name>/checkpoint_<cv>.npz``, built from the
            module-level ``args`` parsed in the ``__main__`` block.
    """
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_dataloader = DataLoader(list(zip(train_data, train_labels)), batch_size=32, shuffle=True)
    # Evaluation does not depend on order; keep it deterministic.
    val_dataloader = DataLoader(list(zip(val_data, val_labels)), batch_size=32, shuffle=False)
    test_dataloader = DataLoader(list(zip(test_data, test_labels)), batch_size=32, shuffle=False)

    val_losses = []
    test_losses = []
    val_hamming = float('nan')  # keeps the save below valid when num_epochs == 0

    # Training loop
    for epoch in range(num_epochs):
        # Evaluation helpers switch the model to eval mode; make sure every
        # epoch actually trains in train mode (dropout enabled).
        model.train()
        for inputs, labels in train_dataloader:
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs.float())
            outputs = outputs.permute(0, 2, 1)  # change shape to (batch_size, num_classes, sequence_length)
            loss = criterion(outputs, labels.long())

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

        val_preds = _predict_labels(model, val_data)
        train_preds = _predict_labels(model, train_data)
        # NOTE(review): these Hamming losses are over all (flattened) positions,
        # unlike the test metric below, which truncates each sequence to its
        # length — confirm this is intended.
        val_hamming = hamming_loss(val_preds.reshape(-1,), val_labels.reshape(-1,))
        train_hamming = hamming_loss(train_preds.reshape(-1,), train_labels.reshape(-1,))
        test_losses.append(get_loss(test_dataloader, model, criterion))
        val_losses.append(get_loss(val_dataloader, model, criterion))
        print(f'Epoch [{epoch + 1}/{num_epochs}], '
              f'Loss: {loss.item():.4f}, '
              f'Val Hamming: {val_hamming:.4f}, '
              f'Train Hamming: {train_hamming:.4f}')

    test_preds = _predict_labels(model, test_data)
    test_hamming = compute_persample_hamming(preds=test_preds, labels=test_labels, lens=test_lens)
    print(f'Test Hamming: {np.mean(test_hamming):.4f}')
    if checkpoint_path is None:
        # Falls back to the global `args` parsed when run as a script.
        checkpoint_path = f'./{args.data}/{args.short_name}/checkpoint_{args.cv}.npz'
    np.savez(checkpoint_path,
             loglik=[-1 * test_loss for test_loss in test_losses],
             val_loglik=[-1 * val_loss for val_loss in val_losses],
             validation_hammings=[val_hamming],
             test_hammings=[test_hamming])
    

def main(args, data_load_config):
    """Load the dataset selected by ``args.data``, build an LSTMClassifier, train it.

    Labels are remapped to ``0..K-1`` when the set of labels found across the
    train, validation and test splits is not already ``0..max``. The
    cross-entropy loss used in training needs class indices in
    ``[0, num_classes)``.

    Args:
        args: parsed CLI namespace; reads ``data``, ``num_layers``,
            ``dropout``, ``epochs`` and ``lr``.
        data_load_config: per-dataset config passed through to ``load_data``.
    """
    print('Selecting data...')
    (x_train, z_train, train_lens,
     x_valid, z_valid, valid_lens,
     x_test, z_test, test_lens) = load_data(data_type=args.data,
                                            config=data_load_config,
                                            normalize=False,
                                            pad_ragged=True,
                                            path_to_data='../../../data/')

    # Sorted classes present in any split. Building the mapping from every
    # split guarantees validation/test labels are covered too.
    classes = np.unique(np.concatenate([np.ravel(z) for z in (z_train, z_valid, z_test)]))
    # np.array_equal yields a single bool; a bare `!=` between arrays is
    # elementwise, which makes `if` raise.
    if not np.array_equal(classes, np.arange(classes.max() + 1)):
        print('Discontinuity found in lables. Remapping..')
        # `classes` is sorted and contains every label, so searchsorted
        # returns each label's exact index, i.e. its new contiguous id.
        z_train = np.searchsorted(classes, z_train)
        z_valid = np.searchsorted(classes, z_valid)
        z_test = np.searchsorted(classes, z_test)

    print('Training model...')
    # Feature dimension; assumes each x_train entry is (seq_len, D) — TODO confirm.
    D = x_train[0].shape[1]
    model = LSTMClassifier(input_dim=D,
                           hidden_dim=2*D,
                           num_classes=len(classes),
                           num_layers=args.num_layers,
                           dropout=args.dropout).float()
    train(model=model,
          train_data=x_train,
          train_labels=z_train,
          train_lens=train_lens,
          val_data=x_valid,
          val_labels=z_valid,
          val_lens=valid_lens,
          test_data=x_test,
          test_labels=z_test,
          test_lens=test_lens,
          num_epochs=args.epochs,
          lr=args.lr)
    
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Run LSTM')
    parser.add_argument('--data', type=str, default='sim_hard')

    # Training hyperparams. There is no usable default (training cannot run
    # with None), so fail at parse time rather than after loading the data.
    parser.add_argument('--epochs', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)

    # Model hyperparams
    parser.add_argument('--num_layers', type=int)
    parser.add_argument('--dropout', type=float)

    # Fold index; appears in the checkpoint file name.
    parser.add_argument('--cv', type=int)
    # Only forwarded through the 'har' / 'har_70' entries of data_load_config.
    parser.add_argument('--ds_factor', type=int)

    # Run name; used as a sub-directory of the checkpoint path.
    parser.add_argument('--short_name', type=str)
    # NOTE: `args` stays module-global because `train` falls back to it when
    # building the checkpoint path.
    args = parser.parse_args()

    # Per-dataset config passed to load_data.
    data_load_config = {'sim_easy': {'n_train': 100, 'n_valid': 50, 'n_test': 50},
                        'sim_hard': {'n_train': 100, 'n_valid': 50, 'n_test': 50},
                        'sim_semi_markov': {'n_train': 60, 'n_valid': 20, 'n_test': 20},
                        'har': {'ds_factor': args.ds_factor},
                        'har_70': {'ds_factor': args.ds_factor},
                        }
    main(args, data_load_config)
