import os
import sys
import json

import torch
import pandas as pd
import wandb

from torchvision.models.densenet import DenseNet
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms

from trainer import Trainer
from classifier import HEXClassificationSystem
from models import Net, get_transform
from pyHEXgraph import HEXGraph, fastHEXLayer
from CustomDataset import CustomDataset
from metrics import compute_metrics
from MetricsRecorder import MetricsRecorder

# from fvcore.nn import FlopCountAnalysis

def idx_to_state(idx, i2n, tm):
    """Return column ``i2n[str(idx)]`` of ``tm`` as the target for class ``idx``.

    Args:
        idx: Dataset class index. It is converted with ``str()`` before the
            lookup, because ``i2n`` is loaded from JSON and so has string keys.
        i2n: Mapping from stringified class index to a column index of ``tm``.
        tm: 2-D indexable matrix, presumably the HEX-graph matrix returned by
            ``HEXGraph.getTM()``. TODO confirm its exact meaning.

    Returns:
        ``tm[:, i2n[str(idx)]]``, i.e. every row of the selected column.
    """
    column = i2n[str(idx)]
    return tm[:, column]


def define_metrics(run):
    """Declare the WandB metric definitions used by this training script.

    The job-bookkeeping metrics are declared first: ``job``, then ``jobID``
    stepped by ``job``, then ``epoch``. After those, each validation metric is
    stepped by ``epoch`` and summarised by its maximum value.

    Args:
        run: Object exposing ``define_metric`` (a WandB run).
    """
    bookkeeping = (
        ("job", {}),
        ("jobID", {"summary": "last", "step_metric": "job"}),
        ("epoch", {"summary": "last"}),
    )
    for metric_name, options in bookkeeping:
        run.define_metric(metric_name, **options)

    # Registration order matters for reproducing the original call sequence.
    validation_metrics = (
        "val/base_perfect_acc",
        "val/hex_perfect_acc",
        "val/base_leaves_acc",
        "val/hex_leaves_acc",
        "val/base_validity",
        "val/hex_validity",
    )
    for metric_name in validation_metrics:
        run.define_metric(metric_name, summary="max", step_metric="epoch")

# Location of the CIFAR100 dataset on the cluster filesystem.
_DATA_ROOT = '/gpfsdswork/dataset'
# Directory where network weights are written and from which checkpoints are read.
_OUTPUT_DIR = '/gpfswork/rech/ifs/ukm19fl/output/'


def _weights_path(job_id):
    """Return the path of the saved network weights for job ``job_id``."""
    return os.path.join(_OUTPUT_DIR, str(job_id) + '_netweights.save')


def _init_wandb(args, network, jobID):
    """Start a new WandB run, or resume one, and record bookkeeping on ``args``.

    When ``args.load_checkpoint`` is set, the run ``args.load_run_ID`` is
    resumed. The learning rate is restored from it, the weights of the
    referenced job are loaded into ``network``, and ``args.epoch_offset`` is set
    to the epoch after the last logged one. Otherwise a fresh run is created,
    with ``args`` as its config.

    Args:
        args: Parsed command-line namespace. It is mutated: ``lr``,
            ``epoch_offset``, ``model_size`` and ``runID`` may be set.
        network: Model whose weights are loaded when resuming.
        jobID: Identifier of the current job; it is logged to the run.

    Returns:
        The WandB run, or ``None`` when ``args.log`` is false.
    """
    if not args.log:
        if getattr(args, "load_checkpoint", False):
            # Checkpoint resumption relies on the WandB run, so without logging it is skipped.
            print("warning: load_checkpoint is ignored because logging is disabled",
                  file=sys.stderr)
        # Both logging paths set epoch_offset; keep the attribute defined here too.
        args.epoch_offset = getattr(args, "epoch_offset", 0)
        return None

    if args.online:
        from wandb_co import wandb_connect
        # Connect to the WandB service before creating the run.
        wandb_connect()
        mode = "online"
    else:
        mode = "offline"

    settings = wandb.Settings(_disable_stats=True, _disable_meta=True)

    if args.load_checkpoint:
        wandb_run = wandb.init(project="Cifar100", entity="anonymous", id=args.load_run_ID, resume="must",
                               dir=args.log_path, mode=mode, settings=settings)
        args.lr = float(wandb_run.config.get("lr"))
        if args.load_job_ID is None:
            # Default to the last job that logged into this run.
            load_job_id = wandb_run.summary.get("jobID")['last']
        else:
            load_job_id = args.load_job_ID
        network.load_state_dict(torch.load(_weights_path(load_job_id), map_location='cpu'))
        define_metrics(wandb_run)
        if args.scheduler:
            # With a scheduler the last logged lr supersedes the configured one.
            args.lr = float(wandb_run.summary.get("lr")['last'])
            wandb_run.define_metric("lr", summary="last")
        wandb_run.log({"jobID": int(jobID)})
        args.epoch_offset = int(wandb_run.summary.get("epoch")['last']) + 1
    else:
        # Set before wandb.init so the trainable-parameter count lands in the run config.
        args.model_size = sum(p.numel() for p in network.parameters() if p.requires_grad)
        wandb_run = wandb.init(project="Cifar100", entity="anonymous", dir=args.log_path, config=args,
                               mode=mode, settings=settings)
        args.runID = wandb_run.id
        define_metrics(wandb_run)
        if args.scheduler:
            wandb_run.define_metric("lr", summary="last")
        wandb_run.log({"jobID": int(jobID)})
        args.epoch_offset = 0
    return wandb_run


def _build_loaders(args, hexg):
    """Build the CIFAR100 train loader and, when ``args.do_eval`` is set, the val loader.

    Targets are mapped from class index to a column of ``hexg.getTM()``
    through ``./i2n.json`` (see ``idx_to_state``).

    Returns:
        ``(train_loader, val_loader)``. ``val_loader`` is ``None`` when
        evaluation is disabled.
    """
    with open("./i2n.json", 'r', encoding='utf-8') as f:
        i2n = json.load(f)

    def target_transform(idx):
        return idx_to_state(idx, i2n, hexg.getTM())

    train_dataset = CIFAR100(root=_DATA_ROOT,
                             train=True,
                             transform=get_transform(name=args.model, train=True),
                             target_transform=target_transform,
                             download=False)

    if args.trainset_ratio:
        # Keep only the first part of a split that is fixed by its own seed,
        # so the subset does not depend on args.seed.
        total_length = len(train_dataset)
        first_split = int(args.trainset_ratio * total_length)
        train_dataset = torch.utils.data.random_split(train_dataset,
                                                      [first_split, total_length - first_split],
                                                      generator=torch.Generator().manual_seed(42))[0]

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.bs, num_workers=args.num_workers,
                                               shuffle=True, drop_last=True)

    val_loader = None
    if args.do_eval:
        val_dataset = CIFAR100(root=_DATA_ROOT,
                               train=False,
                               transform=get_transform(name=args.model, train=False),
                               target_transform=target_transform,
                               download=False)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.bs, num_workers=args.num_workers,
                                                 shuffle=False, drop_last=True)
    return train_loader, val_loader


def _build_optimizer(args, params):
    """Create the optimizer named by ``args.optim`` ('SGD' or 'Adam') over ``params``.

    Raises:
        ValueError: If ``args.optim`` is not a supported optimizer name.
    """
    if args.optim == 'SGD':
        return torch.optim.SGD(params, lr=args.lr, momentum=args.momentum)
    if args.optim == 'Adam':
        return torch.optim.Adam(params, lr=args.lr)
    raise ValueError(f"Unsupported optimizer {args.optim!r}; expected 'SGD' or 'Adam'")


def run(args):
    """Train a HEX-layer classifier on CIFAR100, as configured by ``args``.

    Steps:
        1. Seed torch.
        2. Build the HEX graph, HEX layer and network on ``cuda:0``.
        3. Optionally start or resume a WandB run.
        4. Build the data loaders, optimizer and optional StepLR scheduler.
        5. Train with ``Trainer``.
        6. Optionally save the network weights under the job ID.

    Args:
        args: Parsed command-line namespace. Several attributes (``lr``,
            ``epoch_offset``, ``runID``, ``model_size``) may be set or
            overwritten during setup.
    """
    torch.manual_seed(args.seed)

    jobID = args.job_ID
    device = "cuda:0"

    # Load the HEX graph from its text files and build the HEX layer on the device.
    hexg = HEXGraph(file_path="./hex")
    hexL = fastHEXLayer(hexg, beta=args.beta, loss_normalize=args.loss_normalization)
    hexL.set_device(device=device)

    # One output per HEX-graph node.
    network = Net(name=args.model, num_classes=hexL.hexg.numV, pretrained=args.pretrained)
    network.to(device)
    classifier = HEXClassificationSystem(model=network, hexL=hexL)

    # Must run before the optimizer is built: resuming may overwrite args.lr.
    wandb_run = _init_wandb(args, network, jobID)

    train_loader, val_loader = _build_loaders(args, hexg)

    optimizer = _build_optimizer(args, network.parameters())
    if args.scheduler:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    else:
        scheduler = None

    # NOTE(review): the recorder is never used here. It is kept in case its
    # constructor has side effects; confirm and remove if it has none.
    MetricsRecorder(hyperparams=vars(args))

    trainer = Trainer(classifier=classifier, optimizer=optimizer, scheduler=scheduler,
                      train_loader=train_loader, val_loader=val_loader,
                      device=device, compute_metrics=compute_metrics, wandb_run=wandb_run, args=args)

    trainer.train()

    if args.save_checkpoint:
        torch.save(network.state_dict(), _weights_path(jobID))

    # No run exists when logging is disabled.
    if wandb_run is not None:
        wandb_run.finish()