import argparse
import sys
import torch as th
from copy import copy
from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage

from utils.configuration import Configuration
from model.scripts import evaluation, evaluation_clevrer2, evaluation_clevrer_blackout
from data.datasets.ADEPT.dataset import AdeptDataset

def eval(model: str, cfg: str,  dataset: str, device: int = 0, n: int = 100, num_gpus: int = 1, objects: bool = True, nice: bool = False, individual: bool = False, trace: bool = False, concept: str = '', mota: bool = True, plot_frequency: int = 1, plot_first_samples: int = 2):
    """Load a trained model's configuration and run the dataset-specific evaluation.

    Paths are resolved under ``out/final_adept/`` for ADEPT and
    ``out/final_blackout/`` for CLEVRER:
    ``cfg/cfg_<cfg>.json`` for the configuration and ``nets/net_<model>.pt``
    for the checkpoint.

    Args:
        model: Checkpoint name stem (``nets/net_<model>.pt``).
        cfg: Configuration name stem (``cfg/cfg_<cfg>.json``).
        dataset: ``'adept'`` or ``'clevrer'``.
        device: CUDA device index written into the config. A negative value
            keeps the config's own device. The index is clamped to the last
            device reported by ``th.cuda.device_count()``.
        n: Run index appended to ``cfg.model_path`` as ``.run<n>``. A negative
            value skips the suffix.
        num_gpus: Ignored. It is kept for backward compatibility; the reported
            GPU count always comes from ``th.cuda.device_count()``.
        objects, nice, individual, trace, mota: Flags forwarded unchanged to
            the evaluation module's ``save``.
        concept: Split/concept name passed to the dataset constructor and to
            ``save``.
        plot_frequency, plot_first_samples: Plotting options forwarded to
            ``save``.

    Raises:
        ValueError: If ``dataset`` is neither ``'adept'`` nor ``'clevrer'``.
    """
    # Reject unknown datasets before touching the filesystem. Previously this
    # loaded a config and then silently evaluated nothing.
    if dataset not in ('adept', 'clevrer'):
        raise ValueError(f"Dataset not supported: {dataset!r}")

    root_path = 'out/final_adept/' if dataset == 'adept' else 'out/final_blackout/'
    cfg_path = root_path + 'cfg/cfg_' + cfg + '.json'
    load = root_path + 'nets/net_' + model + '.pt'

    # config
    cfg = Configuration(cfg_path)

    if device >= 0:
        cfg.device = device
        cfg.model.device = cfg.device
        cfg.model_path = f"{cfg.model_path}.device{cfg.device}"

    if n >= 0:
        cfg.model_path = f"{cfg.model_path}.run{n}"

    # gpu: clamp the requested device to the devices actually present.
    # NOTE(review): with no CUDA devices this sets the device to -1; confirm
    # that downstream code treats -1 as "CPU".
    num_gpus = th.cuda.device_count()
    if cfg.device >= num_gpus:
        cfg.device = num_gpus - 1
        cfg.model.device = cfg.device

    # info
    print(f'Using device {cfg.device} Cuda aivalable: {th.cuda.is_available()} Cuda count: {th.cuda.device_count()}')
    print(f'Using {num_gpus} GPU{"s" if num_gpus > 1 else ""}')
    print(f'Evaluating model {cfg.model_path}')

    # Dataset resolution: (latent_size[1], latent_size[0]), each scaled by
    # 2 ** (2 * level).
    scale = 2 ** (cfg.model.level * 2)
    size = (cfg.model.latent_size[1] * scale, cfg.model.latent_size[0] * scale)

    if dataset == 'adept':
        evalset = AdeptDataset("./", cfg.dataset, concept, size)
        evaluation.save(cfg, evalset, load, (cfg.model.level*2), cfg.model.input_size, objects, nice, individual, trace, concept, mota, plot_frequency=plot_frequency, plot_first_samples=plot_first_samples)
    else:  # 'clevrer' (validated above)
        evalset = ClevrerDataset("./", cfg.dataset, concept, size, use_slotformer=True, evaluation=True)
        evaluation_clevrer_blackout.save(cfg, evalset, load, (cfg.model.level*2), cfg.model.input_size, objects, nice, individual, trace, concept, mota, plot_frequency=plot_frequency, plot_first_samples=plot_first_samples)
    

if __name__ == "__main__":

    # Command-line options choose which model, config and dataset to evaluate.
    parser = argparse.ArgumentParser()
    for flag in ('-model', '-cfg', '-dataset'):
        parser.add_argument(flag, default="", type=str)
    args = parser.parse_args()

    # Evaluation runs per dataset. Each entry holds the keyword overrides for
    # one eval() call, executed in order.
    # For ADEPT, concept='test' is an alternative to 'createdown'.
    runs_by_dataset = {
        'adept': [
            dict(plot_first_samples=3, concept='createdown'),
        ],
        "clevrer": [
            dict(plot_first_samples=3, concept='val'),
            dict(plot_first_samples=0, concept='val'),
        ],
    }

    runs = runs_by_dataset.get(args.dataset)
    if runs is None:
        print("Dataset not supported")
    else:
        for overrides in runs:
            eval(model=args.model, cfg=args.cfg, dataset=args.dataset, plot_frequency=2, **overrides)