import os
import time
import torch
import argparse

import numpy as np
import pandas as pd

from tqdm import tqdm
from pathlib import Path    
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import StepLR
from datasets import load_from_disk
from sklearn.metrics import accuracy_score
from datasets import load_dataset, concatenate_datasets

from model import RNN
from constants import all_classes, ZTF_passband_to_wavelengths

# <----- Defaults for training the models ----->
# These four values are the fallbacks for the --num_epochs, --batch_size,
# --lr and --dir command-line options declared in parse_args().
default_num_epochs = 1000
default_batch_size = 128
default_learning_rate = 2e-3
default_model_dir = Path('./models/test_model')

# Maximum number of observations kept per light curve; custom_collate()
# drops every row after the first `max_seq_length` time-sorted rows.
max_seq_length = 200
# Passbands that filter_missing_or_short_lc() requires each sample to have.
bands = ['g','r']

# Per-class sample cap for dataset balancing. Not used by any active code
# in this file.
max_n_per_class = 2000

# Constructor arguments handed to RNN(ts_dim, n_classes).
# custom_collate() builds 4 columns per observation (MJD, MAG, MAG_ERR,
# MEAN_LAM), which matches ts_dim.
ts_dim = 4
# NOTE(review): presumably equal to len(all_classes) -- confirm against constants.py.
n_classes = 7

# Value used by custom_collate() to pad light curves to a common length.
flag_value = -9


# Switch device to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
# Tensors created without an explicit device are placed on `device` from here on.
torch.set_default_device(device)
print(f"Using {device} device")

def parse_args():
    '''
    Parse the training options from the command line.

    Returns:
        argparse.Namespace with the attributes ``num_epochs`` (int),
        ``batch_size`` (int), ``lr`` (float) and ``dir`` (pathlib.Path).
        Options that are not given fall back to the module-level defaults.
    '''
    # (flag, converter, fallback, help text) -- registered in this order.
    option_specs = (
        ('--num_epochs', int, default_num_epochs, 'Number of epochs to train the model for.'),
        ('--batch_size', int, default_batch_size, 'Batch size used for training.'),
        ('--lr', float, default_learning_rate, 'Learning rate used for training.'),
        ('--dir', Path, default_model_dir, 'Directory for saving the models and best model during training.'),
    )

    cli = argparse.ArgumentParser()
    for flag, converter, fallback, description in option_specs:
        cli.add_argument(flag, type=converter, default=fallback, help=description)

    return cli.parse_args()

def get_label_encoding(source_class):
    '''
    Map a source class to its integer label.

    The label is the position of ``source_class`` in ``all_classes``.

    Args:
        source_class: class identifier; must compare equal to one of the
            entries of ``all_classes``.

    Returns:
        int: zero-based index of ``source_class`` in ``all_classes``.

    Raises:
        ValueError: if ``source_class`` is not one of ``all_classes``.
    '''
    known_classes = all_classes.tolist()
    try:
        return known_classes.index(source_class)
    except ValueError:
        # Raised explicitly (instead of relying on `assert`) so the check
        # also runs under `python -O` and the message names the offender.
        raise ValueError(
            f"Unknown source class {source_class!r}; expected one of {known_classes}."
        ) from None

def save_args_to_csv(args, filepath):
    '''
    Write the parsed options to a CSV file.

    The file gets one header row (the option names) and one data row
    (their values); no index column is written.

    Args:
        args: namespace-like object whose attributes are the options.
        filepath: destination of the CSV file.
    '''
    # A one-element list of the attribute dict yields a single-row table
    # with one column per option.
    single_row = [vars(args)]
    pd.DataFrame(single_row).to_csv(filepath, index=False)

def filter_missing_or_short_lc(datapoint, bands):
    '''
    Decide whether a sample has usable data in the requested g/r bands.

    Only the 'g' and 'r' bands are ever inspected, and each only when it is
    listed in ``bands``. A sample is rejected when such a band's entry in
    ``datapoint['bands_data']`` is None or has a length of exactly 1.

    Args:
        datapoint: sample with a 'bands_data' mapping from band name to data.
        bands: collection of band names that are required.

    Returns:
        bool: False if a required g/r band is missing or has length 1,
        True otherwise.
    '''
    # 'g' is checked before 'r'; any other band name in `bands` is ignored.
    for band in ('g', 'r'):
        if band not in bands:
            continue
        if datapoint['bands_data'][band] is None:
            return False
        if len(datapoint['bands_data'][band]) == 1:
            return False
    return True

def custom_collate(batch):
    '''
    Collate a list of raw dataset samples into one padded batch.

    For every sample, the observations of all bands except 'i' whose data is
    not None are merged into a single light curve sorted by time, with the
    columns (MJD relative to the first observation, MAG, MAG_ERR, MEAN_LAM).
    Light curves longer than ``max_seq_length`` rows are cut to their first
    ``max_seq_length`` rows, and all light curves of the batch are then
    right-padded with ``flag_value`` to a common length.

    Args:
        batch: list of samples. Each sample must provide 'sourceid',
            'class_str' (convertible to int) and 'bands_data', a mapping from
            band name to either None or a mapping with the keys 'mjd',
            'target' and 'past_feat_dynamic_real'.

    Returns:
        dict with the entries
            'ts': float32 tensor, shape (batch, longest light curve, 4)
            'length': int64 tensor with the unpadded length of each light curve
            'source_id': float32 tensor with the source ids
            'label': int64 tensor with the encoded class labels

    Raises:
        ValueError: if a sample has no band to build a light curve from, or
            if its class is not known to get_label_encoding().
    '''
    lcs = []
    lengths = []
    source_ids = []
    labels = []

    for sample in batch:

        band_lcs = []
        for band in sample['bands_data']:

            # This could be r, g, or i bands; 'i' is skipped, and so is any
            # band without data.
            band_data = sample['bands_data'][band]
            if band == 'i' or band_data is None:
                continue

            mjd = np.array(band_data['mjd'])
            band_lcs.append(pd.DataFrame({
                'MJD': mjd,
                'MAG': np.array(band_data['target']),
                'MAG_ERR': band_data['past_feat_dynamic_real'],
                # Same wavelength for every observation of this band.
                'MEAN_LAM': np.full(len(mjd), ZTF_passband_to_wavelengths[band]),
            }))

        if not band_lcs:
            raise ValueError(
                f"Sample {sample['sourceid']!r} has no usable band data to build a light curve from."
            )

        full_lc = pd.concat(band_lcs).sort_values(by='MJD')

        # Subtract out the MJD of the first observation
        full_lc['MJD'] = full_lc['MJD'] - full_lc['MJD'].min()

        # Slicing is a no-op for light curves that are already short enough.
        full_lc = full_lc.to_numpy(dtype=np.float32)[:max_seq_length, :]
        full_lc = torch.from_numpy(full_lc)

        lcs.append(full_lc)
        lengths.append(full_lc.shape[0])
        source_ids.append(sample['sourceid'])
        labels.append(get_label_encoding(int(sample['class_str'])))

    ts_tensor = pad_sequence(lcs, batch_first=True, padding_value=flag_value)

    # torch.from_numpy always yields CPU tensors, independent of the default
    # device. np.int64 is spelled out because it maps to torch.long on every
    # platform and NumPy version (np.long is missing from NumPy 1.24-1.26).
    lengths = torch.from_numpy(np.array(lengths, dtype=np.int64))
    # NOTE(review): float32 cannot represent integers above 2**24 exactly, so
    # large source ids may be rounded here. Left unchanged because consumers
    # of 'source_id' are not visible from this file -- confirm and consider int64.
    source_ids = torch.from_numpy(np.array(source_ids, dtype=np.float32))
    labels = torch.from_numpy(np.array(labels, dtype=np.int64))

    return {
        'ts': ts_tensor,
        'length': lengths,
        'source_id': source_ids,
        'label': labels,
    }


def _compute_class_weights(class_strs):
    '''
    Compute inverse-frequency class weights, indexed by encoded label.

    Entry ``i`` is ``total / (n_classes * count_i)`` where ``count_i`` is the
    number of samples whose encoded label is ``i``. Classes without any
    sample get weight 0 instead of an infinite one.
    '''
    encoded = np.array([get_label_encoding(int(c)) for c in class_strs], dtype=np.int64)
    n_known_classes = len(all_classes)

    # bincount keeps one slot per encoded label, so the weight vector lines up
    # with the targets even if all_classes is unsorted or a class is absent.
    counts = np.bincount(encoded, minlength=n_known_classes)
    weights = np.zeros(n_known_classes, dtype=np.float64)
    present = counts > 0
    weights[present] = len(encoded) / (n_known_classes * counts[present])

    return torch.from_numpy(weights).float().to(device)


def _batch_to_device(batch):
    '''Return a copy of the batch dict with every tensor moved to ``device``.'''
    return {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}


def _train_one_epoch(model, dataloader, loss_fn, optimizer):
    '''Run one optimisation pass over ``dataloader`` and return the mean batch loss.'''
    model.train()
    loss_values = []

    for batch in tqdm(dataloader, desc='Training Epoch'):
        batch = _batch_to_device(batch)

        # Forward pass
        outputs = model(batch)
        loss = loss_fn(outputs, batch['label'])
        loss_values.append(loss.item())

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return np.mean(loss_values)


def _validate(model, dataloader, loss_fn):
    '''Evaluate ``model`` on ``dataloader``; return (mean batch loss, accuracy).'''
    model.eval()
    loss_values = []
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Validation Epoch'):
            batch = _batch_to_device(batch)

            outputs = model(batch)
            loss_values.append(loss_fn(outputs, batch['label']).item())

            # Softmax is monotonic, so the arg-max of the raw model outputs
            # is already the predicted class.
            all_preds.append(outputs.argmax(dim=1).cpu().numpy())
            all_labels.append(batch['label'].cpu().numpy())

    accuracy = accuracy_score(np.concatenate(all_labels), np.concatenate(all_preds))
    return np.mean(loss_values), accuracy


def run_training_loop(args):
    '''
    Train the RNN classifier and keep the checkpoint with the best validation accuracy.

    The dataset is loaded from disk, samples rejected by
    filter_missing_or_short_lc() for the module-level ``bands`` are dropped,
    and the model is trained with a class-weighted cross-entropy loss, Adam
    and a StepLR schedule (learning rate x0.1 every 20 epochs). Training
    stops early if the average training loss becomes NaN.

    Side effects: creates ``args.dir``, writes ``train_args.csv`` and
    ``best_model.pth`` into it, and prints progress to stdout.

    Args:
        args: namespace with ``num_epochs``, ``batch_size``, ``lr`` and
            ``dir`` (a pathlib.Path), as produced by parse_args().

    Raises:
        ValueError: if the filtered training or validation split is empty.
    '''
    # Assign the arguments to variables
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    lr = args.lr
    model_dir = args.dir

    # Create a dir to save the model files
    model_dir.mkdir(parents=True, exist_ok=True)

    # The sampler's generator is created on the same device as the default device.
    generator = torch.Generator(device=device)

    # Load the dataset and create the dataloaders
    dataset = load_from_disk("hf_csdr1_multiband_raw_lc_subclass_class_str_v2", keep_in_memory=True)

    train_set = dataset['train'].filter(lambda x: filter_missing_or_short_lc(x, bands))
    val_set = dataset['validation'].filter(lambda x: filter_missing_or_short_lc(x, bands))

    if len(train_set) == 0 or len(val_set) == 0:
        raise ValueError("The training and validation splits must both be non-empty after filtering.")

    save_args_to_csv(args, model_dir / 'train_args.csv')

    train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=custom_collate, generator=generator)
    # Validation is not shuffled: the reported loss is a mean of per-batch
    # weighted means, so a fixed batch composition keeps it comparable
    # from one epoch to the next.
    val_dataloader = DataLoader(val_set, batch_size=batch_size, shuffle=False, collate_fn=custom_collate, generator=generator)

    # Calculate the class weights to address the class imbalance problem
    weights = _compute_class_weights(train_set['class_str'])

    model = RNN(ts_dim, n_classes).to(device)

    loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=20, gamma=0.1)

    best_val_acc = -np.inf

    # Training loop
    for epoch in range(num_epochs):

        print(f"Epoch {epoch+1}/{num_epochs} started")
        start_time = time.time()

        avg_train_loss = _train_one_epoch(model, train_dataloader, loss_fn, optimizer)
        avg_val_loss, acc = _validate(model, val_dataloader, loss_fn)

        scheduler.step()

        print(f"Avg training loss: {float(avg_train_loss):.4f}")
        print(f"Avg validation loss: {float(avg_val_loss):.4f}")
        print(f"Validation accurarcy: {float(acc):.4f}")

        if np.isnan(avg_train_loss):
            print("Training loss was nan. Exiting the loop.")
            break

        # Keep the weights of the best epoch so far (ties overwrite the
        # earlier checkpoint); they can be loaded later for inference.
        if acc >= best_val_acc:
            best_val_acc = acc
            print("Saving model")
            torch.save(model.state_dict(), model_dir / 'best_model.pth')
        print(f"Time taken: {time.time() - start_time:.2f}s\n=======\n")

def main():
    '''
    Script entry point: parse the command-line options and start training.
    '''
    run_training_loop(parse_args())


if __name__ == '__main__':
    main()