import sys
from torch.utils.data import DataLoader
from cont_data import *
from common import parse_args
import numpy as np
from utils import *
from training import train
from net import Net
from datetime import datetime

def collect_sup_labels(task_list):
    """Return the unique class names across all tasks, in first-seen order.

    Args:
        task_list: sequence of tasks (as produced by ``generate_random_cl``).
            Only ``task[0]`` is read and is iterated as a collection of class
            names — structure assumed from usage; TODO confirm.

    Returns:
        A list of names with duplicates removed, preserving the order in
        which each name first appears.
    """
    # dict keeps insertion order, giving an O(n) order-preserving dedup
    # instead of the O(n^2) ``if name not in some_list`` pattern.
    return list(dict.fromkeys(name for task in task_list for name in task[0]))


if __name__ == '__main__':
    # NOTE(review): ``torch`` and ``clip`` are not imported explicitly in this
    # file; they presumably come from one of the star imports — verify.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    args = parse_args()
    args.logger = Logger(args, args.folder)
    args.logger.now()

    # When ``args.dynamic`` is set, it overrides ``args.n_components``.
    if args.dynamic is not None:
        args.n_components = args.dynamic

    np.random.seed(args.seed)
    args.device = device

    train_data, test_data = get_data(args)

    if args.task_type == 'standardCL_randomcls':
        task_list = generate_random_cl(args)
        train_data = StandardCL(train_data, args, task_list)
        test_data = StandardCL(test_data, args, task_list)
    else:
        # Without this, ``task_list`` would be unbound and the script would
        # crash further down with a NameError; fail here with a clear message.
        raise NotImplementedError(f"Unsupported task_type: {args.task_type!r}")

    for task in task_list:
        args.logger.print(task)
    args.sup_labels = collect_sup_labels(task_list)

    args.logger.print('\n\n', sys.argv, '\n\n')

    # number of heads after final task
    args.out_size = len(args.sup_labels)
    args.logger.print('\n', args, '\n')

    model_clip, _ = clip.load('ViT-B/32', args.device)
    args.model_clip = model_clip

    if args.model == 'maha_ipca':
        from maha_ipca import PLS as Model
    elif args.model == 'batch_pca':
        from batch_pca import PLS as Model
    else:
        # The original evaluated ``NotImplementedError`` without raising it,
        # which led to a confusing NameError/AttributeError on the next line.
        raise NotImplementedError(f"Unsupported model: {args.model!r}")
    # Both supported models build the same ``Net`` configuration.
    args.net = Net(args.in_dim, args.out_dim, bias=True).to(args.device)
    args.criterion = Criterion(args, args.net)
    model = Model(args)

    train(task_list, args, train_data, test_data, model_clip, model)

    args.logger.now()
