import torch
from feature_extractor.for_image_data import models ####
import argparse
from feature_extractor.for_image_data import data_load
from DA_algorithms.DeepDA import models as adapted_models
from DA_algorithms.DeepDA.models import TransferNet
import numpy as np
import torch.nn as nn
import pickle
import os
# Command-line configuration. The default paths are the author's Windows
# layout. They are raw strings so backslashes are not read as escape
# sequences ('\O', '\s', '\D' are invalid escapes and warn on Python 3.12+).
parser = argparse.ArgumentParser(description='Extract')
parser.add_argument('--model_name', type=str,
                    help='model name', default='resnet50')
parser.add_argument('--batchsize', type=int, help='batch size', default=64)
parser.add_argument('--gpu', type=int, help='cuda id', default=0)
parser.add_argument('--source', type=str, default='webcam',
                    help='source domain (sub-folder of --dataset_path)')
parser.add_argument('--target', type=str, default='dslr',
                    help='target domain (sub-folder of --dataset_path)')
parser.add_argument('--save_dir', type=str,
                    default=r'F:\OT_Score/feature_extractor\save_features',
                    help='directory receiving the extracted-feature pickle')
parser.add_argument('--num_class', type=int, default=31,
                    help='number of classes of both networks')
# A raw string cannot end in a single backslash, so the trailing separator is
# written as an escaped '\\'. The previous literal ended in a backslash that
# escaped its own closing quote, leaving the string unterminated (SyntaxError).
# The trailing separator is kept because domain names are appended to this root.
parser.add_argument('--dataset_path', type=str,
                    default='G:\\datasets\\',
                    help='dataset root holding one sub-folder per domain')
parser.add_argument('--finetuned_path', type=str,
                    default=r'F:\OT_Score/feature_extractor/for_image_data\save_model/best_resnet50_amazon.pth',
                    help='state_dict file of the fine-tuned source network')
parser.add_argument('--adapted_path', type=str,
                    default=r'F:\OT_Score/feature_extractor/adapted_models\DSAN_webcam_to_dslr_resnet50.pth',
                    help='state_dict file of the domain-adapted TransferNet')
args = parser.parse_args()


def _model_device(model):
    """Return the device of ``model``'s first parameter, or CPU if it has none."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device('cpu')


def _concat_cpu(chunks):
    """Concatenate per-batch tensors along dim 0 and move the result to CPU.

    An empty list (no batches) yields an empty tensor instead of the
    RuntimeError that ``torch.cat`` raises on an empty sequence.
    """
    if not chunks:
        return torch.empty(0)
    return torch.cat(chunks, dim=0).cpu()


def _accuracy(n_correct, n_total):
    """Return ``n_correct / n_total`` as a float, or NaN when ``n_total`` is 0."""
    return n_correct / n_total if n_total else float('nan')


def extract_feature(fintuned_model, adapted_model, src_dataloader, tar_dataloader, save_path,
                    device=None):
    """Run both models over the source/target loaders and pickle the results.

    Source batches go through ``fintuned_model`` only. Target batches go
    through ``fintuned_model`` (features and softmax probabilities) and
    through ``adapted_model.predict`` (predicted labels). The function prints
    the fine-tuned model's accuracy on the source set and the adapted model's
    accuracy on the target set. Accuracy is NaN for an empty dataset.

    Args:
        fintuned_model: model exposing ``get_features(x)`` and a forward pass
            returning class scores; softmax/argmax are taken over dim 1
            (assumed shape (batch, n_class) -- confirm against the model).
        adapted_model: model exposing ``predict(x)`` returning class scores
            over dim 1.
        src_dataloader: iterable of ``(inputs, labels)`` batches with a
            ``.dataset`` whose ``len`` is the source-accuracy denominator.
        tar_dataloader: same as ``src_dataloader``, for the target domain.
        save_path: destination pickle file. Missing parent directories are
            created.
        device: device batches are moved to. Defaults to the device of
            ``fintuned_model``'s first parameter (CPU if it has none), so the
            function no longer depends on a module-level ``DEVICE`` global.

    The pickle holds a dict of CPU tensors with keys ``src_features``,
    ``src_probs``, ``src_labels``, ``tar_features``, ``tar_probs``,
    ``tar_labels`` and ``tar_pred_labels``. Ground-truth labels are stored as
    float column vectors of shape (N, 1).
    """
    if device is None:
        device = _model_device(fintuned_model)

    fintuned_model.eval()
    adapted_model.eval()

    src_features, src_probs, src_labels = [], [], []
    tar_features, tar_probs, tar_labels, tar_pred_labels = [], [], [], []

    with torch.no_grad():
        # Source domain: fine-tuned model only.
        correct = 0
        for inputs, labels in src_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            feas = fintuned_model.get_features(inputs)
            labels = labels.view(-1, 1).float()
            outputs = fintuned_model(inputs)

            src_features.append(feas)
            src_probs.append(torch.softmax(outputs, dim=1))
            src_labels.append(labels)

            preds = torch.max(outputs, 1)[1]
            correct += int((preds == labels.squeeze(1).long()).sum())
        src_acc = _accuracy(correct, len(src_dataloader.dataset))
        print('src acc: %f' % src_acc)

        # Target domain: features/probabilities come from the fine-tuned model;
        # predicted labels (and the reported accuracy) come from the adapted model.
        correct = 0
        for inputs, labels in tar_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            labels = labels.view(-1, 1).float()
            fintuned_feas = fintuned_model.get_features(inputs)
            fintuned_outputs = fintuned_model(inputs)

            tar_features.append(fintuned_feas)
            tar_probs.append(torch.softmax(fintuned_outputs, dim=1))
            tar_labels.append(labels)

            adapted_pred = torch.max(adapted_model.predict(inputs), 1)[1]
            tar_pred_labels.append(adapted_pred)
            # squeeze(1), not squeeze(): stays 1-D even for a batch of one,
            # consistent with the source loop.
            correct += int((adapted_pred == labels.squeeze(1).long()).sum())
    tar_acc = _accuracy(correct, len(tar_dataloader.dataset))
    print('tar Test acc: %f' % tar_acc)

    results = {
        'src_features': _concat_cpu(src_features),
        'src_probs': _concat_cpu(src_probs),
        'src_labels': _concat_cpu(src_labels),
        'tar_features': _concat_cpu(tar_features),
        'tar_probs': _concat_cpu(tar_probs),
        'tar_labels': _concat_cpu(tar_labels),
        'tar_pred_labels': _concat_cpu(tar_pred_labels),
    }

    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(save_path, 'wb') as f:
        pickle.dump(results, f)



if __name__ == '__main__':
    # Script entry point. Kept at module scope (not wrapped in a main()
    # function) because extract_feature may read the module-level DEVICE.
    torch.manual_seed(10)
    BATCH_SIZE = {'src': int(args.batchsize), 'tar': int(args.batchsize)}
    # Use the CUDA device selected by --gpu; fall back to CPU without CUDA.
    DEVICE = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')

    finetuned_net = models.Network(base_net=args.model_name,
                                   n_class=args.num_class).to(DEVICE)
    # map_location lets checkpoints saved on a GPU load on a CPU-only host.
    finetuned_net.load_state_dict(torch.load(args.finetuned_path, map_location=DEVICE))
    finetuned_net.eval()

    adapted_net = adapted_models.TransferNet(
        args.num_class, base_net=args.model_name, use_bottleneck=False).to(DEVICE)
    adapted_net.load_state_dict(torch.load(args.adapted_path, map_location=DEVICE))
    adapted_net.eval()

    print('Loading data...')
    data_folder = args.dataset_path
    domain = {'src': str(args.source), 'tar': str(args.target)}
    # os.path.join inserts a separator when --dataset_path lacks a trailing one;
    # the trailing '/' passed to load_data is preserved.
    tar_dir = os.path.join(data_folder, domain['tar']) + '/'
    src_dir = os.path.join(data_folder, domain['src']) + '/'
    dataloaders = {}
    data_test = data_load.load_data(tar_dir, BATCH_SIZE['tar'], 'test')  # images
    # NOTE(review): assumes load_data(..., train_val_split=True) returns
    # (train_loader, val_loader, ...) -- confirm against data_load.
    data_train = data_load.load_data(
        src_dir, BATCH_SIZE['src'], 'train', train_val_split=True, train_ratio=.99)  # images
    dataloaders['train'], dataloaders['val'], dataloaders['test'] = data_train[0], data_train[1], data_test
    print('Data loaded: Source: {}, Target: {}'.format(args.source, args.target))

    os.makedirs(args.save_dir, exist_ok=True)
    feature_save_path = os.path.join(args.save_dir, f"{args.source}2{args.target}_features.pkl")
    extract_feature(
        finetuned_net, adapted_net, dataloaders['train'], dataloaders['test'], feature_save_path)
    print('Deep tar_test features are extracted and saved!')


