from lib.config import cfg, cfg_from_file
from lib.utils import *
from data.dataset import *
from lib.train import train_and_evaluate
from models.model_register import get_model
from lib.random_projection import SimpleVitNet


import numpy as np
import torch
import random
import argparse
import pprint
import random

def _build_class_mapping(class_mask):
    """Map every class label in ``class_mask`` to a contiguous index.

    Labels are numbered 0, 1, 2, ... in the order they are encountered while
    walking the tasks in sequence, so the labels of task ``i`` occupy the
    index range that directly follows the labels of task ``i - 1``.

    A running counter is used rather than looking up the index assigned to
    the previous task's last label: that lookup raises ``IndexError`` when a
    task is empty and skews the offsets when the previous task's last label
    reappears in the current task.

    Args:
        class_mask: Sequence of per-task sequences of hashable class labels.

    Returns:
        dict mapping each label to its position in the flattened sequence.
    """
    mapping = {}
    next_index = 0
    for task_classes in class_mask:
        for label in task_classes:
            mapping[label] = next_index
            next_index += 1
    return mapping


def learn_continually():
    """Run the continual-learning experiment described by the global ``cfg``.

    Steps:
      1. Choose the task order (optionally shuffled).
      2. Obtain the data loaders and per-task class masks, either from the
         pre-built ``VTAB5T-Sim50`` files or via ``build_continual_dataloader``.
      3. Build the backbone with ``get_model`` and, if ``cfg.dtask.freeze`` is
         set, leave only parameters whose name contains ``_lora_`` or ``head``
         trainable.
      4. Allocate the ``Q`` and ``G`` accumulators and call
         ``train_and_evaluate``.

    Side effects: overwrites ``cfg.dtask.n_components`` with a per-task list
    and emits progress through ``log`` / ``print``.
    """
    log('\nRunning experiments using PROTEUS')
    tasks = range(cfg.continual.n_tasks)
    if cfg.continual.shuffle_task:
        tasks = torch.randperm(cfg.continual.n_tasks).tolist()

    if cfg.run_label == "VTAB5T-Sim50":
        # SECURITY: both loads below unpickle arbitrary Python objects
        # (weights_only=False / allow_pickle=True). Only point data_path at
        # files you trust.
        data_loader = torch.load(cfg.dtask.data_path + "/dataloader_output.pt", weights_only=False)
        class_mask = np.load(cfg.dtask.data_path + "/class_mask.npy", allow_pickle=True).tolist()
        # NOTE(review): this list has four entries; confirm it matches
        # cfg.continual.n_tasks for this benchmark.
        cfg.dtask.n_components = [70, 100, 100, 20]
        # The stored labels are not guaranteed to be sorted/contiguous, so
        # remap them to 0..N-1 in task order.
        mapping_classes = _build_class_mapping(class_mask)
    else:
        data_loader, class_mask = build_continual_dataloader(batch_size=cfg.dtask.batch_size)
        # Replicate the single configured value once per task.
        cfg.dtask.n_components = [cfg.dtask.n_components] * cfg.continual.n_tasks
        mapping_classes = None

    log('class_mask %s ' % class_mask)
    log('Number of classes %d ' % cfg.dtask.nb_classes)

    # Define a ViT model
    model = get_model()
    if cfg.dtask.freeze:
        for name, param in model.named_parameters():
            trainable = ('_lora_' in name) or ('head' in name)
            param.requires_grad = trainable
            if trainable:
                print(name)

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log('number of params: %d' % n_parameters)

    # Define a RP model
    network = SimpleVitNet()
    criterion = torch.nn.CrossEntropyLoss().to(cfg.device)
    acc_matrix = np.zeros((cfg.continual.n_tasks, cfg.continual.n_tasks))
    reg_weight = cfg.dtask.reg_weight
    ortho_weight = cfg.dtask.ortho_weight
    # NOTE(review): M is presumably the random-projection feature dimension;
    # confirm against SimpleVitNet / train_and_evaluate.
    M = 10000
    Q = torch.zeros(M, cfg.dtask.nb_classes)
    # M x M accumulator: 10000 * 10000 elements, ~400 MB with torch's default
    # float32 dtype.
    G = torch.zeros(M, M)

    train_and_evaluate(tasks, model, criterion, data_loader, cfg.device, class_mask, acc_matrix,
                       reg_weight=reg_weight, added_units=cfg.dtask.added_units, alpha=cfg.dtask.alpha,
                       network=network, M=M, Q=Q, G=G, mapping_classes=mapping_classes, ortho_weight=ortho_weight)
    
def set_seed(seed):
    """Record ``seed`` on the global config and seed every RNG in use.

    After storing the value as ``cfg.seed``, that stored value is passed, in
    order, to the CUDA generators of all devices, torch's default generator,
    NumPy's global RNG and Python's ``random`` module.

    Args:
        seed: Seed value to store and apply.
    """
    cfg.seed = seed
    seeders = (
        torch.cuda.manual_seed_all,
        torch.manual_seed,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(cfg.seed)

def main():
    """Parse the command line, load the config, set up the device and run.

    Command-line options:
        --cfg: Path to the YAML config file merged into the global ``cfg``
            (default ``./config/cifar-100.yml``).

    Side effects: sets ``cfg.device`` to the first GPU listed in
    ``cfg.gpu_ids``, configures cuDNN for reproducible execution, seeds the
    RNGs via ``set_seed`` and, when ``cfg.continual.method.run_merlin`` is
    truthy, runs ``learn_continually``.
    """
    parser = argparse.ArgumentParser(description='PROTEUS in Continual Learning')
    parser.add_argument('--cfg', dest='cfg_file', default='./config/cifar-100.yml')

    args = parser.parse_args()

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)

    log(pprint.pformat(cfg))

    # str() guards against YAML parsing a single id (``gpu_ids: 0``) as an int.
    gpus = [int(gpu_id) for gpu_id in str(cfg.gpu_ids).split(',')]
    cfg.device = torch.device('cuda:' + str(gpus[0]))

    # benchmark=True lets cuDNN auto-tune and possibly pick a different
    # algorithm on each run, which defeats the fixed seed and
    # deterministic=True below; keep it off for reproducible results.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    set_seed(cfg.seed)

    if cfg.continual.method.run_merlin:
        learn_continually()


# Run the experiment only when this file is executed as a script.
if "__main__" == __name__:
    main()