import dataset
import os
from train import trainer
import model
import torch.utils.data as data
import torch.nn as nn
import torch
import torchvision.models as md

# Number of tasks: used below as the range of task ids, as the output size of
# the linear head, and as the last argument of model.fix_net.
NUM_TASK = 10
# How many times pre_train() calls MODEL.pre_train.
PRE_NUM = 2
# Not referenced in the visible part of this file; the (sic) spelling is kept
# in case other modules use the name.
FREQUENCEY = 1
# Dataset class; instantiated below as DATASET(task=[...]) and
# DATASET(train=False, task=[...]).
DATASET = dataset.SplitMNIST
# Resolve the device BEFORE building the model so the network lands on the
# same device as everything else and the script still starts without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# The 32 matches the input size of the nn.Linear head created in the script
# section. NOTE(review): 0.01 is presumably the learning rate expected by
# train.trainer -- confirm against that class.
MODEL = trainer(model.fix_net(model.get_res_model(1,32),32,NUM_TASK).to(device),0.01)

def getF(num_task,DATASET,model_test:trainer):
    """Collect one mean-feature column per task.

    For every task id ``t`` in ``range(num_task)`` a shuffled loader
    (batch size 32, no worker processes) over ``DATASET(task=[t])`` is handed
    to ``model_test.test_feature``. The tensors it returns are concatenated
    and averaged over dim 0, and the average becomes one column.

    Args:
        num_task: number of task ids to visit, starting at 0.
        DATASET: callable accepting a ``task=[t]`` keyword and returning
            something ``DataLoader`` can iterate.
        model_test: object exposing ``test_feature(loader)``, which must
            return a sequence of tensors accepted by ``torch.cat``.

    Returns:
        The per-task columns concatenated along dim 1 (column ``t`` belongs
        to task ``t``).
    """
    def task_column(task_id):
        loader = data.DataLoader(
            DATASET(task=[task_id]),
            batch_size=32,
            shuffle=True,
            num_workers=0,
        )
        chunks = model_test.test_feature(loader)
        # Average over dim 0 of the concatenated output, then insert a new
        # axis at position 1 so the result can be joined column-wise.
        return torch.cat(chunks).mean(0).unsqueeze(1)

    return torch.cat([task_column(task_id) for task_id in range(num_task)], 1)

def pre_train():
    """Pre-train the module-level MODEL on every task, then call fixByName.

    Performs PRE_NUM passes. Each pass builds a fresh shuffled loader
    (batch size 64, 4 workers) over DATASET restricted to task ids
    0..NUM_TASK-1 and forwards it to ``MODEL.pre_train`` together with the
    pass index and the total number of passes. Afterwards
    ``MODEL.fixByName()`` is invoked once.
    """
    all_tasks = list(range(NUM_TASK))
    for epoch in range(PRE_NUM):
        loader = data.DataLoader(
            DATASET(task=all_tasks),
            batch_size=64,
            shuffle=True,
            num_workers=4,
        )
        # The value returned by MODEL.pre_train is not used.
        MODEL.pre_train(loader, epoch, PRE_NUM)
    MODEL.fixByName()
    

if __name__ == '__main__':
    # Script flow:
    #   1. pre-train MODEL on all tasks (pre_train),
    #   2. fit a fresh linear head on the per-task mean features of DATASET,
    #      calling MODEL.test on the train and test loaders after every step,
    #   3. print MODEL.target_net.trs for the DATASET features and for the
    #      dataset.SplitMnistFashion features.
    pre_train()
    task = list(range(NUM_TASK))
    trainloader = data.DataLoader(DATASET(task=task),batch_size=64,shuffle=True,num_workers=4)
    testloader = data.DataLoader(DATASET(train=False,task=task),batch_size=64,shuffle=True,num_workers=4)

    fc = nn.Linear(32,NUM_TASK).to(device)
    opt = torch.optim.SGD(fc.parameters(),0.01)
    loss_f = nn.CrossEntropyLoss()

    # getF yields one column per task. That layout is kept for the trs report;
    # the transposed view (one row per task) is the input batch for the head.
    f_mnist = getF(NUM_TASK,DATASET,MODEL).to(device)
    features = f_mnist.transpose(0,-1)
    # Row i holds the mean feature of task i, so its target label is i.
    targets = torch.arange(NUM_TASK, device=device)

    # Install the head once, after the features were extracted. MODEL keeps a
    # reference to the same module, so every optimizer step below is visible
    # to MODEL.test without re-assigning it.
    MODEL.target_net.fc = fc

    head_steps = 8  # number of full-batch SGD steps applied to the head
    for _ in range(head_steps):
        y = fc(features)
        loss = loss_f(y,targets)
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Accuracy of the predictions made before this step's update.
        acc = (y.argmax(dim=-1) == targets).float().mean()
        print(f" loss = {loss:.5f}, acc = {acc:.5f}")
        MODEL.test(trainloader)
        MODEL.test(testloader)

    # Pass the second dataset class directly instead of rebinding the
    # module-level DATASET constant, and keep its features on the same device
    # as the MNIST ones.
    f_fashion = getF(NUM_TASK,dataset.SplitMnistFashion,MODEL).to(device)
    print('mnist trs value:',MODEL.target_net.trs(f_mnist.squeeze()),',mnist fashion trs value:',MODEL.target_net.trs(f_fashion.squeeze()))