from utils.util import read_yaml, get_backbone, load_state_dict
from dataset import data_process
import torchvision.transforms as transforms
import torch
from utils.cca_util import CCAHook

class SvccaEvaluation():
    """Compare the representations of three networks at one layer via CCA.

    Three instances of the same backbone are built: a "pretrain" net, a
    "standard" net and a "finetune" net. A ``CCAHook`` is attached to the
    layer named ``layer_name`` on each one, and all three are run over the
    test split. The pairwise ``cca_similar`` scores are then printed.
    """

    def __init__(self, opt):
        """Store the run settings taken from the parsed command-line ``opt``.

        Args:
            opt: namespace providing ``root``, ``dataset``, ``device``,
                ``backbone``, ``data_conf``, ``pretrain``, ``standard``,
                ``finetune`` and ``layer_name``.
        """
        self.root = opt.root
        self.dataset = opt.dataset
        self.device = opt.device
        self.backbone = opt.backbone
        # Set by data_process(); net_process() sizes the classifier heads with it.
        self.num_classes = None
        # Config for this (backbone, dataset) pair; only 'batch_size' is read here.
        self.conf = read_yaml(opt.data_conf, self.backbone, self.dataset)
        # Weight sources for the three networks. 'pretrain' may be the literal
        # string 'imagenet'; every other value is handed to load_state_dict.
        self.load = {'pretrain': opt.pretrain, 'standard': opt.standard, 'finetune': opt.finetune}
        self.layer_name = opt.layer_name

    def data_process(self):
        """Build the data loaders and record the dataset's class count.

        The bare name ``data_process`` called below is the module-level
        function imported from ``dataset``, not this method.

        Returns:
            ``(train_loader, test_loader)`` as returned by ``dataset.data_process``.
        """
        batch_size = self.conf['batch_size']
        train_loader, test_loader = data_process(root=self.root, dataset=self.dataset,
                                                 batch_size=batch_size, train=False)
        self.num_classes = test_loader.dataset.num_classes
        return train_loader, test_loader

    def net_process(self):
        """Instantiate the pretrain, standard and finetune networks.

        Call data_process() first: every net built with a custom head uses
        ``self.num_classes``, which is ``None`` until then.

        Returns:
            ``(pretrain_net, standard_net, finetune_net)``.
        """
        backbone = get_backbone(self.backbone)

        def build_from_checkpoint(checkpoint):
            # Fresh, randomly initialised net whose weights come from `checkpoint`.
            net = backbone(pretrained=False, num_classes=self.num_classes)
            load_state_dict(net, checkpoint)
            return net

        if self.load['pretrain'] == 'imagenet':
            # Uses the constructor's own pretrained weights and default head.
            pretrain_net = backbone(pretrained=True)
        else:
            pretrain_net = build_from_checkpoint(self.load['pretrain'])

        standard_net = build_from_checkpoint(self.load['standard'])
        finetune_net = build_from_checkpoint(self.load['finetune'])

        return pretrain_net, standard_net, finetune_net

    def eval(self):
        """Run the CCA comparison on the test split and print the scores.

        Returns:
            ``[standard~finetune, standard~pretrain, finetune~pretrain]``, the
            Python scalars obtained from ``cca_similar(...).item()``.
        """
        print('-------------svcca eval-------------')
        print('dataset: {}\tbackbone: {}'.format(self.dataset, self.backbone))

        # Must precede net_process(): it sets self.num_classes.
        _, test_loader = self.data_process()

        pretrain_net, standard_net, finetune_net = self.net_process()

        for net in (pretrain_net, standard_net, finetune_net):
            net.to(self.device).eval()

        # NOTE(review): presumably each hook collects `layer_name` activations
        # during forward passes of its net -- confirm against utils.cca_util.
        standard = CCAHook(standard_net, self.layer_name, device=self.device)
        finetune = CCAHook(finetune_net, self.layer_name, device=self.device)
        pretrain = CCAHook(pretrain_net, self.layer_name, device=self.device)

        # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]; assumes the loader yields
        # unnormalised tensors in [0, 1] -- TODO confirm in dataset.data_process.
        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        # Inference only: a single no_grad context covers the whole pass.
        # Labels are unused, so they are no longer copied to the device.
        with torch.no_grad():
            for images, _ in test_loader:
                images = normalize(images.to(self.device))
                pretrain_net(images)
                standard_net(images)
                finetune_net(images)

        cca_value = [standard.cca_similar(finetune).item(),
                     standard.cca_similar(pretrain).item(),
                     finetune.cca_similar(pretrain).item()]

        print('cca:')
        print('standard and finetune:{}'.format(cca_value[0]))
        print('standard and pretrain:{}'.format(cca_value[1]))
        print('finetune and pretrain:{}'.format(cca_value[2]))

        return cca_value


import argparse
if __name__ == '__main__':
    # Each entry: (flag, keyword arguments for ArgumentParser.add_argument).
    # The order here is the order flags appear in --help.
    cli_flags = (
        ('--backbone', dict(type=str, default='resnet18')),
        ('--dataset', dict(type=str, default='cifar10')),
        ('--data_conf', dict(type=str, default='conf.yaml')),
        ('--root', dict(type=str, default='/data/jiaming/datasets')),
        # Weight sources: 'imagenet' makes net_process call backbone(pretrained=True);
        # other values (and --standard / --finetune) go to load_state_dict.
        ('--pretrain', dict(type=str, default='imagenet')),
        ('--standard', dict(type=str)),
        ('--finetune', dict(type=str)),
        ('--layer_name', dict(type=str, default='layer4')),
        # Integer device id handed to .to() / CCAHook (presumably a CUDA index).
        ('--device', dict(type=int, default=0)),
    )

    cli = argparse.ArgumentParser()
    for flag, options in cli_flags:
        cli.add_argument(flag, **options)

    SvccaEvaluation(cli.parse_args()).eval()
