import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import copy
import pandas as pd
import os
import datetime


def mixupsample(x, y, a, alpha=2.0):
    """Build a LISA-style selectively mixed-up copy of ``(x, y)``.

    Each round flips a fair coin (``np.random.random() <= 0.5``):

    * same-label round: inside every label value, samples with the first two
      distinct attribute values are paired and mixed (labels stay the same);
    * same-attribute round: inside every attribute value, samples with the
      first two distinct label values are paired and mixed, inputs and labels.

    Rounds repeat until at least ``len(x)`` mixed samples exist. The result is
    then shuffled, inputs and labels with the SAME permutation, and truncated
    to ``len(x)`` rows.

    Args:
        x: numpy array of inputs; the first axis indexes samples.
        y: 1-D numpy array of labels, one per sample.
        a: 1-D numpy array of attribute values, one per sample.
        alpha: concentration of the Beta(alpha, alpha) mixing coefficients.

    Returns:
        ``(mixed_x, mixed_y)`` with ``len(x)`` rows each. The original
        ``(x, y)`` is returned unchanged if ``x`` is empty or if a round finds
        nothing to pair (e.g. every attribute occurs with a single label).
    """

    def _mix_up(alpha, x1, x2, y1, y2):
        # Pair the i-th sample of each side, truncating to the shorter side.
        length = min(len(x1), len(x2))
        x1, x2 = x1[:length], x2[:length]
        y1, y2 = y1[:length], y2[:length]

        # One coefficient per pair, reshaped to broadcast over every trailing
        # axis of x (flat feature vectors, images, ...).
        lam = np.random.beta(alpha, alpha, length)
        lam_x = lam.reshape((length,) + (1,) * (x1.ndim - 1))

        mixed_x = lam_x * x1 + (1 - lam_x) * x2
        mixed_y = lam * y1 + (1 - lam) * y2
        return mixed_x, mixed_y

    def _mix_round(group_by, pair_by):
        # Within each distinct value of `group_by`, mix the samples whose
        # `pair_by` equals its first distinct value with those equal to its
        # second. Groups with fewer than two distinct `pair_by` values are
        # skipped (indexing the second value would raise IndexError).
        round_x, round_y = [], []
        for value in np.unique(group_by):
            in_group = group_by == value
            x_g, y_g, p_g = x[in_group], y[in_group], pair_by[in_group]
            pair_values = np.unique(p_g)
            if len(pair_values) < 2:
                continue
            first = p_g == pair_values[0]
            second = p_g == pair_values[1]
            mixed_x, mixed_y = _mix_up(
                alpha, x_g[first], x_g[second], y_g[first], y_g[second]
            )
            round_x.append(mixed_x)
            round_y.append(mixed_y)
        return round_x, round_y

    bs = len(x)
    if bs == 0:
        # Nothing to mix; np.concatenate([]) below would raise.
        return x, y

    all_mix_x, all_mix_y = [], []
    n_mixed = 0
    # repeat until enough samples
    while n_mixed < bs:
        same_label = np.random.random() <= 0.5  # self.hparams["LISA_p_sel"]
        if same_label:
            # same label, mixup between attributes
            round_x, round_y = _mix_round(group_by=y, pair_by=a)
        else:
            # same attribute, mixup between labels
            round_x, round_y = _mix_round(group_by=a, pair_by=y)

        n_new = sum(len(chunk) for chunk in round_x)
        if n_new == 0:
            # Nothing could be paired in this round: fall back to raw data.
            return x, y
        all_mix_x.extend(round_x)
        all_mix_y.extend(round_y)
        n_mixed += n_new

    all_mix_x = np.concatenate(all_mix_x, axis=0)
    all_mix_y = np.concatenate(all_mix_y, axis=0)

    # Shuffle with ONE permutation so every mixed input keeps its mixed label.
    perm = np.random.permutation(len(all_mix_x))
    return all_mix_x[perm][:bs], all_mix_y[perm][:bs]


def oversample(g, n_groups):
    """Return indices that oversample every group up to the largest group's size.

    Each non-empty smaller group is repeated ``max_count // count`` times in
    full, then topped up with ``max_count % count`` of its indices drawn
    without replacement. The largest group(s) are kept as-is. Empty groups
    contribute nothing; the division by their zero count is skipped.

    Args:
        g: 1-D numpy array of integer group ids, one per sample.
        n_groups: number of group ids to consider (``0 .. n_groups - 1``).

    Returns:
        1-D integer index array into ``g``, ordered group by group.
    """
    group_counts = [int((g == group_idx).sum()) for group_idx in range(n_groups)]
    target = max(group_counts)
    resampled_idx = []
    for group_idx, count in enumerate(group_counts):
        (idx,) = np.where(g == group_idx)
        if 0 < count < target:
            # Whole copies first, then a random partial copy for the remainder.
            resampled_idx.extend([idx] * (target // count))
            resampled_idx.append(
                np.random.choice(idx, target % count, replace=False)
            )
        else:
            # Largest group, or an empty one (idx is then empty): keep as-is.
            resampled_idx.append(idx)
    return np.concatenate(resampled_idx)


def undersample(g, n_groups):
    """Return indices that subsample every group down to the smallest group's size.

    For each group id ``0 .. n_groups - 1`` (in that order), ``smallest``
    indices of that group are drawn without replacement, where ``smallest``
    is the size of the smallest group (0 when any group is empty).

    Args:
        g: 1-D numpy array of group ids, one per sample.
        n_groups: number of group ids to consider.

    Returns:
        1-D integer index array into ``g``, ordered group by group.
    """
    members = [np.flatnonzero(g == group_idx) for group_idx in range(n_groups)]
    smallest = min(len(member_idx) for member_idx in members)
    return np.concatenate(
        [np.random.choice(member_idx, smallest, replace=False) for member_idx in members]
    )


def groupdro_loss(yhat, y, gs, q, eta=1e-3):
    """Group-DRO step: reweight groups by exponentiated loss, return weighted loss.

    Args:
        yhat: logits tensor, same shape as ``y``.
        y: float target tensor for BCE-with-logits.
        gs: per-sample integer group ids (numpy array or torch tensor) whose
            length matches the first dimension of ``yhat``.
        q: float tensor of group weights indexed by group id. It is updated
            IN PLACE and also returned.
        eta: step size of the multiplicative update ``q[g] *= exp(eta * loss_g)``.

    Returns:
        ``(loss, q)``: the q-weighted sum of per-group mean losses, which is
        differentiable w.r.t. ``yhat``, and the renormalised weights. Groups
        absent from ``gs`` keep their (renormalised) weight and add no loss.
    """
    losses = F.binary_cross_entropy_with_logits(yhat, y, reduction="none")
    gs = torch.as_tensor(gs, device=losses.device)

    # Mean loss of every group present in this batch, computed once and reused
    # for both the weight update and the objective.
    group_losses = {int(g): losses[gs == g].mean() for g in torch.unique(gs)}

    # .item() keeps the weight update out of the autograd graph.
    for g, g_loss in group_losses.items():
        q[g] *= (eta * g_loss).exp().item()
    q /= q.sum()

    loss = 0
    for g, g_loss in group_losses.items():
        loss += q[g] * g_loss

    return loss, q


def load_emb(data_args, split):
    """Load precomputed embeddings, labels and attributes, and derive group ids.

    Args:
        data_args: dict with keys ``n``, ``sc``, ``ci``, ``ai``, ``y_task``,
            ``a_task`` and ``model``. The first six select the ``.npz`` file.
        split: requested split name (``run`` passes ``"tr"`` and ``"te"``).

    Returns:
        ``((activ, ys, g), n_groups, attrs)`` with ``g = 2 * ys + attrs`` and
        ``n_groups == 4``. This assumes ``ys`` and ``attrs`` are 0/1 valued
        (TODO confirm).

    NOTE(review): neither ``split`` nor ``model`` enters the file path below,
    so the train and test calls currently load the SAME file unless
    ``emb_path`` / ``emb_file`` are edited to include them.
    """
    n, sc, ci, ai, y_task, a_task, model = (
        data_args["n"],
        data_args["sc"],
        data_args["ci"],
        data_args["ai"],
        data_args["y_task"],
        data_args["a_task"],
        data_args["model"],
    )
    emb_path = "YOUR_DIR"  # placeholder: directory holding the embedding files
    emb_file = f"celeba_y{y_task}_a{a_task}_n{n}_sc{sc}_ci{ci}_ai{ai}.npz"
    # np.load on an .npz returns an NpzFile that keeps the archive open; the
    # context manager closes it once the arrays have been read into memory.
    # os.path.join works whether or not emb_path ends with a separator.
    with np.load(os.path.join(emb_path, emb_file)) as emb:
        activ = emb["activ"]
        ys = emb["ys"]
        attrs = emb["attrs"]
    n_groups = 4
    # generate group labels by ys and attrs
    g = 2 * ys + attrs
    return (activ, ys, g), n_groups, attrs


_KNOWN_METHODS = (
    "ERM",
    "GroupDRO",
    "LISA",
    "remax-margin",
    "oversample",
    "undersample",
)


def run(method, n, sc, ci, ai, y_task, a_task, model, seed=0, tol=1e-3, verbose=False):
    """Fit a bias-free linear probe with one group-robustness method and score it.

    Args:
        method: one of ``_KNOWN_METHODS``.
        n, sc, ci, ai, y_task, a_task, model: dataset selectors forwarded to
            ``load_emb`` and copied into the returned result row.
        seed: seeds NumPy (resampling / mixup) and torch (weight init).
        tol: full-batch training stops at a logging checkpoint (every 100
            steps) once the loss dropped by less than ``tol`` since the
            previous checkpoint.
        verbose: print training progress and final metrics.

    Returns:
        dict: ``data_args`` plus ``method``, ``avg_tr_err``, ``avg_te_err``
        and ``wga_te_err``. ``wga_te_err`` is the worst test error over the
        groups that actually occur in the test set.

    Raises:
        ValueError: for an unknown ``method``.
    """
    if method not in _KNOWN_METHODS:
        raise ValueError(f"unknown method {method}")

    # nn.Linear initialises its weights from torch's RNG, so seeding NumPy
    # alone does not make runs reproducible.
    np.random.seed(seed)
    torch.manual_seed(seed)
    data_args = {
        "n": n,
        "sc": sc,
        "ci": ci,
        "ai": ai,
        "y_task": y_task,
        "a_task": a_task,
        "model": model,
    }

    (tr_x, tr_y, tr_g), n_groups, tr_a = load_emb(data_args, split="tr")
    (te_x, te_y, te_g), _, _ = load_emb(data_args, split="te")

    # TODO: check if we need to normalize the embeddings

    net = nn.Linear(tr_x.shape[1], 1, bias=False)
    opt = torch.optim.Adam(net.parameters(), lr=0.001)
    loss_fn = nn.BCEWithLogitsLoss()
    if method == "GroupDRO":
        loss_fn = groupdro_loss
        q = torch.ones(n_groups, dtype=torch.float32)
        tr_g = torch.tensor(tr_g, dtype=torch.int64)
    elif method == "LISA":
        # Labels become soft mixed targets, so the training accuracy below
        # compares 0/1 predictions against those soft targets.
        tr_x, tr_y = mixupsample(tr_x, tr_y, tr_a)
    elif method == "remax-margin":
        # Count per group id (rather than via np.unique) so c[i] stays aligned
        # with group i even when a group is missing from the training set.
        cnts = np.array([(tr_g == i).sum() for i in range(n_groups)])
        c = cnts / np.sum(cnts)
        c = c / c.max()
        # Empty groups are skipped: BCE averaged over zero samples is NaN.
        group_masks = {
            i: torch.as_tensor(tr_g == i) for i in range(n_groups) if cnts[i] > 0
        }
    elif method in ("oversample", "undersample"):
        resampler = oversample if method == "oversample" else undersample
        resample_idx = resampler(tr_g, n_groups)
        tr_x = tr_x[resample_idx, :]
        tr_y = tr_y[resample_idx]

    train_iter = 1000
    log_every = 100

    # convert data
    tr_x = torch.tensor(tr_x, dtype=torch.float32)
    tr_y = torch.tensor(tr_y, dtype=torch.float32).reshape(-1, 1)

    te_x = torch.tensor(te_x, dtype=torch.float32)
    te_y = torch.tensor(te_y, dtype=torch.float32).reshape(-1, 1)

    last_loss = 0.0
    for t in range(train_iter + 1):
        logits = net(tr_x)
        if method == "GroupDRO":
            loss, q = loss_fn(logits, tr_y, tr_g, q)
        elif method == "remax-margin":
            # Scale each group's logits by its relative frequency c[i].
            loss = sum(
                loss_fn(c[i] * logits[mask], tr_y[mask]) / n_groups
                for i, mask in group_masks.items()
            )
        else:  # ERM, LISA, oversample, undersample
            loss = loss_fn(logits, tr_y)

        # L2 penalty on the (unsquared) norm of each parameter tensor.
        l2_reg = sum(torch.norm(param) for param in net.parameters())
        loss = loss + 1e-6 * l2_reg

        opt.zero_grad()
        loss.backward()
        opt.step()

        if t % log_every == 0:
            cur_loss = loss.item()
            # Early stop once the improvement since the last checkpoint < tol.
            if t > 0 and (last_loss - cur_loss) < tol:
                break
            last_loss = cur_loss
            if verbose:
                print(f"{t=} xent {cur_loss:.5f}")

    with torch.no_grad():
        # get training accuracy
        tr_pred = (torch.sigmoid(net(tr_x)) > 0.5).float()
        tr_acc = (tr_pred == tr_y).float().sum() / len(tr_y)

        # get test accuracy
        pred = (torch.sigmoid(net(te_x)) > 0.5).float()
        correct = (pred == te_y).float()
        te_acc = correct.sum() / len(te_y)

    # get worst-case accuracy over the groups present in the test set; an
    # empty group would give 0/0 = NaN and make min() depend on list order.
    worst_case_acc = []
    for g in range(n_groups):
        idx = te_g == g
        n_g = int(idx.sum())
        if n_g == 0:
            continue
        worst_case_acc.append((correct[torch.as_tensor(idx)].sum() / n_g).item())
    worst_acc = min(worst_case_acc) if worst_case_acc else float("nan")

    if verbose:
        print(f"train acc {tr_acc.item():.5f}")
        print(f"avg test acc {te_acc.item():.5f}")
        print(f"avg test 0-1 error {(1 - te_acc).item():.5f}")
        print(f"worst-case test error {1 - worst_acc:.5f}")

    # collect results
    res = copy.deepcopy(data_args)
    res["method"] = method
    res["avg_tr_err"] = 1 - tr_acc.item()
    res["avg_te_err"] = 1 - te_acc.item()
    res["wga_te_err"] = 1 - worst_acc

    return res


import argparse

parser = argparse.ArgumentParser(
    description="Run experiments with different methods and configurations."
)
parser.add_argument("--model", type=str, default="resnet", help="Model to use")
parser.add_argument(
    "--methods",
    type=str,
    nargs="+",
    # Reject typos up front instead of failing after data has been loaded.
    choices=["ERM", "GroupDRO", "LISA", "remax-margin", "oversample", "undersample"],
    default=["ERM", "GroupDRO", "remax-margin", "oversample", "undersample"],
    help="List of methods to use",
)
parser.add_argument("--seed", type=int, default=2, help="Random seed")
# The task ids and shift type have no sensible default. Without them a run
# would read/write files named "...yNone_aNone..." or fail late with
# "Unknown shift type None", so argparse now rejects the call up front.
parser.add_argument(
    "--y_task",
    type=int,
    required=True,
    help="label task id (the y<id> part of the embedding/results file names)",
)
parser.add_argument(
    "--a_task",
    type=int,
    required=True,
    help="attribute task id (the a<id> part of the embedding/results file names)",
)
parser.add_argument(
    "--shift_type",
    type=str,
    choices=["sc", "ci", "ai", "3shifts"],
    required=True,
    help="which shift parameter(s) to sweep",
)

args = parser.parse_args()


model = args.model
methods = args.methods
n_list = [200, 500, 1000, 2000, 5000, 10000]
task = [args.y_task, args.a_task]
seed = args.seed
verbose = False
shift_type = args.shift_type

if shift_type in ("sc", "ci", "ai"):
    # Sweep only the selected shift parameter; hold the other two at 0.5.
    _one_axis_sweep = [0.01, 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99]
    sc_list = _one_axis_sweep if shift_type == "sc" else [0.5]
    ci_list = _one_axis_sweep if shift_type == "ci" else [0.5]
    ai_list = _one_axis_sweep if shift_type == "ai" else [0.5]
elif shift_type == "3shifts":
    # 30 jointly shifted (sc, ci, ai) settings, randomly generated once;
    # the j-th entries of the three lists form one dataset configuration.
    # fmt: off
    sc_list = [
        0.68, 0.5, 0.3, 0.51, 0.4, 0.57, 0.17, 0.51, 0.55, 0.47,
        0.49, 0.52, 0.23, 0.12, 0.36, 0.15, 0.5, 0.49, 0.48, 0.26,
        0.28, 0.68, 0.27, 0.25, 0.23, 0.59, 0.72, 0.67, 0.64, 0.7,
    ]
    ci_list = [
        0.35, 0.47, 0.49, 0.36, 0.11, 0.46, 0.4, 0.65, 0.24, 0.22,
        0.29, 0.42, 0.24, 0.26, 0.39, 0.39, 0.7, 0.65, 0.4, 0.67,
        0.38, 0.81, 0.39, 0.81, 0.6, 0.42, 0.58, 0.85, 0.24, 0.42,
    ]
    ai_list = [
        0.47, 0.31, 0.71, 0.27, 0.65, 0.36, 0.49, 0.7, 0.47, 0.52,
        0.55, 0.16, 0.64, 0.8, 0.45, 0.68, 0.65, 0.66, 0.57, 0.17,
        0.75, 0.69, 0.53, 0.12, 0.21, 0.46, 0.61, 0.69, 0.36, 0.33,
    ]
    # fmt: on
else:
    raise ValueError(f"Unknown shift type {shift_type}")

# One dataset configuration per index j; the longest parameter list sets the count.
n_datasets = max(len(sc_list), len(ci_list), len(ai_list))
total_num_loop = len(n_list) * n_datasets * len(methods)

i = 0
all_res = []
for n in n_list:
    for j in range(n_datasets):
        # Shorter parameter lists are padded by repeating their last value.
        sc, ci, ai = (
            values[min(j, len(values) - 1)] for values in (sc_list, ci_list, ai_list)
        )
        for method in methods:
            # task unpacks positionally into (y_task, a_task).
            all_res.append(
                run(method, n, sc, ci, ai, *task, model, seed, verbose=verbose)
            )
            i += 1
            if i % 10 == 0:
                print(f"Done {i}/{total_num_loop}", datetime.datetime.now())

path = "YOUR_DIR"  # placeholder: output directory for the results CSV
# exist_ok=True replaces the exists()/makedirs() pair, which races (and raises
# FileExistsError) when several invocations start at the same time.
os.makedirs(path, exist_ok=True)
pd.DataFrame(all_res).to_csv(
    os.path.join(path, f"results_seed{seed}_{shift_type}_y{task[0]}_a{task[1]}.csv")
)

# Example invocations: each (y_task, a_task) pair below under each shift type.
# python linear_probe.py --model clip --seed 0 --shift_type sc --y_task 2 --a_task 31
# python linear_probe.py --model clip --seed 0 --shift_type sc --y_task 21 --a_task 36
# python linear_probe.py --model clip --seed 0 --shift_type sc --y_task 8 --a_task 20
# python linear_probe.py --model clip --seed 0 --shift_type sc --y_task 25 --a_task 19
# python linear_probe.py --model clip --seed 0 --shift_type ci --y_task 2 --a_task 31
# python linear_probe.py --model clip --seed 0 --shift_type ci --y_task 21 --a_task 36
# python linear_probe.py --model clip --seed 0 --shift_type ci --y_task 8 --a_task 20
# python linear_probe.py --model clip --seed 0 --shift_type ci --y_task 25 --a_task 19
# python linear_probe.py --model clip --seed 0 --shift_type ai --y_task 2 --a_task 31
# python linear_probe.py --model clip --seed 0 --shift_type ai --y_task 21 --a_task 36
# python linear_probe.py --model clip --seed 0 --shift_type ai --y_task 8 --a_task 20
# python linear_probe.py --model clip --seed 0 --shift_type ai --y_task 25 --a_task 19
# (y_task, a_task) pairs used above: [[2, 31], [21, 36], [8, 20], [25, 19]]
