import sys
import os

import numpy as np

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

from ShortestPathDataset import ShortestPathDataset
from trainer import Trainer
from classifier import GridClassificationSystem
from metrics import compute_metrics
from MetricsRecorder import MetricsRecorder
from models import Net

# append the path of the parent directory
sys.path.append("..")
sys.path.append(os.path.join(sys.path[0], '..' , 'SPL', 'grids'))
sys.path.append(os.path.join(sys.path[0], '..' ,'SPL','grids', 'pypsdd'))
from compute_mpe import CircuitMPE

import wandb

def define_metrics(run):
    """Declare how the W&B run steps and summarises the logged metrics.

    ``jobID`` uses ``job`` as its step metric and keeps its last value;
    ``epoch`` keeps its last value; the two validation perfect-accuracy
    metrics use ``epoch`` as their step metric and are summarised by their
    maximum.

    Args:
        run: the wandb run object on which ``define_metric`` is called.
    """
    metric_specs = (
        ("job", {}),
        ("jobID", {"summary": "last", "step_metric": "job"}),
        ("epoch", {"summary": "last"}),
        ("val/base_perfect_acc", {"summary": "max", "step_metric": "epoch"}),
        ("val/hex_perfect_acc", {"summary": "max", "step_metric": "epoch"}),
        # Validity metrics ("val/base_validity", "val/hex_validity") are
        # intentionally not registered at the moment.
    )
    for metric_name, options in metric_specs:
        run.define_metric(metric_name, **options)

# Directory holding per-job network weights ("<jobID>_netweights.save").
# NOTE(review): cluster-specific absolute path; consider making it configurable.
_CHECKPOINT_DIR = "/gpfswork/rech/ifs/ukm19fl/output/"


def _checkpoint_path(job_id):
    """Return the path of the weights file written by job ``job_id``."""
    return os.path.join(_CHECKPOINT_DIR, "{}_netweights.save".format(job_id))


def _init_wandb_run(args, network, job_id):
    """Create or resume the W&B run and update ``args`` to match it.

    Args:
        args: parsed command-line namespace; mutated in place (``lr``,
            ``epoch_offset``, and for fresh runs ``model_size``/``runID``).
        network: model whose weights are loaded when resuming a checkpoint.
        job_id: identifier of the current job, logged as ``jobID``.

    Returns:
        The wandb run object, or ``None`` when ``args.log`` is falsy.

    When ``args.load_checkpoint`` is set, run ``args.load_run_ID`` is resumed
    (``resume="must"``), the weights saved by job ``args.load_job_ID`` (or by
    the last ``jobID`` in the run summary) are loaded into ``network``, and
    the learning rate and epoch offset are restored from the run.
    """
    if not args.log:
        # NOTE(review): the logging paths always set epoch_offset, presumably
        # because the Trainer reads it; default to a fresh start (0) here
        # unless the caller already provided a value.
        args.epoch_offset = getattr(args, "epoch_offset", 0)
        return None

    if args.online:
        from wandb_co import wandb_connect
        # Connect to the WandB run
        wandb_connect()
        mode = "online"
    else:
        mode = "offline"

    settings = wandb.Settings(_disable_stats=True, _disable_meta=True)

    if args.load_checkpoint:
        wandb_run = wandb.init(project="WSP-diag", entity="anonymous", id=args.load_run_ID,
                               dir=args.log_path, resume="must", mode=mode, settings=settings)
        args.lr = float(wandb_run.config.get("lr"))
        if args.load_job_ID is None:
            # Resume from the last job that logged into this run.
            load_job_id = wandb_run.summary.get("jobID")['last']
        else:
            load_job_id = args.load_job_ID
        network.load_state_dict(torch.load(_checkpoint_path(load_job_id), map_location='cpu'))
        define_metrics(wandb_run)
        if args.scheduler:
            # Continue from the last logged (scheduler-adjusted) learning rate.
            args.lr = float(wandb_run.summary.get("lr")['last'])
            wandb_run.define_metric("lr", summary="last")
        wandb_run.log({"jobID": int(job_id)})
        args.epoch_offset = int(wandb_run.summary.get("epoch")['last']) + 1
    else:
        # Number of trainable parameters, stored in the run config via args.
        args.model_size = sum(p.numel() for p in network.parameters() if p.requires_grad)
        wandb_run = wandb.init(project="WSP-diag", entity="anonymous", dir=args.log_path, config=args,
                               mode=mode, settings=settings)
        args.runID = wandb_run.id
        define_metrics(wandb_run)
        if args.scheduler:
            wandb_run.define_metric("lr", summary="last")
        wandb_run.log({"jobID": int(job_id)})
        args.epoch_offset = 0
    return wandb_run


def _build_loaders(folder_path, args):
    """Build the training DataLoader and, if requested, the evaluation one.

    Training data are read from ``train_maps.npy`` / ``train_diag_labels.npy``
    in ``folder_path``. If ``args.trainset_ratio < 1``, a subset of that
    fraction is kept, chosen by a generator seeded with 42. When
    ``args.do_eval`` is set, ``test_maps.npy`` / ``test_diag_labels.npy`` from
    the same folder are used for evaluation.

    Returns:
        ``(train_loader, val_loader)``; ``val_loader`` is ``None`` when
        ``args.do_eval`` is falsy.
    """
    # Fixed normalisation constants (dataset statistics are not used).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=0.406, std=0.225),
    ])

    train_maps = np.load(folder_path + "train_maps.npy")
    train_targets = np.load(folder_path + "train_diag_labels.npy")
    train_dataset = ShortestPathDataset(maps=train_maps, targets=train_targets, transform=transform)

    if args.trainset_ratio < 1:
        total_length = len(train_dataset)
        first_split = int(args.trainset_ratio * total_length)
        # Fixed seed so the same subset is selected on every run/resume.
        train_dataset = torch.utils.data.random_split(
            train_dataset,
            [first_split, total_length - first_split],
            generator=torch.Generator().manual_seed(42))[0]

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.bs, num_workers=args.num_workers,
                                               shuffle=True, drop_last=True)

    val_loader = None
    if args.do_eval:
        # Use the test split matching args.grid_size (was hard-coded to 12x12,
        # which mismatched the network/circuit for any other grid size).
        test_maps = np.load(folder_path + "test_maps.npy")
        test_targets = np.load(folder_path + "test_diag_labels.npy")
        val_dataset = ShortestPathDataset(maps=test_maps, targets=test_targets, transform=transform)
        # NOTE(review): drop_last=True skips the final partial batch, so some
        # test samples are never evaluated — confirm this is intended.
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.bs, num_workers=args.num_workers,
                                                 shuffle=False, drop_last=True)
    return train_loader, val_loader


def _build_optimizer(network, args):
    """Return the optimizer named by ``args.optim`` ('SGD' or 'Adam').

    Raises:
        ValueError: if ``args.optim`` is not a supported optimizer name.
    """
    if args.optim == 'SGD':
        return torch.optim.SGD(network.parameters(), lr=args.lr, momentum=args.momentum)
    if args.optim == 'Adam':
        return torch.optim.Adam(network.parameters(), lr=args.lr)
    raise ValueError("Unsupported optimizer {!r}; expected 'SGD' or 'Adam'".format(args.optim))


def run(args):
    """Train the Warcraft shortest-path classifier configured by ``args``.

    Builds the network and the circuit-based ``GridClassificationSystem`` for
    an ``args.grid_size`` x ``args.grid_size`` grid. It then sets up or resumes
    W&B logging when ``args.log`` is set, and trains with ``Trainer``. When
    ``args.save_checkpoint`` is set, it saves the network weights for
    ``args.job_ID``. Finally, it closes the W&B run if one was created.
    """
    print("Test of Warcraft Shortest Path classification")

    torch.manual_seed(args.seed)

    # Load the network and set device
    device = "cuda:0"
    network = Net(name=args.model, num_classes=2 * (args.grid_size - 1) * args.grid_size,
                  pretrained=args.pretrained)
    network.to(device)
    job_id = args.job_ID

    folder_path = "./data/{0}x{0}/".format(args.grid_size)
    sdd_path = "./data/{0}x{0}/{0}x{0}.sdd".format(args.grid_size)
    vtree_path = "./data/{0}x{0}/{0}x{0}.vtree".format(args.grid_size)
    cmpe = CircuitMPE(vtree_path, sdd_path)
    classifier = GridClassificationSystem(model=network, n=args.grid_size, circuit=cmpe, beta=args.beta)

    # May load checkpoint weights into `network` and overwrite args.lr, so it
    # must run before the optimizer is built.
    wandb_run = _init_wandb_run(args, network, job_id)

    train_loader, val_loader = _build_loaders(folder_path, args)

    optimizer = _build_optimizer(network, args)
    if args.scheduler:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    else:
        scheduler = None

    # NOTE(review): tmp_mr is never used below; kept in case the
    # MetricsRecorder constructor has side effects — confirm and remove.
    tmp_mr = MetricsRecorder(hyperparams=vars(args))

    # NOTE(review): val_loader is None when args.do_eval is false; confirm the
    # Trainer skips evaluation in that case.
    trainer = Trainer(classifier=classifier, optimizer=optimizer, scheduler=scheduler,
                      train_loader=train_loader, val_loader=val_loader,
                      device=device, compute_metrics=compute_metrics, wandb_run=wandb_run, args=args)

    trainer.train()

    if args.save_checkpoint:
        torch.save(network.state_dict(), _checkpoint_path(job_id))

    # wandb_run is None when logging is disabled.
    if wandb_run is not None:
        wandb_run.finish()