import argparse
import sys
import torch as th
from copy import copy


from utils.configuration import Configuration
from model.scripts import training, evaluation, evaluation_cater
#from data.datasets.moving_mnist.dataset import MovingMNISTDataset
#from data.datasets.video.dataset import VideoDataset, MultipleVideosDataset, TimeLossVideoDataset
from data.datasets.CATER.dataset import CaterDataset
from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage
from data.datasets.ADEPT.dataset import AdeptDataset
#from data.datasets.SPACE.dataset import SpaceDataset
#from data.datasets.MOVIE.dataset import MoviEDataset
#from data.datasets.PhysicalConcepts.dataset import PhysicalConceptsDataset

CFG_PATH = "cfg.json"


def _build_parser():
    """Return the command-line parser for this script.

    Exactly one mode flag (-train, -eval, -save, -export, -surprise) is
    required; the remaining flags are options read by the dataset loading and
    dispatch code in this file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-cfg", default=CFG_PATH)
    # GPU count handed to training.run; when > 0 it overrides the detected count.
    parser.add_argument("-num-gpus", default=1, type=int)
    # Run index; when >= 0 it is appended to cfg.model_path as ".run<n>".
    parser.add_argument("-n", default=-1, type=int)
    parser.add_argument("-load", default="", type=str)
    # Only used by the "latent-cater" datatype to override cfg.dataset.
    parser.add_argument("-dataset-file", default="", type=str)
    # When >= 0 it overrides cfg.device and is appended to cfg.model_path.
    parser.add_argument("-device", default=0, type=int)
    # CATER -save: use the test split instead of the train split.
    parser.add_argument("-testset", action="store_true")
    mode_group = parser.add_mutually_exclusive_group(required=True)
    mode_group.add_argument("-train", action="store_true")
    mode_group.add_argument("-eval", action="store_true")
    mode_group.add_argument("-save", action="store_true")
    mode_group.add_argument("-export", action="store_true")
    mode_group.add_argument("-surprise", action="store_true")
    parser.add_argument("-objects", action="store_true")
    parser.add_argument("-nice", action="store_true")
    parser.add_argument("-individual", action="store_true")
    parser.add_argument("-trace", action="store_true")
    # ADEPT split name used by -save / -surprise.
    parser.add_argument("-concept", default="", type=str)
    parser.add_argument("-mota", action="store_true")
    return parser


def _image_size(cfg):
    """Return the image size handed to the dataset constructors.

    Both entries of ``cfg.model.latent_size`` are scaled by
    ``2 ** (2 * cfg.model.level)`` and returned in swapped order:
    ``(latent_size[1] * scale, latent_size[0] * scale)``.
    """
    scale = 2 ** (cfg.model.level * 2)
    return (cfg.model.latent_size[1] * scale, cfg.model.latent_size[0] * scale)


def _load_datasets(cfg, args):
    """Build the ``(trainset, valset, testset)`` triple for ``cfg.datatype``.

    Any split a branch does not build stays ``None``.  Some branches also
    mutate ``cfg``: several increment ``cfg.sequence_len`` by one after the
    datasets are built, and "latent-cater" may replace ``cfg.dataset`` with
    the ``-dataset-file`` argument.  Dataset modules that are not imported at
    the top of the file are imported inside their branch, so a dataset module
    only has to be importable when its datatype is selected.
    """
    trainset = valset = testset = None
    datatype = cfg.datatype

    if datatype == "moving_mnist":
        from data.datasets.moving_mnist.dataset import MovingMNISTDataset
        trainset = MovingMNISTDataset(2, 'train', 64, 64, cfg.sequence_len)
        valset   = MovingMNISTDataset(2, 'train', 64, 64, cfg.sequence_len)
        testset  = MovingMNISTDataset(2, 'test', 64, 64, cfg.sequence_len)

    elif datatype in ("video", "multiple-videos", "time-loss-video"):
        from data.datasets.video.dataset import (
            MultipleVideosDataset,
            TimeLossVideoDataset,
            VideoDataset,
        )
        dataset_cls = {
            "video": VideoDataset,
            "multiple-videos": MultipleVideosDataset,
            "time-loss-video": TimeLossVideoDataset,
        }[datatype]
        size = _image_size(cfg)
        # The loaders receive one extra frame relative to cfg.sequence_len.
        seq_len = cfg.sequence_len + 1
        trainset = dataset_cls("./", cfg.dataset, "train", size, seq_len)
        valset   = dataset_cls("./", cfg.dataset, "test", size, seq_len)
        testset  = dataset_cls("./", cfg.dataset, "test", size, seq_len)
        # The time-loss variant is the only video loader that leaves sequence_len unchanged.
        if datatype != "time-loss-video":
            cfg.sequence_len += 1

    elif datatype == "physical-concepts":
        from data.datasets.PhysicalConcepts.dataset import PhysicalConceptsDataset
        size = _image_size(cfg)
        trainset = PhysicalConceptsDataset("./", cfg.dataset, "train", size)
        valset   = PhysicalConceptsDataset("./", cfg.dataset, "test", size)
        testset  = PhysicalConceptsDataset("./", cfg.dataset, "test", size)
        cfg.sequence_len += 1

    elif datatype == "clevrer":
        size = _image_size(cfg)
        # The train split is only loaded when training.
        if args.train:
            trainset = ClevrerDataset("./", cfg.dataset, "train", size, use_slotformer=False)
        # Both the validation and the test set are built from the "val" split.
        valset  = ClevrerDataset("./", cfg.dataset, "val", size, use_slotformer=True)
        testset = ClevrerDataset("./", cfg.dataset, "val", size, use_slotformer=True)
        cfg.sequence_len += 1

    elif datatype == "cater":
        size = _image_size(cfg)
        trainset = CaterDataset("./", cfg.dataset, "train", size)
        valset   = CaterDataset("./", cfg.dataset, "val", size)
        testset  = CaterDataset("./", cfg.dataset, "test", size)
        cfg.sequence_len += 1

    elif datatype == "adept":
        size = _image_size(cfg)
        if args.surprise:
            # NOTE(review): no runner consumes this for -surprise yet; the
            # dataset is still built so a missing split fails early.
            testset = AdeptDataset("./", cfg.dataset, args.concept, size)
        elif args.save:
            # An empty -concept falls back to the "test" split.
            testset = AdeptDataset("./", cfg.dataset, args.concept or "test", size)
        else:
            trainset = AdeptDataset("./", cfg.dataset, "train", size)
            valset   = AdeptDataset("./", cfg.dataset, "test", size)
            testset  = AdeptDataset("./", cfg.dataset, "test", size)

    elif datatype == "space":
        from data.datasets.SPACE.dataset import SpaceDataset
        size = _image_size(cfg)
        trainset = SpaceDataset("./", cfg.dataset, "train", size)
        valset   = SpaceDataset("./", cfg.dataset, "val", size)
        testset  = SpaceDataset("./", cfg.dataset, "test", size)

    elif datatype == "movi-e":
        from data.datasets.MOVIE.dataset import MoviEDataset
        size = _image_size(cfg)
        trainset = MoviEDataset("./", cfg.dataset, "train", size)
        valset   = MoviEDataset("./", cfg.dataset, "val", size)
        testset  = MoviEDataset("./", cfg.dataset, "test", size)

    elif datatype == "latent-cater":
        if args.dataset_file != "":
            cfg.dataset = args.dataset_file
        # NOTE(review): CaterLatentDataset is not imported anywhere in this
        # file, so this branch raises NameError unless it is brought into scope.
        trainset = CaterLatentDataset("./", cfg.dataset, "train")
        valset   = CaterLatentDataset("./", cfg.dataset, "val")
        testset  = CaterLatentDataset("./", cfg.dataset, "test")

    else:
        print(f"Warning: unknown datatype '{datatype}'; no datasets were loaded.", file=sys.stderr)

    return trainset, valset, testset


def main(argv=None):
    """Parse the command line, configure devices, load datasets and run the selected mode.

    ``argv`` defaults to ``None``, in which case argparse reads ``sys.argv[1:]``.
    """
    args = _build_parser().parse_args(argv)

    cfg = Configuration(args.cfg)

    # -device defaults to 0, so this runs unless a negative device is passed.
    # The path suffix records the requested device, before any clamping below.
    if args.device >= 0:
        cfg.device = args.device
        cfg.model.device = cfg.device
        cfg.model_path = f"{cfg.model_path}.device{cfg.device}"

    if args.n >= 0:
        cfg.model_path = f"{cfg.model_path}.run{args.n}"

    detected_gpus = th.cuda.device_count()

    # Clamp to the last visible GPU (this yields -1 when no GPU is visible).
    if cfg.device >= detected_gpus:
        cfg.device = detected_gpus - 1
        cfg.model.device = cfg.device

    # -num-gpus defaults to 1, so the detected count is only kept when it is <= 0.
    num_gpus = args.num_gpus if args.num_gpus > 0 else detected_gpus

    print(f'Using device {cfg.device} Cuda available: {th.cuda.is_available()} Cuda count: {detected_gpus}')
    print(f'Using {num_gpus} GPU{"s" if num_gpus > 1 else ""}')
    print(f'{"Training" if args.train else "Evaluating"} model {cfg.model_path}')

    trainset, valset, testset = _load_datasets(cfg, args)

    # "latent-cater" always trains a latent model, regardless of the mode flag.
    if cfg.datatype == "latent-cater":
        if cfg.latent_type == "snitch_tracker":
            training.train_latent_tracker(cfg, trainset, valset, testset, args.load)
        elif cfg.latent_type == "object_behavior":
            training.train_latent_action_classifier(cfg, trainset, testset, args.load)
        else:
            print(f"Warning: unknown latent_type '{cfg.latent_type}'; nothing was run.", file=sys.stderr)
    elif args.train:
        training.run(cfg, num_gpus, trainset, valset, testset, args.load, cfg.model.level * 2)
    elif args.save and cfg.datatype == "cater":
        evaluation_cater.save(cfg, testset if args.testset else trainset, args.load, cfg.model.level * 2, cfg.model.input_size, args.objects, args.nice, args.individual)
    elif args.save:
        evaluation.save(cfg, testset, args.load, cfg.model.level * 2, cfg.model.input_size, args.objects, args.nice, args.individual, args.trace, args.concept, args.mota)
    elif args.export:
        evaluation.export_dataset(cfg, trainset, testset, args.load, f"{args.load}.latent-states")
    else:
        # -eval and -surprise have no runner here; report it instead of exiting silently.
        mode = "eval" if args.eval else "surprise"
        print(f"Warning: no action is implemented for -{mode} with datatype '{cfg.datatype}'; nothing was run.", file=sys.stderr)


if __name__ == "__main__":
    main()
