import os
import shutil

import torch.utils.data
# import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import argparse
import re

from helpers import makedir
import push, model, save, train_and_test as tnt
from log import create_logger
from preprocess import mean, std, preprocess_input_function
from FullModels import base_architecture_to_features,construct_VLNet
import LanguageModels
import settings_CUB
import image_caption_dataset

# Command-line options for one training run.
# Example: python3 main.py -gpuid=0,1,2,3
parser = argparse.ArgumentParser()
_cli_defaults = (
    # (flag, type, default)
    ('-gpuid', str, '1,2,3,7'),   # comma-separated GPU ids, exported below
    ('-arch', str, 'vgg19'),      # backbone name, e.g. vgg19
    ('-dataset', str, "CUB"),
    # loss-coefficient and schedule controls
    ('-update_projector_freq', int, 50),
    ('-start_align_epoch', int, 15),
    ('-coeff_ce', float, 1),
    ('-coeff_clst', float, 0.8),
    ('-coeff_sep', float, -0.08),
    ('-coeff_l1', float, 1e-4),
    ('-alpha', float, 0.5),
    # NOTE(review): a negative momentum default looks unusual -- confirm it is intended.
    ('-momentum_G', float, -0.9),
)
for _flag, _kind, _default in _cli_defaults:
    parser.add_argument(_flag, type=_kind, default=_default)
# Free-form label for this experiment run.
parser.add_argument('-times', type=str, default="test1108_nocrop", help="experiment_run")
args = parser.parse_args()

# Export the requested GPU ids through the environment and echo them.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpuid
print(os.environ['CUDA_VISIBLE_DEVICES'])



# Experiment naming: the run label followed by the key schedule settings.
experiment_run = "{}start_MA_{}_alpha_{}_momentum_G_{}_update_freq_{}".format(
    args.times,
    args.start_align_epoch,
    args.alpha,
    args.momentum_G,
    args.update_projector_freq,
)
base_architecture = args.arch
dataset_name = args.dataset

# Leading lowercase letters of the architecture name, e.g. 'vgg19' -> 'vgg'.
base_architecture_type = re.match('^[a-z]*', base_architecture).group(0)

# Everything produced by this run is stored under model_dir.
model_dir = './saved_models/{}/{}/{}/'.format(dataset_name, base_architecture, experiment_run)
makedir(model_dir)

# Snapshot the source files that define this run next to its outputs.
for _source_name in (
    __file__,
    'settings_CUB.py',
    base_architecture_type + '_features.py',
    'FullModels.py',
    'train_and_test.py',
):
    shutil.copy(src=os.path.join(os.getcwd(), _source_name), dst=model_dir)

# Run logger (writes train.log inside model_dir) and the image output folder.
log, logclose = create_logger(log_filename=os.path.join(model_dir, 'train.log'))
img_dir = os.path.join(model_dir, 'img')
makedir(img_dir)

# File-name stems for artifacts; not referenced elsewhere in this script section.
weight_matrix_filename = 'outputL_weights'
prototype_img_filename_prefix = 'prototype-img'
prototype_self_act_filename_prefix = 'prototype-self-act'
proto_bound_boxes_filename_prefix = 'bb'
# Load the hyper-parameters for the selected dataset.
# Unsupported datasets/architectures fail here with an explicit error instead of
# surfacing later as a NameError on an unbound setting.
if dataset_name == "CUB":
    # model parameters
    num_classes = settings_CUB.num_classes
    img_size = settings_CUB.img_size
    add_on_layers_type = settings_CUB.add_on_layers_type
    arch_name = str(args.arch).upper()
    # The prototype shape depends on the backbone family.
    if arch_name.startswith("VGG"):
        prototype_shape = (2000, 128, 1, 1)
    elif arch_name.startswith(("RES", "DENSE")):
        prototype_shape = (2000, 64, 1, 1)
        #prototype_shape = (2000, 128, 1, 1)
    else:
        raise ValueError(
            "no prototype_shape configured for architecture {0!r}; "
            "expected a name starting with 'vgg', 'res' or 'dense'".format(args.arch)
        )
    prototype_activation_function = settings_CUB.prototype_activation_function
    # datasets
    train_dir = settings_CUB.train_dir
    test_dir = settings_CUB.test_dir
    train_push_dir = settings_CUB.train_push_dir
    train_batch_size = settings_CUB.train_batch_size
    test_batch_size = settings_CUB.test_batch_size
    train_push_batch_size = settings_CUB.train_push_batch_size
    # optimizer
    joint_optimizer_lrs = settings_CUB.joint_optimizer_lrs
    joint_lr_step_size = settings_CUB.joint_lr_step_size
    warm_optimizer_lrs = settings_CUB.warm_optimizer_lrs

    last_layer_optimizer_lr = settings_CUB.last_layer_optimizer_lr
    # weighting of different training losses
    coefs = settings_CUB.coefs
    # number of training epochs, number of warm epochs, push start epoch, push epochs
    num_train_epochs = settings_CUB.num_train_epochs
    num_warm_epochs = settings_CUB.num_warm_epochs
    push_start = settings_CUB.push_start
    push_epochs = settings_CUB.push_epochs
else:
    raise ValueError(
        "unsupported dataset {0!r}; only 'CUB' is configured".format(dataset_name)
    )


# Per-channel normalization applied to the train/test images
# (presumably the usual ImageNet statistics -- TODO confirm against preprocess.mean/std).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
train_transforms = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    normalize,
])
val_transforms = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    normalize,
])
# Push images are resized to img_size and converted to tensors, but not normalized here.
# (This pipeline was previously bound to the name `train_push_dataset` and then
# rebuilt inline for the ImageFolder below; it is now defined once and reused.)
train_push_transforms = transforms.Compose([
    transforms.Resize(size=(img_size, img_size)),
    transforms.ToTensor(),
])

# all datasets
# train / test sets (image + caption)
# NOTE(review): data_path is empty in this file -- set it to the dataset root before running.
data_path = ""
# cropped variant of the dataset
train_dataset = image_caption_dataset.CUB200_CAPTION_AUG(
    root=data_path, train=True, transform=train_transforms, cropped=True, resize=224)
test_dataset = image_caption_dataset.CUB200_CAPTION_AUG(
    root=data_path, train=False, transform=val_transforms, cropped=True, resize=224)
# un-cropped alternative:
#train_dataset = image_caption_dataset.CUB200_CAPTION_AUG(root=data_path,train=True,transform=train_transforms, cropped = False, resize=224)
#test_dataset = image_caption_dataset.CUB200_CAPTION_AUG(root=data_path,train=False,transform=val_transforms, cropped = False, resize=224)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=train_batch_size, shuffle=True,
    num_workers=4, pin_memory=False)
# push set
train_push_dataset = datasets.ImageFolder(train_push_dir, train_push_transforms)
train_push_loader = torch.utils.data.DataLoader(
    train_push_dataset, batch_size=train_push_batch_size, shuffle=False,
    num_workers=4, pin_memory=False)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=test_batch_size, shuffle=False,
    num_workers=4, pin_memory=False)

# we should look into distributed sampler more carefully at torch.utils.data.distributed.DistributedSampler(train_dataset)
log('training set size: {0}'.format(len(train_loader.dataset)))
log('push set size: {0}'.format(len(train_push_loader.dataset)))
log('test set size: {0}'.format(len(test_loader.dataset)))
log('batch size: {0}'.format(train_batch_size))
log('batch size: {0}'.format(train_batch_size))

# Build the model: an image backbone and a language model combined by construct_VLNet.
# The backbone factory is looked up by architecture name and asked for pretrained weights.
_make_image_features = base_architecture_to_features[args.arch]
image_model = _make_image_features(pretrained=True)
language_model = LanguageModels.Bert_base()

_vlnet_kwargs = dict(
    img_size=img_size,
    prototype_shape=prototype_shape,
    num_classes=num_classes,
    prototype_activation_function=prototype_activation_function,
    add_on_layers_type=add_on_layers_type,
    mode=-1,  # same value the training loop passes to change_mode() before alignment starts
)
# Construct the network and move it to the GPU.
ppnet = construct_VLNet(args, image_model, language_model, **_vlnet_kwargs).cuda()
# `ppnet_multi` is the DataParallel wrapper; `ppnet` remains the unwrapped module.
ppnet_multi = torch.nn.DataParallel(ppnet)
class_specific = False

# Optimizers: one Adam instance per group of parameters that is trained together.
# Weight decay is applied to everything returned by .parameters(), biases included;
# the prototype vectors and the last layer get no weight decay.
_WEIGHT_DECAY = 1e-3

# Joint optimizer: both encoders, the add-on/projection layers and the prototypes.
joint_optimizer_specs = [
    dict(params=ppnet.image_model.parameters(),
         lr=joint_optimizer_lrs['features'], weight_decay=_WEIGHT_DECAY),
    dict(params=ppnet.language_model.parameters(),
         lr=joint_optimizer_lrs['features'], weight_decay=_WEIGHT_DECAY),
    dict(params=ppnet.add_on_layers.parameters(),
         lr=joint_optimizer_lrs['add_on_layers'], weight_decay=_WEIGHT_DECAY),
    dict(params=ppnet.language_projection_head.parameters(),
         lr=joint_optimizer_lrs['add_on_layers'], weight_decay=_WEIGHT_DECAY),
    dict(params=ppnet.prototype_vectors,
         lr=joint_optimizer_lrs['prototype_vectors']),
]
joint_optimizer = torch.optim.Adam(joint_optimizer_specs)
# Multiply every joint learning rate by 0.1 each `joint_lr_step_size` scheduler steps.
joint_lr_scheduler = torch.optim.lr_scheduler.StepLR(
    joint_optimizer, step_size=joint_lr_step_size, gamma=0.1)

# Warm optimizer: same groups as above minus the two encoders.
warm_optimizer_specs = [
    dict(params=ppnet.add_on_layers.parameters(),
         lr=warm_optimizer_lrs['add_on_layers'], weight_decay=_WEIGHT_DECAY),
    dict(params=ppnet.language_projection_head.parameters(),
         lr=warm_optimizer_lrs['add_on_layers'], weight_decay=_WEIGHT_DECAY),
    dict(params=ppnet.prototype_vectors,
         lr=warm_optimizer_lrs['prototype_vectors']),
]
warm_optimizer = torch.optim.Adam(warm_optimizer_specs)

# Last-layer optimizer: only the final layer's parameters.
last_layer_optimizer_specs = [
    dict(params=ppnet.last_layer.parameters(), lr=last_layer_optimizer_lr),
]
last_layer_optimizer = torch.optim.Adam(last_layer_optimizer_specs)


# Main training loop. Each epoch:
#   1. select the model mode (-1 before args.start_align_epoch, 1 afterwards),
#   2. train with the warm optimizer for the first num_warm_epochs epochs, then
#      with the joint optimizer,
#   3. evaluate on the test set and save the model if accuracy reaches the target.
# The logger is closed in `finally` so the log file is released even when an
# epoch raises.
log('start training')
try:
    for epoch in range(num_train_epochs):  # original note: 1000
        log('epoch: \t{0}'.format(epoch))

        if epoch < args.start_align_epoch:
            log("undo manifold alignment")
            ppnet_multi.module.change_mode(-1)
        else:
            log("doing manifold alignment")
            ppnet_multi.module.change_mode(1)

        # stage 1: train the layers before the last
        if epoch < num_warm_epochs:
            tnt.warm_only(model=ppnet_multi, log=log)
            _ = tnt.train(args=args, model=ppnet_multi, dataloader=train_loader,
                          optimizer=warm_optimizer, class_specific=class_specific,
                          coefs=coefs, log=log)
        else:
            tnt.joint(model=ppnet_multi, log=log)
            _ = tnt.train(args=args, model=ppnet_multi, dataloader=train_loader,
                          optimizer=joint_optimizer, class_specific=class_specific,
                          coefs=coefs, log=log)
            # Step the LR schedule only after this epoch's optimizer updates
            # (assumes tnt.train calls optimizer.step()). Stepping first would
            # skip the first value of the schedule on PyTorch >= 1.1.
            joint_lr_scheduler.step()

        accu = tnt.test(args=args, model=ppnet_multi, dataloader=test_loader,
                        class_specific=class_specific, coefs=coefs, log=log)
        # Save the unwrapped model when test accuracy meets the 0.70 target.
        save.save_model_w_condition(model=ppnet, model_dir=model_dir,
                                    model_name=str(epoch) + 'nopush', accu=accu,
                                    target_accu=0.70, log=log)
finally:
    logclose()

