# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import time
from alg.opt import *
from alg import alg, modelopera
from utils.util import set_random_seed, get_args, print_row, print_args, train_valid_target_eval_names, alg_loss_dict, print_environ, save_checkpoint
from datautil.getdataloader_single import get_act_dataloader
import os
from plot_a_distance import proxy_a_distance, proxy_mlp_a_distance

def _extract_features(algorithm, loader, device):
    """Apply ``algorithm.get_features`` to every batch of *loader*.

    Each batch's output is moved to the CPU immediately, so results do not
    accumulate in device memory. The outputs are concatenated along dim 0
    into a single NumPy array.

    Raises:
        ValueError: if *loader* yields no batches.
    """
    chunks = []
    # Inference only: skip building an autograd graph for every batch.
    with torch.no_grad():
        for data in loader:
            # Batches unpack as (x, y, d, pctarget, pdtarget, index);
            # only the input x is needed to compute features.
            x = data[0].float().to(device)
            chunks.append(algorithm.get_features(x).cpu())
    if not chunks:
        raise ValueError("loader yielded no batches; cannot compute features")
    return torch.cat(chunks, dim=0).numpy()


def get_a_distance(args, algorithm, train_loaders_not_infinity, target_loader, device):
    """Load the saved checkpoint and estimate the proxy A-distance.

    Args:
        args: namespace whose ``output`` directory contains ``model.pkl``
            (a state dict saved with ``torch.save``).
        algorithm: model exposing ``load_state_dict`` and ``get_features``.
            Its state is overwritten by the checkpoint.
        train_loaders_not_infinity: finite loader over source-domain batches.
        target_loader: finite loader over target-domain batches.
        device: torch device on which features are computed.

    Returns:
        The value returned by ``proxy_mlp_a_distance`` on the source and
        target feature matrices.
    """
    checkpoint = os.path.join(args.output, 'model.pkl')
    # map_location lets a checkpoint saved on one device load on another.
    algorithm.load_state_dict(torch.load(checkpoint, map_location=device))

    # Extract features in eval mode, so normalisation/dropout layers behave
    # deterministically and do not update buffers of the loaded weights.
    # Restore the caller's mode afterwards.
    was_training = algorithm.training
    algorithm.eval()
    try:
        source_X = _extract_features(algorithm, train_loaders_not_infinity, device)
        target_X = _extract_features(algorithm, target_loader, device)
    finally:
        algorithm.train(was_training)
    print(source_X.shape, target_X.shape)

    # The file also imports proxy_a_distance as an alternative estimator.
    result = proxy_mlp_a_distance(source_X, target_X)
    print("A-distance:", result)
    return result

def main(args):
    """Evaluate the proxy A-distance of a previously trained checkpoint.

    This script does not train. It builds the model described by *args*,
    then hands it to ``get_a_distance``. That function loads the weights
    from ``<args.output>/model.pkl`` and compares source and target
    features.

    Args:
        args: namespace from ``get_args``. ``seed``, ``gpu_id`` and
            ``algorithm`` are read here. The namespace is also passed to
            ``get_act_dataloader``, the algorithm constructor and
            ``get_a_distance``.

    Returns:
        The A-distance value computed by ``get_a_distance``.
    """
    s = print_args(args, [])
    set_random_seed(args.seed)

    print_environ()
    print(s)

    # NOTE(review): a latent_domain_num-based batch size (32*k when k < 6,
    # otherwise 16*k) used to be computed here, but it was immediately
    # overwritten. This fixed value is the one that actually took effect.
    args.batch_size = 128

    # Only the source (train) loader and the target loader are needed to
    # compare the two feature distributions.
    train_loader, _, _, target_loader, _, _, _ = get_act_dataloader(args)

    # Fall back to CPU when CUDA is unavailable, instead of failing on the
    # first .to(device).
    if torch.cuda.is_available():
        device = torch.device("cuda:{}".format(args.gpu_id))
    else:
        device = torch.device("cpu")

    algorithm_class = alg.get_algorithm_class(args.algorithm)
    algorithm = algorithm_class(args).to(device)

    # get_a_distance already prints "A-distance: <value>", so the result is
    # returned rather than printed a second time.
    return get_a_distance(args, algorithm, train_loader, target_loader, device)


if __name__ == '__main__':
    # Script entry point: parse the command line and run the evaluation.
    main(get_args())
