import dataset
import os
from train import trainer
import model
import torch.utils.data as data
import torch.nn as nn
import torch
import torchvision.models as md

# ---- Experiment configuration -------------------------------------------
# Number of tasks: enumerated as 0..NUM_TASK-1 below and also passed as the
# last argument of model.fix_net.
NUM_TASK = 10
# Run an evaluation after every FREQUENCEY-th task (name kept as-is because
# other code may import it).
FREQUENCEY = 1
# Dataset class; instantiated with task=[...] (and train=False for the
# held-out split).
DATASET = dataset.SplitCIFAR10
# Shared trainer, built at import time. `.to('cuda')` means importing this
# module requires a CUDA device. 512 is presumably the backbone feature width
# and 0.01 presumably the learning rate -- confirm against model.fix_net and
# train.trainer.
# NOTE(review): with the 'spawn' multiprocessing start method, DataLoader
# workers re-import the main module, so this construction would run again in
# each worker.
MODEL = trainer(
    model.fix_net(
        model.get_resnet18(),
        512,
        NUM_TASK,
    ).to('cuda'),
    0.01,
)

def getF(num_task,DATASET,model_test:trainer):
    features = []
    for i in range(num_task):
        feature = model_test.test_feature(data.DataLoader(DATASET(task=[i]),batch_size=32,shuffle=True,num_workers=0))
        feature = torch.cat(feature).mean(0).unsqueeze(0).transpose(0,1)
        feature = model_test.target_net.normalize(feature,1)
        features.append(feature)
    return torch.cat(features,1)

def test(num_task,DATASET,model_test,features):
    print(num_task)
    task=[i for i in range(num_task)]
    trainloader = data.DataLoader(DATASET(task=task),batch_size=64,shuffle=True,num_workers=4)
    testloader = data.DataLoader(DATASET(train=False,task=task),batch_size=64,shuffle=True,num_workers=4)
    model_test.target_net.fc.weight.data[:num_task]=features[:num_task]
    model_test.test(trainloader)
    model_test.test(testloader)

if __name__ == '__main__':
    # Transposed so that dim 0 indexes tasks; test() slices features[:num_task].
    features = getF(NUM_TASK, DATASET, MODEL).transpose(0, 1)
    # Evaluate on the growing task prefix [0, seen) every FREQUENCEY tasks.
    for seen in range(1, NUM_TASK + 1):
        if seen % FREQUENCEY != 0:
            continue
        test(seen, DATASET, MODEL, features)