import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import copy
import pandas as pd
import os
import datetime


def mixupsample(x, y, a, alpha=2.0):
    """Build a LISA-style mixup training set with as many samples as ``x``.

    Every round flips a fair coin and then

    * heads -- for every label, mixes samples that share that label but carry two
      different attribute values (mixup between attributes);
    * tails -- for every attribute value, mixes samples that share that attribute
      but carry two different labels (mixup between labels).

    Within a group the first two distinct values (in sorted order) are paired;
    groups offering fewer than two distinct values are skipped. Rounds repeat
    until at least ``len(x)`` mixed samples exist. The pool is then shuffled --
    inputs and labels with the *same* permutation -- and truncated to ``len(x)``.
    If a round produces no samples at all, the unmixed ``(x, y)`` is returned.

    Args:
        x: numpy array of inputs; the first axis indexes samples.
        y: numpy array of labels aligned with ``x`` along the first axis.
        a: 1-D numpy array of attribute values aligned with ``x``.
        alpha: parameter of the symmetric Beta(alpha, alpha) mixing distribution.

    Returns:
        ``(mixed_x, mixed_y)``; ``mixed_y`` holds interpolated (soft) labels.
    """

    def _mix_up(x1, x2, y1, y2):
        """Mix the i-th sample of one pool with the i-th sample of the other."""
        # surplus samples of the larger pool are dropped
        length = min(len(x1), len(x2))
        x1, x2, y1, y2 = x1[:length], x2[:length], y1[:length], y2[:length]
        lam = np.random.beta(alpha, alpha, [length, 1])
        # one coefficient per sample, broadcast over every remaining axis
        lam_x = lam.reshape((length,) + (1,) * (np.ndim(x1) - 1))
        lam_y = lam.reshape((length,) + (1,) * (np.ndim(y1) - 1))
        mixed_x = lam_x * x1 + (1 - lam_x) * x2
        mixed_y = lam_y * y1 + (1 - lam_y) * y2
        return mixed_x, mixed_y

    bs = len(x)
    if bs == 0:
        # nothing to mix (and nothing to concatenate below)
        return x, y

    pool_x, pool_y = [], []
    n_mixed = 0
    while n_mixed < bs:
        # selection probability between the two strategies (LISA_p_sel) fixed at 0.5
        if np.random.random() <= 0.5:
            # same label, mixup between attributes
            outer, pair = y, a
        else:
            # same attribute, mixup between labels
            outer, pair = a, y

        added = 0
        for value in np.unique(outer):
            mask = outer == value
            x_v, y_v, pair_v = x[mask], y[mask], pair[mask]
            candidates = np.unique(pair_v)
            if len(candidates) < 2:
                # this group has nothing to be mixed with
                continue
            # NOTE: always pairs the first two values; a random pair is not implemented
            m1 = pair_v == candidates[0]
            m2 = pair_v == candidates[1]
            mixed_x, mixed_y = _mix_up(x_v[m1], x_v[m2], y_v[m1], y_v[m2])
            pool_x.append(mixed_x)
            pool_y.append(mixed_y)
            added += len(mixed_x)

        if added == 0:
            # e.g. each attribute only has one unique label: give up on mixing
            return x, y
        n_mixed += added

    all_mix_x = np.concatenate(pool_x, axis=0)
    all_mix_y = np.concatenate(pool_y, axis=0)

    # a single shared permutation keeps every mixed input aligned with its label
    perm = np.random.permutation(len(all_mix_x))
    return all_mix_x[perm][:bs], all_mix_y[perm][:bs]


def oversample(g, n_groups):
    """Return indices that oversample every group up to the size of the largest one.

    A group smaller than the largest is repeated ``max_count // count`` times and
    topped up with ``max_count % count`` of its indices drawn without replacement,
    so every non-empty group ends up with exactly ``max_count`` entries. Groups
    with no samples contribute nothing.

    Args:
        g: numpy array of group labels, one per sample.
        n_groups: number of groups; labels ``0 .. n_groups - 1`` are considered.

    Returns:
        1-D integer numpy array of indices into ``g``, concatenated group by group.
    """
    group_counts = [(g == group_idx).sum() for group_idx in range(n_groups)]
    max_count = max(group_counts, default=0)

    resampled_idx = []
    for group_idx, count in enumerate(group_counts):
        if count == 0:
            # an absent group cannot be replicated (and would divide by zero below)
            continue
        (idx,) = np.where(g == group_idx)
        if count < max_count:
            n_copies, n_extra = divmod(int(max_count), int(count))
            resampled_idx.extend([idx] * n_copies)
            # drawn even when n_extra == 0 so the global RNG advances as before
            resampled_idx.append(np.random.choice(idx, n_extra, replace=False))
        else:
            resampled_idx.append(idx)

    if not resampled_idx:
        return np.array([], dtype=np.intp)
    return np.concatenate(resampled_idx)


def undersample(g, n_groups):
    """Return indices that undersample every group down to the smallest present group.

    Each non-empty group contributes ``min_count`` of its indices drawn without
    replacement, where ``min_count`` is the size of the smallest *non-empty*
    group. Empty groups are skipped rather than shrinking every group to zero.

    Args:
        g: numpy array of group labels, one per sample.
        n_groups: number of groups; labels ``0 .. n_groups - 1`` are considered.

    Returns:
        1-D integer numpy array of indices into ``g``, concatenated group by group.
    """
    group_counts = [(g == group_idx).sum() for group_idx in range(n_groups)]
    present_counts = [count for count in group_counts if count > 0]
    if not present_counts:
        return np.array([], dtype=np.intp)
    min_count = min(present_counts)

    resampled_idx = []
    for group_idx, count in enumerate(group_counts):
        if count == 0:
            continue
        (idx,) = np.where(g == group_idx)
        resampled_idx.append(np.random.choice(idx, min_count, replace=False))
    return np.concatenate(resampled_idx)


def groupdro_loss(yhat, y, gs, q, eta=1e-3):
    """Group DRO objective with a multiplicative update of the group weights.

    For every group present in ``gs`` the mean binary cross-entropy is computed.
    The group's weight is multiplied by ``exp(eta * mean_loss)``, the weights are
    renormalised to sum to 1, and the returned loss is the weighted sum of the
    group means.

    Args:
        yhat: logits tensor, same shape as ``y``.
        y: target tensor for binary cross-entropy.
        gs: group label per sample (torch tensor or numpy array).
        q: float tensor of group weights indexed by group label; updated IN PLACE.
        eta: step size of the weight update (1e-3 reproduces the original value).

    Returns:
        ``(loss, q)`` -- the weighted loss (differentiable w.r.t. ``yhat``) and
        the updated weights.
    """
    losses = F.binary_cross_entropy_with_logits(yhat, y, reduction="none")
    group_losses = [(g, losses[g == gs].mean()) for g in np.unique(gs)]

    for g, loss_g in group_losses:
        # .item() keeps the weight update outside the autograd graph
        q[g] *= (eta * loss_g).exp().item()

    q /= q.sum()
    loss = 0
    for g, loss_g in group_losses:
        loss += q[g] * loss_g

    return loss, q


def load_emb(data_args, split, data_folder=None, root="//exps/div_explore"):
    """Load cached COCO embeddings for one dataset configuration and split.

    Reads ``{root}/{data_folder}/{model}_embeddings_{split}/coco_n{n}_sc{sc}_ci{ci}_ai{ai}.npz``,
    which must contain the arrays ``activ``, ``ys`` and ``attrs``.

    Args:
        data_args: dict with keys ``"n"``, ``"sc"``, ``"ci"``, ``"ai"``, ``"model"``
            and, optionally, ``"folder"``.
        split: split suffix used in the directory name (the script uses "tr"/"te").
        data_folder: sub-directory of ``root``. When None, ``data_args["folder"]``
            is used if present, otherwise the module-level ``folder`` set by the CLI.
        root: base directory of the experiments.

    Returns:
        ``None`` if the file does not exist, otherwise
        ``((activ, ys, g), n_groups, attrs)`` with ``g = 2 * ys + attrs`` and
        ``n_groups = 4`` (presumes binary ``ys`` and ``attrs`` -- TODO confirm).
    """
    if data_folder is None:
        # only fall back to the CLI-level global when no folder was supplied
        data_folder = data_args["folder"] if "folder" in data_args else folder
    n, sc, ci, ai, model = (data_args[key] for key in ("n", "sc", "ci", "ai", "model"))
    emb_path = f"{root}/{data_folder}/{model}_embeddings_{split}/"
    emb_file = f"coco_n{n}_sc{sc}_ci{ci}_ai{ai}.npz"
    full_path = emb_path + emb_file

    if not os.path.exists(full_path):
        return None

    # the NpzFile keeps the archive open; the context manager closes it
    # once the arrays have been read into memory
    with np.load(full_path) as emb:
        activ = emb["activ"]
        ys = emb["ys"]
        attrs = emb["attrs"]

    n_groups = 4
    # group label combines class and attribute: g = 2 * y + a
    g = 2 * ys + attrs
    return (activ, ys, g), n_groups, attrs


def run(method, n, sc, ci, ai, model, seed=0, tol=1e-3, verbose=False):
    """Train and evaluate a bias-free linear probe with one group-robustness method.

    Training uses full-batch Adam (lr 1e-3) for up to ``train_iter + 1`` steps and
    adds a 1e-6 penalty on the L2 norm of the weights. Every ``log_every`` steps the
    loss is compared with the previous check; training stops once it decreased by
    less than ``tol``.

    Args:
        method: one of "ERM", "GroupDRO", "LISA", "remax-margin", "oversample",
            "undersample".
        n, sc, ci, ai, model: dataset configuration forwarded to ``load_emb``.
        seed: seeds numpy (resampling / mixup) and torch (probe initialisation).
        tol: minimum loss decrease between checks to keep training.
        verbose: print the loss trajectory and final metrics.

    Returns:
        ``None`` when the train or test embeddings are missing. Otherwise a dict with
        the dataset arguments plus ``method``, ``avg_tr_err``, ``avg_te_err`` and
        ``wga_te_err`` (worst error over the test groups that have samples).

    Raises:
        ValueError: for an unknown ``method`` (checked once embeddings are found).
    """
    np.random.seed(seed)
    # the probe's initial weights come from torch's RNG, so seed it too
    torch.manual_seed(seed)
    data_args = {"n": n, "sc": sc, "ci": ci, "ai": ai, "model": model}

    tr_emb = load_emb(data_args, split="tr")
    te_emb = load_emb(data_args, split="te")
    if tr_emb is None or te_emb is None:
        return None

    (tr_x, tr_y, tr_g), n_groups, tr_a = tr_emb
    (te_x, te_y, te_g), n_groups, _ = te_emb

    # TODO: check if we need to normalize the embeddings

    net = nn.Linear(tr_x.shape[1], 1, bias=False)
    opt = torch.optim.Adam(net.parameters(), lr=0.001)
    loss_fn = nn.BCEWithLogitsLoss()
    if method == "GroupDRO":
        loss_fn = groupdro_loss
        q = torch.ones(n_groups, dtype=torch.float32)
        tr_g = torch.tensor(tr_g, dtype=torch.int64)
    elif method == "LISA":
        # mixed samples belong to no single group; tr_g is not used afterwards
        tr_x, tr_y = mixupsample(tr_x, tr_y, tr_a)
    elif method == "remax-margin":
        # count every group (absent ones too) so c[i] really belongs to group i
        cnts = np.array([(tr_g == i).sum() for i in range(n_groups)])
        c = cnts / np.sum(cnts)
        c = c / c.max()
        present_groups = [i for i in range(n_groups) if cnts[i] > 0]
    elif method in ("oversample", "undersample"):
        sampler = oversample if method == "oversample" else undersample
        resample_idx = sampler(tr_g, n_groups)
        # keep the group labels aligned with the resampled data
        tr_x = tr_x[resample_idx, :]
        tr_y = tr_y[resample_idx]
        tr_g = tr_g[resample_idx]
    elif method != "ERM":
        raise ValueError(f"unknown method {method}")

    train_iter = 1000
    log_every = 100

    # convert data
    tr_x = torch.tensor(tr_x, dtype=torch.float32)
    tr_y = torch.tensor(tr_y, dtype=torch.float32).reshape(-1, 1)
    te_x = torch.tensor(te_x, dtype=torch.float32)
    te_y = torch.tensor(te_y, dtype=torch.float32).reshape(-1, 1)

    last_loss = 0.0
    for t in range(train_iter + 1):
        logits = net(tr_x)
        if method == "GroupDRO":
            loss, q = loss_fn(logits, tr_y, tr_g, q)
        elif method == "remax-margin":
            # per-group BCE on logits scaled by the group's relative frequency,
            # averaged over non-empty groups (an empty group would give NaN)
            loss = 0.0
            for i in present_groups:
                idx = tr_g == i
                loss += loss_fn(c[i] * logits[idx], tr_y[idx]) / len(present_groups)
        else:
            loss = loss_fn(logits, tr_y)

        # small penalty on the (unsquared) L2 norm of the weights
        loss = loss + 1e-6 * sum(torch.norm(param) for param in net.parameters())

        opt.zero_grad()
        loss.backward()
        opt.step()

        if t % log_every == 0:
            if t > 0 and last_loss - loss.item() < tol:
                break
            last_loss = loss.item()
            if verbose:
                print(f"{t=} xent {loss.item():.5f}")

    with torch.no_grad():
        tr_pred = (torch.sigmoid(net(tr_x)) > 0.5).float()
        te_pred = (torch.sigmoid(net(te_x)) > 0.5).float()

    # LISA leaves soft labels in tr_y; score against the nearest hard label
    # (identical to tr_y when the labels are already 0/1)
    tr_acc = (tr_pred == (tr_y > 0.5).float()).float().sum() / len(tr_y)
    te_acc = (te_pred == te_y).float().sum() / len(te_y)

    # worst-group accuracy over the test groups that actually have samples
    group_accs = []
    for g in range(n_groups):
        idx = te_g == g
        if not idx.any():
            # 0/0 = nan would make min() order-dependent
            continue
        correct = (te_pred[idx] == te_y[idx]).float().sum()
        group_accs.append((correct / len(te_y[idx])).item())
    worst_group_err = 1 - min(group_accs) if group_accs else float("nan")

    if verbose:
        print(f"train acc {tr_acc.item():.5f}")
        print(f"avg test acc {te_acc.item():.5f}")
        print(f"avg test 0-1 error {(1 - te_acc).item():.5f}")
        print(f"worst-case test error {worst_group_err:.5f}")

    return {
        **data_args,
        "method": method,
        "avg_tr_err": 1 - tr_acc.item(),
        "avg_te_err": 1 - te_acc.item(),
        "wga_te_err": worst_group_err,
    }


import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run experiments with different methods and configurations."
    )
    parser.add_argument("--model", type=str, default="resnet", help="Model to use")
    parser.add_argument(
        "--methods",
        type=str,
        nargs="+",
        choices=["ERM", "GroupDRO", "LISA", "remax-margin", "oversample", "undersample"],
        default=["ERM", "GroupDRO", "remax-margin", "oversample", "undersample"],
        help="List of methods to use",
    )
    parser.add_argument("--seed", type=int, default=2, help="Random seed")
    parser.add_argument(
        "--shift_type", type=str, required=True, choices=["sc", "ci", "ai", "3shifts"]
    )
    parser.add_argument(
        "--datafolder",
        type=str,
        default="coco_v2",
        help="Folder under //exps/div_explore holding the embeddings and the results",
    )
    args = parser.parse_args()

    model = args.model
    methods = args.methods
    n_list = [200, 500, 1000, 2000, 5000, 10000]
    seed = args.seed
    verbose = False
    shift_type = args.shift_type
    # module-level global: load_emb reads it to locate the embeddings
    folder = args.datafolder

    # single-shift sweeps vary one knob and keep the other two at 0.5
    sweep = [0.01, 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99]
    if shift_type == "sc":
        sc_list, ci_list, ai_list = sweep, [0.5], [0.5]
    elif shift_type == "ci":
        sc_list, ci_list, ai_list = [0.5], sweep, [0.5]
    elif shift_type == "ai":
        sc_list, ci_list, ai_list = [0.5], [0.5], sweep
    elif shift_type == "3shifts":
        # 30 randomly generated (sc, ci, ai) configurations.
        # An earlier random draw, no longer used, kept for reference:
        # sc: 0.68, 0.5, 0.3, 0.51, 0.4, 0.57, 0.17, 0.51, 0.55, 0.47,
        #     0.49, 0.52, 0.23, 0.12, 0.36, 0.15, 0.5, 0.49, 0.48, 0.26,
        #     0.28, 0.68, 0.27, 0.25, 0.23, 0.59, 0.72, 0.67, 0.64, 0.7
        # ci: 0.35, 0.47, 0.49, 0.36, 0.11, 0.46, 0.4, 0.65, 0.24, 0.22,
        #     0.29, 0.42, 0.24, 0.26, 0.39, 0.39, 0.7, 0.65, 0.4, 0.67,
        #     0.38, 0.81, 0.39, 0.81, 0.6, 0.42, 0.58, 0.85, 0.24, 0.42
        # ai: 0.47, 0.31, 0.71, 0.27, 0.65, 0.36, 0.49, 0.7, 0.47, 0.52,
        #     0.55, 0.16, 0.64, 0.8, 0.45, 0.68, 0.65, 0.66, 0.57, 0.17,
        #     0.75, 0.69, 0.53, 0.12, 0.21, 0.46, 0.61, 0.69, 0.36, 0.33
        sc_list = [
            0.79, 0.34, 0.52, 0.54, 0.12, 0.26, 0.62, 0.53, 0.09, 0.51,
            0.64, 0.31, 0.28, 0.4, 0.3, 0.8, 0.45, 0.12, 0.17, 0.71,
            0.51, 0.26, 0.28, 0.55, 0.4, 0.37, 0.12, 0.36, 0.57, 0.54,
        ]
        ci_list = [
            0.46, 0.12, 0.5, 0.1, 0.65, 0.46, 0.78, 0.29, 0.2, 0.43,
            0.48, 0.56, 0.5, 0.47, 0.48, 0.91, 0.45, 0.47, 0.78, 0.8,
            0.55, 0.49, 0.67, 0.5, 0.13, 0.26, 0.91, 0.72, 0.1, 0.67,
        ]
        ai_list = [
            0.36, 0.76, 0.39, 0.54, 0.36, 0.52, 0.74, 0.32, 0.87, 0.82,
            0.71, 0.58, 0.36, 0.84, 0.67, 0.79, 0.2, 0.47, 0.22, 0.59,
            0.27, 0.66, 0.37, 0.57, 0.65, 0.65, 0.13, 0.28, 0.43, 0.62,
        ]
    else:
        raise ValueError(f"Unknown shift type {shift_type}")

    n_datasets = max(len(sc_list), len(ci_list), len(ai_list))
    total_num_loop = len(n_list) * n_datasets * len(methods)

    i = 0
    all_res = []
    for n in n_list:
        for j in range(n_datasets):
            # shorter lists keep repeating their last value
            sc = sc_list[min(j, len(sc_list) - 1)]
            ci = ci_list[min(j, len(ci_list) - 1)]
            ai = ai_list[min(j, len(ai_list) - 1)]
            for method in methods:
                curr_res = run(method, n, sc, ci, ai, model, seed, verbose=verbose)
                if curr_res is not None:
                    all_res.append(curr_res)
                i += 1
                if i % 10 == 0:
                    print(f"Done {i}/{total_num_loop}", datetime.datetime.now())

    path = f"//exps/div_explore/{folder}/{model}_lp"
    os.makedirs(path, exist_ok=True)
    pd.DataFrame(all_res).to_csv(f"{path}/results_seed{seed}_{shift_type}.csv")

# Usage:
# python linear_probe_coco.py --model clip --seed 0 --shift_type sc
# python linear_probe_coco.py --model clip --seed 0 --shift_type ci
# python linear_probe_coco.py --model clip --seed 0 --shift_type ai
# python linear_probe_coco.py --model clip --seed 0 --shift_type 3shifts
