import math
import time
import wandb
import functools
from tqdm import tqdm
from pprint import pformat
from datetime import datetime
from omegaconf import OmegaConf

import numpy as np
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch_geometric.loader import DataLoader

from parsers import Parser, get_config
from dataset.loader import MultiEpochsPYGDataLoader
from utils import (
    AverageMeter, validate, print_info, init_weights, load_generator, 
    ImbalancedSampler, build_augmentation_dataset, set_seed,
    dict_of_dicts_to_dict, flatten_dict, unflatten_dict
)
from utils.augment import augment_segments
from utils.loader import load_device, load_data, load_downstream_model

# Per-element loss (no reduction): ``train`` averages it explicitly, and the same
# object is handed to ``augment_segments`` inside ``main``.
criterion = torch.nn.CrossEntropyLoss(reduction='none')

# Cap torch's intra-op CPU parallelism for this process.
_NUM_TORCH_THREADS = 16
torch.set_num_threads(_NUM_TORCH_THREADS)


def train(model, optimizer, data, input_hidden=False):
    """Take one full-batch optimisation step on the nodes selected by ``data.train_mask``.

    Args:
        model: Module called as ``model(data, input_hidden=...)`` and expected to
            return a 2-tuple whose first element holds the per-node logits.
        optimizer: Optimizer over ``model``'s parameters; stepped exactly once.
        data: Graph object exposing ``train_mask`` and ``y``.
        input_hidden: Forwarded unchanged to the model call.

    Returns:
        The mean training loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    logits, _hidden = model(data, input_hidden=input_hidden)
    selected = data.train_mask
    # ``criterion`` has reduction='none', so reduce to a scalar here.
    per_node_loss = criterion(logits[selected], data.y[selected])
    mean_loss = per_node_loss.mean()
    mean_loss.backward()
    optimizer.step()
    return float(mean_loss)

@torch.no_grad()
def test(model, data, input_hidden=False):
    model.eval()
    logits, _ = model(data, input_hidden=input_hidden)
    pred = logits.argmax(dim=-1)

    accs = []
    for mask in [data.train_mask, data.valid_mask, data.test_mask]:
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs

def main(config, task='NC', data_name='cora', seed=1, augment=True, num_workers=0,
         sweep=False, prof_flag=False, extreme_mode=True, aug_orig=False, **kwargs):
    """Train a node-classification GNN, optionally with generator-based augmentation.

    Args:
        config: OmegaConf config. Reads ``train.{lr,wdecay,epochs,patience,initw_name}``,
            ``augment.{start,iteration,strategy,ckpt_path,...}``, ``sde.{x,adj}`` and
            ``model``. ``config.model`` is overwritten: ``target`` is forced to
            ``'SimpleGCN'`` and the in/out channels are taken from the dataset.
        task, data_name: Only used to build the wandb group name in sweep mode.
        seed: Passed to ``set_seed``.
        augment: If True, load the generator and, every ``augment.iteration`` epochs
            starting at ``augment.start``, replace the training graph with the output
            of ``augment_segments``.
        num_workers, prof_flag, **kwargs: Accepted for interface compatibility; not
            used by this function.
        sweep: If True, start a wandb run and merge the (unflattened) sweep
            parameters from ``wandb.config`` over ``config``.
        extreme_mode: Passed to ``set_seed``.
        aug_orig: Forwarded to ``augment_segments``; when False the model is called
            with ``input_hidden=True`` after the first augmentation.

    Returns:
        ``{'acc': test_acc}``: the test accuracy at the epoch with the highest
        validation accuracy (0 if validation accuracy never exceeded 0).
    """
    if sweep:
        # NOTE(review): clears the global run handle before wandb.init, presumably so
        # each agent invocation starts a fresh run; confirm for the wandb version used.
        wandb.run = None
        wandb.init(group=f'downstream_{task}_{data_name}_full_aug{str(augment)}')
        sweep_config = unflatten_dict(wandb.config)
        config = OmegaConf.merge(config, dict(sweep_config))

    set_seed(seed, extreme_mode=extreme_mode)

    lr = config.train.lr
    wdecay = config.train.wdecay
    epochs = config.train.epochs
    patience = config.train.patience
    # Early stopping ignores every epoch up to and including this one.
    min_epochs_before_stopping = 30

    start = config.augment.start
    iteration = config.augment.iteration

    device = load_device()
    if isinstance(device, list):
        # Several devices reported: train on the first one only.
        device = f'cuda:{device[0]}'
    labeled_dataset = load_data(config, return_loader=False)
    data = labeled_dataset[0].to(device)
    strategy = config.augment.strategy

    config.model.target = 'SimpleGCN'  # DGCNN not supported
    config.model.in_channels = labeled_dataset.num_features
    config.model.out_channels = labeled_dataset.num_classes
    model = load_downstream_model(config.model).to(device)
    if augment:
        generator = load_generator(device, path=config.augment.ckpt_path)
    init_weights(model, config.train.initw_name, init_gain=0.02)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wdecay)

    best_valid_acc = test_acc = 0
    best_epoch = None  # stays None if validation accuracy never rises above 0
    counter = 0  # consecutive non-improving epochs after the warm-up
    if augment:
        # Source graph for augmentation. With strategy 'accumulate' it is replaced by
        # each new augmented graph; otherwise it stays a copy of the original data.
        data_to_aug = data.clone()
    input_hidden = False
    for epoch in (p_bar := tqdm(range(epochs))):
        loss = train(model, optimizer, data, input_hidden)
        train_acc, valid_acc, tmp_test_acc = test(model, data, input_hidden)

        if augment and epoch >= start and epoch % iteration == 0:
            # After augmenting, call the model with input_hidden=True unless aug_orig.
            input_hidden = not aug_orig
            data = augment_segments(
                model, generator, data_to_aug, criterion, aug_orig=aug_orig, device=device,
                sde_x_config=config.sde.x, sde_adj_config=config.sde.adj, **config.augment
            )
            # Carry the partition list and split masks over from the source graph.
            data.part_list = data_to_aug.part_list
            data.train_mask, data.valid_mask, data.test_mask = (
                data_to_aug.train_mask, data_to_aug.valid_mask, data_to_aug.test_mask
            )
            if strategy == 'accumulate':
                data_to_aug = data.clone()

        if valid_acc > best_valid_acc:
            best_valid_acc = valid_acc
            test_acc = tmp_test_acc
            best_epoch = epoch
            # Patience counts *consecutive* stale epochs, so restart it on improvement.
            counter = 0
        elif epoch > min_epochs_before_stopping:
            counter += 1
            if counter > patience:
                break

        wandb.log({'loss': loss, 'train_acc': train_acc, 'valid_acc': valid_acc, 'test_acc': tmp_test_acc})
        p_bar.set_description("Epoch: {e}/{es:4}. Loss: {l:.6f}. Train metric: {trm:.4f}. Validation metric: {vam:.4f}. Test metric: {tem:.4f}. ".format(
            e=epoch + 1,
            es=epochs,
            l=loss,
            trm=train_acc,
            vam=valid_acc,
            tem=tmp_test_acc,
        ))

    print('Finished training! Best validation results from epoch {}.'.format(best_epoch))
    print(f'valid_best: {best_valid_acc:.4f}')
    print(f'test_final: {test_acc:.4f}')
    wandb.log({'test_final-acc': test_acc})

    return {'acc': test_acc}

def run(prefix, task, data_name, config, augment, aug_orig=False, seed=0,
        extreme_mode=True, sweep=False, sweep_id=None, prof_flag=False):
    """Launch one wandb-tracked training run, or a wandb sweep agent.

    Args:
        prefix: Prefix of the experiment name (a timestamp is appended).
        task, data_name: Used for the wandb group name and forwarded to ``main``.
        config: OmegaConf config forwarded to ``main``.
        augment: Whether ``main`` trains with augmentation. When False, the sweep
            search space contains no ``augment.*`` parameters.
        aug_orig, seed, extreme_mode, prof_flag: Forwarded to ``main``.
        sweep: If True, run a wandb sweep agent instead of a single run.
        sweep_id: Existing sweep to attach the agent to. When None, a new sweep is
            registered in project ``GraphDiff``.

    Returns:
        ``main``'s result dict for a single run, or None in sweep mode.
    """
    datetime_now = datetime.now().strftime("%Y%m%d.%H%M%S")
    exp_name = f'{prefix}-{datetime_now}'
    if not sweep:
        wandb.init(
            project='GraphDiff',
            group=f'downstream_{task}_{data_name}_full_aug{str(augment)}',
            name=exp_name,
            # Plain nested containers, so wandb records the config structure instead
            # of stringified DictConfig objects.
            config=OmegaConf.to_container(config, resolve=True),
        )
        try:
            results = main(config, task, data_name, seed, augment=augment, aug_orig=aug_orig,
                           prof_flag=prof_flag, extreme_mode=extreme_mode)
        except BaseException:
            wandb.finish(exit_code=1)
            raise
        # Close the run so that a later call (e.g. the next trial) starts a new one.
        wandb.finish()
        return results

    params = {
        'model': {
            'dropout': {'values': [0.5]},  # [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
        },
        'train': {
            'lr': {'values': [1e-2]},  # [1e-4, 5e-4, 1e-3, 5e-3, 1e-2]
            'wdecay': {'values': [0]},  # [0., 1e-9, 1e-8, 1e-7, 1e-6, 1e-5]
        },
        'augment': {
            'out_steps': {'values': [1, 10, 50, 100, 500]},
            'topk': {'values': [None]},  # [None, 5, 10, 50, 128]
            'start': {'values': [10, 20, 50, 100]},
            'iteration': {'values': [1, 5, 10, 20]},
            'strategy': {'values': ['once', 'accumulate']},
            'snr': {'values': [0, 0.25, 0.5, 0.75, 0.9]},
            'scale_eps': {'values': [0, 0.25, 0.5, 0.75, 0.9]},
            'perturb_ratio': {'values': [0., 1e-8, 5e-8, 1e-7, 5e-7, 1e-6]},
            'n_steps': {'values': [1, 5, 10, 20]},
            # what to aug
            'aug_x': {'values': [False]},
            'aug_adj': {'values': [True]},
            'cutoff': {'values': [None, 0.1, 0.3, 0.5, 0.7, 0.9]},
            # infonce
            'infonce': {'values': [False]},
        },
    }
    if not augment:
        params.pop('augment')
    # Sweep parameter names are flattened section keys; ``main`` unflattens them.
    params = flatten_dict(params, stop_keys=['values'])

    sweep_configuration = {
        'method': 'bayes',
        'name': exp_name,
        'metric': {
            'goal': 'maximize',
            # NOTE(review): ``main`` only logs 'test_final-acc'; the LP metric name
            # is never logged by this file.
            'name': 'test_final-acc' if task == 'NC' else 'test_final-hits@100',
        },
        'parameters': params,
    }
    # Only register a new sweep when not resuming an existing one; otherwise every
    # resume would leave an unused sweep behind on the server.
    if sweep_id is None:
        sweep_id = wandb.sweep(sweep=sweep_configuration, project='GraphDiff')
    print(f'sweep_id: {sweep_id}')
    wandb.agent(sweep_id, project='GraphDiff', count=10000,
                function=functools.partial(main, config, task, data_name, seed, augment=augment,
                                           sweep=sweep, extreme_mode=extreme_mode, aug_orig=aug_orig,
                                           prof_flag=prof_flag))
    return None

if __name__ == '__main__':
    # TODO: update argument; implement GNN for NC and LP
    args, unknown = Parser().parse()
    # Unrecognised command-line tokens are parsed as OmegaConf dot-list overrides
    # (``key.sub=value``) and merged over the file config.
    cli = OmegaConf.from_dotlist(unknown)
    config = get_config(args.config, args.seed)
    config = OmegaConf.merge(config, cli)
    if args.orig_feat:
        config.data.feature_types = None
        config.data.stru_feat_principle = None
    if args.full_subgraph:
        config.data.max_node_num = None

    print(pformat(dict(config)))
    # Copy every second-level config entry onto ``args``. Section names are dropped,
    # so a config key overwrites any same-named CLI argument. Non-mapping top-level
    # values have no entries to copy and are skipped instead of crashing.
    for section in config.values():
        if not OmegaConf.is_dict(section):
            continue
        for arg, value in section.items():
            setattr(args, arg, value)

    task = config.data.task
    data_name = config.data.data_name
    metric_list = ['acc']

    if args.trails > 0:  # ``trails``: number of independent trials, seeded 0..n-1
        results = {m: [] for m in metric_list}
        for i in range(args.trails):
            # Forward the same flags as the single-run path so trials honour them.
            temp_results = run(args.prefix, task, data_name, config, args.augment, args.aug_orig,
                               seed=i, extreme_mode=True, prof_flag=args.profile)
            for m in metric_list:
                results[m].append(temp_results[m])
        for m in metric_list:
            print(f'{m}: {np.mean(results[m])} +/- {np.std(results[m])}')
    else:
        results = run(args.prefix, task, data_name, config, args.augment, args.aug_orig,
                      seed=args.seed, extreme_mode=True, sweep=args.sweep, sweep_id=args.sweep_id,
                      prof_flag=args.profile)