import torch
import random
import numpy as np
import math

from models.scorenet import ScoreNetworkAttnFix
from models.scorenet_node import ScoreNetworkNodeFix, ScoreNetworkNodeAttn

from utils.graph_utils import quantize, quantize_mol, quantize_mol_tensor, node_flags, mask_x


def load_seed(seed):
    """Seed every random number generator used by the project.

    Seeds Python's ``random``, NumPy and PyTorch (CPU, plus all CUDA devices
    when CUDA is available) and switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: Integer seed forwarded to each generator.

    Returns:
        The same ``seed`` that was passed in.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    # Trade cuDNN autotuning for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    return seed


def load_device(gpu):
    """Parse a GPU specification into a list of integer device ids.

    Args:
        gpu: Either a single integer id or a comma-separated string of ids
            (for example ``"0,1"``).

    Returns:
        A list of integers, one per comma-separated entry.
    """
    spec = str(gpu) if isinstance(gpu, int) else gpu
    return list(map(int, spec.split(',')))


def load_model(params):
    """Instantiate the score network selected by ``params['model_type']``.

    The input mapping is copied first, so the caller's ``params`` keeps its
    ``'model_type'`` entry; every remaining entry is forwarded to the chosen
    network's constructor as a keyword argument.

    Args:
        params: Mapping with a ``'model_type'`` key plus constructor kwargs.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``model_type`` is missing or not a known name.
    """
    kwargs = params.copy()
    model_type = kwargs.pop('model_type', None)

    if model_type == 'attnX':
        return ScoreNetworkNodeAttn(**kwargs)
    if model_type == 'gcnX':
        return ScoreNetworkNodeFix(**kwargs)
    if model_type in ('attn', 'seq'):
        return ScoreNetworkAttnFix(**kwargs)

    raise ValueError(f"Model Name <{model_type}> is Unknown")


def load_model_optimizer(params, config_train, device):
    """Build a model on the GPU together with its Adam optimizer and LR scheduler.

    Args:
        params: Model parameter mapping understood by ``load_model``.
        config_train: Training config providing ``lr``, ``weight_decay``,
            ``lr_schedule`` and (when scheduling is enabled) ``lr_decay``.
        device: List of integer GPU ids; the model lives on ``device[0]`` and
            is wrapped in ``DataParallel`` unless exactly one id is given.

    Returns:
        ``(model, optimizer, scheduler)``; ``scheduler`` is ``None`` when
        ``config_train.lr_schedule`` is falsy.
    """
    net = load_model(params)
    if len(device) != 1:
        net = torch.nn.DataParallel(net, device_ids=device)
    net = net.to(f'cuda:{device[0]}')

    adam = torch.optim.Adam(
        net.parameters(),
        lr=config_train.lr,
        weight_decay=config_train.weight_decay,
    )

    if config_train.lr_schedule:
        lr_sched = torch.optim.lr_scheduler.ExponentialLR(adam, gamma=config_train.lr_decay)
    else:
        lr_sched = None

    return net, adam, lr_sched


def load_optimizer(model, config_train):
    """Create a fresh Adam optimizer (and optional scheduler) for an existing model.

    Intended for continued training of a model that is already constructed.

    Args:
        model: Module whose ``parameters()`` are optimized.
        config_train: Training config providing ``lr``, ``weight_decay``,
            ``lr_schedule`` and (when scheduling is enabled) ``lr_decay``.

    Returns:
        ``(optimizer, scheduler)``; ``scheduler`` is an ``ExponentialLR`` when
        ``config_train.lr_schedule`` is truthy, otherwise ``None``.
    """
    adam = torch.optim.Adam(
        model.parameters(),
        lr=config_train.lr,
        weight_decay=config_train.weight_decay,
    )
    lr_sched = (
        torch.optim.lr_scheduler.ExponentialLR(adam, gamma=config_train.lr_decay)
        if config_train.lr_schedule
        else None
    )
    return adam, lr_sched


def load_data(config, get_graph_list=False):
    """Build the molecule data loader(s) for the configured dataset.

    Only the ``'QM9'`` and ``'ZINC250k'`` datasets are accepted. The loader
    module is imported lazily, after the dataset name has been checked.

    Args:
        config: Config whose ``data.data`` names the dataset.
        get_graph_list: Forwarded unchanged to the project ``dataloader``.

    Returns:
        Whatever ``utils.data_loader_transform_mol.dataloader`` returns.
    """
    dataset = config.data.data
    assert dataset in ('QM9', 'ZINC250k')

    from utils.data_loader_transform_mol import dataloader as build_loader
    return build_loader(config, get_graph_list)


def load_prop_data(config):
    """Build the ZINC250k data loader(s) with a property target attached.

    Args:
        config: Config whose ``data.data`` must be ``'ZINC250k'`` and whose
            ``train.prop`` is forwarded to the loader as ``prop``.

    Returns:
        Whatever ``utils.data_loader_transform_mol.dataloader`` returns.
    """
    dataset = config.data.data
    assert dataset == 'ZINC250k'

    from utils.data_loader_transform_mol import dataloader as build_loader
    return build_loader(config, get_graph_list=False, prop=config.train.prop)


def load_low_data(config, get_graph_list=False):
    """Build the ZINC250k data loader(s) from the ``..._mol_low`` loader module.

    Args:
        config: Config whose ``data.data`` must be ``'ZINC250k'`` and whose
            ``train.low_protein`` is forwarded to the loader as ``protein``.
        get_graph_list: Forwarded unchanged to the project ``dataloader``.

    Returns:
        Whatever ``utils.data_loader_transform_mol_low.dataloader`` returns.
    """
    dataset = config.data.data
    assert dataset == 'ZINC250k'

    from utils.data_loader_transform_mol_low import dataloader as build_loader
    return build_loader(config, get_graph_list, protein=config.train.low_protein)


def load_batch(batch, device):
    """Move the first two entries of a batch onto the primary GPU.

    Args:
        batch: Indexable whose items 0 and 1 expose a ``.to(device)`` method
            (named ``x`` and ``adj`` here).
        device: List of integer GPU ids; only ``device[0]`` is used.

    Returns:
        ``(x_b, adj_b)``: the results of calling ``.to`` on ``batch[0]`` and
        ``batch[1]``.
    """
    target = f'cuda:{device[0]}'
    return batch[0].to(target), batch[1].to(target)


def load_sample(x, samples, check_sample, mol=False):
    """Quantize sampled adjacency tensors, optionally pausing for inspection.

    The quantization helper is picked from ``utils.graph_utils``:
    ``quantize`` when ``mol`` is false, ``quantize_mol_tensor`` when ``mol``
    is true and ``samples`` has four dimensions, ``quantize_mol`` otherwise.

    Args:
        x: Node features; only used in the ``check_sample`` debugging branch.
        samples: Sampled adjacency data passed to the helper as ``adjs``.
        check_sample: When truthy, compute a few diagnostic values and drop
            into the ``pdb`` debugger before returning.
        mol: Selects the molecule-specific quantization helpers.

    Returns:
        The quantized samples.
    """
    if not mol:
        samples_out = quantize(adjs=samples)
    elif len(samples.shape) == 4:
        samples_out = quantize_mol_tensor(adjs=samples)
    else:
        samples_out = quantize_mol(adjs=samples)

    if check_sample:
        # Debugging hook: ``x``, ``xa`` and ``sa`` exist only so they can be
        # inspected from the debugger prompt. This blocks non-interactive runs.
        x = mask_x(x, node_flags(samples_out))
        xa = x.argmax(-1)
        sa = samples_out.sum(-1).to(torch.long)
        import pdb; pdb.set_trace()

    return samples_out


def load_sde(config_sde):
    """Construct the SDE object described by ``config_sde``.

    Args:
        config_sde: Config providing ``type`` (``'VP'``, ``'subVP'`` or
            ``'VE'``), ``beta_min``, ``beta_max`` and ``num_scales``. For
            ``'VE'`` the two beta values are passed as ``sigma_min`` and
            ``sigma_max``.

    Returns:
        A ``VPSDE``, ``subVPSDE`` or ``VESDE`` instance.

    Raises:
        NotImplementedError: If ``config_sde.type`` is not a supported name.
    """
    from sde import VPSDE, VESDE, subVPSDE

    kind = config_sde.type
    lower = config_sde.beta_min
    upper = config_sde.beta_max
    n_scales = config_sde.num_scales

    if kind == 'VE':
        return VESDE(sigma_min=lower, sigma_max=upper, N=n_scales)
    if kind == 'subVP':
        return subVPSDE(beta_min=lower, beta_max=upper, N=n_scales)
    if kind == 'VP':
        return VPSDE(beta_min=lower, beta_max=upper, N=n_scales)

    raise NotImplementedError(f"SDE class {kind} not yet supported.")


def load_loss_fn(config):
    """Build the training loss function for the node and adjacency SDEs.

    Args:
        config: Config providing ``sde.x`` and ``sde.adj`` (see ``load_sde``)
            plus ``train.eps`` and ``train.flags``.

    Returns:
        The callable produced by ``losses.get_sde_loss_fn``, configured for
        training with continuous time, no mean reduction and no likelihood
        weighting.
    """
    sde_x = load_sde(config.sde.x)
    sde_adj = load_sde(config.sde.adj)

    from losses import get_sde_loss_fn

    loss_options = dict(
        train=True,
        reduce_mean=False,
        continuous=True,
        likelihood_weighting=False,
        eps=config.train.eps,
        use_flags=config.train.flags,
    )
    return get_sde_loss_fn(sde_x, sde_adj, **loss_options)


def load_sampling_fn(config_train, config_module, config_sample, device, snr=None, scale_eps=None):
    """Build the predictor-corrector sampling function.

    Args:
        config_train: Training config providing ``sde.x``/``sde.adj``,
            ``data.max_node_num``, ``data.max_feat_num``, ``module`` and
            ``proj.size``.
        config_module: Module config providing ``predictor``, ``corrector``,
            ``n_steps`` and the default ``snr``/``scale_eps``.
        config_sample: Sampling config providing ``num_samples``,
            ``probability_flow``, ``noise_removal``, ``eps`` and ``ood``.
        device: List of integer GPU ids; sampling runs on ``device[0]``.
        snr: Overrides ``config_module.snr`` when not ``None``.
        scale_eps: Overrides ``config_module.scale_eps`` when not ``None``.

    Returns:
        The sampler returned by ``predictor.get_pc_sampler``.
    """
    snr = config_module.snr if snr is None else snr
    scale_eps = config_module.scale_eps if scale_eps is None else scale_eps

    sde_x = load_sde(config_train.sde.x)
    sde_adj = load_sde(config_train.sde.adj)

    # Unless the module string is exactly 'S', shrink the node count once per
    # character of ``config_train.module`` by the matching ``proj.size`` ratio.
    node_count = config_train.data.max_node_num
    if config_train.module != 'S':
        for level in range(len(config_train.module)):
            node_count = math.ceil(node_count * config_train.proj.size[level])

    from predictor import get_pc_sampler

    n_samples = config_sample.num_samples
    shape_x = (n_samples, node_count, config_train.data.max_feat_num)
    shape_adj = (n_samples, node_count, node_count)

    sampler_kwargs = dict(
        sde_x=sde_x,
        sde_adj=sde_adj,
        shape_x=shape_x,
        shape_adj=shape_adj,
        predictor=config_module.predictor,
        corrector=config_module.corrector,
        snr=snr,
        scale_eps=scale_eps,
        n_steps=config_module.n_steps,
        probability_flow=config_sample.probability_flow,
        continuous=True,
        denoise=config_sample.noise_removal,
        eps=config_sample.eps,
        device=f'cuda:{device[0]}',
        ood=config_sample.ood,
    )
    return get_pc_sampler(**sampler_kwargs)


def load_sampling_fn_conditional(config_train, config_module, config_classifier, config_sample,
                                 configc, device, logger=None, snr=None, scale_eps=None):
    """Build the conditional (guided) predictor-corrector sampling function.

    Args:
        config_train: Training config providing ``sde.x``/``sde.adj``,
            ``data.max_node_num`` and ``data.max_feat_num``.
        config_module: Module config providing ``predictor``, ``corrector``,
            ``n_steps`` and the default ``snr``/``scale_eps``.
        config_classifier: Guidance config providing ``time_dep``,
            ``weight_x``, ``weight_adj`` and ``weight_clamp``.
        config_sample: Sampling config providing ``num_samples``,
            ``probability_flow``, ``noise_removal``, ``eps``, ``ood``,
            ``weight_scheduling`` and ``ood_scheduling``.
        configc: Config of the guiding model; ``regress`` is set when the
            string ``'Regressor'`` occurs in ``configc.model.model``.
        device: List of integer GPU ids; sampling runs on ``device[0]``.
        logger: Optional logger forwarded to the sampler.
        snr: Overrides ``config_module.snr`` when not ``None``.
        scale_eps: Overrides ``config_module.scale_eps`` when not ``None``.

    Returns:
        The sampler returned by ``predictor.get_pc_sampler_conditional``.
    """
    snr = config_module.snr if snr is None else snr
    scale_eps = config_module.scale_eps if scale_eps is None else scale_eps

    sde_x = load_sde(config_train.sde.x)
    sde_adj = load_sde(config_train.sde.adj)

    from predictor import get_pc_sampler_conditional

    n_samples = config_sample.num_samples
    n_nodes = config_train.data.max_node_num
    shape_x = (n_samples, n_nodes, config_train.data.max_feat_num)
    shape_adj = (n_samples, n_nodes, n_nodes)

    sampler_kwargs = dict(
        sde_x=sde_x,
        sde_adj=sde_adj,
        shape_x=shape_x,
        shape_adj=shape_adj,
        predictor=config_module.predictor,
        corrector=config_module.corrector,
        time_dep=config_classifier.time_dep,
        weight_x=config_classifier.weight_x,
        weight_adj=config_classifier.weight_adj,
        weight_clamp=config_classifier.weight_clamp,
        snr=snr,
        scale_eps=scale_eps,
        n_steps=config_module.n_steps,
        probability_flow=config_sample.probability_flow,
        continuous=True,
        denoise=config_sample.noise_removal,
        eps=config_sample.eps,
        device=f'cuda:{device[0]}',
        ood=config_sample.ood,
        regress=('Regressor' in configc.model.model),
        weight_scheduling=config_sample.weight_scheduling,
        ood_scheduling=config_sample.ood_scheduling,
        logger=logger,
    )
    return get_pc_sampler_conditional(**sampler_kwargs)


def load_model_params(config):
    """Assemble constructor parameter dicts for the node and adjacency networks.

    Args:
        config: Config providing ``model.x``, ``model.adj``,
            ``data.max_feat_num`` and ``data.max_node_num``.

    Returns:
        ``(params_x, params_adj)``. Both carry a ``'model_type'`` entry for
        ``load_model``. When ``model.x.type`` is ``'attnX'`` the node params
        additionally take their attention settings from ``model.adj``.
    """
    x_cfg = config.model.x
    adj_cfg = config.model.adj
    feat_dim = config.data.max_feat_num
    node_dim = config.data.max_node_num

    params_x = {
        'model_type': x_cfg.type,
        'max_feat_num': feat_dim,
        'depth': x_cfg.depth,
        'nhid': x_cfg.nhid,
    }
    if x_cfg.type == 'attnX':
        # The attention node network borrows its hyperparameters from the
        # adjacency config.
        params_x.update(
            num_linears=adj_cfg.num_linears,
            c_init=adj_cfg.c_init,
            c_hid=adj_cfg.c_hid,
            c_final=adj_cfg.c_final,
            adim=adj_cfg.adim,
            num_heads=adj_cfg.num_heads,
            conv=adj_cfg.conv,
        )

    params_adj = {
        'model_type': adj_cfg.type,
        'max_feat_num': feat_dim,
        'max_node_num': node_dim,
        'nhid': adj_cfg.nhid,
        'num_layers': adj_cfg.num_layers,
        'num_linears': adj_cfg.num_linears,
        'c_init': adj_cfg.c_init,
        'c_hid': adj_cfg.c_hid,
        'c_final': adj_cfg.c_final,
        'adim': adj_cfg.adim,
        'num_heads': adj_cfg.num_heads,
        'conv': adj_cfg.conv,
    }

    return params_x, params_adj


def config_match(config1, config2, norm_check=True):
    """Assert that two configs describe compatible experiments.

    ``config1`` may cover fewer modules than ``config2``; its projection
    settings are compared against the leading ``len(config1.module)`` entries
    of ``config2``'s.

    Args:
        config1: Config with ``data.data``, ``data.init``, ``seed``,
            ``proj.type``, ``proj.size``, ``proj.norm`` and ``module``.
        config2: Config with the same fields, at least as many modules.
        norm_check: When true, also compare ``proj.norm``.

    Raises:
        AssertionError: On the first mismatch found; the message names the
            field that differs.
    """
    assert config1.data.data == config2.data.data, f'Data mismatch: {config1.data.data} {config2.data.data}'
    assert config1.data.init == config2.data.init, f'Init mismatch: {config1.data.init} {config2.data.init}'
    assert config1.seed == config2.seed, f'Seed mismatch: {config1.seed} {config2.seed}'
    assert config1.proj.type == config2.proj.type, f'ProjType mismatch: {config1.proj.type} {config2.proj.type}'

    depth = len(config1.module)
    assert depth <= len(config2.module), f'Module mismatch: {config1.module} {config2.module}'
    assert config1.proj.size == config2.proj.size[:depth], f'Psize mismatch: {config1.proj.size} {config2.proj.size}'
    if norm_check:
        # Report the field actually being compared (was mislabelled as "Data").
        assert config1.proj.norm == config2.proj.norm[:depth], f'Norm mismatch: {config1.proj.norm} {config2.proj.norm}'


def load_ckpt(config, device):
    """Load the 'S' module checkpoint named by ``config`` from ``./checkpoints``.

    Args:
        config: Config providing ``data.data`` (dataset directory name) and
            ``module["S"].ckpt`` (checkpoint name).
        device: List of integer GPU ids; tensors are mapped to ``device[0]``.

    Returns:
        ``{'S': {...}}`` where the inner dict holds the saved ``'config'``,
        ``'params_x'``, ``'x_state_dict'``, ``'params_adj'`` and
        ``'adj_state_dict'``.
    """
    path = f'./checkpoints/{config.data.data}/S-{config.module["S"].ckpt}.pth'
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(path, map_location=f'cuda:{device[0]}')
    print(f'{path} loaded')

    entry = {'config': ckpt['model_config']}
    for key in ('params_x', 'x_state_dict', 'params_adj', 'adj_state_dict'):
        entry[key] = ckpt[key]

    return {'S': entry}


def load_model_from_ckpt(params, state_dict, device):
    """Rebuild a model from saved parameters and weights and move it to the GPU(s).

    Checkpoints saved from a ``torch.nn.DataParallel`` wrapper prefix every
    key with ``'module.'``. That prefix is removed before loading, but only
    when every key actually starts with it. Keys that merely contain
    ``'module.'`` somewhere are left untouched, and an empty ``state_dict``
    is passed through unchanged instead of raising ``IndexError``. The
    caller's ``state_dict`` is never modified.

    Args:
        params: Model parameter mapping understood by ``load_model``.
        state_dict: Saved weights, with or without the ``'module.'`` prefix.
        device: List of integer GPU ids; the model lives on ``device[0]`` and
            is wrapped in ``DataParallel`` when more than one id is given.

    Returns:
        The model with weights loaded, on ``cuda:{device[0]}``.
    """
    model = load_model(params)

    prefix = 'module.'
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
    model.load_state_dict(state_dict)

    if len(device) > 1:
        model = torch.nn.DataParallel(model, device_ids=device)
    model = model.to(f'cuda:{device[0]}')

    return model
