
import sys
sys.path.append(".")
import numpy as np
import matplotlib.pyplot as plt
from domainbed import algorithms
from domainbed import datasets
from domainbed.lib import misc
from domainbed.lib.fast_data_loader import InfiniteDataLoader, FastDataLoader
import torch



data_dir = '../datasets_for_domainbed'

# Evaluation always runs on the CPU. The previous code queried
# torch.cuda.is_available() but assigned "cpu" in both branches, so the check
# was dead. Forcing CPU also keeps the model on the same device as the data:
# paper_3 feeds loader batches to the model without moving them to `device`.
device = "cpu"


def get_model(save_path, alg_name, map_location="cpu"):
    '''
    Rebuild a trained DomainBed algorithm from a checkpoint file.

    Args:
        save_path: path to a checkpoint dict containing the keys
            'model_input_shape', 'model_num_classes', 'model_num_domains',
            'model_hparams' and optionally 'model_dict' (the state dict).
        alg_name: algorithm class name, resolved with
            algorithms.get_algorithm_class.
        map_location: forwarded to torch.load. Defaults to "cpu" so that a
            checkpoint saved from a GPU can be loaded on a CPU-only machine
            (this script evaluates on the CPU). Pass None to restore tensors
            to the device they were saved from.

    Returns:
        The algorithm instance, with its weights loaded when the checkpoint
        provides a non-None 'model_dict'. NOTE(review): the methods it exposes
        (e.g. predict, eval_setup/eval_predict) depend on the algorithm class.
    '''
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    save_dict = torch.load(save_path, map_location=map_location)
    algorithm_class = algorithms.get_algorithm_class(alg_name)
    # Constructor order: input_shape, num_classes, num_domains, hparams.
    algorithm = algorithm_class(save_dict['model_input_shape'],
                                save_dict['model_num_classes'],
                                save_dict['model_num_domains'],
                                save_dict['model_hparams'])
    model_dict = save_dict.get('model_dict')
    if model_dict is not None:
        algorithm.load_state_dict(model_dict)
    return algorithm


def get_datasets(dataset_name, hparams, batch_size=32, test_envs=None):
    '''
    Build one FastDataLoader per environment of a domainbed dataset.

    Prints the environment names and per-environment sizes as a side effect.

    Args:
        dataset_name: name of a dataset class in domainbed.datasets.
        hparams: hyper-parameter dict passed to the dataset constructor.
        batch_size: batch size of every loader (default 32).
        test_envs: second positional constructor argument (DomainBed's
            test-environment list); defaults to [0].

    Returns:
        A list of loaders, in the dataset's environment order.

    Raises:
        NotImplementedError: if dataset_name is not defined in domainbed.datasets.
    '''
    # Previously an unknown name fell through and crashed with UnboundLocalError.
    if dataset_name not in vars(datasets):
        raise NotImplementedError(f"Unknown dataset: {dataset_name}")
    if test_envs is None:
        test_envs = [0]
    dataset = vars(datasets)[dataset_name](data_dir, test_envs, hparams)  # TODO hparams dict

    print(f"Dataset(all) - Size of {len(dataset.ENVIRONMENTS)} datasets: ")
    misc.print_row(dataset.ENVIRONMENTS, colwidth=12)
    misc.print_row([len(each) for each in dataset], colwidth=12)

    # TODO loader: no sampling weights are used.
    return [
        FastDataLoader(dataset=env,
                       batch_size=batch_size,
                       num_workers=dataset.N_WORKERS)
        for env in dataset
    ]


'''paper 6
Input: domains, an encoder, and a transform (15, 19, or random).
1. Take m domains.
2. Encode all data points with the encoder (note: use a limited number of points).
3. Apply the transform and then the encoder: transform -> encoder -> trans_encoded.
4. For each pair (trans_encoded, encoded), compute a distance vector of size m-1.
'''


def paper_6():
    '''Placeholder for the Paper 6 experiment; not implemented yet and returns None.'''
    return None


'''paper 3
Input: domains (2-D points, binary classification problem) and an encoder.
Draw the m domains on a plot.
Draw the classifier's decision boundary based on the model.
'''


def paper_3():
    '''
    Scatter-plot a trained model's predictions on the EDGCircle domains.

    Loads the EDG checkpoint 'EXPS/Paper3/model_step400.pkl', builds one loader
    per EDGCircle environment and plots the first three batches of each
    environment. Each point is coloured by its predicted class. The figure is
    saved to 'EXPS/Paper3/tmp.png'.
    '''
    alg_name = 'EDG'
    algorithm = get_model(
        save_path='EXPS/Paper3/model_step400.pkl',
        alg_name=alg_name
    )
    hparams = {}
    # NOTE(review): literal string, looks like a leftover placeholder for args.test_type.
    hparams['test_type'] = 'args.test_type'
    hparams['env_distance'] = 10
    hparams['env_number'] = 10
    hparams['env_sample_number'] = 1000
    loaders = get_datasets('EDGCircle', hparams)

    # Non-ERM algorithms need one domain loaded as support before predicting.
    # The support loader is the same for every batch, so set it up once here.
    # Previously this ran once per batch.
    if alg_name != 'ERM':
        algorithm.eval_setup([loaders[-10]], device, None)

    fig = plt.figure()
    for each_loader in loaders:
        for data_i, (x, _) in enumerate(each_loader):
            if alg_name == 'ERM':
                y_ = algorithm.predict(x).argmax(1)
            else:
                y_ = algorithm.eval_predict(x).argmax(1)
            # The green channel encodes the predicted class index (0 or 1 for a
            # binary problem).
            plt.scatter(x[:, 0].numpy(), x[:, 1].numpy(),
                        c=[(0.5, float(each), 0.5) for each in y_])
            # Plot only the first three batches (data_i = 0, 1, 2) of each domain.
            if data_i > 1:
                break
    # TODO print points on graph with different colors for different domains
    # TODO print border line
    plt.savefig('EXPS/Paper3/tmp.png')
    plt.close(fig)
    return


'''paper 8
'''


def paper_8(save_path='EXPS/Paper8/v1.pdf'):
    '''
    Draw a grouped bar chart of per-domain Rotated-MNIST accuracies (Paper 8).

    Each rotation-angle domain gets one group of bars, with one bar per
    baseline algorithm. The values are hard-coded below.

    Args:
        save_path: output file for the figure. Missing parent directories
            are created.
    '''
    import os

    domain_names = ['0', '15', '30', '45', '60', '75']
    data = [
        {
            'alg_name': 'FeatureCritic',
            'accs': [89.23, 99.68, 99.20, 99.24, 99.53, 91.44]
        },
        {
            'alg_name': 'HEX',
            'accs': [90.10, 98.90, 98.90, 98.80, 98.30, 90.00]
        },
        {
            'alg_name': 'CrossGrad',
            'accs': [88.30, 98.60, 98.00, 97.70, 97.70, 91.40]
        },
        {
            'alg_name': 'DIVA',
            'accs': [93.50, 99.30, 99.10, 99.20, 99.30, 93.00]
        },
        {
            'alg_name': 'ERM',
            'accs': [95.60, 99.00, 98.90, 99.10, 99.00, 96.70]
        },
    ]
    domain_num = len(data[0]['accs'])
    alg_num = len(data)
    total_width = 0.6

    font1 = {
        'family': 'Times New Roman',
        'weight': 'normal',
        'size': 14,
        'color': '#111111'
    }

    # Open a fresh figure rather than clearing whichever figure is current.
    fig = plt.figure()
    plt.subplot(111)

    # The bars of one group span `total_width`. Shift them left so that each
    # group is centred on its integer position in `centers`.
    centers = np.arange(domain_num)
    width = total_width / alg_num
    x = centers - (total_width - width) / 2

    for alg_i, entry in enumerate(data):
        plt.bar(x + alg_i * width, entry['accs'],
                width=width, label=entry['alg_name'], edgecolor="#333333")

    plt.xlabel(r"Domain index/angle of RMNIST ($^\circ$)", font1)
    # The plotted values are accuracies (higher is better), not error rates.
    plt.ylabel("Accuracy (%)", font1)
    x1, x2, _, _ = plt.axis()
    # Fix the y-range to 85-110 but tick only up to 100 (presumably to leave
    # room above the bars for the legend).
    plt.axis((x1, x2, 85, 110))
    plt.legend()
    # Label the centre of each group, not the position of its first bar.
    plt.xticks(centers, domain_names)
    plt.yticks(np.arange(85, 101, 5))

    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_path)
    plt.close(fig)

    # TODO font


'''paper 9
'''

def paper_9(save_path='EXPS/Paper9/v1.pdf'):
    '''
    Draw a grouped bar chart comparing algorithms in two settings (Paper 9).

    Each algorithm gets one group of bars, with one bar per setting
    (forward with target 75 degrees, backward with target 0 degrees). The
    values are hard-coded below.

    Args:
        save_path: output file for the figure. Missing parent directories
            are created.
    '''
    import os

    alg_names = ['FeatureCritic', 'HEX', 'CrossGrad', 'DIVA', 'ERM', 'Ours']
    data = [
        {
            'name': r'Forward - Target = $75^\circ$',
            'accs': [87.23, 93.44, 89.23, 92.44, 88.23, 96.9],
            'vars': [5, 5, 5, 5, 5, 5]
        },
        {
            'name': r'Backward - Target = $0^\circ$',
            'accs': [89.23, 91.94, 88.23, 93.44, 89.23, 97.3],
            'vars': [5, 5, 5, 5, 5, 5]
        }
    ]
    # One group per algorithm, one bar (series) per setting.
    group_num = len(data[0]['accs'])
    series_num = len(data)
    total_width = 0.6

    font1 = {
        'family': 'Times New Roman',
        'weight': 'normal',
        'size': 14,
        'color': '#111111'
    }

    # Open a fresh figure rather than clearing whichever figure is current.
    fig = plt.figure()
    plt.subplot(111)

    # The bars of one group span `total_width`. Shift them left so that each
    # group is centred on its integer position in `centers`.
    centers = np.arange(group_num)
    width = total_width / series_num
    x = centers - (total_width - width) / 2

    for series_i, entry in enumerate(data):
        # NOTE(review): entry['vars'] is not plotted; pass yerr=entry['vars']
        # here if error bars are wanted.
        plt.bar(x + series_i * width, entry['accs'],
                width=width, label=entry['name'], edgecolor="#333333")

    plt.xlabel(r"Algorithms", font1)
    # The plotted values are accuracies (higher is better), not error rates.
    plt.ylabel("Accuracy (%)", font1)
    x1, x2, _, _ = plt.axis()
    # Fix the y-range to 85-110 but tick only up to 100 (presumably to leave
    # room above the bars for the legend).
    plt.axis((x1, x2, 85, 110))
    plt.legend()
    # Label the centre of each group, not the position of its first bar.
    plt.xticks(centers, alg_names)
    plt.yticks(np.arange(85, 101, 5))

    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_path)
    plt.close(fig)

    # TODO font


# def paper_9():
#     alg_names = ['FeatureCritic', 'HEX', 'CrossGrad', 'DIVA', 'ERM', 'Ours']
#     aves = np.random.random(len(alg_names))
#     vars = np.random.random(len(alg_names)) / 4

#     domain_names = ['0', '15', '30', '45', '60', '75']
#     data = [
#         {
#             'alg_name': 'FeatureCritic',
#             'accs': [89.23, 99.68, 99.20, 99.24, 99.53, 91.44]
#         },
#         {
#             'alg_name': 'HEX',
#             'accs': [90.10, 98.90, 98.90, 98.80, 98.30, 90.00]
#         },
#         {
#             'alg_name': 'CrossGrad',
#             'accs': [88.30, 98.60, 98.00, 97.70, 97.70, 91.40]
#         },
#         {
#             'alg_name': 'DIVA',
#             'accs': [93.50, 99.30, 99.10, 99.20, 99.30, 93.00]
#         },
#         {
#             'alg_name': 'ERM',
#             'accs': [95.60, 99.00, 98.90, 99.10, 99.00, 96.70]
#         },
#     ]
#     domain_num = len(data[0]['accs'])
#     alg_num = len(data)
#     total_width = 0.6

#     font1 = {
#         'family' : 'Times New Roman',
#         'weight' : 'normal',
#         'size'   : 14,
#         'color': '#111111'
#     }

#     plt.clf()
#     plt.figure()
#     plt.subplot(111)

#     x = np.arange(len(alg_names))
#     # width = total_width / alg_num
#     # x = x - (total_width - width) / 2

#     plt.bar(x, aves, yerr=vars, width=0.3, capsize=0.5, label=data[1]['alg_name'], color="#d95319",edgecolor="#333333")

#     # plt.yticks(fontproperties = 'Times New Roman', size = 12)
#     # plt.xticks(fontproperties = 'Times New Roman', size = 12)
#     plt.xlabel(r"Domain index/angle of RMNIST ($^\circ$)", font1)
#     plt.ylabel("Error Rate (%)", font1)
#     x1,x2,y1,y2 = plt.axis() # set y range
#     # plt.axis((x1,x2,85,110))
#     # plt.legend() # show legend
#     locs, labels = plt.xticks()
#     plt.xticks(x, alg_names)
#     # plt.yticks(np.arange(85, 101, 5))

#     plt.savefig('EXPS/Paper9/tmp.png')




# NOTE: no t-SNE visualization is produced yet (should one be added?).
if __name__ == '__main__':
    # Figures regenerated when run as a script. Add paper_6, paper_8 or
    # paper_9 to this tuple to produce those as well.
    for make_figure in (paper_3,):
        make_figure()
