from __future__ import print_function
import nni, os, sys
import argparse
import warnings
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from fvcore.nn import FlopCountAnalysis, parameter_count_table

# Make the project sources importable and switch the working directory to
# them. Both paths are relative, so they resolve against the directory the
# script is launched from; every later relative path (e.g. "./log/..." in
# main) is therefore resolved under ../src.
sys.path.append('../src')
os.chdir("../src")
# NOTE(review): names used below that are not imported explicitly in this
# file (e.g. F, logging, get_logger, randomness_control,
# get_phylogeny_loader, get_model) presumably come from these wildcard
# imports -- confirm against the modules.
from utils import *
from dataset.dataset import *
from detector_utils.model import *

# Silence every UserWarning for the rest of the process.
warnings.simplefilter(action='ignore', category=UserWarning)

def train(args, model, device, train_loader, optimizer, epoch, logger, start_index):
    """Run one pass over ``train_loader`` and return the (updated) model.

    Args:
        args: Namespace providing ``log_interval``, ``dry_run``,
            ``save_model_along``, ``save_model_interval``, ``exp_id`` and
            ``operation``.
        model: Module being optimised; switched to training mode here.
        device: Device every batch is moved to before the forward pass.
        train_loader: Iterable of ``(weight, label)`` batches that also
            supports ``len()`` and exposes ``.dataset``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in the progress message.
        logger: Object whose ``info`` method receives progress messages.
        start_index: Global step of the first batch of this epoch; logging
            and checkpoint intervals are evaluated on the global step.

    Returns:
        The same ``model`` object that was passed in.
    """
    progress_template = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'

    model.train()
    for local_step, (weight, label) in enumerate(train_loader):
        # Position of this batch counted across all epochs so far.
        global_step = start_index + local_step

        weight = weight.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        logits = model(weight)

        # Merge the two leading axes of the prediction and of the target so
        # cross-entropy scores every (dim0, dim1) position as one sample.
        logits_shape = logits.shape
        logits = logits.view(logits_shape[0] * logits_shape[1], logits_shape[2])
        label = label.view(label.shape[0] * label.shape[1])

        loss = F.cross_entropy(logits, label)
        loss.backward()
        optimizer.step()

        if global_step % args.log_interval == 0:
            samples_seen = local_step * len(weight)
            percent_done = 100. * local_step / len(train_loader)
            logger.info(progress_template.format(
                epoch, samples_seen, len(train_loader.dataset),
                percent_done, loss.item()))
            # A dry run stops right after the first logged batch.
            if args.dry_run:
                break

        completed_steps = global_step + 1
        if args.save_model_along and completed_steps % args.save_model_interval == 0:
            checkpoint_path = f"{args.exp_id}/{args.operation}_{completed_steps}.pt"
            torch.save(model.state_dict(), checkpoint_path)
            logger.info(f"model was saved to {checkpoint_path}")
    return model

def complexity_analysis(model, loader, logger):
    """Log the FLOP count and the parameter table of ``model``.

    Args:
        model: Model to analyse. ``model.to("cpu")`` is called on it and it
            is not moved back afterwards.
        loader: Iterable whose first item unpacks into exactly two values;
            only the first value is used.
        logger: Object whose ``info`` method receives the two report lines.
    """
    first_batch = next(iter(loader))
    weight, _ = first_batch

    # The batch is used as it comes from the loader, so the model is put on
    # the CPU as well (assumes the loader yields CPU tensors -- TODO confirm).
    model.to("cpu")

    # Only the first sample of the batch is traced, keeping its leading axis.
    single_sample = weight[:1]
    flop_counter = FlopCountAnalysis(model, (single_sample,))
    kilo_flops = flop_counter.total() / 1e3
    logger.info(f"FLOPs: {kilo_flops:.2f}K")

    parameter_table = parameter_count_table(model)
    logger.info(f"Parameters: {parameter_table}")

def main(argv=None):
    """Parse the command line, train the model and log its complexity.

    Steps, in order: parse arguments, pick the device, seed the random
    generators via ``randomness_control``, derive the experiment directory
    (created under ``./log/``), build logger, data loaders and model, train
    unless ``--pre_train_ckpt`` is given, optionally save the final weights,
    and finally run ``complexity_analysis``.

    Args:
        argv: Optional list of command-line tokens. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is what the script
            entry point relies on.

    Raises:
        SystemExit: Raised by argparse on ``--help`` or on invalid arguments.
    """
    # Imported explicitly so ``logging.INFO`` below does not depend on one of
    # the file's wildcard imports happening to re-export the module.
    import logging

    # Training settings. ``%(default)s`` is expanded by argparse, so the help
    # text can no longer drift away from the actual default values.
    parser = argparse.ArgumentParser(description='PyTorch Training')
    parser.add_argument('--exp_group', type=str, default="")
    parser.add_argument('--batch-size', type=int, default=1, metavar='N',
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--task', type=str, default="vit_rerun",
                        help='task identifier')
    parser.add_argument('--model', type=str, default="default",
                        help='network structure')
    parser.add_argument('--pmodels', type=str, nargs="+",
                        default=["FC_G1_MNIST", "FC_G1_MNIST"],
                        help='parent models in the dataset')
    parser.add_argument('--cmodels', type=str, nargs="+",
                        default=["FC_G2_FMNIST", "FC_G2_EMNIST_Letters"],
                        help='child models in the dataset')
    parser.add_argument('--pre_train_ckpt', type=str, default="",
                        help='path of the pretrained model')
    parser.add_argument('--test-batch-size', type=int, default=1, metavar='N',
                        help='input batch size for testing (default: %(default)s)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',  # 100 MNIST pretrain, 5 Finetune
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7, 1.0 for fewshot)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1314, metavar='S',
                        help='random seed (default: %(default)s)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--embed_dim', type=int, default=32, metavar='N',
                        help='how many hidden dims')
    parser.add_argument('--num_layers', type=int, default=0,
                        help='how many hidden layers')
    parser.add_argument('--num_encoder_layers', type=int, default=1,
                        help='how many encoder layers')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--save-model-interval', type=int, default=-1,  # pretrain -1, finetune 100
                        help='whether to save the model along training: save a '
                             'checkpoint every N batches (N <= 0 disables)')
    parser.add_argument('--kfold_split', type=int, default=5,
                        help='number of folds used to validate')
    args = parser.parse_args(argv)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    randomness_control(args.seed)

    # A non-empty checkpoint path switches the run to evaluation: training is
    # skipped entirely.
    # NOTE(review): this function never loads ``args.pre_train_ckpt`` itself;
    # presumably ``get_model(args)`` does -- confirm.
    args.evaluation_flag = len(args.pre_train_ckpt) > 0
    args.operation = "Evaluation" if args.evaluation_flag else "Train"

    # Periodic checkpoints inside ``train`` are enabled only for a positive interval.
    args.save_model_along = args.save_model_interval > 0

    # One directory per (dataset composition, hyper-parameter set).
    dataset_tag = "__".join(args.pmodels) + "___" + "__".join(args.cmodels)
    run_tag = (f"{args.model}_{args.embed_dim}_{args.num_encoder_layers}_{args.num_layers}"
               f"_{args.batch_size}_{args.lr}_{args.gamma}_{args.seed}")
    args.exp_id = "./log/" + dataset_tag + "/" + run_tag
    os.makedirs(args.exp_id, exist_ok=True)

    logger, _formatter = get_logger(args.exp_id, None, "log.log", level=logging.INFO)

    # Only the training loader and the count returned last are used here.
    train_loader, _val_loader, _test_loaders, _full_loaders, num_p = get_phylogeny_loader(args)

    # Stored on ``args`` before the model is built from ``args``.
    args.num_p = num_p

    model = get_model(args)
    model = model.to(device)

    if not args.evaluation_flag:
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
        # The learning rate is multiplied by ``gamma`` after every epoch.
        scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

        batches_per_epoch = len(train_loader)
        for epoch in range(1, args.epochs + 1):
            # ``start_index`` keeps the step counter running across epochs.
            train(args, model, device, train_loader, optimizer, epoch,
                  logger=logger, start_index=(epoch - 1) * batches_per_epoch)
            scheduler.step()
            if args.dry_run:
                break

        if args.save_model:
            final_path = f"{args.exp_id}/{args.operation}.pt"
            torch.save(model.state_dict(), final_path)
            logger.info(f"model was saved to {final_path}")

        logger.info("training process was finished")

    complexity_analysis(model, train_loader, logger)

# Run the training entry point only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()
