import sys

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import MNIST


from trainer import MCTrainer
from classifier import MCClassificationSystem
from metrics import compute_metrics
from MetricsRecorder import MetricsRecorder
from models import TinyCNNs

import wandb

def define_metrics(run):
    """Register the custom metrics tracked for a training run.

    Calls ``run.define_metric`` once per metric:

    * ``job`` with no extra options,
    * ``jobID`` keeping its last value and stepped by ``job``,
    * ``epoch`` keeping its last value,
    * six ``val/{base,hex}_{perfect_acc,leaves_acc,validity}`` metrics,
      each summarised by its maximum and stepped by ``epoch``.

    Args:
        run: Object exposing ``define_metric(name, **options)``.
    """
    run.define_metric("job")
    run.define_metric("jobID", summary="last", step_metric="job")
    run.define_metric("epoch", summary="last")

    # Registration order: for each quantity, the "base" variant then the "hex" one.
    for quantity in ("perfect_acc", "leaves_acc", "validity"):
        for variant in ("base", "hex"):
            run.define_metric(
                "val/" + variant + "_" + quantity,
                summary="max",
                step_metric="epoch",
            )

def run(args):
    """Train a ``TinyCNNs`` classifier on MNIST, optionally logging to wandb.

    Steps, in order: seed torch, build the network and classification
    system, optionally start (or resume) a wandb run, build the data
    loaders, optimizer and scheduler, train with ``MCTrainer``, optionally
    save the network weights, and close the wandb run if one was opened.

    Args:
        args: Namespace-like object. Attributes read here: ``seed``,
            ``num_layers``, ``beta``, ``job_ID``, ``log``, ``online``,
            ``load_checkpoint``, ``load_run_ID``, ``load_job_ID``,
            ``log_path``, ``scheduler``, ``lr``, ``trainset_ratio``, ``bs``,
            ``num_workers``, ``do_eval``, ``optim``, ``momentum``,
            ``step_size``, ``gamma`` and ``save_checkpoint``.
            Attributes written here when ``args.log`` is true:
            ``epoch_offset``, plus ``lr`` (when resuming) or
            ``model_size`` and ``runID`` (when starting a fresh run).

    Raises:
        ValueError: If ``args.optim`` is neither ``'SGD'`` nor ``'Adam'``.
    """
    print("Test of MNIST classification")

    torch.manual_seed(args.seed)

    # Directory used for both loading and saving network weights.
    output_dir = '/gpfswork/rech/ifs/ukm19fl/output/'

    # Build the network and move it to the training device.
    device = "cuda:0"
    num_layers = int(args.num_layers)
    print(num_layers)
    network = TinyCNNs(num_classes=10, num_layers=num_layers)
    network.to(device)
    beta = args.beta
    jobID = args.job_ID

    classifier = MCClassificationSystem(model=network, beta=beta)

    # Stays None when logging is disabled; every later use is guarded.
    wandb_run = None
    if args.log:
        if args.online:
            from wandb_co import wandb_connect
            # Connect to the WandB service before opening an online run.
            wandb_connect()
            mode = "online"
        else:
            mode = "offline"

        if args.load_checkpoint:
            # Resume an existing run: restore lr, weights and epoch counter.
            wandb_run = wandb.init(project="MNIST", entity="anonymous", id=args.load_run_ID,
                                   dir=args.log_path, resume="must", mode=mode,
                                   settings=wandb.Settings(_disable_stats=True, _disable_meta=True))
            args.lr = float(wandb_run.config.get("lr"))
            if args.load_job_ID is None:
                # Default to the job that most recently logged into this run.
                loadJobID = wandb_run.summary.get("jobID")['last']
            else:
                loadJobID = args.load_job_ID
            load_path = output_dir + str(loadJobID) + '_netweights.save'
            network.load_state_dict(torch.load(load_path, map_location='cpu'))
            define_metrics(wandb_run)
            if args.scheduler:
                # With a scheduler, restart from the last logged lr rather
                # than the initial one stored in the run config.
                args.lr = float(wandb_run.summary.get("lr")['last'])
                wandb_run.define_metric("lr", summary="last")
            wandb_run.log({"jobID": int(jobID)})
            args.epoch_offset = int(wandb_run.summary.get("epoch")['last']) + 1
        else:
            # Record the number of trainable parameters in the run config.
            args.model_size = sum(p.numel() for p in network.parameters() if p.requires_grad)

            # Launch a fresh wandb run and define metrics.
            wandb_run = wandb.init(project="MNIST", entity="anonymous", dir=args.log_path,
                                   config=args, mode=mode,
                                   settings=wandb.Settings(_disable_stats=True, _disable_meta=True))
            args.runID = wandb_run.id
            define_metrics(wandb_run)
            if args.scheduler:
                wandb_run.define_metric("lr", summary="last")
            wandb_run.log({"jobID": int(jobID)})
            args.epoch_offset = 0

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=0.406, std=0.225),
    ])

    train_dataset = MNIST(root='/gpfsdswork/dataset', train=True, download=False, transform=transform)

    if args.trainset_ratio < 1:
        # Keep only the first part of a fixed-seed random split so the
        # subset is the same from one invocation to the next.
        total_length = len(train_dataset)
        first_split = int(args.trainset_ratio * total_length)
        train_dataset = torch.utils.data.random_split(
            train_dataset,
            [first_split, total_length - first_split],
            generator=torch.Generator().manual_seed(42),
        )[0]

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.bs,
                                               num_workers=args.num_workers,
                                               shuffle=True, drop_last=True)

    # Bind val_loader unconditionally: it is always handed to the trainer
    # below, and was previously undefined when evaluation was disabled.
    # NOTE(review): assumes MCTrainer accepts val_loader=None when
    # args.do_eval is false -- confirm against the trainer.
    val_loader = None
    if args.do_eval:
        val_dataset = MNIST(root='/gpfsdswork/dataset', train=False, download=False, transform=transform)
        # NOTE(review): drop_last=True discards the final partial validation
        # batch; confirm this is intended for the reported metrics.
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.bs,
                                                 num_workers=args.num_workers,
                                                 shuffle=False, drop_last=True)

    if args.optim == 'SGD':
        optimizer = torch.optim.SGD(network.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)
    elif args.optim == 'Adam':
        optimizer = torch.optim.Adam(network.parameters(),
                                     lr=args.lr)
    else:
        # Fail with a clear message instead of an UnboundLocalError later.
        raise ValueError(
            f"Unsupported optimizer {args.optim!r}; expected 'SGD' or 'Adam'."
        )

    if args.scheduler:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    else:
        scheduler = None

    # NOTE(review): this recorder is never used afterwards; it is kept in
    # case its constructor has side effects -- confirm and remove if not.
    tmp_mr = MetricsRecorder(hyperparams=vars(args))

    trainer = MCTrainer(classifier=classifier, optimizer=optimizer, scheduler=scheduler,
                        train_loader=train_loader, val_loader=val_loader,
                        device=device, compute_metrics=compute_metrics,
                        wandb_run=wandb_run, args=args)

    trainer.train()

    if args.save_checkpoint:
        net_path = output_dir + str(jobID) + '_netweights.save'
        torch.save(network.state_dict(), net_path)

    # Only close the wandb run if one was actually opened.
    if wandb_run is not None:
        wandb_run.finish()