import torch
import sys
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import argparse
import numpy as np
from torch import nn
from sklearn.cluster import KMeans
from sklearn.metrics import hamming_loss
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

sys.path.append('../../code/')
from LSTM import LSTMForecaster
sys.path.append('../../../')
from utils import evaluate_dist, load_data

class RaggedDataset(Dataset):
    """Map-style dataset pairing each (padded) sequence with its stored length.

    Indexing returns a ``(sample, length)`` tuple taken from the two parallel
    containers given at construction time.
    """

    def __init__(self, data, data_lens):
        # Parallel containers: data[k] is a sequence, data_lens[k] its length.
        self.data, self.data_lens = data, data_lens

    def __len__(self):
        n_items = len(self.data)
        return n_items

    def __getitem__(self, ind):
        sample = self.data[ind]
        length = self.data_lens[ind]
        return sample, length

def evaluate_loss(model, criterion, dataloader):
    """Return the batch-size-weighted mean next-step loss of ``model`` over ``dataloader``.

    Each batch is an ``(inputs, lens)`` pair. The model is fed
    ``inputs[:, :-1, :]`` and its outputs are compared with ``inputs[:, 1:, :]``
    (the following time step). For sequence ``i``, output/target positions
    ``t >= lens[i]`` are zeroed in both tensors before ``criterion`` is applied.

    Args:
        model: module whose forward returns ``(outputs, hidden)``; it is put
            into eval mode and run under ``torch.no_grad()``.
        criterion: loss callable applied to ``(outputs, targets)``.
        dataloader: iterable yielding ``(inputs, lens)`` batches.

    Returns:
        The mean of the per-batch losses weighted by batch size, or ``nan`` if
        the dataloader yields no samples (instead of raising ZeroDivisionError).
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0

    with torch.no_grad():
        for inputs, lens in dataloader:
            # Clone so targets stay independent of whatever the model does to its input.
            targets = torch.clone(inputs)[:, 1:, :]
            inputs = inputs[:, :-1, :]
            outputs, _ = model(inputs.float())

            # (batch, time, 1) boolean mask, True where t < lens[i]; broadcasts
            # over the feature axis exactly like the former per-row loop.
            # NOTE(review): outputs[:, t] predicts inputs[:, t + 1], so for a
            # sequence shorter than the padded length, position lens[i] - 1 is
            # compared against the first padded step — confirm whether the
            # bound should be lens[i] - 1.
            steps = torch.arange(outputs.shape[1], device=outputs.device)
            lens = torch.as_tensor(lens, device=outputs.device)
            mask = (steps.unsqueeze(0) < lens.unsqueeze(1)).unsqueeze(-1)
            outputs = outputs * mask
            targets = targets * mask

            batch_loss = criterion(outputs.float(), targets.float()).item()
            total_loss += batch_loss * len(inputs)
            total_samples += len(inputs)

    if total_samples == 0:
        return float('nan')
    return total_loss / total_samples


def _kmeans_state_predictions(model, data, n_classes):
    """Cluster ``model``'s per-time-step outputs on ``data`` into ``n_classes`` states.

    Runs the model on all of ``data`` in one forward pass, flattens the outputs
    to ``(num_sequences * seq_len, output_dim)``, fits a fresh ``KMeans`` and
    reshapes the cluster labels back to ``(num_sequences, seq_len)``.
    NOTE(review): padded time steps are clustered too; presumably
    ``evaluate_dist`` ignores them via ``z_lens`` — confirm.
    """
    kmeans = KMeans(n_clusters=n_classes)
    out, _ = model(torch.Tensor(data).float())
    out = out.detach().numpy()
    flat = out.reshape(-1, out.shape[-1])
    return kmeans.fit(flat).labels_.reshape(out.shape[0], out.shape[1])


def train(model, train_data, train_labels, train_lens, val_data, val_labels, val_lens, test_data, test_labels, test_lens, num_epochs, lr,
          checkpoint_path=None, checkpoint_every=10):
    """Train ``model`` on next-step MSE and track clustering Hamming metrics.

    Each epoch runs one pass of Adam over the training sequences, then, in
    eval mode:

    * clusters the model outputs with KMeans and scores them with
      ``evaluate_dist`` (the train-split label mapping is reused for val/test);
    * applies a patience check on the validation losses of previous epochs;
    * computes validation/test losses with ``evaluate_loss``.

    Metric histories are written to ``checkpoint_path`` every
    ``checkpoint_every`` epochs and once more after the loop ends (normally or
    via early stopping), so the final results are never lost.

    Args:
        model: forecaster whose forward returns ``(outputs, hidden)``.
        train_data, val_data, test_data: padded sequences per split.
        train_labels, val_labels, test_labels: per-time-step state labels.
        train_lens, val_lens, test_lens: sequence lengths per split.
        num_epochs: maximum number of epochs.
        lr: Adam learning rate.
        checkpoint_path: ``.npz`` output path. When ``None`` (the default), it
            is built from the module-level ``args`` namespace set in
            ``__main__``, as before.
        checkpoint_every: save period in epochs (default 10).
    """
    batch_size = min(32, max(len(train_data) // 4, 1))
    # Loss and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # NOTE(review): counts every distinct label value, including any padding
    # label if the labels are padded — confirm this matches the true class count.
    n_classes = np.unique(train_labels).shape[0]
    print('n_classes: ', n_classes)
    print('unique train labels: ', np.unique(train_labels))
    print('unique val labels: ', np.unique(val_labels))
    print('unique test labels: ', np.unique(test_labels))
    train_dataloader = DataLoader(RaggedDataset(train_data, train_lens), batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(RaggedDataset(val_data, val_lens), batch_size=batch_size, shuffle=False)
    test_dataloader = DataLoader(RaggedDataset(test_data, test_lens), batch_size=batch_size, shuffle=False)

    early_stopping = 5
    val_hammings = []
    test_hammings = []
    test_losses = []
    val_losses = []

    def _checkpoint():
        # Resolve the default path lazily so ``args`` is only needed when saving.
        path = checkpoint_path
        if path is None:
            path = './%s/%s/checkpoint_' % (args.data, args.short_name) + str(args.cv) + '.npz'
        np.savez(path,
                 loglik=[-1 * loss for loss in test_losses],
                 val_loglik=[-1 * loss for loss in val_losses],
                 validation_hammings=val_hammings,
                 test_hammings=test_hammings)

    # Training loop
    for epoch in range(num_epochs):
        epoch_train_losses = []
        model.train()
        for inputs, lens in train_dataloader:
            targets = torch.clone(inputs)[:, 1:, :]
            inputs = inputs[:, :-1, :]

            optimizer.zero_grad()

            # outputs: (batch, seq_len - 1, output_dim), per the model's contract.
            outputs, _ = model(inputs)
            # (batch, time, 1) mask, True where t < lens[i]; same semantics as
            # zeroing rows from lens[i] onwards.
            steps = torch.arange(outputs.shape[1], device=outputs.device)
            lens = torch.as_tensor(lens, device=outputs.device)
            mask = (steps.unsqueeze(0) < lens.unsqueeze(1)).unsqueeze(-1)
            outputs = outputs * mask
            targets = targets * mask

            loss = criterion(outputs.float(), targets.float())

            # Only full batches contribute to the logged epoch loss.
            if len(inputs) == batch_size:
                epoch_train_losses.append(loss.item())

            loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            train_state_preds = _kmeans_state_predictions(model, train_data, n_classes)
            train_hamming, mapper = evaluate_dist(z_true=train_labels, z_pred=train_state_preds, z_lens=train_lens, k_max=n_classes)

            val_state_preds = _kmeans_state_predictions(model, val_data, n_classes)
            val_hamming, _ = evaluate_dist(z_true=val_labels, z_pred=val_state_preds, z_lens=val_lens, k_max=n_classes, mapping=mapper)
            val_hammings.append(val_hamming)

            # Patience check on losses from previous epochs: stop when none of the
            # most recent losses beat the one recorded `early_stopping` epochs back.
            if len(val_losses) > early_stopping:
                recent = np.array(val_losses[-early_stopping:])
                if np.all(recent >= val_losses[-early_stopping]):
                    print('np.array(val_losses[-early_stopping:]): ', recent)
                    print('val_losses[-early_stopping]: ', val_losses[-early_stopping])
                    print('np.array(val_losses[-early_stopping:]) >= val_losses[-early_stopping]: ', recent >= val_losses[-early_stopping])
                    print('Stopping early...')
                    break

            test_state_preds = _kmeans_state_predictions(model, test_data, n_classes)
            test_hamming, _ = evaluate_dist(z_true=test_labels, z_pred=test_state_preds, z_lens=test_lens, k_max=n_classes, mapping=mapper)
            test_hammings.append(test_hamming)

            val_loss = evaluate_loss(model, criterion, val_dataloader)
            val_losses.append(val_loss)
            test_loss = evaluate_loss(model, criterion, test_dataloader)
            test_losses.append(test_loss)

            print(f'Epoch [{epoch + 1}/{num_epochs}], \
                Loss: {np.mean(epoch_train_losses):.4f}, \
                Val Loss: {val_loss:.4f}, \
                Test Loss: {test_loss:.4f}, \
                Val Hamming: {np.mean(val_hamming):.4f}, \
                Train Hamming: {np.mean(train_hamming):.4f}, \
                Test Hamming: {np.mean(test_hamming):.4f}')

        if epoch % checkpoint_every == 0:
            _checkpoint()

    # Persist the final histories; the periodic save above misses the last
    # epochs whenever the run ends off-period or stops early.
    if val_hammings:
        _checkpoint()
    

def main(args, data_load_config):
    """Load the requested dataset, build an ``LSTMForecaster`` and train it.

    Args:
        args: parsed CLI namespace; reads ``data``, ``num_layers``, ``dropout``,
            ``epochs`` and ``lr``.
        data_load_config: per-dataset options forwarded to ``load_data``.
    """
    print('Selecting data...')
    (x_train, z_train, train_lens,
     x_valid, z_valid, valid_lens,
     x_test, z_test, test_lens) = load_data(data_type=args.data, config=data_load_config, normalize=False, pad_ragged=True, path_to_data='../../../data/')

    # Dense numpy outputs are converted to tensors; any other container is left as-is.
    if isinstance(x_train, np.ndarray):
        x_train = torch.Tensor(x_train)
        x_valid = torch.Tensor(x_valid)
        x_test = torch.Tensor(x_test)
        z_train = torch.Tensor(z_train)
        z_valid = torch.Tensor(z_valid)
        z_test = torch.Tensor(z_test)

    if args.data == 'har_70':
        # Raw har_70 labels are expected to be {1, 3, 4, 5, 6, 7, 8}; moving
        # 7 -> 0 and 8 -> 2 fills the gaps and yields the contiguous range 0..6.
        # The assignments are in place on each split.
        for z in (z_train, z_valid, z_test):
            z[z == 7] = 0
            z[z == 8] = 2

    print('Training model...')
    D = x_train[0].shape[1]
    model = LSTMForecaster(input_dim=D,
                           hidden_dim=2 * D,
                           num_layers=args.num_layers,
                           dropout=args.dropout).float()
    train(model=model,
          train_data=x_train,
          train_labels=z_train,
          train_lens=train_lens,
          val_data=x_valid,
          val_labels=z_valid,
          val_lens=valid_lens,
          test_data=x_test,
          test_labels=z_test,
          test_lens=test_lens,
          num_epochs=args.epochs,
          lr=args.lr)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Run LSTM')
    parser.add_argument('--data', type=str, default='sim_hard')

    # Training hyperparams. epochs and lr have no default and are consumed
    # directly by train() (range() / Adam), so require them up front instead
    # of failing later with an opaque TypeError on None.
    parser.add_argument('--epochs', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--ds_factor', type=int)

    # Model hyperparams
    parser.add_argument('--num_layers', type=int)
    parser.add_argument('--dropout', type=float)

    # Cross-validation index and run name; both feed the checkpoint path.
    parser.add_argument('--cv', type=int)

    parser.add_argument('--short_name', type=str)
    # Kept as a module-level name: train() reads the global ``args`` to build
    # its default checkpoint path.
    args = parser.parse_args()

    data_load_config = {'sim_easy': {'n_train': 100, 'n_valid': 50, 'n_test': 50},
                        'sim_hard': {'n_train': 100, 'n_valid': 50, 'n_test': 50},
                        'sim_semi_markov': {'n_train': 60, 'n_valid': 20, 'n_test': 20},
                        'har': {'ds_factor': args.ds_factor},
                        'har_70': {'ds_factor': args.ds_factor},
                        }
    main(args, data_load_config)
