import argparse
import copy
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

from tqdm import tqdm


def calc_linearized_pairwise_ranking_loss(
    features_chosen, features_rejected, linear, confidence=1
):
    """Score a linear reward vector on a set of preference pairs.

    Row ``i`` of ``features_chosen`` and ``features_rejected`` is projected
    onto ``linear`` (scaled by ``confidence``). Each pair is then treated as a
    two-way softmax whose correct class is the chosen item.

    Args:
        features_chosen: 2-D tensor, one feature row per preferred item.
        features_rejected: 2-D tensor, one feature row per rejected item.
        linear: 1-D weight tensor applied to both feature matrices.
        confidence: Scalar multiplier applied to every score.

    Returns:
        Tuple of 0-d tensors ``(log_likelihood, mean, std, median, mad)``:
        the mean per-pair log-likelihood, followed by the mean, standard
        deviation, median and median absolute deviation of all chosen and
        rejected scores pooled together.
    """
    # Pure evaluation: no gradients are needed.
    with torch.no_grad():
        scores_chosen = confidence * torch.mv(features_chosen, linear)
        scores_rejected = confidence * torch.mv(features_rejected, linear)

        # Build the pooled score matrix once and reuse it for every statistic.
        pooled = torch.stack([scores_chosen, scores_rejected])
        mean_score = torch.mean(pooled)
        std_score = torch.std(pooled)
        median_score = torch.median(pooled)
        mad_score = torch.median(torch.abs(pooled - median_score))

        # Column 0 holds the chosen score, so the target class is always 0.
        logits = torch.stack((scores_chosen, scores_rejected), -1)
        # Keep the labels on the same device as the logits instead of guessing
        # a device from CUDA availability, which breaks for CPU inputs on a
        # CUDA-capable machine.
        labels = torch.zeros(
            len(scores_chosen), dtype=torch.long, device=logits.device
        )

        # reduction="mean" averages over pairs (it does not sum them).
        mean_log_likelihood = -F.cross_entropy(logits, labels, reduction="mean")

    return mean_log_likelihood, mean_score, std_score, median_score, mad_score


def write_weights_likelihood(
    last_layer, loglik, file_writer, mean_score, std_score, median_score, mad_score
):
    """Write one comma-separated record for a weight vector and its statistics.

    The line layout is ``w_0,...,w_k,mean,std,median,mad,loglik`` followed by
    a newline.

    Args:
        last_layer: Weight tensor; it is squeezed and flattened before writing.
        loglik: Score for this weight vector; must expose ``.item()``.
        file_writer: Object with a ``write(str)`` method.
        mean_score, std_score, median_score, mad_score: Score statistics; each
            must expose ``.item()``.
    """
    # `args` is bound only when this file runs as a script. Fall back to quiet
    # mode when the function is used from an importing module.
    if getattr(globals().get("args"), "debug", False):
        print("writing weights")

    # reshape(-1) keeps a single-weight layer iterable: squeeze() alone would
    # turn it into a 0-d array.
    np_weights = last_layer.detach().squeeze().cpu().numpy().reshape(-1)

    fields = [str(w) for w in np_weights]
    fields.extend(
        str(value.item())
        for value in (mean_score, std_score, median_score, mad_score, loglik)
    )
    # One write per record instead of one per field.
    file_writer.write(",".join(fields) + "\n")


def lexicase_pairwise_prefs(
    population, features_chosen, features_rejected, confidence=1, positivity=False
):
    """Select a new population by lexicase selection over preference pairs.

    Every candidate weight vector is scored on each (chosen, rejected) pair;
    a pair counts as passed when the chosen score is strictly greater. The
    function then draws ``len(population)`` survivors: each draw shuffles the
    pairs and repeatedly keeps only the candidates that are best on the next
    pair, until one candidate or no pairs remain. Ties are broken uniformly at
    random.

    Args:
        population: Sequence of 1-D weight tensors.
        features_chosen: 2-D tensor, one feature row per preferred item.
        features_rejected: 2-D tensor, one feature row per rejected item.
        confidence: Scalar multiplier applied to every score.
        positivity: When true, a candidate that yields any negative score
            fails every pair and records zero for all of its statistics.

    Returns:
        Tuple ``(survivors, passed, mean, std, median, mad)``. ``survivors``
        is a list of the selected weight tensors and ``passed`` is a NumPy
        array with the number of pairs each survivor passed. The remaining
        four are 1-D tensors of score statistics, aligned with ``survivors``.
    """
    case_results = []
    mean_scores = []
    std_scores = []
    median_scores = []
    mad_scores = []

    # Pure evaluation: no gradients are needed.
    with torch.no_grad():
        for linear in population:
            scores_chosen = confidence * torch.mv(features_chosen, linear)
            scores_rejected = confidence * torch.mv(features_rejected, linear)

            # Positivity prior. `.min()` has to be called: comparing the bound
            # method itself to 0 raises TypeError.
            if positivity and (scores_chosen.min() < 0 or scores_rejected.min() < 0):
                # Worst possible result. Every per-candidate list gets an
                # entry so that all of them stay aligned with `population`.
                case_results.append(np.zeros(len(scores_chosen), dtype=np.int32))
                mean_scores.append(0.0)
                std_scores.append(0.0)
                median_scores.append(0.0)
                mad_scores.append(0.0)
                continue

            pooled = torch.stack([scores_chosen, scores_rejected])
            median_score = torch.median(pooled)

            case_results.append((scores_chosen > scores_rejected).int().cpu().numpy())
            mean_scores.append(torch.mean(pooled).item())
            std_scores.append(torch.std(pooled).item())
            median_scores.append(median_score.item())
            mad_scores.append(torch.median(torch.abs(pooled - median_score)).item())

    all_scores = np.stack(case_results)
    pop_count, num_cases = all_scores.shape
    mean_scores = torch.tensor(mean_scores)
    std_scores = torch.tensor(std_scores)
    median_scores = torch.tensor(median_scores)
    mad_scores = torch.tensor(mad_scores)

    # Lexicase selection: one independent draw per slot in the new population.
    selected = []
    for _ in range(pop_count):
        case_order = np.arange(num_cases)
        np.random.shuffle(case_order)
        pool = np.ones(pop_count, dtype=bool)  # candidates still in contention

        for case in case_order:
            if np.sum(pool) == 1:
                break
            # `best` is taken from the current pool, so at least one candidate
            # always survives the filter.
            best = np.max(all_scores[pool, case])
            pool = np.logical_and(pool, all_scores[:, case] >= best)

        # Plain Python ints index lists, NumPy arrays and tensors alike.
        selected.append(int(np.random.choice(np.flatnonzero(pool))))

    return (
        [population[sel] for sel in selected],
        np.sum(all_scores[selected], axis=1),
        mean_scores[selected],
        std_scores[selected],
        median_scores[selected],
        mad_scores[selected],
    )


def lexicase_search(
    features_chosen,
    features_rejected,
    last_layer,
    num_steps,
    step_stdev,
    weight_output_filename,
    pop_size=1000,
    ds_rate=1,
):
    """Evolve unit-norm reward weights with lexicase selection and log survivors.

    The population starts as ``pop_size`` copies of the L2-normalised
    ``last_layer``. Each of the ``num_steps // pop_size`` generations perturbs
    every member with Gaussian noise, re-normalises it, and selects survivors
    with ``lexicase_pairwise_prefs`` on a random subset of the preference
    pairs. One record per survivor is written with
    ``write_weights_likelihood``.

    Args:
        features_chosen: 2-D tensor, one feature row per preferred item.
        features_rejected: 2-D tensor, one feature row per rejected item.
        last_layer: Initial weight tensor.
        num_steps: Total candidate budget across all generations.
        step_stdev: Standard deviation of the Gaussian mutation noise.
        weight_output_filename: Path of the text file to (over)write.
        pop_size: Number of candidates per generation.
        ds_rate: Fraction of preference pairs sampled each generation.
    """
    with torch.no_grad():
        seed_weights = torch.nn.functional.normalize(last_layer, p=2, dim=0)

    best_score = 0
    population = [seed_weights.clone() for _ in range(pop_size)]

    # The context manager closes the file even if a generation raises.
    with open(weight_output_filename, "w") as writer:
        for gen in tqdm(range(num_steps // pop_size)):
            new_population = []
            with torch.no_grad():
                for parent in population:
                    # Noise is sampled as before, then moved to the parent's
                    # own device and dtype rather than a globally guessed one.
                    noise = torch.randn(parent.size()).to(
                        device=parent.device, dtype=parent.dtype
                    )
                    # Mutate out of place, then project back onto the unit
                    # sphere.
                    new_population.append(
                        torch.nn.functional.normalize(
                            parent + noise * step_stdev, p=2, dim=0
                        )
                    )

            # Random down-sampling of the preference pairs for this generation.
            ds_indices = np.arange(len(features_chosen))
            np.random.shuffle(ds_indices)
            ds_indices = ds_indices[: int(len(features_chosen) * ds_rate)]

            population, scores, mean_scores, std_scores, median_scores, mad_scores = (
                lexicase_pairwise_prefs(
                    new_population,
                    features_chosen[ds_indices],
                    features_rejected[ds_indices],
                )
            )

            # Save every survivor. The column that holds the log-likelihood in
            # the MCMC chain holds the number of passed pairs here.
            for (
                cur_reward,
                cur_score,
                mean_score,
                std_score,
                median_score,
                mad_score,
            ) in zip(
                population, scores, mean_scores, std_scores, median_scores, mad_scores
            ):
                write_weights_likelihood(
                    cur_reward,
                    cur_score,
                    writer,
                    mean_score,
                    std_score,
                    median_score,
                    mad_score,
                )
                if cur_score > best_score:
                    best_score = cur_score
                    print(gen, "gens, updating best to", best_score)

    print("Finished, best score", best_score)


def mcmc_map_search(
    features_chosen,
    features_rejected,
    last_layer,
    num_steps,
    step_stdev,
    weight_output_filename,
):
    """Run Metropolis-Hastings over unit-norm reward weights and log the chain.

    The chain starts from the L2-normalised ``last_layer``. Each step adds
    Gaussian noise to the current weights, re-normalises them, and scores the
    proposal with ``calc_linearized_pairwise_ranking_loss``. A proposal with a
    higher log-likelihood is always accepted; otherwise it is accepted with
    probability ``exp(prop_loglik - cur_loglik)``. The current state is
    written with ``write_weights_likelihood`` after every step, whether or not
    the proposal was accepted.

    Args:
        features_chosen: 2-D tensor, one feature row per preferred item.
        features_rejected: 2-D tensor, one feature row per rejected item.
        last_layer: Initial weight tensor.
        num_steps: Number of proposals to generate.
        step_stdev: Standard deviation of the Gaussian proposal noise.
        weight_output_filename: Path of the text file to (over)write.
    """
    # `args` is bound only when this file runs as a script. Fall back to quiet
    # mode when the function is used from an importing module.
    debug = getattr(globals().get("args"), "debug", False)

    with torch.no_grad():
        cur_reward = torch.nn.functional.normalize(last_layer, p=2, dim=0)

    cur_loglik, *cur_stats = calc_linearized_pairwise_ranking_loss(
        features_chosen, features_rejected, cur_reward
    )
    # Best log-likelihood seen so far; used only for progress reporting.
    map_loglik = cur_loglik

    reject_cnt = 0
    accept_cnt = 0
    start_t = time.time()

    def elapsed():
        return time.strftime("%H:%M:%S", time.gmtime(time.time() - start_t))

    # The context manager closes the file even if a step raises.
    with open(weight_output_filename, "w") as writer:
        for i in tqdm(range(num_steps)):
            if debug:
                print("step", i, "time:", elapsed())

            # Proposal: perturb out of place, then project back onto the unit
            # sphere. Noise is moved to the weights' own device and dtype
            # rather than a globally guessed one.
            with torch.no_grad():
                noise = torch.randn(cur_reward.size()).to(
                    device=cur_reward.device, dtype=cur_reward.dtype
                )
                proposal_reward = torch.nn.functional.normalize(
                    cur_reward + noise * step_stdev, p=2, dim=0
                )

            prop_loglik, *prop_stats = calc_linearized_pairwise_ranking_loss(
                features_chosen, features_rejected, proposal_reward
            )
            if debug:
                print("proposal loglik", prop_loglik.item())
                print("cur loglik", cur_loglik.item())

            improved = bool(prop_loglik > cur_loglik)
            if improved:
                # Uphill moves are always accepted.
                accepted = True
                if debug:
                    print("accept")
            else:
                # Downhill moves are accepted with probability
                # exp(prop_loglik - cur_loglik). The random draw happens only
                # on this path.
                accepted = (
                    np.random.rand() < torch.exp(prop_loglik - cur_loglik).item()
                )
                if debug:
                    if accepted:
                        print("proposal loglik", prop_loglik.item())
                        print("probabilistic accept")
                    else:
                        print("reject")

            if accepted:
                accept_cnt += 1
                # `proposal_reward` is a fresh tensor each step, so it can be
                # adopted without copying.
                cur_reward = proposal_reward
                cur_loglik = prop_loglik
                cur_stats = prop_stats

                # Only an uphill move can beat the best log-likelihood so far.
                if improved and prop_loglik > map_loglik:
                    map_loglik = prop_loglik
                    print()
                    print("step", i, "time:", elapsed())
                    print("proposal loglik", prop_loglik.item())

                    print("updating map to ", prop_loglik.item())
            else:
                # Reject and stick with cur_reward.
                reject_cnt += 1

            # Save the chain of weights (the current state after this step).
            write_weights_likelihood(cur_reward, cur_loglik, writer, *cur_stats)

    print("num rejects", reject_cnt)
    print("num accepts", accept_cnt)


if __name__ == "__main__":
    # Command-line entry point: load pre-extracted features and linear weights,
    # then run the selected search and write its chain to disk.
    parser = argparse.ArgumentParser(description=None)
    parser.add_argument(
        "--seed", default=0, type=int, help="random seed for experiments"
    )
    parser.add_argument(
        "--features_dir",
        default=".",
        help="path to directory that contains pre-extracted features and linear weights",
    )
    parser.add_argument(
        "--num_mcmc_steps",
        default=200000,
        type=int,
        help="number of proposals to generate for MCMC",
    )
    parser.add_argument(
        "--mcmc_step_size",
        default=0.01,
        type=float,
        help="proposal step is gaussian with zero mean and mcmc_step_size stdev",
    )
    parser.add_argument(
        "--method",
        help="search method: 'brex' (Metropolis-Hastings MCMC) or 'lexicase' "
        "(lexicase selection); also used as the output directory name",
        default="brex",
    )
    parser.add_argument(
        "--debug",
        help="print per-step diagnostics and cap the run at 2000 steps",
        action="store_true",
    )
    # NOTE: --encoding_dims, --demo_dist and --flip_ratio are parsed but not
    # read anywhere in this file.
    parser.add_argument("--encoding_dims", help="size of latent space", type=int)
    parser.add_argument(
        "--demo_dist",
        default="uniform",
        help="uniform: uniform levels of demos, low: more low-level, mid: more mid-level",
    )
    parser.add_argument(
        "--flip_ratio",
        default=0,
        type=float,
        help="ratio to flip the preference labels",
    )
    # `args` stays a module-level name: the search functions read args.debug.
    args = parser.parse_args()

    # Resolve the method first, so an unknown name fails before any data is
    # loaded or any output directory is created.
    search_methods = {
        "brex": mcmc_map_search,
        "lexicase": lexicase_search,
    }
    if args.method not in search_methods:
        raise NotImplementedError("method not implemented")
    search_func = search_methods[args.method]

    if args.debug:
        args.num_mcmc_steps = 2000

    # set seeds
    seed = int(args.seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    start_t = time.time()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("start experiments on:", device)

    # NOTE(security): torch.load unpickles its input; only point --features_dir
    # at files from a trusted source.
    features_chosen = torch.load(
        os.path.join(args.features_dir, "features_chosen.pt"), map_location="cpu"
    )
    features_rejected = torch.load(
        os.path.join(args.features_dir, "features_rejected.pt"), map_location="cpu"
    )
    # as_tensor accepts tensors, arrays and lists without the copy-construct
    # warning that torch.tensor(tensor) emits. Features are cast to float32 to
    # match the float32 weights used in torch.mv.
    features_chosen = (
        torch.as_tensor(features_chosen)
        .detach()
        .to(device=device, dtype=torch.float32)
    )
    features_rejected = (
        torch.as_tensor(features_rejected)
        .detach()
        .to(device=device, dtype=torch.float32)
    )

    linear = torch.load(
        os.path.join(args.features_dir, "linear_weights.pt"), map_location="cpu"
    )
    linear = linear.squeeze().float().to(device)
    print("time to load data:", time.time() - start_t)

    # The chain is written to <dirname(features_dir)>/<method>/mcmc_chain.txt.
    # NOTE(review): dirname() depends on whether features_dir ends with a path
    # separator — confirm the intended output location.
    output_dir = os.path.join(os.path.dirname(args.features_dir), args.method)
    os.makedirs(output_dir, exist_ok=True)

    mcmc_output_filename = os.path.join(output_dir, "mcmc_chain.txt")

    num_features = linear.shape[0]
    print("reward is linear combination of ", num_features, "features")

    search_func(
        features_chosen,
        features_rejected,
        linear,
        args.num_mcmc_steps,
        args.mcmc_step_size,
        mcmc_output_filename,
    )
