
import numpy as np
import torch
import torchvision
import time
import argparse

import h5py
    
from mercat import mercat

import signal

### Signal handling for SLURM clusters: react to external termination signals
class TerminationSignalError(Exception):
    """
    Raised when an external UNIX signal forces this process to stop.

    The exception always carries the same fixed, human-readable message.
    """

    def __init__(self):
        message = "The process is forced to terminate by an external signal!"
        super().__init__(message)


def __handle_signal(signum, frame):
    """
    Signal handler that turns any received signal into a TerminationSignalError.

    `signum` and `frame` are the standard handler arguments; they are not used.
    """
    error = TerminationSignalError()
    raise error


def raise_on_termination():
    """
    Install the module's handler for SIGINT and SIGTERM.

    Once installed, delivery of either signal raises a TerminationSignalError
    in the running code.
    """
    # Register the same handler for both signals, SIGINT first, then SIGTERM.
    for signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signum, __handle_signal)

## NOTE: ALL PATHS BELOW NEED TO BE CHANGED APPROPRIATELY! ANONYMIZED FOR REVIEW!
def _load_h5_expression(filename, group):
    """Load ``f[group]['exp']`` from an HDF5 file as a NumPy array and print its shape.

    Args:
        filename: Path to the HDF5 file.
        group: Name of the top-level group that holds an ``exp`` dataset.

    Returns:
        The array produced by ``np.asarray`` on that dataset.
    """
    with h5py.File(filename, "r") as f:
        # Materialize the data so it remains usable after the file is closed.
        exp = np.asarray(f[group]['exp'])
    print(exp.shape)
    return exp


def _time_mercat(data, label, device, itermax, stepsize):
    """Run mercat with the shared experiment settings and print the elapsed runtime.

    Args:
        data: Input passed straight to ``mercat``.
        label: Name used in the "<label> runtime: " report line.
        device: torch device forwarded to ``mercat``.
        itermax: Value forwarded as ``itermax``.
        stepsize: List forwarded as ``stepsize``.

    Returns:
        Whatever ``mercat`` returns.
    """
    # perf_counter is a monotonic, high-resolution clock meant for measuring intervals.
    start = time.perf_counter()
    m_res = mercat(data, itermax=itermax, lr=0.01, stepsize=stepsize, batch_size=64,
                   angle_subsample=64, shuffle=True, device=device)
    # The print stays inside the timed region, as in the original measurements.
    # NOTE(review): if m_res lives on a CUDA device, printing it also waits for
    # the GPU work to finish, which keeps the timing honest -- confirm.
    print(m_res[1, :])
    elapsed = time.perf_counter() - start
    print(f"{label} runtime: ")
    print(elapsed)
    return m_res


def main(args):
    """Run one timed mercat experiment, selected by ``args.data``.

    Installs the SIGINT/SIGTERM handler via ``raise_on_termination`` and runs on CUDA
    when available, otherwise on CPU.

    Args:
        args: Namespace with a ``data`` attribute, one of 'mnist', 'tabula',
            'pancreas', 'paul' or 'hela'. Any other value runs nothing.
    """
    raise_on_termination()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ## Run MNIST example
    if args.data == 'mnist':
        # Construct the dataset once. The original built (and possibly downloaded) it twice.
        dataset = torchvision.datasets.MNIST(root='data', train=True, download=True,
                                             transform=torchvision.transforms.ToTensor())
        # NOTE(review): `.data`/`.targets` are the stored tensors. `transform` applies
        # only when indexing the dataset, so these are raw pixels -- confirm that is intended.
        mnist = torch.flatten(dataset.data, start_dim=1)
        # Keep only images of even digits.
        mnist = mnist[(dataset.targets % 2) == 0, :]
        _time_mercat(mnist, "MNIST", device, itermax=250, stepsize=[100])
        return

    # HDF5-backed experiments:
    # name -> (file path, HDF5 group, runtime label, itermax, stepsize)
    h5_experiments = {
        # Tabula Sapiens human blood
        'tabula': ("data/tabula_sapiens_blood.h5", 'tabula_sapiens',
                   "Tabula Sapiens blood", 50, [10, 30]),
        # Murine pancreas
        'pancreas': ("data/murine_pancreas.h5", 'panc',
                     "Murine Pancreas", 50, [10, 30]),
        # Paul mouse bone marrow.
        # NOTE(review): unlike the others, this path is not under data/ -- confirm.
        'paul': ("paul_hema.h5", 'paul',
                 "Bone Marrow", 200, [50, 150]),
        # HeLa cell cycle
        'hela': ("data/hela_cc.h5", 'cell_cycle',
                 "Cell Cycle", 200, [50, 150]),
    }

    experiment = h5_experiments.get(args.data)
    if experiment is None:
        return

    filename, group, label, itermax, stepsize = experiment
    exp = _load_h5_expression(filename, group)
    _time_mercat(exp, label, device, itermax=itermax, stepsize=stepsize)



if __name__ == '__main__':
    # Command-line entry point: exactly one dataset must be chosen with -d/--data.
    parser = argparse.ArgumentParser(
        prog='Mercat GPU Experiments',
        description='Real world experiments (clocking only) for Mercat on GPU',
    )
    dataset_choices = ['mnist', 'tabula', 'pancreas', 'paul', 'hela']
    parser.add_argument('-d', '--data', choices=dataset_choices, required=True)
    main(parser.parse_args())