import argparse
import numpy as np
import importlib
import os
import sys
import socket
import setproctitle
from backbone.ResNet18 import resnet18
from backbone.alexnet import STLAlexNet
from backbone.convNet import fetch_net
conf_path = os.getcwd()
sys.path.append(conf_path)
from models import get_all_models
from argparse import ArgumentParser
from utils.args import add_management_args
from datasets import get_dataset
from models import get_model
from utils.conf import set_random_seed, get_device
from utils import create_if_not_exists
import torch
import uuid
import datetime
def lecun_fix():
    # Yann moved his website to CloudFlare. You need this now
    from six.moves import urllib
    opener = urllib.request.build_opener()
    opener.addheaders = [('User-agent', 'Mozilla/5.0')]
    urllib.request.install_opener(opener)

def parse_args():
    parser = ArgumentParser(description='mammoth', allow_abbrev=False)
    parser.add_argument('--model', type=str, required=True,
                        help='Model name.', choices=get_all_models())
    parser.add_argument('--epochs', type=int, default=50)


    # torch.set_num_threads(4)
    add_management_args(parser)

    # increment
    parser.add_argument('--increment', type=int, default=5, metavar='S',
                        help='(default: 5)')
    args = parser.parse_known_args()[0]
    mod = importlib.import_module('models.' + args.model)

    get_parser = getattr(mod, 'get_parser')
    parser = get_parser()
    args = parser.parse_args()

    if args.seed is not None:
        set_random_seed(args.seed)

    if args.model == 'mer': setattr(args, 'batch_size', 1)
    return args


def main(args=None):
    if args is None:
        args = parse_args()

    # job number
    args.conf_jobnum = str(uuid.uuid4())
    args.conf_timestamp = str(datetime.datetime.now())
    args.conf_host = socket.gethostname()

    dataset = get_dataset(args)
    total_class = dataset.N_CLASSES
    dataset.N_TASKS = args.n_tasks
    dataset.N_CLASSES_PER_TASK = int(total_class/args.n_tasks)

    if args.backbone == 'resnet18_lg':
        backbone = resnet18(dataset.N_CLASSES_PER_TASK * dataset.N_TASKS, nf=64)
    if args.backbone == 'resnet18':
        backbone = resnet18(dataset.N_CLASSES_PER_TASK * dataset.N_TASKS, nf=20)
    if args.backbone == 'resnet18_sm':
        backbone = resnet18(dataset.N_CLASSES_PER_TASK * dataset.N_TASKS, nf=10)

    if args.lamb > 0:
        if args.teacher_backbone == 'convnet_real':
            distill_backbone = fetch_net(args, dataset.N_TASKS, dataset.N_TASKS * dataset.N_CLASSES_PER_TASK, 0.2)
        elif args.teacher_backbone == 'resnet18':
            distill_backbone = resnet18(dataset.N_CLASSES_PER_TASK * dataset.N_TASKS, nf=20)
        elif args.teacher_backbone == 'resnet18_lg':
            distill_backbone = resnet18(dataset.N_CLASSES_PER_TASK * dataset.N_TASKS, nf=64)
        distill_loss = dataset.get_loss()
        distill_model = get_model(args, distill_backbone, distill_loss, dataset.get_transform(), "teachers")

    loss = dataset.get_loss()
    model = get_model(args, backbone, loss, dataset.get_transform())
    args.distributed = None
    import wandb
    os.environ["WANDB_API_KEY"] = "YOUR_WANDB_KEY"
    wandb.init(project=f"{args.name}_{args.dataset}_{args.model}_{args.backbone}_t{args.n_tasks}", config=args)

    from utils.training import train
    train(model, dataset, args, distill_model=distill_model)

if __name__ == '__main__':
    main()
