import os, signal, sys, time
import pickle

import hydra
import wandb
from omegaconf import DictConfig, OmegaConf

import torch

from oucl.utils import set_seed, smart_dir
from oucl.scenarios.datasets import load_dataset
from oucl.scenarios.scenarios import load_scenario
from oucl.agents.agents import load_agent


from tqdm import tqdm

def clean_exit(a, b):
    """SIGINT handler: finish the wandb run, release cached CUDA memory, and exit.

    Installed by ``main`` via ``signal.signal(signal.SIGINT, clean_exit)``.

    Args:
        a: Signal number supplied by ``signal.signal`` (unused).
        b: Interrupted stack frame supplied by ``signal.signal`` (unused).

    Raises:
        SystemExit: Always, with exit code 0.
    """
    # ``main_pid`` is only assigned under the ``__main__`` guard at the bottom of
    # this file.  Look it up defensively so that running ``main`` some other way
    # does not turn a Ctrl-C into a NameError inside the signal handler.  If the
    # launcher pid is unknown we cannot tell whether this process owns the wandb
    # run, so wandb.finish() is skipped.
    if os.getpid() == globals().get('main_pid'):
        wandb.finish()

    torch.cuda.empty_cache()
    sys.exit(0)
    

def _save_results(log_name, overall_results, agent) -> None:
    """Pickle ``overall_results`` and save the agent's backbone weights under ``results/<log_name>/``."""
    out_dir = smart_dir(f'results/{log_name}/')
    with open(out_dir + 'results.pkl', 'wb') as fp:
        pickle.dump(overall_results, fp)

    with open(out_dir + 'model_final.pth', 'wb') as fp:
        torch.save(agent.encoder.backbone.state_dict(), fp)


def _shutdown_agent_memory(agent) -> None:
    """Call ``shutdown()`` on the agent's memory component, if it has one.

    The attributes ``memory``, ``ltm`` and ``stm`` are checked in that order.
    Only the first one present on ``agent`` is considered. It is shut down only
    if it exposes a ``shutdown`` method.
    """
    # NOTE(review): mirrors the original elif-chain, so an agent that has both
    # ``ltm`` and ``stm`` only gets ``ltm.shutdown()``.  Confirm whether ``stm``
    # should also be shut down.
    for attr in ('memory', 'ltm', 'stm'):
        if hasattr(agent, attr):
            component = getattr(agent, attr)
            if hasattr(component, 'shutdown'):
                component.shutdown()
            return


@hydra.main(version_base=None, config_path='configs', config_name='main')
def main(config: DictConfig) -> None:
    """Train an agent over a scenario stream, evaluating periodically, and save results.

    Steps:
    - Seed the RNGs, install the SIGINT handler and start a wandb run.
    - Build the agent, dataset and scenario stream.
    - Feed every stream batch to the agent, calling the evaluator at the
      iterations listed in ``stream.eval_iters``.
    - Write the evaluation results and the wall-clock time spent outside
      evaluation to ``results/<config.log>/results.pkl``.
    - Write the backbone weights to ``results/<config.log>/model_final.pth``.

    The agent's memory component (if any) is shut down even when training
    fails or is interrupted.
    """
    set_seed(config.seed)
    signal.signal(signal.SIGINT, clean_exit)
    print(os.getpid())

    wandb.login(key=config.key)

    wandb.init(
        project=config.project,
        name=config.log,
        notes="",
        tags=["development"],
        config=OmegaConf.to_container(config, resolve=True)
    )

    overall_results = {}

    agent = load_agent(config)
    try:
        dataset = load_dataset(config.dataset)

        stream, evaluator = load_scenario(dataset, config)
        # Move the last scheduled evaluation to the final index so the model is
        # always evaluated at the end (assumes iterating ``stream`` yields
        # ``len(stream)`` batches — TODO confirm).
        stream.eval_iters[-1] = len(stream) - 1
        current_task = 0

        total_time = 0
        start_time = time.time()
        for it, (data, y, inds, _) in enumerate(tqdm(stream)):
            if it in stream.task_boundaries:
                current_task += 1
                print(f'New Task - {current_task}')

            agent(data, y, current_task, inds)

            if it in stream.eval_iters:
                # Stop the clock around evaluation so ``total_time`` only
                # measures time spent outside the evaluator.
                total_time += time.time() - start_time
                class_res, clust_res = evaluator.evaluate(agent)
                overall_results[evaluator.step] = {
                    'classification': class_res,
                    'clustering': clust_res,
                }
                start_time = time.time()

        total_time += time.time() - start_time

        wandb.finish()
        overall_results['time'] = total_time
        _save_results(config.log, overall_results, agent)
    finally:
        # Runs on success, on exceptions, and on the SystemExit raised by
        # clean_exit() for Ctrl-C, so memory components are never leaked.
        _shutdown_agent_memory(agent)

    

if __name__ == '__main__':
    # Record the launching process's pid before entering main().  clean_exit()
    # compares os.getpid() against it and calls wandb.finish() only when they match.
    main_pid = os.getpid()
    main()