import math
import torch
import numpy as np
from scipy.stats import stats
from utils import test_xk


def train_1cell(model, loader, optimizer, device, cfg, aug_weight=0.2):
    """Run one training epoch for the single-cell ranking model.

    The loss of every batch is the sum of

    * a pairwise margin ranking loss over all sample pairs whose targets ``y``
      differ (randomly capped at ``int(max_compare_ratio * len(batch))`` pairs
      when ``cfg.dag.compare.do_limit`` is set). It pushes the score of the
      sample with the larger ``y`` above the other by ``cfg.dag.compare.margin``;
    * ``aug_weight`` times a mixup loss. Each embedding is interpolated with a
      partner drawn with weight ``minmax(e_i . e_j) * exp(-|y_i - y_j|)`` (never
      itself). The mix is rescored via ``model.post_mp(..., aug=True)``, and the
      mixed score is pushed, with the same margin, between the scores of the two
      original samples.

    Args:
        model: called as ``model(batch) -> (embeddings, scores, y)``. It must
            also expose ``post_mp(batch, aug=True) -> (scores, _)``, which reads
            the mixed embeddings from ``batch.graph_emb``.
        loader: iterable of batches supporting ``.to``, ``.clone`` and ``len``.
        optimizer: optimizer over ``model.parameters()``.
        device: device each batch is moved to.
        cfg: config providing ``dag.compare.{do_limit, max_compare_ratio, margin}``
            and ``optim.{clip_grad_norm, clip_grad_norm_value}``.
        aug_weight: weight of the mixup loss term.

    Returns:
        The per-batch loss averaged with weights equal to the number of compared
        pairs, or ``nan`` when no batch contained a comparable pair.
    """
    model.train()
    margin = cfg.dag.compare.margin
    total_loss = 0.0
    tot = 0
    for batch in loader:
        batch = batch.to(device)
        # Copy that will carry the mixed embeddings into post_mp.
        aug_batch = batch.clone()

        optimizer.zero_grad()
        embeddings, sample_scores, y = model(batch)
        y = y.reshape(-1)
        n = y.numel()
        if n < 2:
            # Nothing to compare or mix with. torch.multinomial would also
            # reject the all-zero sampling weights of a lone sample.
            continue
        scores = sample_scores.reshape(n)  # one score per sample

        # Mixup partner sampling; the weights are not trained.
        with torch.no_grad():
            sim = embeddings @ embeddings.T
            sim.fill_diagonal_(0)
            sim = (sim - sim.min()) / (sim.max() - sim.min() + 1e-6)
            y_col = y.reshape(-1, 1)
            weights = torch.exp(-(y_col - y_col.T).abs()) * sim
            # Min-max scaling lifts the zeroed diagonal above 0 whenever some
            # similarity is negative, so self-pairs are masked out again here.
            off_diag = 1.0 - torch.eye(n, dtype=weights.dtype, device=weights.device)
            weights = weights * off_diag
            # A row without mass would make torch.multinomial raise. Such a
            # row draws its partner uniformly from the other samples instead.
            no_mass = weights.sum(dim=1, keepdim=True) <= 0
            weights = torch.where(no_mass, off_diag, weights)
            partner = torch.multinomial(weights, num_samples=1, replacement=True).squeeze(1)

        lam = float(np.random.beta(1, 1))
        aug_batch.graph_emb = lam * embeddings + (1 - lam) * embeddings[partner]
        aug_scores, _ = model.post_mp(aug_batch, aug=True)

        # Pairwise ranking loss over pairs (row < col) with differing targets.
        y_np = y.detach().cpu().numpy()
        acc_diff = y_np[:, None] - y_np
        rows, cols = np.nonzero(np.triu(np.abs(acc_diff), 1) > 0.0)
        if cfg.dag.compare.do_limit:
            n_max_pairs = int(cfg.dag.compare.max_compare_ratio * len(batch))
            if len(rows) > n_max_pairs:
                keep = np.random.choice(len(rows), n_max_pairs, replace=False)
                rows, cols = rows[keep], cols[keep]
        n_diff_pairs = len(rows)

        if n_diff_pairs:
            # +1 where y[row] > y[col], -1 otherwise.
            sign = torch.as_tensor(acc_diff[rows, cols] > 0, device=scores.device).to(scores.dtype) * 2 - 1
            rows_t = torch.as_tensor(rows, device=scores.device)
            cols_t = torch.as_tensor(cols, device=scores.device)
            rank_loss = torch.relu(margin - sign * (scores[rows_t] - scores[cols_t])).mean()
        else:
            # A mean over zero pairs is NaN and would poison every weight. Use a
            # zero that stays attached to the graph so backward() always works.
            rank_loss = scores.sum() * 0.0

        # Mixup loss: the mixed score should sit at least `margin` below the
        # higher-labelled original and `margin` above the lower-labelled one.
        own = scores.unsqueeze(1)
        other = scores[partner].unsqueeze(1)
        own_higher = (y > y[partner]).unsqueeze(1)
        upper = torch.where(own_higher, own, other)
        lower = torch.where(own_higher, other, own)
        # NOTE(review): assumes post_mp returns n (or n x k) scores, as the
        # original per-sample indexing did.
        mixed = aug_scores.reshape(n, -1)
        aug_loss = (
            torch.relu(margin - (upper - mixed)).mean(dim=1)
            + torch.relu(margin - (mixed - lower)).mean(dim=1)
        ).mean()

        loss = rank_loss + aug_weight * aug_loss
        loss.backward()
        if cfg.optim.clip_grad_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.optim.clip_grad_norm_value)
        optimizer.step()
        total_loss += float(loss) * n_diff_pairs
        tot += n_diff_pairs
    return total_loss / tot if tot else float("nan")


@torch.no_grad()
def eval_1cell(model, loader, device):
    """Score every sample in ``loader`` and compare the scores with the targets.

    Args:
        model: called as ``model(batch) -> (_, scores, y)``.
        loader: iterable of batches supporting ``.to``.
        device: device each batch is moved to.

    Returns:
        dict with Kendall tau ``"kt"``, Spearman rho ``"sp"``, the ``test_xk``
        result ``"pak"`` (presumably a top-k precision; verify in utils) and the
        Pearson correlation coefficient ``"lc"`` between targets and scores.
    """
    model.eval()

    all_scores = []
    true_accs = []
    for batch in loader:
        batch = batch.to(device)
        _, output, y = model(batch)
        # reshape(-1) rather than squeeze(): a single-sample batch would squeeze
        # to a 0-d tensor whose tolist() is a bare float, which extend() rejects.
        all_scores.extend(output.reshape(-1).cpu().tolist())
        true_accs.extend(y.reshape(-1).cpu().tolist())

    kt = stats.kendalltau(true_accs, all_scores).correlation
    sp = stats.spearmanr(true_accs, all_scores).correlation
    lc = np.corrcoef(true_accs, all_scores)[0, 1]
    pak = test_xk(true_accs, all_scores)
    return {"kt": kt, "sp": sp, "pak": pak, "lc": lc}


def train_2cell(model, loader, optimizer, device, cfg, aug_weight=0.3):
    """Run one training epoch for the two-cell ranking model.

    Each loader item is a pair ``(batch1, batch2)`` describing the same samples.
    The loss of every batch is the sum of

    * a pairwise margin ranking loss over all sample pairs whose targets ``y``
      differ (randomly capped at ``int(max_compare_ratio * len(batch1))`` pairs
      when ``cfg.dag.compare.do_limit`` is set);
    * ``aug_weight`` times a mixup loss. Both embeddings of a sample are
      interpolated with the same partner, drawn with weight
      ``minmax(e1_i . e1_j) * minmax(e2_i . e2_j) * exp(-|y_i - y_j|)`` (never
      itself). The mix is rescored via ``model.post_mp(..., aug=True)``, and the
      mixed score is pushed, with the same margin, between the two original
      scores.

    Args:
        model: called as ``model(batch1, batch2) -> (emb1, emb2, scores, y)``.
            It must also expose ``post_mp(batch1, batch2, aug=True) -> (scores, _)``,
            which reads the mixed embeddings from each batch's ``graph_emb``.
        loader: iterable of ``(batch1, batch2)`` pairs supporting ``.to``,
            ``.clone`` and ``len``.
        optimizer: optimizer over ``model.parameters()``.
        device: device each batch is moved to.
        cfg: config providing ``dag.compare.{do_limit, max_compare_ratio, margin}``
            and ``optim.{clip_grad_norm, clip_grad_norm_value}``.
        aug_weight: weight of the mixup loss term.

    Returns:
        The per-batch loss averaged with weights equal to the number of compared
        pairs, or ``nan`` when no batch contained a comparable pair.
    """
    model.train()
    margin = cfg.dag.compare.margin
    total_loss = 0.0
    tot = 0
    for batch1, batch2 in loader:
        batch1 = batch1.to(device)
        batch2 = batch2.to(device)
        # Copies that will carry the mixed embeddings into post_mp.
        aug_batch1 = batch1.clone()
        aug_batch2 = batch2.clone()

        optimizer.zero_grad()
        embeddings1, embeddings2, sample_scores, y = model(batch1, batch2)
        y = y.reshape(-1)
        n = y.numel()
        if n < 2:
            # Nothing to compare or mix with. torch.multinomial would also
            # reject the all-zero sampling weights of a lone sample.
            continue
        scores = sample_scores.reshape(n)  # one score per sample

        # Mixup partner sampling; the weights are not trained.
        with torch.no_grad():
            y_col = y.reshape(-1, 1)
            weights = torch.exp(-(y_col - y_col.T).abs())
            for emb in (embeddings1, embeddings2):
                sim = emb @ emb.T
                sim.fill_diagonal_(0)
                weights = weights * ((sim - sim.min()) / (sim.max() - sim.min() + 1e-6))
            # Min-max scaling lifts the zeroed diagonal above 0 whenever some
            # similarity is negative, so self-pairs are masked out again here.
            off_diag = 1.0 - torch.eye(n, dtype=weights.dtype, device=weights.device)
            weights = weights * off_diag
            # A row without mass would make torch.multinomial raise. Such a
            # row draws its partner uniformly from the other samples instead.
            no_mass = weights.sum(dim=1, keepdim=True) <= 0
            weights = torch.where(no_mass, off_diag, weights)
            partner = torch.multinomial(weights, num_samples=1, replacement=True).squeeze(1)

        lam = float(np.random.beta(1, 1))
        aug_batch1.graph_emb = lam * embeddings1 + (1 - lam) * embeddings1[partner]
        aug_batch2.graph_emb = lam * embeddings2 + (1 - lam) * embeddings2[partner]
        aug_scores, _ = model.post_mp(aug_batch1, aug_batch2, aug=True)

        # Pairwise ranking loss over pairs (row < col) with differing targets.
        y_np = y.detach().cpu().numpy()
        acc_diff = y_np[:, None] - y_np
        rows, cols = np.nonzero(np.triu(np.abs(acc_diff), 1) > 0.0)
        if cfg.dag.compare.do_limit:
            n_max_pairs = int(cfg.dag.compare.max_compare_ratio * len(batch1))
            if len(rows) > n_max_pairs:
                keep = np.random.choice(len(rows), n_max_pairs, replace=False)
                rows, cols = rows[keep], cols[keep]
        n_diff_pairs = len(rows)

        if n_diff_pairs:
            # +1 where y[row] > y[col], -1 otherwise.
            sign = torch.as_tensor(acc_diff[rows, cols] > 0, device=scores.device).to(scores.dtype) * 2 - 1
            rows_t = torch.as_tensor(rows, device=scores.device)
            cols_t = torch.as_tensor(cols, device=scores.device)
            rank_loss = torch.relu(margin - sign * (scores[rows_t] - scores[cols_t])).mean()
        else:
            # A mean over zero pairs is NaN and would poison every weight. Use a
            # zero that stays attached to the graph so backward() always works.
            rank_loss = scores.sum() * 0.0

        # Mixup loss: the mixed score should sit at least `margin` below the
        # higher-labelled original and `margin` above the lower-labelled one.
        own = scores.unsqueeze(1)
        other = scores[partner].unsqueeze(1)
        own_higher = (y > y[partner]).unsqueeze(1)
        upper = torch.where(own_higher, own, other)
        lower = torch.where(own_higher, other, own)
        # NOTE(review): assumes post_mp returns n (or n x k) scores, as the
        # original per-sample indexing did.
        mixed = aug_scores.reshape(n, -1)
        aug_loss = (
            torch.relu(margin - (upper - mixed)).mean(dim=1)
            + torch.relu(margin - (mixed - lower)).mean(dim=1)
        ).mean()

        loss = rank_loss + aug_weight * aug_loss
        loss.backward()
        if cfg.optim.clip_grad_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.optim.clip_grad_norm_value)
        optimizer.step()
        total_loss += float(loss) * n_diff_pairs
        tot += n_diff_pairs
    return total_loss / tot if tot else float("nan")

@torch.no_grad()
def eval_2cell(model, loader, device):
    """Score every sample pair in ``loader`` and compare the scores with the targets.

    Args:
        model: called as ``model(batch1, batch2) -> (_, _, scores, y)``.
        loader: iterable of ``(batch1, batch2)`` pairs supporting ``.to``.
        device: device each batch is moved to.

    Returns:
        dict with Kendall tau ``"kt"``, Spearman rho ``"sp"``, the ``test_xk``
        result ``"pak"`` (presumably a top-k precision; verify in utils) and the
        Pearson correlation coefficient ``"lc"``, matching the keys that
        ``eval_1cell`` reports.
    """
    model.eval()

    all_scores = []
    true_accs = []
    for batch1, batch2 in loader:
        batch1 = batch1.to(device)
        batch2 = batch2.to(device)
        _, _, output, y = model(batch1, batch2)
        # reshape(-1) rather than squeeze(): a single-sample batch would squeeze
        # to a 0-d tensor whose tolist() is a bare float, which extend() rejects.
        all_scores.extend(output.reshape(-1).cpu().tolist())
        true_accs.extend(y.reshape(-1).cpu().tolist())

    kt = stats.kendalltau(true_accs, all_scores).correlation
    sp = stats.spearmanr(true_accs, all_scores).correlation
    lc = np.corrcoef(true_accs, all_scores)[0, 1]
    pak = test_xk(true_accs, all_scores)
    return {"kt": kt, "sp": sp, "pak": pak, "lc": lc}


# Registry mapping a cell-setting key to its (train_fn, eval_fn) pair.
train_dict = {
    "1cell": (train_1cell, eval_1cell),
    "2cell": (train_2cell, eval_2cell),
}
