import os
import numpy as np
from sklearn.metrics import average_precision_score
from torch.optim.optimizer import Optimizer
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset
import torch
from torchvision import transforms

# Per-channel statistics handed to the Normalize step (three values each).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every image: resize to 256x256, convert to a tensor,
# then normalize each channel with the statistics above.
tfms = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)


class CelebA(Dataset):
    """Image dataset whose file names and attribute rows come from a DataFrame.

    The DataFrame index supplies the image file names (relative to
    ``folder_dir``) and ``dataframe.to_numpy()`` supplies one row of attribute
    values per image.  ``__getitem__`` returns the (optionally transformed)
    image together with the attribute stored in column ``target_id``.

    Args:
        dataframe: Table indexed by file name; must have more than
            ``GENDER_ID`` columns when ``gender`` is given.
        folder_dir: Directory that contains the image files.
        target_id: Column index of the attribute returned as the label.
        transform: Optional callable applied to the opened image.
        gender: If not ``None``, keep only rows whose gender column equals
            this value.
        target: If not ``None`` *and* ``gender`` is given, additionally keep
            only rows whose ``target_id`` column equals this value.  It is
            ignored when ``gender`` is ``None`` (behavior kept from the
            original implementation).
    """

    # Column index of the gender attribute inside the attribute rows.
    GENDER_ID = 20

    def __init__(self, dataframe, folder_dir, target_id, transform=None, gender=None, target=None):
        self.dataframe = dataframe
        self.folder_dir = folder_dir
        self.target_id = target_id
        self.transform = transform
        self.file_names = dataframe.index
        self.labels = dataframe.to_numpy()

        if gender is not None:
            # A boolean row mask keeps file names and labels aligned and in
            # their original order (a set intersection would not).
            keep = self.labels[:, self.GENDER_ID] == gender
            if target is not None:
                keep = keep & (self.labels[:, target_id] == target)
            self.file_names = self.file_names[keep]
            self.labels = self.labels[keep]

    def __len__(self):
        """Return the number of rows left after filtering."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, label)`` for row ``index``."""
        image = Image.open(os.path.join(self.folder_dir, self.file_names[index]))
        label = self.labels[index]
        if self.transform:
            image = self.transform(image)
        return image, label[self.target_id]


def get_loader(df, data_path, target_id, batch_size, gender=None, target=None):
    """Build a DataLoader over a ``CelebA`` dataset using the module-level ``tfms``.

    When the substring ``"train"`` occurs in ``data_path`` the loader shuffles
    and drops the last incomplete batch; otherwise it iterates in order and
    keeps every sample.  Three worker processes are used in both cases.

    Args:
        df: DataFrame forwarded to ``CelebA``.
        data_path: Image directory; also decides train vs. non-train settings.
        target_id: Column index of the label attribute.
        batch_size: Number of samples per batch.
        gender: Optional gender filter forwarded to ``CelebA``.
        target: Optional target filter forwarded to ``CelebA``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    dataset = CelebA(df, data_path, target_id, transform=tfms, gender=gender, target=target)

    is_train = "train" in data_path
    loader_kwargs = {"shuffle": is_train, "batch_size": batch_size, "num_workers": 3}
    if is_train:
        loader_kwargs["drop_last"] = True

    return torch.utils.data.DataLoader(dataset, **loader_kwargs)


def evaluate4pp(model, model_linear, dataloader):
    """Score every batch of ``dataloader`` and compute the average precision.

    Each batch is moved to the GPU with ``.cuda()``, passed through ``model``
    and then ``model_linear`` with gradients disabled, and the outputs and
    targets are gathered as NumPy arrays.

    Returns:
        ``(ap, y_scores, y_true)`` where ``ap`` is
        ``average_precision_score(y_true, y_scores)`` and the two arrays are
        the concatenated per-batch outputs and targets.
    """
    score_chunks, label_chunks = [], []

    with torch.no_grad():
        for inputs, target in dataloader:
            inputs = inputs.cuda()
            target = target.float().cuda()

            scores = model_linear(model(inputs)).detach()

            score_chunks.append(scores.cpu().numpy())
            label_chunks.append(target.detach().cpu().numpy())

    y_scores = np.concatenate(score_chunks)
    y_true = np.concatenate(label_chunks)
    return average_precision_score(y_true, y_scores), y_scores, y_true


# Helpers that disable (freeze) or re-enable (activate) gradient tracking for all parameters of a model
def model_freeze(model):
    """Disable gradient tracking on every tensor yielded by ``model.parameters()``."""
    for weight in model.parameters():
        weight.requires_grad_(False)


def model_activate(model):
    """Enable gradient tracking on every tensor yielded by ``model.parameters()``."""
    for weight in model.parameters():
        weight.requires_grad_(True)


def matrix_evaluator(loss, x, y, model):
    """Return a one-argument callable ``v -> hessian_vector_prodct(loss, x, y, model, v)``.

    The returned function has the ``f_Ax`` shape that ``cg_solve`` expects:
    it takes a flat vector and returns the matching Hessian-vector product.
    """
    return lambda v: hessian_vector_prodct(loss, x, y, model, v)


def hessian_vector_prodct(loss, x, y, model, vector_to_optimize):
    """Compute a flat Hessian-vector product of ``loss(model(x), y)``.

    The loss is re-evaluated, its gradient w.r.t. ``model.parameters()`` is
    taken with ``create_graph=True``, the flattened gradient is dotted with
    ``vector_to_optimize``, and that scalar is differentiated once more.

    Args:
        loss: Callable ``loss(prediction, y)`` returning a scalar tensor.
        x: Input passed to ``model``.
        y: Target passed to ``loss``.
        model: Object exposing ``parameters()`` and callable on ``x``.
        vector_to_optimize: Flat vector with one entry per model parameter
            element (it is multiplied element-wise with the flat gradient).

    Returns:
        A 1-D tensor: the second-order gradients concatenated in parameter order.
    """

    def _flatten(pieces):
        # Concatenate per-parameter tensors into a single 1-D tensor.
        return torch.cat([piece.contiguous().view(-1) for piece in pieces])

    objective = loss(model(x), y)
    # create_graph=True keeps the gradient differentiable for the second pass.
    first_order = torch.autograd.grad(objective, model.parameters(), create_graph=True)
    directional = (_flatten(first_order) * vector_to_optimize).sum()
    second_order = torch.autograd.grad(directional, model.parameters(), retain_graph=True)
    return _flatten(second_order)


def cg_solve(f_Ax, b, cg_iters=10, callback=None, verbose=False, residual_tol=1e-10, x_init=None):
    """Solve ``A x = b`` with the conjugate-gradient method.

    Solving ``A x = b`` is equivalent to minimizing
    ``f(x) = 1/2 x^T A x - x^T b``.  ``A`` is assumed to be PSD; no damping is
    applied here (damp inside ``f_Ax`` if needed).

    Args:
        f_Ax: Callable returning the product ``A @ v`` for a flat vector ``v``.
        b: Right-hand side, a 1-D ``torch.Tensor`` or ``numpy.ndarray``.
        cg_iters: Maximum number of CG iterations.
        callback: Optional callable invoked with the current iterate before
            every iteration and once more after the loop.
        verbose: If true, print iteration number, squared residual norm,
            solution norm and objective value for each iterate.
        residual_tol: Stop once the squared residual norm drops below this.
        x_init: Optional starting point; zeros are used when ``None``.  For
            tensor inputs it is moved to ``b``'s device and dtype.

    Returns:
        The approximate solution, of the same kind (tensor/array) as ``b``.

    Raises:
        TypeError: If ``b`` is neither a ``torch.Tensor`` nor a ``numpy.ndarray``.
    """
    if isinstance(b, torch.Tensor):
        if x_init is None:
            # zeros_like keeps b's shape, dtype and device in one step.
            x = torch.zeros_like(b)
        else:
            x = x_init.to(device=b.device, dtype=b.dtype)
        r = b - f_Ax(x)
        p = r.clone()
    elif isinstance(b, np.ndarray):
        x = np.zeros_like(b) if x_init is None else x_init
        r = b - f_Ax(x)
        p = r.copy()
    else:
        raise TypeError(
            "cg_solve expects b to be a torch.Tensor or numpy.ndarray, got {}".format(
                type(b).__name__
            )
        )

    fmtstr = "%10i %10.3g %10.3g %10.3g"
    titlestr = "%10s %10s %10s %10s"

    def _report(step, x_cur, r_cur):
        # Objective matches the docstring: 1/2 x^T A x - x^T b.
        obj_fn = 0.5 * x_cur.dot(f_Ax(x_cur)) - b.dot(x_cur)
        if isinstance(x_cur, torch.Tensor):
            norm_x = torch.norm(x_cur)
        else:
            norm_x = np.linalg.norm(x_cur)
        # float() makes the %-formatting work for tensors and NumPy scalars alike.
        print(fmtstr % (step, float(r_cur.dot(r_cur)), float(norm_x), float(obj_fn)))

    if verbose:
        print(titlestr % ("iter", "residual norm", "soln norm", "obj fn"))

    completed = 0
    for i in range(cg_iters):
        if callback is not None:
            callback(x)
        if verbose:
            _report(i, x, r)

        rdotr = r.dot(r)
        Ap = f_Ax(p)
        curvature = p.dot(Ap)
        if curvature == 0:
            # The search direction carries no curvature (e.g. the residual is
            # already exactly zero); stepping would divide by zero and turn
            # the iterate into NaN/inf, so keep the current x.
            break
        alpha = rdotr / curvature
        x = x + alpha * p
        r = r - alpha * Ap
        newrdotr = r.dot(r)
        beta = newrdotr / rdotr
        p = r + beta * p
        completed = i + 1

        if newrdotr < residual_tol:
            # Early termination: the residual is small enough.
            break

    if callback is not None:
        callback(x)
    if verbose:
        _report(completed, x, r)
    return x


def construct_b(loss_1, loss_2, model_1, model_2, kappa=1.0):
    """Build the two right-hand-side vectors used by the CG solves.

    With ``g_k`` the flattened gradient of ``loss_k`` w.r.t. the parameters of
    ``model_k`` and ``gap = flat(model_1) - flat(model_2)``::

        b_1 = g_1 + kappa * gap
        b_2 = g_2 - kappa * gap

    Both results are plain values with no autograd history: the gradients are
    taken with ``create_graph=False`` and the parameters are detached before
    they are flattened, so downstream solver iterates are not tracked.

    Args:
        loss_1: Scalar loss whose graph reaches ``model_1.parameters()``.
        loss_2: Scalar loss whose graph reaches ``model_2.parameters()``.
        model_1: First model; must have as many parameter elements as
            ``model_2`` because the flattened parameters are subtracted.
        model_2: Second model.
        kappa: Weight of the parameter-gap term.

    Returns:
        Tuple ``(b_1, b_2)`` of 1-D tensors.
    """

    def _flatten(tensors):
        # Concatenate per-parameter tensors into a single 1-D tensor.
        return torch.cat([t.contiguous().view(-1) for t in tensors])

    partial_grad_1 = torch.autograd.grad(
        loss_1, model_1.parameters(), create_graph=False, retain_graph=False
    )
    partial_grad_2 = torch.autograd.grad(
        loss_2, model_2.parameters(), create_graph=False, retain_graph=False
    )

    flat_partial_grad_1 = _flatten(partial_grad_1)
    flat_partial_grad_2 = _flatten(partial_grad_2)

    # Detach so the gap is a constant; otherwise b_1/b_2 would require grad.
    flat_model_1 = _flatten(p.detach() for p in model_1.parameters())
    flat_model_2 = _flatten(p.detach() for p in model_2.parameters())

    gap = flat_model_1 - flat_model_2

    b_1 = flat_partial_grad_1 + kappa * gap
    b_2 = flat_partial_grad_2 - kappa * gap

    return b_1, b_2


def meta_grad_update(meta_grad, model, optimizer, flat_grad=False):
    """Install ``meta_grad`` as the parameters' gradients and step the optimizer.

    Args:
        meta_grad: Either a sequence with one tensor per entry of
            ``model.parameters()`` (default), or, when ``flat_grad`` is true,
            a single 1-D tensor holding all gradients concatenated in
            parameter order.
        model: Model whose parameters receive the gradients.
        optimizer: Optimizer whose ``step()`` is called afterwards.
        flat_grad: Interpret ``meta_grad`` as one flattened vector.
    """
    if flat_grad:
        offset = 0
        for p in model.parameters():
            count = p.nelement()
            this_grad = meta_grad[offset : offset + count].view(p.size()).detach()
            if p.grad is None:
                # No gradient buffer exists yet (never back-propagated, or
                # cleared to None), so there is nothing to copy into.
                p.grad = this_grad.clone()
            else:
                p.grad.copy_(this_grad)
            offset += count
    else:
        for i, p in enumerate(model.parameters()):
            p.grad = meta_grad[i]
    optimizer.step()


def train_implicit(
    fea,
    clf_0,
    clf_1,
    criterion,
    optimizer_fea,
    optimizer_clf_0,
    optimizer_clf_1,
    dataloader_0,
    dataloader_1,
    kappa=1e-3,
    n_epoch=20,
    max_inner=15,
    warmup_epochs=2,
    warmup_iters=100,
):
    """Train a shared feature extractor with two group-specific classifiers.

    During the first ``warmup_epochs`` epochs, plain multi-task training is
    run for at most ``warmup_iters`` iterations per epoch.  Every later
    iteration runs the implicit procedure:

    1. freeze ``fea`` and take ``max_inner`` optimizer steps on each classifier;
    2. solve two linear systems with ``cg_solve`` (operators from
       ``matrix_evaluator``, right-hand sides from ``construct_b``);
    3. re-enable ``fea``, combine the direct loss gradients with the mixed
       second-order terms via ``make_meta_grad`` and step ``optimizer_fea``.

    All batches are moved to the GPU with ``.cuda()``.

    Args:
        fea: Shared feature extractor.
        clf_0: Classifier applied to features of ``dataloader_0`` batches.
        clf_1: Classifier applied to features of ``dataloader_1`` batches.
        criterion: Loss callable ``criterion(prediction, target)``.
        optimizer_fea: Optimizer for ``fea``.
        optimizer_clf_0: Optimizer for ``clf_0``.
        optimizer_clf_1: Optimizer for ``clf_1``.
        dataloader_0: Batches for the first group.
        dataloader_1: Batches for the second group.
        kappa: Weight of the parameter-gap term passed to ``construct_b``.
        n_epoch: Number of epochs.
        max_inner: Number of inner classifier updates per implicit iteration.
        warmup_epochs: Number of initial epochs that use multi-task training.
        warmup_iters: Maximum multi-task iterations per warm-up epoch.
    """
    for epoch in range(n_epoch):
        print("Epoch=[{}/{}]".format(epoch + 1, n_epoch))

        fea.train()
        clf_0.train()
        clf_1.train()

        in_warmup = epoch < warmup_epochs
        n_steps = min(len(dataloader_0), len(dataloader_1))
        if in_warmup:
            # Nothing is trained past the warm-up budget in these epochs, so
            # do not load (and move to the GPU) batches that would be skipped.
            n_steps = min(n_steps, warmup_iters)

        data_iter_0 = iter(dataloader_0)
        data_iter_1 = iter(dataloader_1)

        for _ in tqdm(range(n_steps)):
            # Use the builtin next(); iterators have no .next() method.
            batch_x_0, batch_y_0 = next(data_iter_0)
            batch_x_1, batch_y_1 = next(data_iter_1)
            batch_x_0, batch_y_0 = batch_x_0.cuda(), batch_y_0.float().cuda()
            batch_x_1, batch_y_1 = batch_x_1.cuda(), batch_y_1.float().cuda()

            if in_warmup:
                # At start, use multi-task training to stabilize the feature extractor.
                y_pred_0 = clf_0(fea(batch_x_0))
                y_pred_1 = clf_1(fea(batch_x_1))
                loss_0 = criterion(y_pred_0, batch_y_0)
                loss_1 = criterion(y_pred_1, batch_y_1)
                optimizer_fea.zero_grad()
                optimizer_clf_0.zero_grad()
                optimizer_clf_1.zero_grad()
                loss_0.backward()
                loss_1.backward()
                optimizer_fea.step()
                optimizer_clf_0.step()
                optimizer_clf_1.step()
                continue

            # The complete implicit algorithm

            # Step 1: freeze feature representation, inner optimization
            model_freeze(fea)

            z_0 = fea(batch_x_0)
            z_1 = fea(batch_x_1)

            # inner loop for obtaining h_{\epsilon}
            for _inner in range(max_inner):
                y_pred_0 = clf_0(z_0)
                y_pred_1 = clf_1(z_1)

                loss_0 = criterion(y_pred_0, batch_y_0)
                loss_1 = criterion(y_pred_1, batch_y_1)

                optimizer_clf_0.zero_grad()
                loss_0.backward()
                optimizer_clf_0.step()

                optimizer_clf_1.zero_grad()
                loss_1.backward()
                optimizer_clf_1.step()

            # Step 2: computing P_0 and P_1 (does not require the gradient of lambda)
            # clear gradient
            optimizer_clf_0.zero_grad()
            optimizer_clf_1.zero_grad()

            y_pred_0 = clf_0(z_0)
            y_pred_1 = clf_1(z_1)

            loss_0 = criterion(y_pred_0, batch_y_0)
            loss_1 = criterion(y_pred_1, batch_y_1)

            AX_0 = matrix_evaluator(criterion, z_0, batch_y_0, clf_0)
            AX_1 = matrix_evaluator(criterion, z_1, batch_y_1, clf_1)

            b_0, b_1 = construct_b(loss_0, loss_1, clf_0, clf_1, kappa=kappa)

            # detach() returns a new tensor, so the result must be kept;
            # P_0/P_1 are constants in the meta-gradient below.
            P_0 = cg_solve(AX_0, b_0, cg_iters=20).detach()
            P_1 = cg_solve(AX_1, b_1, cg_iters=20).detach()

            # Step 3: compute meta-gradient (gradient of the representation)
            model_activate(fea)
            optimizer_clf_0.zero_grad()
            optimizer_clf_1.zero_grad()
            optimizer_fea.zero_grad()

            z_0 = fea(batch_x_0)
            z_1 = fea(batch_x_1)
            y_pred_0 = clf_0(z_0)
            y_pred_1 = clf_1(z_1)

            loss_0 = criterion(y_pred_0, batch_y_0)
            loss_1 = criterion(y_pred_1, batch_y_1)

            # Direct gradients of the losses w.r.t. the feature extractor.
            partial_lam_0 = torch.autograd.grad(loss_0, fea.parameters(), retain_graph=True)
            partial_lam_1 = torch.autograd.grad(loss_1, fea.parameters(), retain_graph=True)

            # Classifier gradients kept differentiable for the mixed second-order term.
            partial_h_0 = torch.autograd.grad(
                loss_0, clf_0.parameters(), create_graph=True, allow_unused=True
            )
            partial_h_1 = torch.autograd.grad(
                loss_1, clf_1.parameters(), create_graph=True, allow_unused=True
            )

            flat_grad_0 = torch.cat([g.contiguous().view(-1) for g in partial_h_0])
            hessian_vector_0 = torch.sum(flat_grad_0 * P_0)
            joint_hessian_0 = torch.autograd.grad(hessian_vector_0, fea.parameters())

            flat_grad_1 = torch.cat([g.contiguous().view(-1) for g in partial_h_1])
            hessian_vector_1 = torch.sum(flat_grad_1 * P_1)
            joint_hessian_1 = torch.autograd.grad(hessian_vector_1, fea.parameters())

            # The gradients come back as tuples; combine them entry by entry.
            meta_gradient = make_meta_grad(
                partial_lam_0, partial_lam_1, joint_hessian_0, joint_hessian_1
            )

            meta_grad_update(meta_gradient, fea, optimizer_fea)


def make_meta_grad(partial_lam_0, partial_lam_1, joint_hessian_0, joint_hessian_1):
    """Combine four aligned gradient sequences into one tuple.

    Entry ``k`` of the result is
    ``partial_lam_0[k] + partial_lam_1[k] - joint_hessian_0[k] - joint_hessian_1[k]``;
    the result has ``len(partial_lam_0)`` entries.
    """
    n_entries = len(partial_lam_0)
    return tuple(
        partial_lam_0[k] + partial_lam_1[k] - joint_hessian_0[k] - joint_hessian_1[k]
        for k in range(n_entries)
    )


def evaluate_pp_implicit(fea, clf_0, clf_1, dataloader_0, dataloader_1):
    """Evaluate both group classifiers and report accuracy and a group gap.

    The three models are switched to eval mode and each group is scored with
    ``evaluate4pp``.  Scores are thresholded at 0.5.  For each group and each
    predicted class, the fraction of predictions whose true label equals the
    predicted class is computed (0 when that class is never predicted).  The
    gap is the absolute between-group difference of these fractions, averaged
    over the two classes.

    Returns:
        ``(ap, gap)`` where ``ap`` is the mean of the two groups' average
        precision.  The same two values are also printed.
    """
    for module in (fea, clf_0, clf_1):
        module.eval()

    ap_0, y_hat_0, Y_test_0 = evaluate4pp(fea, clf_0, dataloader_0)
    ap_1, y_hat_1, Y_test_1 = evaluate4pp(fea, clf_1, dataloader_1)
    ap = (ap_0 + ap_1) / 2.0

    def _hit_rate(predicted_idx, true_idx):
        # Share of the predicted indices that also occur among the true indices.
        if len(predicted_idx) == 0:
            return 0
        return len(set(predicted_idx) & set(true_idx)) / len(predicted_idx)

    # Predicted negative (score < 0.5) compared with true label 0.
    pp_00 = _hit_rate(np.where(y_hat_0 < 0.5)[0], np.where(Y_test_0 == 0)[0])
    pp_10 = _hit_rate(np.where(y_hat_1 < 0.5)[0], np.where(Y_test_1 == 0)[0])
    gap_0 = np.abs(pp_00 - pp_10)

    # Predicted positive (score >= 0.5) compared with true label 1.
    pp_01 = _hit_rate(np.where(y_hat_0 >= 0.5)[0], np.where(Y_test_0 == 1)[0])
    pp_11 = _hit_rate(np.where(y_hat_1 >= 0.5)[0], np.where(Y_test_1 == 1)[0])
    gap_1 = np.abs(pp_01 - pp_11)

    gap = (gap_0 + gap_1) / 2.0
    print("ap={:.5f}, gap={:.5f}".format(ap, gap))
    return ap, gap