import numpy as np
import pandas as pd

import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from copy import deepcopy
from torch.utils.data import Dataset, TensorDataset, DataLoader, random_split

import ot

from torch.autograd import grad
from torch.autograd import Variable
import random
from copy import deepcopy
from ot.backend import get_backend
from ot.utils import list_to_array

import os
"""

This code is adapted from, or inspired by:
* https://github.com/zwebzone/ggf/blob/main/utils.py
* https://github.com/Jae-Moo/Unbalanced-Optimal-Transport-Generative-Model
Credit to the original authors.
"""
def construct_path(target_path):
    """Ensure the directory ``target_path`` exists, creating parents as needed.

    If ``target_path`` already exists (as a directory or any other kind of
    file), nothing is done.

    :param target_path: path of the directory to create.
    :return: None
    """
    if not os.path.exists(target_path):
        # exist_ok=True: another process may create the directory between the
        # existence check and this call; that must not raise.
        os.makedirs(target_path, exist_ok=True)

def eval_model(model, test_loader):
    """Return the top-1 accuracy of ``model`` over every batch of ``test_loader``.

    :param model: ``torch.nn.Module`` whose output is argmax-ed over dim 1 to
        obtain the predicted class of each sample.
    :param test_loader: iterable of ``(x, y)`` batches; ``y`` holds class
        labels (cast with ``.long()`` before comparison).
    :return: ``correct / total`` as a float in ``[0, 1]``.
    :raises ZeroDivisionError: if ``test_loader`` yields no samples.

    The model's train/eval mode on entry is restored on exit, also when an
    exception is raised.
    """
    was_training = model.training
    # Move batches to wherever the model's parameters live; for a model
    # without parameters fall back to CUDA-if-available (the old behaviour).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    correct = 0
    total = 0
    model.eval()
    try:
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                # model(x) rather than model.forward(x) so module hooks run.
                pred = model(x).argmax(dim=1, keepdim=True)
                correct += pred.eq(y.long().view_as(pred)).sum().item()
                total += len(y)
    finally:
        model.train(was_training)
    return correct / total

def setup_seed(seed):
    """Seed all random number generators used here and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU and CUDA
    generators with ``seed``.

    :param seed: integer seed applied to every generator.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # deterministic=True alone is not sufficient: with benchmark mode enabled
    # cuDNN may still select different algorithms from run to run.
    torch.backends.cudnn.benchmark = False






def initial_classfier(d, o, *, hidden=128, dropout=0.25, device=None):
    """Build the MLP classifier used in this module.

    Layers: Linear(d, hidden) -> ReLU -> Dropout -> BatchNorm1d(hidden)
    -> Linear(hidden, hidden) -> ReLU -> Dropout -> Linear(hidden, o).

    :param d: number of input features.
    :param o: number of outputs (logits).
    :param hidden: width of both hidden layers (default 128, the original value).
    :param dropout: dropout probability (default 0.25, the original value).
    :param device: device to place the model on. ``None`` (default) uses CUDA
        when available and otherwise the CPU, instead of failing without a GPU.
    :return: the ``nn.Sequential`` model on ``device``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return nn.Sequential(
        nn.Linear(d, hidden), nn.ReLU(),
        nn.Dropout(p=dropout),
        nn.BatchNorm1d(num_features=hidden),
        nn.Linear(hidden, hidden), nn.ReLU(),
        nn.Dropout(p=dropout),
        nn.Linear(hidden, o),
    ).to(device)



def cal_entropy(s_pred):
    """Entropy-style penalty of the softmax distribution over dim 1 of ``s_pred``.

    Computes ``-mean(p * log p)`` with ``p = softmax(s_pred, dim=1)``, where the
    mean runs over *all* entries (not a sum over dim 1). For a 2-D input
    with C columns this equals the average per-row entropy divided by C.

    :param s_pred: logits tensor; softmax is taken over dim 1.
    :return: scalar tensor.
    """
    probs = F.softmax(s_pred, dim=1)
    log_probs = F.log_softmax(s_pred, dim=1)
    plogp = probs * log_probs
    return -1.0 * plogp.mean()







def get_pseudo_dataset(model, data, confidence_q=0.1):
    """Pseudo-label ``data`` with ``model`` and keep only the confident samples.

    The confidence of a sample is the spread ``max - min`` of its model
    outputs over dim 1. Samples whose confidence is at least the
    ``confidence_q`` quantile (``np.quantile``) are kept, so roughly the least
    confident ``confidence_q`` fraction is dropped.

    :param model: ``torch.nn.Module`` applied to the whole of ``data`` at once.
    :param data: input tensor passed to ``model`` in a single forward call.
    :param confidence_q: quantile level in ``[0, 1]`` used as the threshold.
    :return: ``(kept_data, pseudo_labels)``, where the labels are the argmax
        over dim 1 of the model outputs for the kept rows.

    The model's train/eval mode on entry is restored, even on error.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # model(...) rather than model.forward(...) so module hooks run.
            logits = model(data)
    finally:
        model.train(was_training)

    confidence = (logits.amax(dim=1) - logits.amin(dim=1)).cpu().numpy()
    alpha = np.quantile(confidence, confidence_q)
    keep = torch.as_tensor(np.flatnonzero(confidence >= alpha))
    pseudo_y = logits.argmax(dim=1)
    return data[keep.to(data.device)], pseudo_y[keep.to(pseudo_y.device)]


# ------------------------
# Select $f^{\star}$
# ------------------------
def select_phi(name):
    """Return the elementwise function ``phi`` registered under ``name``.

    Supported names:
      * ``'linear'``   -> ``phi(x) = x``
      * ``'kl'``       -> ``phi(x) = exp(x)``
      * ``'chi'``      -> ``y = relu(x + 2) - 2``; ``phi(x) = y**2 / 4 + y``
      * ``'softplus'`` -> ``phi(x) = softplus(x)``

    :param name: one of the names above.
    :return: a callable taking and returning a tensor.
    :raises NotImplementedError: for any other ``name``.
    """
    def _identity(x):
        return x

    def _exponential(x):
        return torch.exp(x)

    def _chi_squared(x):
        clipped = F.relu(x + 2) - 2
        return 0.25 * clipped ** 2 + clipped

    def _soft(x):
        return F.softplus(x)

    registry = (
        ('linear', _identity),
        ('kl', _exponential),
        ('chi', _chi_squared),
        ('softplus', _soft),
    )
    for key, fn in registry:
        if name == key:
            return fn
    raise NotImplementedError



def uni_semi_uot(a, b, M, reg_sink, reg_b, numItermax=1000, stopThr=1.0e-15, log=False):
    """Semi-unbalanced transport plan between uniform weights via scaling iterations.

    The plan has the form ``G = u * K * v`` with
    ``K = exp(-M / (reg_b + reg_sink))``. In every iteration ``u`` is chosen
    so that the row sums of ``G`` equal ``a`` exactly (hard source marginal),
    while ``v = (b / colsum(G_prev)) ** (reg_b / (reg_b + reg_sink))`` only
    pulls the column sums toward ``b`` (relaxed target marginal).

    NOTE(review): ``a`` and ``b`` are used only for ``list_to_array`` /
    ``get_backend``; their values are then overwritten with uniform weights
    ``1/dim_a`` and ``1/dim_b``. Confirm this is intended — callers passing
    non-uniform weights will get them silently ignored.

    :param a: source weights; values ignored (see note), replaced by a
        uniform vector of length ``dim_a``.
    :param b: target weights; values ignored (see note), replaced by a
        uniform vector of length ``dim_b``.
    :param M: cost matrix of shape ``(dim_a, dim_b)``.
    :param reg_sink: regularization weight (name suggests the entropic term —
        TODO confirm); enters only through ``reg_b + reg_sink``.
    :param reg_b: weight of the target-marginal relaxation; enters through
        ``reg_b + reg_sink`` and the exponent ``reg_b / (reg_b + reg_sink)``.
    :param numItermax: maximum number of scaling iterations.
    :param stopThr: stop once the Frobenius norm of the change in ``G``
        between two consecutive iterations is below this value.
    :param log: if truthy, also return a dict with per-iteration lists
        ``"err"`` and ``"G"`` (every iterate of ``G`` is kept, so memory grows
        with the iteration count) and a final ``"cost"`` entry equal to
        ``sum(G * M) + nx.kl_div(colsum(G), b, mass=True)``.
    :return: ``G`` of shape ``(dim_a, dim_b)``, or ``(G, log)`` when ``log``.
    """
    M, a, b = list_to_array(M, a, b)
    nx = get_backend(M, a, b)
    dim_a, dim_b = M.shape

    if log:
        # Rebinds ``log`` from a flag to the (always truthy) result dict.
        log = {"err": [], "G": []}

    # weight normalized: the passed-in a/b values are discarded here and
    # replaced by uniform weights (see NOTE(review) in the docstring).
    a = nx.ones(dim_a, type_as=M) / dim_a
    b = nx.ones(dim_b, type_as=M) / dim_b
    # Start from the independent coupling a b^T.
    G = a[:, None] * b[None, :]

    sum_r = reg_b + reg_sink
    # r2 is the exponent of the column update; r is computed but never used.
    r2, r = reg_b / sum_r, reg_sink / sum_r
    K = nx.exp(-M / sum_r)
    # Redundant: Gprev is reassigned at the top of every iteration.
    Gprev = G
    for i in range(numItermax):
        Gprev = G
        # update of (G^T 1)_j: column sums of the previous plan
        G_col_sum = nx.sum(G, 0, keepdims=True)

        # update of v_j = (b_j / (G^T 1)_j) ** (lambda_b / lambda),
        # with lambda_b = reg_b and lambda = reg_b + reg_sink
        v = (b[None, :] / G_col_sum) ** (r2)

        # update of u_i = a_i / (K * v)_i; makes the row sums of the new G
        # equal to a_i
        Kv = nx.sum(K * v, 1, keepdims=True)
        u = a[:, None] / Kv

        # update of G_ij = u_i * K_ij * v_j
        G = u * K * v

        # Frobenius norm of the change between consecutive plans.
        err = nx.sqrt(nx.sum((G - Gprev) ** 2))
        if log:
            log["err"].append(err)
            log["G"].append(G)

        if err < stopThr:
            break
    # Column sums (target-side mass) of the final plan.
    m2 = nx.sum(G, 0)
    if log:
        linear_cost = nx.sum(G * M)
        # Presumably POT's generalized KL including the mass-difference term
        # when mass=True — TODO confirm against the installed POT version.
        kl_loss = nx.kl_div(m2, b, mass=True)
        log["cost"] = linear_cost + kl_loss
        return G, log
    else:
        return G



