import  torch, os
import  numpy as np
import  scipy.stats
from    torch.utils.data import DataLoader
from    torch.optim import lr_scheduler
import  random, sys, pickle
import  argparse

from meta_PN import Meta
from dataloader_train import miniImageNetGenerator as Loader_train
from dataloader_test import miniImageNetGenerator as Loader_test
from functions import *
from tqdm import tqdm

torch.set_num_threads(1)

def mean_confidence_interval(accs, confidence=0.95):
    """Return the mean of *accs* and the half-width of its confidence interval.

    The half-width is the standard error of the mean multiplied by the
    two-sided Student-t critical value with ``n - 1`` degrees of freedom,
    where ``n`` is the length of the first axis of *accs*.

    Args:
        accs: Array-like of accuracy samples (a list or a NumPy array).
        confidence: Two-sided confidence level, e.g. ``0.95``.

    Returns:
        A ``(mean, half_width)`` tuple; the interval is ``mean +/- half_width``.
    """
    # Accept plain sequences as well as ndarrays.
    accs = np.asarray(accs)
    n = accs.shape[0]
    m, se = np.mean(accs), scipy.stats.sem(accs)
    # Use the public ``ppf``: the private ``_ppf`` skips argument validation
    # and is not part of SciPy's stable API.
    h = se * scipy.stats.t.ppf((1 + confidence) / 2, n - 1)
    return m, h

# Root directory of the pickled miniImageNet splits (machine-specific).
data_path = '/drive1/YH/datasets/miniImageNet/'
# Build the split paths with os.path.join so the trailing slash of
# ``data_path`` does not produce a doubled separator ("//") in the result.
train_path = os.path.join(data_path, 'miniImageNet_category_split_train_phase_train.pickle')
val_path = os.path.join(data_path, 'miniImageNet_category_split_val.pickle')
test_path = os.path.join(data_path, 'miniImageNet_category_split_test.pickle')

def _to_device(batch, n_user, device):
    """Move the first *n_user* entries of *batch* onto *device*, in place."""
    for user in range(n_user):
        batch[user] = batch[user].to(device)


def _evaluate(maml, loader, n_user, device):
    """Run ``maml.finetunning`` on every episode of *loader* and average.

    Returns the mean over episodes (axis 0) of the per-episode values
    returned by ``maml.finetunning``, as produced by ``np.array(...).mean``.
    """
    episode_accs = []
    for _, x_spt, y_spt, x_qry, y_qry in loader:
        for batch in (x_spt, y_spt, x_qry, y_qry):
            _to_device(batch, n_user, device)
        episode_accs.append(maml.finetunning(x_spt, y_spt, x_qry, y_qry))
    return np.array(episode_accs).mean(axis=0)


def main():
    """Train the meta-learner, checkpointing and evaluating periodically.

    Reads its settings from the module-level ``args`` namespace. During
    training it prints the training accuracy every 40 steps, saves the model
    state every 500 steps under ``save/`` and runs a short 30-episode test
    every 400 steps. After training it runs a final 5000-episode test.
    """
    # Fixed seeds for torch (CPU and all GPUs) and NumPy.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # Layer specification handed to Meta: four conv -> relu -> bn stages with
    # max-pooling between them. The list ends at the last batch-norm; no
    # pooling/flatten/linear head is configured.
    config = [
        ('conv2d', [32, 3, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
    ]

    device = torch.device('cuda:' + str(args.gpu))
    args.device = device
    maml = Meta(args, config).to(device)

    # Number of scalar elements across all parameters that require gradients.
    num = sum(p.numel() for p in maml.parameters() if p.requires_grad)
    print('Total trainable tensors:', num)

    loader_train = Loader_train(data_file=train_path, nb_classes=args.n_way, num_user=args.n_user,
                                n_spt=args.n_spt, n_qry=args.n_qry, max_iter=10001)
    loader_test = Loader_test(data_file=test_path, nb_classes=args.n_way, num_user=args.n_user,
                              n_spt=args.n_spt, n_qry=args.n_qry, max_iter=30)

    print("Data Loading Completed!")

    # Keeps the final report below well-defined even if the training loader
    # yields nothing.
    step = 0
    for step, x_spt, y_spt, x_qry, y_qry in tqdm(loader_train):
        for batch in (x_spt, y_spt, x_qry, y_qry):
            _to_device(batch, args.n_user, device)
        # x_spt : [10,30,3,84,84], x_qry : [90,3,84,84], y_spt : [10,30], y_qry : [90]

        accs = maml(step, x_spt, y_spt, x_qry, y_qry)

        if step % 40 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:
            # exist_ok avoids the separate isdir check (and its race).
            os.makedirs('save', exist_ok=True)
            torch.save(maml.state_dict(), 'save/%d_pth' % (step))

        if step % 400 == 0:  # evaluation
            # Rewind the shared test loader so it yields its episodes again.
            loader_test.num_iter = 0
            # [b, update_step+1]
            accs = _evaluate(maml, loader_test, args.n_user, device).astype(np.float16)
            print('step:', step, 'Test acc:', accs)

    # Final evaluation on a fresh, much longer test loader.
    loader_test = Loader_test(data_file=test_path, nb_classes=args.n_way, num_user=args.n_user,
                              n_spt=args.n_spt, n_qry=args.n_qry, max_iter=5000)
    # np.float was removed in NumPy 1.24; np.float64 is the dtype it aliased.
    accs = _evaluate(maml, loader_test, args.n_user, device).astype(np.float64)
    print('step:', step, 'Test acc:', accs)


if __name__ == '__main__':

    # Command-line options as (flag, type, help text, default), registered
    # in this order.
    cli_options = [
        ('--epoch', int, 'epoch number', 1000),
        ('--n_way', int, 'n way', 5),
        ('--n_spt', int, 'n shot for support set', 6),
        ('--n_qry', int, 'n shot for query set', 6),
        ('--imgsz', int, 'imgsz', 84),
        ('--imgc', int, 'imgc', 3),
        ('--n_user', int, 'meta batch size, namely task num', 10),
        ('--meta_lr', float, 'meta-level outer learning rate', 1e-3),
        ('--update_lr', float, 'task-level inner update learning rate', 0.1),
        ('--update_step', int, 'task-level inner update steps', 1),
        ('--update_step_test', int, 'update steps for finetunning', 10),
        ('--gpu', int, 'gpu number', 1),
        ('--round', int, 'round number', 2),
    ]

    argparser = argparse.ArgumentParser()
    for flag, value_type, description, default_value in cli_options:
        argparser.add_argument(flag, type=value_type, help=description, default=default_value)

    # main() takes no parameters; the parsed options are published as the
    # module-level name ``args``.
    args = argparser.parse_args()

    main()
