import torch
from torch.utils.data import TensorDataset, DataLoader
import pandas as pd
from utils.graph_utils import node_flags, mask_x, mask_adjs, graphs_to_tensor, init_features, gen_noise
from models.classifier import get_classifier_fn, RegressorGCN
from utils.loader import load_sde


def graphs_to_dataloader_classifier(config, graph_list, labels):
    """Wrap graphs and their labels in a shuffled ``DataLoader``.

    Args:
        config: Experiment config. Reads ``config.data.max_node_num``,
            ``config.data.init``, ``config.data.max_feat_num`` and
            ``config.data.batch_size``.
        graph_list: Graphs handed to ``graphs_to_tensor``.
        labels: Per-graph targets, converted with ``torch.tensor``.

    Returns:
        A ``DataLoader`` yielding ``(x, adj, label)`` batches in shuffled order.
    """
    data_cfg = config.data
    adjacency = graphs_to_tensor(graph_list, data_cfg.max_node_num)
    # Node features are derived from the adjacency tensor by ``init_features``.
    features = init_features(data_cfg.init, adjacency, data_cfg.max_feat_num)
    dataset = TensorDataset(features, adjacency, torch.tensor(labels))
    return DataLoader(dataset, batch_size=data_cfg.batch_size, shuffle=True)


def load_classifier_params(config):
    """Collect the constructor arguments of the regressor from ``config``.

    Args:
        config: Experiment config. Reads ``config.model`` (``model``,
            ``depth``, ``nhid``, ``dropout``), ``config.data``
            (``max_node_num``, ``max_feat_num``) and ``config.train.prop``.

    Returns:
        A dict with the keys ``model_type``, ``max_node_num``,
        ``max_feat_num``, ``depth``, ``nhid``, ``dropout`` and ``prop``.
    """
    model_cfg = config.model
    return dict(
        model_type=model_cfg.model,
        max_node_num=config.data.max_node_num,
        max_feat_num=config.data.max_feat_num,
        depth=model_cfg.depth,
        nhid=model_cfg.nhid,
        dropout=model_cfg.dropout,
        prop=config.train.prop,
    )

def load_classifier(params):
    """Instantiate the regressor described by ``params``.

    Args:
        params: Dict as produced by ``load_classifier_params``. The
            ``model_type`` entry selects the architecture; every other entry
            is forwarded to the model constructor. The dict is not modified.

    Returns:
        A freshly constructed ``RegressorGCN``.

    Raises:
        AssertionError: If ``model_type`` is missing or is not
            ``'RegressorGCN'``.
    """
    # Only one architecture is supported; anything else is rejected up front.
    assert params.get('model_type', None) == 'RegressorGCN'
    model_kwargs = {key: value for key, value in params.items() if key != 'model_type'}
    return RegressorGCN(**model_kwargs)


def load_classifier_optimizer(params, config_train, device):
    """Create the regressor together with its Adam optimizer and LR scheduler.

    Args:
        params: Model parameters accepted by ``load_classifier``.
        config_train: Training config. Reads ``lr``, ``weight_decay``,
            ``lr_schedule`` and, when scheduling is enabled, ``lr_decay``.
        device: Indexable of CUDA device ids; only ``device[0]`` is used.

    Returns:
        ``(model, optimizer, scheduler)``. ``scheduler`` is ``None`` when
        ``config_train.lr_schedule`` is falsy.
    """
    target = f'cuda:{device[0]}'
    model = load_classifier(params).to(target)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config_train.lr,
        weight_decay=config_train.weight_decay,
    )
    # Exponential decay is the only schedule offered; otherwise no scheduler.
    scheduler = (
        torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=config_train.lr_decay)
        if config_train.lr_schedule
        else None
    )
    return model, optimizer, scheduler


def load_classifier_batch(batch, device):
    """Move one ``(x, adj, label)`` batch onto the first CUDA device.

    Args:
        batch: Indexable whose first three items are the node features, the
            adjacency tensor and the labels.
        device: Indexable of CUDA device ids; only ``device[0]`` is used.

    Returns:
        ``(x, adj, label)`` on ``cuda:{device[0]}``. The labels gain a
        trailing singleton dimension before the transfer.
    """
    target = f'cuda:{device[0]}'
    features, adjacency, labels = batch[0], batch[1], batch[2]
    return (
        features.to(target),
        adjacency.to(target),
        labels.unsqueeze(-1).to(target),
    )


def load_classifier_loss_fn(config):
    """Build the noise-conditioned regression loss for the property predictor.

    The SDEs for node features and adjacency are created once from
    ``config.sde``. The returned closure perturbs every batch at a random
    diffusion time and scores the model's prediction on the perturbed graph
    with a mean-squared error.

    Args:
        config: Experiment config. Reads ``config.sde.x``, ``config.sde.adj``,
            ``config.train.eps``, ``config.model.time_dep`` and, on every
            call of the returned closure, ``config.train.flags``.

    Returns:
        ``loss_fn(model, x, adj, labels)`` returning
        ``(loss, spearman_corr, pearson_corr)``, where the correlations are
        computed between predictions and labels of the current batch. The
        closure also prints a one-line summary of the batch.
    """
    sde_x = load_sde(config.sde.x)
    sde_adj = load_sde(config.sde.adj)
    eps = config.train.eps
    time_dep = config.model.time_dep

    def loss_fn(model, x, adj, labels):
        classifier_fn = get_classifier_fn(sde_adj, model, time_dep=time_dep)

        flags = node_flags(adj) if config.train.flags else None
        # One diffusion time per graph, drawn uniformly from [eps, T).
        t = torch.rand(adj.shape[0], device=adj.device) * (sde_adj.T - eps) + eps

        # Perturb node features (the order of the random draws is kept fixed:
        # t, then feature noise, then adjacency noise).
        z_x = gen_noise(x, flags, sym=False)
        mean_x, std_x = sde_x.marginal_prob(x, t)
        perturbed_x = mean_x + std_x[:, None, None] * z_x
        perturbed_x = mask_x(perturbed_x, flags)

        # Perturb the adjacency with symmetric noise.
        z_adj = gen_noise(adj, flags, sym=True)
        mean_adj, std_adj = sde_adj.marginal_prob(adj, t)
        perturbed_adj = mean_adj + std_adj[:, None, None] * z_adj
        perturbed_adj = mask_adjs(perturbed_adj, flags)

        pred = classifier_fn(perturbed_x, perturbed_adj, flags, t)
        # The objective is a plain MSE regression loss.
        loss = (pred - labels).pow(2).mean()

        # Diagnostics only: rank and linear correlation on the current batch.
        with torch.no_grad():
            df = pd.DataFrame()
            df['pred'] = pred.cpu().detach().numpy().squeeze()
            df['labels'] = labels.cpu().detach().numpy().squeeze()
            spearman_corr = df.corr('spearman')['pred']['labels']
            pearson_corr = df.corr('pearson')['pred']['labels']

            print(f'pred: {pred.min().item():.4f}, {pred.mean().item():.4f}, {pred.max().item():.4f} | '
                  f'labels: {labels.min().item():.4f}, {labels.mean().item():.4f}, {labels.max().item():.4f} | '
                  f'S corr: {spearman_corr:.4f} | P corr: {pearson_corr:.4f}')

        return loss, spearman_corr, pearson_corr

    return loss_fn


def load_classifier_loss_fn_dlsm(config):
    """Build a loss combining a score-matching term with the regression MSE.

    The returned closure perturbs a batch at a random diffusion time, takes
    the gradient of the regressor output with respect to the perturbed input
    selected by ``obj`` (``prop_score``), and penalises
    ``prop_score * std + score * std + z`` (half the sum of squares per
    sample, averaged over the batch). An MSE term on the prediction is added,
    rescaled so that its value matches the score-matching term.

    Args:
        config: Experiment config. Reads ``config.sde.x``, ``config.sde.adj``,
            ``config.train.eps``, ``config.model.time_dep`` and, on every
            call of the returned closure, ``config.train.flags``.

    Returns:
        ``loss_fn(model, score_fn, optimizer, x, adj, labels, obj='X')``
        returning ``(dlsm_loss, mse_loss, loss, spearman_corr, pearson_corr)``.
        ``obj == 'X'`` differentiates with respect to the node features; any
        other value differentiates with respect to the adjacency. The closure
        calls ``optimizer.zero_grad()`` and prints a one-line batch summary.
    """
    sde_x = load_sde(config.sde.x)
    sde_adj = load_sde(config.sde.adj)
    eps = config.train.eps
    time_dep = config.model.time_dep

    def reduce_op(*args, **kwargs):
        # Half the sum, i.e. 0.5 * ||.||^2 when applied to squared entries.
        return 0.5 * torch.sum(*args, **kwargs)

    def loss_fn(model, score_fn, optimizer, x, adj, labels, obj='X'):
        sde = sde_x if obj == 'X' else sde_adj
        classifier_fn = get_classifier_fn(sde, model, time_dep=time_dep)

        flags = node_flags(adj) if config.train.flags else None
        # One diffusion time per graph, drawn uniformly from [eps, T).
        t = torch.rand(adj.shape[0], device=adj.device) * (sde_adj.T - eps) + eps

        # Order of the random draws is kept fixed: t, feature noise, adjacency noise.
        z_x = gen_noise(x, flags, sym=False)
        mean_x, std_x = sde_x.marginal_prob(x, t)
        perturbed_x = mean_x + std_x[:, None, None] * z_x
        perturbed_x = mask_x(perturbed_x, flags)

        z_adj = gen_noise(adj, flags, sym=True)
        mean_adj, std_adj = sde_adj.marginal_prob(adj, t)
        perturbed_adj = mean_adj + std_adj[:, None, None] * z_adj
        perturbed_adj = mask_adjs(perturbed_adj, flags)

        score = score_fn(perturbed_x, perturbed_adj, flags)

        with torch.enable_grad():
            # Differentiate the summed regressor output w.r.t. a detached copy
            # of the selected input.
            if obj == 'X':
                grad_input = perturbed_x.detach().requires_grad_(True)
                output_sum = classifier_fn(grad_input, perturbed_adj, flags, t).sum()
                mask_fn, std, z = mask_x, std_x, z_x
            else:
                grad_input = perturbed_adj.detach().requires_grad_(True)
                output_sum = classifier_fn(perturbed_x, grad_input, flags, t).sum()
                mask_fn, std, z = mask_adjs, std_adj, z_adj
            # create_graph=True keeps prop_score differentiable w.r.t. the
            # regressor parameters. Reading ``.grad`` after ``backward()``
            # yields a detached tensor, which would make the score-matching
            # term a constant for the regressor (and would also accumulate
            # stray gradients into its parameters).
            prop_score = torch.autograd.grad(output_sum, grad_input, create_graph=True)[0]
            prop_score = mask_fn(prop_score, flags)

        # Kept for backward compatibility: callers get cleared gradients.
        optimizer.zero_grad()
        pred = classifier_fn(perturbed_x, perturbed_adj, flags, t)

        dlsm_loss = torch.square(prop_score * std[:, None, None] + score * std[:, None, None] + z)
        dlsm_loss = reduce_op(dlsm_loss.reshape(dlsm_loss.shape[0], -1), dim=-1).mean()

        mse_loss = (pred - labels).pow(2).mean()
        # Automatic weighting: rescale the MSE to the magnitude of the
        # score-matching term. Skipped when the MSE is exactly zero, which
        # would otherwise raise ZeroDivisionError.
        mse_value = mse_loss.item()
        if mse_value != 0:
            mse_loss = mse_loss * (dlsm_loss.item() / mse_value)

        loss = dlsm_loss + mse_loss

        # Diagnostics only: rank and linear correlation on the current batch.
        with torch.no_grad():
            df = pd.DataFrame()
            df['pred'] = pred.cpu().detach().numpy().squeeze()
            df['labels'] = labels.cpu().detach().numpy().squeeze()
            spearman_corr = df.corr('spearman')['pred']['labels']
            pearson_corr = df.corr('pearson')['pred']['labels']

            print(f'[{obj}] pred: {pred.min().item():.4f}, {pred.mean().item():.4f}, {pred.max().item():.4f} | '
                  f'labels: {labels.min().item():.4f}, {labels.mean().item():.4f}, {labels.max().item():.4f} | '
                  f'S corr: {spearman_corr:.4f} | P corr: {pearson_corr:.4f}')

        return dlsm_loss, mse_loss, loss, spearman_corr, pearson_corr

    return loss_fn


def load_classifier_from_ckpt(params, state_dict, device):
    """Rebuild the regressor from saved parameters and weights.

    Args:
        params: Model parameters accepted by ``load_classifier``.
        state_dict: Weights passed to ``load_state_dict``.
        device: Indexable of CUDA device ids; only ``device[0]`` is used.

    Returns:
        The model with the weights loaded, moved to ``cuda:{device[0]}``.
    """
    classifier = load_classifier(params)
    # Weights are loaded before the model is moved to the target device.
    classifier.load_state_dict(state_dict)
    return classifier.to(f'cuda:{device[0]}')


def load_classifier_ckpt(config, device):
    """Load the regressor checkpoint named in ``config.module["C"]``.

    The file is read from
    ``./checkpoints/<config.module["C"].data>/C-<config.module["C"].ckpt>.pth``
    and mapped onto ``cuda:{device[0]}``. Three checkpoint layouts are
    recognised by the keys they contain.

    Args:
        config: Config providing ``module["C"].data``, ``module["C"].ckpt``
            and ``data.data``.
        device: Indexable of CUDA device ids; only ``device[0]`` is used.

    Returns:
        ``{'C': entry}`` where ``entry`` holds ``config``, ``params`` and the
        weight entries of the detected layout. ``entry['config']['data']['data']``
        is overwritten with ``config.data.data``.
    """
    path = f'./checkpoints/{config.module["C"].data}/C-{config.module["C"].ckpt}.pth'
    ckpt = torch.load(path, map_location=f'cuda:{device[0]}')
    print(f'{path} loaded')

    # Pick the weight entries that belong to the checkpoint layout.
    if 'final_state_dicts' in ckpt:     # layer share
        weight_keys = ('state_dict', 'final_state_dicts')
    elif 'state_dict' in ckpt:
        weight_keys = ('state_dict',)
    else:   # dlsm
        weight_keys = ('x_state_dict', 'adj_state_dict')

    entry = {'config': ckpt['model_config'], 'params': ckpt['params']}
    entry.update((key, ckpt[key]) for key in weight_keys)
    # The dataset name stored in the checkpoint is replaced by the current one.
    entry['config']['data']['data'] = config.data.data

    return {'C': entry}


def data_log(logger, config):
    """Log the dataset name, feature initialisation, seed and batch size.

    Args:
        logger: Object exposing ``log(str)``.
        config: Config providing ``seed`` and ``data`` (``data``, ``init``,
            ``max_feat_num``, ``batch_size``).
    """
    summary = (f'[{config.data.data}]   init={config.data.init} ({config.data.max_feat_num})   '
               f'seed={config.seed}   batch_size={config.data.batch_size}')
    logger.log(summary)


def sde_log(logger, config_sde):
    """Log type, beta range and number of scales of both SDEs on one line.

    Args:
        logger: Object exposing ``log(str)``.
        config_sde: Config with ``x`` and ``adj`` entries, each providing
            ``type``, ``beta_min``, ``beta_max`` and ``num_scales``.
    """
    x_cfg, adj_cfg = config_sde.x, config_sde.adj
    x_part = f'(x:{x_cfg.type})=({x_cfg.beta_min:.2f}, {x_cfg.beta_max:.2f}) N={x_cfg.num_scales} '
    adj_part = f'(adj:{adj_cfg.type})=({adj_cfg.beta_min:.2f}, {adj_cfg.beta_max:.2f}) N={adj_cfg.num_scales}'
    logger.log(x_part + adj_part)

def model_log(logger, config):
    """Log the regressor architecture settings on one line.

    Args:
        logger: Object exposing ``log(str)``.
        config: Config whose ``model`` entry provides ``model``, ``depth``,
            ``nhid``, ``dropout`` and ``time_dep``.
    """
    model_cfg = config.model
    pieces = (
        f'({model_cfg.model}): ',
        f'depth={model_cfg.depth} nhid={model_cfg.nhid} ',
        f'dropout={model_cfg.dropout} time_dep={model_cfg.time_dep}',
    )
    logger.log(''.join(pieces))


def start_log(logger, config, is_train=True):
    """Log the run header: experiment name (training only) and data summary.

    Args:
        logger: Object exposing ``log(str)``.
        config: Config providing ``exp_name`` and the fields read by
            ``data_log``.
        is_train: When true, the experiment name is logged above the data
            summary, preceded by its own separator line.
    """
    separator = '-' * 100
    if is_train:
        logger.log(separator)
        logger.log(f'{config.exp_name}')
    logger.log(separator)
    data_log(logger, config)
    logger.log(separator)


def train_log(logger, config):
    """Log SDE, model and optimisation settings, closed by a separator line.

    Args:
        logger: Object exposing ``log(str)``.
        config: Config providing ``sde``, ``model`` and ``train``
            (``num_epochs``, ``lr``, ``lr_schedule``).
    """
    sde_log(logger, config.sde)
    model_log(logger, config)
    train_cfg = config.train
    logger.log(f'EPOCHS={train_cfg.num_epochs} lr={train_cfg.lr} schedule={train_cfg.lr_schedule}')
    logger.log('-' * 100)


def sample_log(logger, configc):
    """Log the sampling-time settings, closed by a separator line.

    Args:
        logger: Object exposing ``log(str)``.
        configc: Config providing ``time_dep``, ``weight_x`` and
            ``weight_adj``.
    """
    summary = f'time_dep={configc.time_dep}   [X] weight={configc.weight_x}   [A] weight={configc.weight_adj}'
    logger.log(summary)
    logger.log('-' * 100)
