from __future__ import print_function
import os, sys
PATH_CUR = os.path.dirname(os.path.realpath(__file__))
sys.path.append(PATH_CUR)
sys.path.append('../')
import argparse
# import cPickle as pickle
import pickle
import random
import numpy as np
import csv
import tensorboardX as tbX
import torch
# Prefer GPU 1 for this experiment, but only when that device actually exists.
# An unconditional set_device(1) raises on CPU-only or single-GPU machines,
# which would make the script unusable there even with --no-cuda.
if torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)
# TensorBoard event files for this run are written under log_dir.
log_dir = './sort_MITR_4N4H_wo_workingmemory_youhx1hx2_day1115'
summary_writer = tbX.SummaryWriter(log_dir)

# Alternative way of selecting the GPU (kept for reference, not active):
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # replace 1 with the index of the GPU to use
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from torch.autograd import Variable
from model import RN, CNN_MLP, Transformer


def str2bool(v):
    """Convert a command-line token into a bool (usable as an argparse ``type``).

    Values that are already ``bool`` pass through unchanged. Strings are
    lower-cased and matched against the accepted spellings of true/false;
    any other string raises ``argparse.ArgumentTypeError``.
    """
    if isinstance(v, bool):
        return v
    token = v.lower()
    accepted = (
        (('yes', 'true', 't', 'y', '1'), True),
        (('no', 'false', 'f', 'n', '0'), False),
    )
    for spellings, value in accepted:
        if token in spellings:
            return value
    raise argparse.ArgumentTypeError('Boolean value expected.')


# Command-line interface. Help texts that contradicted the actual behaviour
# (the --model description and the --epochs default) have been corrected.
parser = argparse.ArgumentParser(description='PyTorch Relational-Network sort-of-CLVR Example')
parser.add_argument('--model', type=str, choices=['RN', 'CNN_MLP', 'Transformer'], default='Transformer',
                    help='model architecture to train (default: Transformer)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train (default: 200)')
parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
                    help='learning rate (default: 0.0001)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--log-interval', type=int, default=300, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--resume', type=str,
                    help='resume from model stored')
parser.add_argument('--relation-type', type=str, default='binary',
                    help='what kind of relations to learn. options: binary, ternary (default: binary)')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
# 1 ~9
# TR + HSW 256 4 True True 5 True 8 1 False
# NOTE(review): the line above looks like a reference configuration for the
# options below -- confirm with the author.
parser.add_argument('--embed_dim', type=int, default=256)
parser.add_argument('--num_layers', type=int, default=4)
parser.add_argument('--num_heads', type=int, default=4)
# Whether parameters are shared between layers; TR+HC and ISAB do not share
# them, all other variants do.
parser.add_argument('--share_vanilla_parameters', type=str2bool, default=True)  # default=False
parser.add_argument('--use_topk', type=str2bool, default=True)  # default=False
parser.add_argument('--topk', type=int, default=5)  # default=3
parser.add_argument('--shared_memory_attention', type=str2bool, default=True) # default=False
parser.add_argument('--mem_slots', type=int, default=8)  # default=4
# The option name ("men") is kept as-is: renaming it would change the
# attribute exposed on ``args``.
parser.add_argument('--use_long_men', type=str2bool, default=True,
                    help='use long-term memory or not')
parser.add_argument('--long_mem_segs', type=int, default=5)
parser.add_argument('--long_mem_aggre', type=str2bool, default=False,
                    help='uses cross-attention between WM and LTM or not')
parser.add_argument('--use_wm_inference', type=str2bool, default=False,
                    help='WM involvement during inference or not')
parser.add_argument('--seed', type=int, default=1)  # default=0
parser.add_argument('--functional', type=str2bool, default=False,
                    help='use set_transformer or not') # default=False
parser.add_argument('--save_dir', type=str, default='model_zxycuda')
parser.add_argument('--null_attention', type=str2bool, default=False)


# Parse the arguments declared above.
args = parser.parse_args()
# CUDA is used only when it was not disabled on the command line and is available.
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed every random number generator the script relies on so that runs with
# the same --seed are repeatable. Python's ``random`` has to be seeded too:
# train() shuffles the datasets with random.shuffle, which the torch and numpy
# seeds do not influence.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Seeds the generators of all GPUs, which includes the current device.
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.deterministic = True


# Fixed input geometry attached to args before the model is built
# (presumably read by the model constructors -- TODO confirm).
args.image_size = 75
args.patch_size = 15

# Build the network selected by --model; any value other than 'CNN_MLP' or
# 'Transformer' falls back to RN.
model = {'CNN_MLP': CNN_MLP, 'Transformer': Transformer}.get(args.model, RN)(args)

# Short aliases used by the helpers and the main loop below.
model_dirs = args.save_dir
bs = args.batch_size

# Reusable batch buffers: tensor_data() resizes and overwrites them for every
# batch, so their initial (uninitialised) contents are never used.
input_img = torch.FloatTensor(bs, 3, 75, 75)
input_qst = torch.FloatTensor(bs, 18)
label = torch.LongTensor(bs)

# Move the model and the buffers to the GPU when CUDA is enabled.
if args.cuda:
    model.cuda()
    input_img, input_qst, label = input_img.cuda(), input_qst.cuda(), label.cuda()

input_img, input_qst, label = Variable(input_img), Variable(input_qst), Variable(label)


def tensor_data(data, i):
    """Copy the ``i``-th mini-batch of ``data`` into the shared input buffers.

    ``data`` is indexed as ``data[0]``, ``data[1]`` and ``data[2]``; the callers
    in this file pass the (images, questions, answers) tuple produced by
    ``cvt_data_axis``. The slice ``[bs * i, bs * (i + 1))`` of each column is
    converted to a tensor and written into the module-level ``input_img``,
    ``input_qst`` and ``label`` buffers, which are resized to fit. Nothing is
    returned.
    """
    window = slice(bs * i, bs * (i + 1))
    # Convert all three columns first, then overwrite the buffers.
    batch = [torch.from_numpy(np.asarray(column[window]))
             for column in (data[0], data[1], data[2])]
    for buffer, values in zip((input_img, input_qst, label), batch):
        buffer.data.resize_(values.size()).copy_(values)


def cvt_data_axis(data):
    """Turn a sequence of 3-item samples into a tuple of three parallel lists.

    Element ``k`` of the result holds item ``k`` of every sample, in the
    original sample order; the callers in this file use this to go from
    ``(image, question, answer)`` samples to ``(images, questions, answers)``.
    """
    return tuple([sample[field] for sample in data] for field in range(3))

# Model training
def train(epoch, ternary, rel, norel):
    """Train the global ``model`` for one epoch on the three question types.

    Each dataset is a list of ``(image, question, answer)`` samples. The three
    lists are shuffled in place, then consumed batch by batch: for every batch
    index the model takes one optimisation step on the ternary batch, one on
    the relational batch and one on the non-relational batch. Mean per-batch
    accuracy and loss are written to the module-level ``summary_writer`` under
    ``Accuracy/train`` and ``Loss/train`` with ``epoch`` as the step.

    Args:
        epoch: Epoch number, used for console output and as TensorBoard step.
        ternary: Samples for ternary questions.
        rel: Samples for relational (binary) questions.
        norel: Samples for non-relational (unary) questions.

    Returns:
        Tuple ``(ternary, binary, unary)`` of mean per-batch training accuracy.

    Raises:
        ValueError: If ``rel`` and ``norel`` contain a different number of
            samples.
    """
    # Compare the number of samples. The previous check indexed the first
    # sample of each list and therefore compared two 3-tuples, so it could
    # never fail; on a mismatch it also returned None, which the caller cannot
    # unpack. Fail loudly before touching the model instead.
    # NOTE(review): ``ternary`` is not validated; the loop below assumes it has
    # at least as many samples as ``rel`` -- confirm against the dataset.
    if len(rel) != len(norel):
        raise ValueError('Not equal length for relation dataset and non-relation dataset.')

    model.train()

    random.shuffle(ternary)
    random.shuffle(rel)
    random.shuffle(norel)

    # Convert each list of (image, question, answer) samples into a tuple of
    # three parallel lists: images, questions, answers.
    ternary = cvt_data_axis(ternary)
    rel = cvt_data_axis(rel)
    norel = cvt_data_axis(norel)

    acc_ternary = []
    acc_rels = []
    acc_norels = []

    l_ternary = []
    l_binary = []
    l_unary = []

    # rel[0] is the list of images (98000 images of shape (3, 75, 75) according
    # to the original author's note), so its length is the sample count.
    num_samples = len(rel[0])
    for batch_idx in range(num_samples // bs):
        # Ternary questions: tensor_data fills the shared input buffers.
        tensor_data(ternary, batch_idx)
        accuracy_ternary, loss_ternary = model.train_(input_img, input_qst, label)
        acc_ternary.append(accuracy_ternary.item())
        l_ternary.append(loss_ternary.item())
        # Relational (binary) questions.
        tensor_data(rel, batch_idx)
        accuracy_rel, loss_binary = model.train_(input_img, input_qst, label)
        acc_rels.append(accuracy_rel.item())
        l_binary.append(loss_binary.item())
        # Non-relational (unary) questions.
        tensor_data(norel, batch_idx)
        accuracy_norel, loss_unary = model.train_(input_img, input_qst, label)
        acc_norels.append(accuracy_norel.item())
        l_unary.append(loss_unary.item())

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)] '
                  'Ternary accuracy: {:.0f}% | Relations accuracy: {:.0f}% | Non-relations accuracy: {:.0f}%'.format(
                epoch,
                batch_idx * bs * 2,
                num_samples * 2,
                100. * batch_idx * bs / num_samples,
                accuracy_ternary,
                accuracy_rel,
                accuracy_norel), flush=True)

    avg_acc_ternary = sum(acc_ternary) / len(acc_ternary)
    avg_acc_binary = sum(acc_rels) / len(acc_rels)
    avg_acc_unary = sum(acc_norels) / len(acc_norels)

    summary_writer.add_scalars('Accuracy/train', {
       'ternary': avg_acc_ternary,
       'binary': avg_acc_binary,
       'unary': avg_acc_unary
    }, epoch)

    avg_loss_ternary = sum(l_ternary) / len(l_ternary)
    avg_loss_binary = sum(l_binary) / len(l_binary)
    avg_loss_unary = sum(l_unary) / len(l_unary)

    summary_writer.add_scalars('Loss/train', {
       'ternary': avg_loss_ternary,
       'binary': avg_loss_binary,
       'unary': avg_loss_unary
    }, epoch)

    # Mean per-batch accuracy for each question type.
    return avg_acc_ternary, avg_acc_binary, avg_acc_unary


# Model evaluation
def test(epoch, ternary, rel, norel):
    """Evaluate the global ``model`` on the three question types.

    Each dataset is a list of ``(image, question, answer)`` samples. The data
    is consumed in order (no shuffling), batch by batch, through
    ``model.test_``. Mean per-batch accuracy and loss are printed / written to
    the module-level ``summary_writer`` under ``Accuracy/test`` and
    ``Loss/test`` with ``epoch`` as the step.

    Args:
        epoch: Epoch number, used as TensorBoard step.
        ternary: Samples for ternary questions.
        rel: Samples for relational (binary) questions.
        norel: Samples for non-relational (unary) questions.

    Returns:
        Tuple ``(ternary, binary, unary)`` of mean per-batch test accuracy.

    Raises:
        ValueError: If ``rel`` and ``norel`` contain a different number of
            samples.
    """
    # Compare the number of samples. The previous check indexed the first
    # sample of each list and therefore compared two 3-tuples, so it could
    # never fail; on a mismatch it also returned None, which the caller cannot
    # unpack. Fail loudly before touching the model instead.
    # NOTE(review): ``ternary`` is not validated; the loop below assumes it has
    # at least as many samples as ``rel`` -- confirm against the dataset.
    if len(rel) != len(norel):
        raise ValueError('Not equal length for relation dataset and non-relation dataset.')

    model.eval()

    # Convert each list of samples into parallel lists: images, questions, answers.
    ternary = cvt_data_axis(ternary)
    rel = cvt_data_axis(rel)
    norel = cvt_data_axis(norel)

    accuracy_ternary = []
    accuracy_rels = []
    accuracy_norels = []

    loss_ternary = []
    loss_binary = []
    loss_unary = []

    for batch_idx in range(len(rel[0]) // bs):
        # tensor_data fills the shared input buffers with the requested batch.
        tensor_data(ternary, batch_idx)
        acc_ter, l_ter = model.test_(input_img, input_qst, label)
        accuracy_ternary.append(acc_ter.item())
        loss_ternary.append(l_ter.item())

        tensor_data(rel, batch_idx)
        acc_bin, l_bin = model.test_(input_img, input_qst, label)
        accuracy_rels.append(acc_bin.item())
        loss_binary.append(l_bin.item())

        tensor_data(norel, batch_idx)
        acc_un, l_un = model.test_(input_img, input_qst, label)
        accuracy_norels.append(acc_un.item())
        loss_unary.append(l_un.item())

    # From here on the names hold the per-epoch means rather than the lists.
    accuracy_ternary = sum(accuracy_ternary) / len(accuracy_ternary)
    accuracy_rel = sum(accuracy_rels) / len(accuracy_rels)
    accuracy_norel = sum(accuracy_norels) / len(accuracy_norels)
    print('\n Test set: Ternary accuracy: {:.0f}% Binary accuracy: {:.0f}% | Unary accuracy: {:.0f}%\n'.format(
        accuracy_ternary, accuracy_rel, accuracy_norel), flush=True)

    summary_writer.add_scalars('Accuracy/test', {
       'ternary': accuracy_ternary,
       'binary': accuracy_rel,
       'unary': accuracy_norel
    }, epoch)

    loss_ternary = sum(loss_ternary) / len(loss_ternary)
    loss_binary = sum(loss_binary) / len(loss_binary)
    loss_unary = sum(loss_unary) / len(loss_unary)

    summary_writer.add_scalars('Loss/test', {
       'ternary': loss_ternary,
       'binary': loss_binary,
       'unary': loss_unary
    }, epoch)

    return accuracy_ternary, accuracy_rel, accuracy_norel


# Data loading
def load_data():
    """Load ``./data/sort-of-clevr.pickle`` and flatten it into per-question samples.

    The pickle holds a ``(train_datasets, test_datasets)`` pair. Every entry of
    either list is ``(img, ternary, relations, norelations)``, where each of
    the last three is indexed as ``[0]`` (questions) and ``[1]`` (answers).
    Each question/answer pair becomes one ``(img, qst, ans)`` sample, with the
    image's axes 0 and 2 swapped (presumably channels-last to channels-first
    -- TODO confirm).

    Returns:
        ``(ternary_train, ternary_test, rel_train, rel_test, norel_train,
        norel_test)``, each a list of ``(img, qst, ans)`` tuples.
    """
    def expand(datasets):
        # Flatten one split into three sample lists (ternary, relational, non-relational).
        ternary_samples, rel_samples, norel_samples = [], [], []
        for img, ternary, relations, norelations in datasets:
            img = np.swapaxes(img, 0, 2)
            ternary_samples.extend(
                (img, qst, ans) for qst, ans in zip(ternary[0], ternary[1]))
            rel_samples.extend(
                (img, qst, ans) for qst, ans in zip(relations[0], relations[1]))
            norel_samples.extend(
                (img, qst, ans) for qst, ans in zip(norelations[0], norelations[1]))
        return ternary_samples, rel_samples, norel_samples

    print('loading data...')
    filename = os.path.join('./data', 'sort-of-clevr.pickle')
    with open(filename, 'rb') as f:
        train_datasets, test_datasets = pickle.load(f)
    print('processing data...', flush=True)

    ternary_train, rel_train, norel_train = expand(train_datasets)
    ternary_test, rel_test, norel_test = expand(test_datasets)

    return (ternary_train, ternary_test, rel_train, rel_test, norel_train, norel_test)


if __name__ == "__main__":

    ternary_train, ternary_test, rel_train, rel_test, norel_train, norel_test = load_data()

    # Create the checkpoint directory. Only "already exists" is tolerated;
    # other failures (e.g. missing permissions) propagate instead of being
    # misreported as an existing directory.
    try:
        os.makedirs(model_dirs)
    except FileExistsError:
        print('directory {} already exists'.format(model_dirs), flush=True)

    # Optionally restore model weights from a checkpoint inside the save directory.
    if args.resume:
        filename = os.path.join(model_dirs, args.resume)
        if os.path.isfile(filename):
            print('==> loading checkpoint {}'.format(filename))
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(filename)
            model.load_state_dict(checkpoint)
            print('==> loaded checkpoint {}'.format(filename), flush=True)
        else:
            # Make it visible that the requested checkpoint was not used.
            print('==> no checkpoint found at {}; training from scratch'.format(filename), flush=True)

    # newline='' is what the csv module expects for files it writes to; without
    # it extra blank rows appear on platforms that translate line endings.
    with open(f'{args.save_dir}/{args.model}_{args.seed}_log.csv', 'w', newline='') as log_file:
        csv_writer = csv.writer(log_file, delimiter=',')
        csv_writer.writerow(['epoch', 'train_acc_ternary', 'train_acc_rel',
                             'train_acc_norel', 'test_acc_ternary', 'test_acc_rel', 'test_acc_norel'])

        print(f"Training {args.model} {f'({args.relation_type})' if args.model == 'RN' else ''} model...", flush=True)
        best_test_acc = 0.0
        best_epoch = 0
        for epoch in range(1, args.epochs + 1):
            # Run one training epoch, then evaluate on the test split.
            train_acc_ternary, train_acc_binary, train_acc_unary = train(
                epoch, ternary_train, rel_train, norel_train)
            torch.cuda.empty_cache()
            test_acc_ternary, test_acc_binary, test_acc_unary = test(
                epoch, ternary_test, rel_test, norel_test)
            # A checkpoint is saved only when the binary (relational) test
            # accuracy improves on the best value seen so far.
            if test_acc_binary > best_test_acc:
                best_test_acc = test_acc_binary
                best_epoch = epoch
                model.save_model(epoch, args.save_dir)
            csv_writer.writerow([epoch, train_acc_ternary, train_acc_binary,
                                 train_acc_unary, test_acc_ternary, test_acc_binary, test_acc_unary])
            # Push the row to disk so the log survives an interrupted run.
            log_file.flush()
        print("Best Model: Epoch {}, Test Accuracy: {}".format(best_epoch, best_test_acc))

    # Flush and close the TensorBoard writer so the last scalars are persisted.
    summary_writer.close()
