import os
import sys
import pickle

import torch
import pandas as pd
import wandb

from trainer import Trainer
from classifier import HEXClassificationSystem
from pyHEXgraph import HEXGraph, fastHEXLayer
from metrics import compute_metrics
from MetricsRecorder import MetricsRecorder
from models.models import Net, get_transform
from CustomDataset import CustomDataset

def getPaths(args, dir_path='./ImageNet/'):
    """Return the compilation file paths for the configuration in ``args``.

    The common stem is ``dir_path`` followed by ``args.nb_leaves``, with a
    ``"p"`` suffix when ``args.pruning`` is truthy. The SDD, vtree, config and
    graph paths derive from that stem. ``file_path`` is the stem itself when
    ``args.assume_exclusive`` is truthy; otherwise the stem gets an extra
    ``"h"`` suffix.

    Args:
        args: namespace exposing ``nb_leaves``, ``pruning`` and
            ``assume_exclusive``.
        dir_path: directory prefix, concatenated as-is (include the trailing
            separator).

    Returns:
        dict with keys ``config_path``, ``sdd_path``, ``vtree_path``,
        ``graph_path`` and ``file_path``.
    """
    stem = dir_path + str(args.nb_leaves) + ("p" if args.pruning else "")
    return {
        "config_path": stem + "_config.json",
        "sdd_path": stem + ".sdd",
        "vtree_path": stem + ".vtree",
        "graph_path": stem + "_graph.pkl",
        # Non-exclusive (hierarchical) variants are stored under an "h" suffix.
        "file_path": stem if args.assume_exclusive else stem + "h",
    }

def define_metrics(run):
    """Declare the WandB metrics (and their summaries/step axes) on ``run``.

    ``jobID`` is stepped by ``job`` and keeps its last value. ``epoch`` keeps
    its last value. Each validation metric is stepped by ``epoch`` and
    summarised by its maximum.
    """
    run.define_metric("job")
    run.define_metric("jobID", summary="last", step_metric="job")
    run.define_metric("epoch", summary="last")
    validation_metrics = (
        "val/base_perfect_acc",
        "val/hex_perfect_acc",
        "val/base_leaves_acc",
        "val/hex_leaves_acc",
        "val/base_validity",
        "val/hex_validity",
    )
    for metric_name in validation_metrics:
        run.define_metric(metric_name, summary="max", step_metric="epoch")

def _checkpoint_path(log_path, job_id):
    """Return the network-weights checkpoint path for ``job_id`` under ``log_path``."""
    return log_path + "/checkpoints/" + str(job_id) + '_netweights.save'


def _load_network(args, num_classes, load_path, device):
    """Build ``Net``, restore weights from ``load_path`` if given, move it to ``device``."""
    network = Net(name=args.model, num_classes=num_classes, pretrained=args.pretrained)
    if load_path is not None:
        network.load_state_dict(torch.load(load_path, map_location='cpu'))
    network.to(device)
    return network


def run(args):
    """Build the classifier, data loaders and optimizer from ``args``, then train.

    With ``args.spl`` the classifier is a ``CircuitClassificationSystem`` built
    on an SDD/vtree circuit. Otherwise it is a ``HEXClassificationSystem``
    built on a ``fastHEXLayer``. When ``args.log`` is set, metrics go to a
    WandB run. That run is resumed from ``args.load_run_ID`` when
    ``args.load_checkpoint`` is set; otherwise a new run is created.

    Args:
        args: parsed command-line namespace. It is mutated in place:
            ``epoch_offset`` is always set. On resume, ``model``, ``lr``,
            ``seed``, ``beta``, ``clamping`` and ``trainset_ratio`` are
            overwritten from the WandB run config. On a fresh logged run,
            ``model_size`` and ``runID`` are set. Optional ``device`` and
            ``data_root`` attributes override the defaults ``"cuda:0"`` and
            ``/gpfs/workdir/shared/datasets/ILVSRC2012``.

    Raises:
        ValueError: if ``args.load_checkpoint`` is set without ``args.log``,
            or if ``args.optim`` is neither ``'SGD'`` nor ``'Adam'``.
    """
    jobID = args.job_ID
    # Handle to the WandB run; stays None when logging is disabled so that the
    # Trainer and the final finish() never hit an unbound name.
    wandb_run = None
    # Checkpoint to restore; only set when resuming.
    load_path = None

    if args.log and args.online:
        from wandb_co import wandb_connect
        # Connect to the WandB run
        wandb_connect()
        mode = "online"
    else:
        mode = "offline"

    if args.load_checkpoint:
        if not args.log:
            raise ValueError('Cannot load checkpoint without logging on wandb')
        wandb_run = wandb.init(project="ImageNet", entity="anonymous", id=args.load_run_ID,
                               resume="must", dir=args.log_path, mode=mode,
                               settings=wandb.Settings(_disable_stats=True, _disable_meta=True))
        # Restore the hyper-parameters the resumed run was launched with.
        run_config = wandb_run.config
        args.model = str(run_config.get("model"))
        args.lr = float(run_config.get("lr"))
        args.seed = int(run_config.get("seed"))
        args.beta = float(run_config.get("beta"))
        args.clamping = bool(run_config.get("clamping"))
        args.trainset_ratio = float(run_config.get("trainset_ratio"))
        if args.load_job_ID is None:
            # Default to the checkpoint of the last job logged in this run.
            loadJobID = wandb_run.summary.get("jobID")['last']
        else:
            loadJobID = args.load_job_ID
        load_path = _checkpoint_path(args.log_path, loadJobID)
        define_metrics(wandb_run)
        if args.scheduler:
            # Continue from the last logged learning rate, not the initial one.
            args.lr = float(wandb_run.summary.get("lr")['last'])
            wandb_run.define_metric("lr", summary="last")
        wandb_run.log({"jobID": int(jobID)})
        args.epoch_offset = int(wandb_run.summary.get("epoch")['last']) + 1

    # Setting seed
    torch.manual_seed(args.seed)
    device = getattr(args, "device", None) or "cuda:0"
    paths = getPaths(args, dir_path='./compilations/')

    # Both classifier variants need the HEX-graph, because the datasets below
    # read its transition matrix via hexg.getTM().
    hexg = HEXGraph(file_path=paths["file_path"])

    if args.spl:
        from compute_mpe import CircuitMPE
        from classifier import CircuitClassificationSystem
        cmpe = CircuitMPE(paths["vtree_path"], paths["sdd_path"])  # TODO: pass device
        # NOTE(review): output size is hard-coded to 271 here, while the HEX
        # branch uses hexg.numV; confirm both agree for the chosen taxonomy.
        network = _load_network(args, 271, load_path, device)
        classifier = CircuitClassificationSystem(model=network, circuit=cmpe, beta=args.beta)
    else:
        # Init the HEX-layer
        hexL = fastHEXLayer(hexg, beta=args.beta, loss_normalize=args.loss_normalization)
        hexL.set_device(device=device)
        network = _load_network(args, hexL.hexg.numV, load_path, device)
        classifier = HEXClassificationSystem(model=network, hexL=hexL)

    if args.log and not args.load_checkpoint:
        # Number of trainable parameters, recorded in the run config.
        args.model_size = sum(p.numel() for p in network.parameters() if p.requires_grad)
        # Launch wandb run and define metrics
        wandb_run = wandb.init(project="ImageNet", entity="anonymous", dir=args.log_path, config=args,
                               mode=mode, settings=wandb.Settings(_disable_stats=True, _disable_meta=True))
        args.runID = wandb_run.id
        define_metrics(wandb_run)
        if args.scheduler:
            wandb_run.define_metric("lr", summary="last")
        wandb_run.log({"jobID": int(jobID)})

    if not args.load_checkpoint:
        # A fresh job starts numbering epochs at 0. Set it even without logging
        # so the attribute always exists for the Trainer.
        args.epoch_offset = 0

    data_root = getattr(args, "data_root", None) or "/gpfs/workdir/shared/datasets/ILVSRC2012"
    train_dataset = CustomDataset(root=data_root + "/train",
                                  config_path=paths["config_path"],
                                  tm=hexg.getTM(),
                                  transform=get_transform(name=args.model, train=True),
                                  mapped=args.mapped)

    if args.trainset_ratio < 1:
        # Fixed generator seed so the training subset is the same across jobs.
        total_length = len(train_dataset)
        first_split = int(args.trainset_ratio * total_length)
        train_dataset = torch.utils.data.random_split(train_dataset,
                                                      [first_split, total_length - first_split],
                                                      generator=torch.Generator().manual_seed(42))[0]

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.bs, num_workers=args.num_workers,
                                               shuffle=True, drop_last=True)

    val_loader = None
    if args.do_eval:
        val_dataset = CustomDataset(root=data_root + "/val",
                                    config_path=paths["config_path"],
                                    tm=hexg.getTM(),
                                    transform=get_transform(name=args.model, train=False),
                                    mapped=args.mapped)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.bs, num_workers=args.num_workers,
                                                 shuffle=False, drop_last=True)

    if args.optim == 'SGD':
        optimizer = torch.optim.SGD(network.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)
    elif args.optim == 'Adam':
        optimizer = torch.optim.Adam(network.parameters(),
                                     lr=args.lr)
    else:
        raise ValueError(f"Unsupported optimizer {args.optim!r}; expected 'SGD' or 'Adam'")

    scheduler = (torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
                 if args.scheduler else None)

    # NOTE(review): the recorder is constructed but never used here; the call
    # is kept in case its constructor has side effects. Confirm and remove.
    MetricsRecorder(hyperparams=vars(args))

    trainer = Trainer(classifier=classifier, optimizer=optimizer, scheduler=scheduler,
                      train_loader=train_loader, val_loader=val_loader,
                      device=device, compute_metrics=compute_metrics, wandb_run=wandb_run, args=args)

    trainer.train()

    if args.save_checkpoint:
        save_path = _checkpoint_path(args.log_path, jobID)
        # Create the checkpoint directory if this is the first save.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(network.state_dict(), save_path)

    if wandb_run is not None:
        wandb_run.finish()