import os
import time
import torch
import argparse
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
from pathlib import Path    
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, ConfusionMatrixDisplay

from model import RNN
from constants import all_classes, class_map, ZTF_passband_to_wavelengths
from train import custom_collate, get_label_encoding, filter_missing_or_short_lc, bands

# Default hyperparameters and paths, mirroring the training configuration.
# NOTE(review): in the visible part of this file only default_batch_size and
# default_model_dir are read; the epoch count and learning rate look unused here.
default_num_epochs = 100
default_batch_size = 1024
default_learning_rate = 1e-5
default_model_dir = Path('./models/test_model')

# Passed positionally to the RNN constructor, in this order.
ts_dim = 4
n_classes = 7

# NOTE(review): not referenced in the visible part of this file; presumably a
# sentinel for missing/padded values -- confirm against the training code.
flag_value = -9

# Prefer the GPU when one is available and make it torch's default device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
torch.set_default_device(device)
print(f"Using {device} device")

def parse_args():
    '''
    Parse the command-line options for this script.

    The only option is `--dir`, converted to a `Path` and defaulting to
    `default_model_dir`. Returns the resulting argparse namespace.
    '''
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--dir',
        type=Path,
        default=default_model_dir,
        help='Directory for saving the models and best model during training.',
    )
    return cli.parse_args()


def run_testing_loop(args):
    '''
    Evaluate the saved best model on the test split and write out the results.

    Loads `best_model.pth` from `args.dir`, runs inference over the filtered
    test split, and prints the accuracy, the classification report and the
    confusion matrix (raw counts). The row-normalised confusion matrix is
    pickled to `confusion_data.pkl` and the confusion-matrix plot is saved as
    `cf.png` and `cf.pdf`, all inside `args.dir`.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options; only `args.dir` is read.

    Raises
    ------
    ValueError
        If the filtered test split produces no batches.
    '''
    batch_size = default_batch_size
    # Accept either a Path (what parse_args produces) or a plain string.
    model_dir = Path(args.dir)

    generator = torch.Generator(device=device)

    # Load the dataset and create the dataloader
    dataset = load_from_disk("hf_csdr1_multiband_raw_lc_subclass_class_str_v2", keep_in_memory=True)
    test_split = dataset['test'].filter(lambda x: filter_missing_or_short_lc(x, bands))
    test_dataloader = DataLoader(
        test_split,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=custom_collate,
        generator=generator,
    )

    model = RNN(ts_dim, n_classes).to(device)
    model.load_state_dict(torch.load(model_dir / 'best_model.pth', map_location=torch.device(device)))
    model.eval()

    all_preds = []
    all_labels = []

    # Inference: collect the predicted and true class index of every sample.
    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Testing"):
            # Move tensors to the device; leave non-tensor entries untouched.
            batch = {k: v.to(device) if torch.is_tensor(v) else v
                     for k, v in batch.items()}

            logits = model(batch)

            # argmax over the logits already gives the predicted class, so no
            # softmax is needed here.
            all_preds.append(logits.argmax(dim=1).cpu().numpy())
            all_labels.append(batch['label'].cpu().numpy())

    if not all_preds:
        raise ValueError("Test dataloader produced no batches; nothing to evaluate.")

    # Flatten the per-batch arrays into one array over the whole test set.
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)

    # -- Metrics ---------------------------------------------------------------
    target_names = [class_map[x] for x in all_classes]
    # Pin the label set so the report and the matrix always have one entry per
    # class, even when a class never shows up in the test split.
    class_indices = list(range(len(target_names)))

    acc = accuracy_score(all_labels, all_preds)
    report = classification_report(all_labels, all_preds, labels=class_indices, target_names=target_names)

    # Build the matrix from the predictions of the WHOLE test set, not from the
    # arrays of the last batch.
    cm = confusion_matrix(all_labels, all_preds, labels=class_indices)

    # Row-normalise (fraction of each true class); rows without any samples
    # stay at zero instead of turning into NaN through a division by zero.
    row_totals = cm.sum(axis=1, keepdims=True)
    cmn = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)
    with open(model_dir / "confusion_data.pkl", "wb") as f:
        pickle.dump(cmn, f)

    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=target_names)
    disp.plot()
    disp.figure_.savefig(model_dir / 'cf.png')
    disp.figure_.savefig(model_dir / 'cf.pdf')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(disp.figure_)

    # -- Print results -----------------------------------------------------------
    print(f"\nTest Accuracy: {acc*100:.2f}%")
    print("Classification Report:")
    print(report)
    print("Confusion Matrix:")
    print(cm)



def main():
    '''
    Script entry point: parse the command-line options and run the evaluation.
    '''
    run_testing_loop(parse_args())


# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()