# Based off of https://github.com/cassidylaidlaw/hidden-context/blob/main/hidden_context/synthetic_experiments.py
# could do with

import argparse
import inspect
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Type, Tuple

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F  # noqa: N812
import tqdm

# stuff for b-rex
from src.bayesian_rex import (
    calc_linearized_pairwise_ranking_loss,
    compute_l2,
    mcmc_map_search,
)

from src.popl import popl_search, select_one, select_one_best
from matplotlib.gridspec import GridSpec
from torch import nn, optim
from torch.optim.lr_scheduler import ExponentialLR

from experiments.synthetic_stateless_utils import BaseRewardModel

import hydra


def get_preferences(feature_model, sample_state, reward_fn, env_name, batch_size, flip_percentage, group_ratio, device):
    """Sample random state pairs and label them with (optionally noisy) preferences.

    Args:
        feature_model: model exposing ``get_penultimate_layer``; ``.to(device)``
            is called on it and it is used to featurize both states.
        sample_state: callable ``n -> tensor`` drawing ``n`` states.
        reward_fn: ground-truth reward. Called as ``reward_fn(state, identity)``
            when it accepts a second positional argument, otherwise as
            ``reward_fn(state)``.
        env_name: for ``"1d_identity"`` a hidden identity ``z`` is sampled per
            comparison; every other environment uses ``z = 0`` throughout.
        batch_size: number of comparisons to generate.
        flip_percentage: fraction of preference labels to invert (label noise).
        group_ratio: probability that a comparison is made by group ``z = 1``
            (0 means all ``z = 0``, 1 means all ``z = 1``). Only used for
            ``"1d_identity"``.
        device: device the returned tensors are moved to.

    Returns:
        ``(preferences, state0, state1, features0, features1)``, where, before
        label flipping, ``preferences[k] == 1`` iff ``state1[k]`` has the
        strictly higher reward.
    """
    reward_model = feature_model.to(device)
    state0 = sample_state(batch_size).to(device)
    state1 = sample_state(batch_size).to(device)
    features0 = reward_model.get_penultimate_layer(state0).to(device)
    features1 = reward_model.get_penultimate_layer(state1).to(device)

    if env_name == "1d_identity":
        # The identity is held constant within a comparison: both states of a
        # pair are judged by the same group.
        identity = torch.multinomial(
            torch.tensor([1 - group_ratio, group_ratio], dtype=torch.float, device=device),
            batch_size,
            replacement=True,
        )
    else:
        # No sampled hidden context here. Use a constant group so that the
        # two-argument (state, identity) reward functions can still be called.
        identity = torch.zeros(batch_size, dtype=torch.long, device=device)

    rewards0 = _call_reward_fn(reward_fn, state0, identity)
    rewards1 = _call_reward_fn(reward_fn, state1, identity)

    preferences = (rewards1 > rewards0).long().to(device)

    # Randomly flip flip_percentage of the preference labels.
    flip_count = int(flip_percentage * preferences.shape[0])
    flip_indices = torch.randperm(preferences.shape[0])[:flip_count]
    preferences[flip_indices] = 1 - preferences[flip_indices]

    return preferences, state0, state1, features0, features1


def _call_reward_fn(reward_fn, state, identity):
    """Evaluate ``reward_fn`` on ``state``, passing ``identity`` only if it accepts one."""
    if _accepts_identity(reward_fn):
        return reward_fn(state, identity)
    return reward_fn(state)


def _accepts_identity(reward_fn) -> bool:
    """Return True if ``reward_fn`` can be called with two positional arguments."""
    try:
        params = list(inspect.signature(reward_fn).parameters.values())
    except (TypeError, ValueError):
        # Not introspectable; assume this module's (state, identity) convention.
        return True
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    if any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in params):
        return True
    return sum(p.kind in positional_kinds for p in params) >= 2


def gaussian_mutation(population, mutation_rate):
    """Add i.i.d. Gaussian noise to every entry of ``population``.

    Despite its name, ``mutation_rate`` is the standard deviation of the
    zero-mean noise, not a per-element probability: all entries are
    perturbed. The input is left untouched and a new tensor is returned.
    """
    noise = torch.randn_like(population)
    return population + noise * mutation_rate


def index_mutation(population, mutation_rate):
    """Perturb a random subset of entries with standard-normal noise.

    Each element of ``population`` is selected independently with probability
    ``mutation_rate``. Selected elements get ``N(0, 1)`` noise added, and the
    rest keep their value. A new tensor is returned.
    """
    # 1.0 where the element is mutated, 0.0 elsewhere.
    selected = (torch.rand_like(population) < mutation_rate).to(population.dtype)
    noise = torch.randn_like(population)
    return population + noise * selected


def set_seed(seed):
    """Seed both the torch and the NumPy global random number generators."""
    for seed_rng in (torch.manual_seed, np.random.seed):
        seed_rng(seed)


def train_rlhf(
    reward_model: BaseRewardModel,
    preferences: torch.Tensor,
    state0: torch.Tensor,
    state1: torch.Tensor,
    lr: float,
    num_iterations: int,
    device: torch.device,
    minibatch_size: int = 32,
) -> BaseRewardModel:
    """Fit ``reward_model`` to pairwise preferences by maximizing their log-likelihood.

    Each iteration draws a random minibatch of comparisons (distinct indices
    within a minibatch) and takes one Adam step on the negative mean of
    ``reward_model.preference_logp``. The learning rate decays exponentially
    from ``lr`` so that it reaches ``1e-5`` after ``num_iterations`` steps.

    Args:
        reward_model: model to train; its parameters are updated in place.
        preferences: preference label per comparison.
        state0, state1: the two states of every comparison.
        lr: initial learning rate; must be positive.
        num_iterations: number of gradient steps. With ``0`` (or fewer) no
            training happens and the model is returned as-is.
        device: device the model and data are moved to.
        minibatch_size: comparisons per gradient step.

    Returns:
        The same ``reward_model`` object, on ``device`` and in train mode.

    Raises:
        ValueError: if ``lr`` is not positive and training is requested.
    """
    reward_model.to(device).train()
    if num_iterations <= 0:
        # Nothing to train; also avoids dividing by zero in the decay rate.
        return reward_model
    if lr <= 0:
        raise ValueError(f"lr must be positive, got {lr}")

    optimizer = optim.Adam(reward_model.parameters(), lr=lr)
    # Per-step factor so that lr * gamma ** num_iterations == 1e-5.
    scheduler = ExponentialLR(optimizer, gamma=(1e-5 / lr) ** (1 / num_iterations))
    progress_bar = tqdm.tqdm(range(num_iterations))

    print(f"state shape: {state0.shape}")
    preferences = preferences.to(device)
    state0 = state0.to(device)
    state1 = state1.to(device)

    for _ in progress_bar:
        optimizer.zero_grad()

        batch_indices = torch.randperm(state0.shape[0])[:minibatch_size]
        loss = -reward_model.preference_logp(
            state0[batch_indices], state1[batch_indices], preferences[batch_indices]
        ).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()
        # get_last_lr() reports the lr in use; get_lr() outside step() is
        # deprecated and returns a value scaled by an extra gamma factor.
        progress_bar.set_description(
            f"loss = {loss.item():.2f}    lr = {scheduler.get_last_lr()[0]:.2e}"
        )

    return reward_model


def get_curated_reward_fns(
    set_of_reward_fns,
    true_reward_fn,
    feature_model,
    sample_state,
    sample_num=50,
    title="",
    device=None,
):
    """Select the best-fitting hypothesis for each hidden group and plot both.

    A grid of state pairs ``(i / sample_num, j / sample_num)`` for
    ``i, j in 1..sample_num-1`` is labelled with the ground-truth preferences
    of group ``z=0`` and of group ``z=1``. ``select_one_best`` then chooses,
    for each group, one candidate from ``set_of_reward_fns``. Their utilities
    over 100 evenly spaced states in ``[0, 1]`` are converted to ranks and
    saved as ``{title}_curated_individuals.png`` in the Hydra output directory.

    Args:
        set_of_reward_fns: candidate last-layer weights, passed to ``select_one_best``.
        true_reward_fn: ground-truth reward, called as ``true_reward_fn(state, z)``.
        feature_model: model exposing ``get_penultimate_layer``.
        sample_state: unused; kept for interface compatibility.
        sample_num: resolution of the comparison grid.
        title: prefix of the saved figure's file name.
        device: device for the generated states. When ``None``, the device of
            ``feature_model``'s first parameter is used (CPU if it has none).
    """
    if device is None:
        device = _module_device(feature_model)

    # Features of 100 evenly spaced states, used to evaluate and plot the
    # selected last layers.
    state_interpolated = torch.linspace(0, 1, 100).to(device)
    features_interpolated = feature_model.get_penultimate_layer(
        state_interpolated[:, None]
    )

    # Full grid of state pairs with i varying slowest. The division happens in
    # float64 before casting, giving the same values as torch.tensor() built
    # from a list of Python floats.
    grid = torch.arange(1, sample_num, dtype=torch.float64) / sample_num
    side = grid.shape[0]
    s0 = grid.repeat_interleave(side).to(torch.get_default_dtype()).to(device)
    s1 = grid.repeat(side).to(torch.get_default_dtype()).to(device)
    f0 = feature_model.get_penultimate_layer(s0[:, None])
    f1 = feature_model.get_penultimate_layer(s1[:, None])

    # Ground-truth preferences for each hidden-context group. The identity
    # tensors live on the same device as the states.
    z0 = torch.tensor(0, device=device)
    z1 = torch.tensor(1, device=device)
    r0_z1 = true_reward_fn(s0, z1)
    r1_z1 = true_reward_fn(s1, z1)
    r0_z0 = true_reward_fn(s0, z0)
    r1_z0 = true_reward_fn(s1, z0)
    pp0 = (r1_z0 > r0_z0).long()
    pp1 = (r1_z1 > r0_z1).long()

    print("got the features")

    # Curated reward function for each hidden context class.
    ind0, score0 = select_one_best(set_of_reward_fns, f0, f1, pp0)
    print(f"ind0: {ind0.shape}")
    ret0 = features_interpolated @ ind0.T

    ind1, score1 = select_one_best(set_of_reward_fns, f0, f1, pp1)
    ret1 = features_interpolated @ ind1.T

    # Scaling by a positive norm does not change the ranks plotted below.
    ret0 = ret0 / torch.norm(ret0)
    ret1 = ret1 / torch.norm(ret1)

    # Plot each hypothesis' ranking of the states (via get_monotone).
    fig, ax1 = plt.subplots(figsize=(5, 5))
    scaled_ret0 = get_monotone(ret0).cpu().detach().numpy()
    scaled_ret1 = get_monotone(ret1).cpu().detach().numpy()
    states_np = state_interpolated.cpu().detach().numpy()

    ax1.plot(
        states_np,
        scaled_ret0,
        label=r"$z=0$",
        color="green",
        linewidth=2,
    )
    ax1.plot(
        states_np,
        scaled_ret1,
        label=r"$z=1$",
        color="red",
        linewidth=2,
        linestyle="dashed",
    )

    print("done plotting")
    ax1.set_ylabel(r"u(a)")
    lines, labels = ax1.get_legend_handles_labels()
    ax1.set_xlabel("a")

    ax1.legend(lines, labels, loc=0)

    folder_name = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir

    plt.savefig(
        f"{folder_name}/{title}_curated_individuals.png"
    )
    plt.close()

    print("done with curation")


def _module_device(module):
    """Device of ``module``'s first parameter; CPU if it has none or exposes no parameters()."""
    try:
        return next(module.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device("cpu")


def train_reward_extrapolation(
    reward_model: BaseRewardModel,
    reward_fn: Callable[..., torch.Tensor],
    sample_state: Callable[[int], torch.Tensor],
    method: Literal["b-rex", "popl"],
    batch_size: int,
    lr: float,
    num_iterations: int,
    popsize: int,  # only used for popl
    device: torch.device,
    likelihood: str = "bradley-terry",
    flip_percentage: float = 0,
    group_ratio: float = 0,
    env_name: str = "1d_identity",
    confidence: float = 1,
    downsample_level: float = 1,
    elitism: bool = False,
    pretrain: bool = False,
    pretrain_steps: int = 0,
    pretrain_lr: float = 0.01,
    mutation_fn: Optional[Callable[[
        torch.Tensor, float], torch.Tensor]] = gaussian_mutation,
) -> Tuple[BaseRewardModel, torch.Tensor]:
    """Search for last-layer reward hypotheses with B-REX (MCMC) or PoPL.

    Preferences are generated from ``reward_fn`` with ``get_preferences``.
    When ``pretrain`` is set, the model is first trained on those same
    preferences with ``train_rlhf``. The pretrained model, or ``None``
    without pretraining, is handed to the search routine together with the
    penultimate-layer features.

    Args:
        method: ``"b-rex"`` (``mcmc_map_search``) or ``"popl"`` (``popl_search``).
        likelihood, confidence: forwarded to ``mcmc_map_search`` (b-rex only).
        popsize, downsample_level, elitism: forwarded to ``popl_search`` (popl only).
        mutation_fn: proposal/mutation operator forwarded to either search.
        Remaining arguments are forwarded to ``get_preferences`` / ``train_rlhf``.

    Returns:
        ``(reward_model, all_reward_funcs)``: the same ``reward_model`` object
        with ``last_layer`` replaced by the selected hypothesis, and every
        hypothesis visited (the MCMC chain for b-rex, the final population
        for popl).

    Raises:
        ValueError: if ``method`` is neither ``"b-rex"`` nor ``"popl"``.
    """
    if method not in ("b-rex", "popl"):
        raise ValueError(f"Unknown method: {method}")

    # generate preferences and features
    preferences, state0, state1, features0, features1 = get_preferences(
        reward_model, sample_state, reward_fn, env_name, batch_size, flip_percentage, group_ratio, device)

    # Pretraining reuses the same preferences as the search below.
    reward_model_pretrain = None
    if pretrain:
        reward_model_pretrain = train_rlhf(
            reward_model=reward_model,
            preferences=preferences,
            state0=state0,
            state1=state1,
            lr=pretrain_lr,
            num_iterations=pretrain_steps,
            device=device,
        )

    if method == "b-rex":
        best_reward_lastlayer, chain, _logliks = mcmc_map_search(
            reward_model_pretrain,
            preferences,
            features0,
            features1,
            num_iterations,
            lr,
            confidence,
            mutation_fn=mutation_fn,
            likelihood=likelihood
        )
        all_reward_funcs = torch.from_numpy(chain).to(device)
        best_reward_lastlayer = best_reward_lastlayer.to(device)
    else:  # "popl"
        population, scores, _best_scores = popl_search(
            reward_model_pretrain,
            preferences,
            features0,
            features1,
            popsize,
            num_iterations,
            lr,
            normalize=True,
            downsample_level=downsample_level,
            elitism=elitism,
            bt=None,
            mutation_fn=mutation_fn,
        )
        all_reward_funcs = population
        # Individual with the highest total score across test cases.
        sum_scores = torch.sum(scores, dim=1)
        best_reward_lastlayer = population[torch.argmax(sum_scores)].to(device)

    # The model's last_layer must be a module; raw weight tensors (always the
    # case for popl) are wrapped in a bias-free Linear layer.
    if isinstance(best_reward_lastlayer, torch.Tensor):
        best_reward_lastlayer = _weights_to_linear(
            best_reward_lastlayer, reward_model.last_layer.in_features, device)

    best_reward_model = reward_model
    best_reward_model.last_layer = best_reward_lastlayer

    return best_reward_model, all_reward_funcs


def _weights_to_linear(weights, in_features, device):
    """Return a bias-free ``nn.Linear(in_features, 1)`` on ``device`` whose weight is ``weights``."""
    layer = torch.nn.Linear(in_features, 1, bias=False).to(device)
    # Assigned directly (no copy, no reshape): the weight takes whatever shape
    # ``weights`` has. NOTE(review): assumes that shape is valid for F.linear.
    layer.weight.data = weights
    return layer


def reward_fn_1d_identity(state: torch.Tensor, identity: torch.Tensor) -> torch.Tensor:
    """1-D utility of ``state`` as judged by the hidden group ``identity``.

    Below 0.8 the reward equals the state for everyone. From 0.8 upward,
    group 1 doubles the reward and group 0 zeroes it. Any other identity
    value leaves it unchanged. Returns a new tensor.
    """
    s = state.squeeze(-1)
    z = identity.squeeze(-1)
    high = s >= 0.8
    doubled = torch.where(high & (z == 1), s * 2, s)
    return torch.where(high & (z == 0), doubled * 0, doubled)


def reward_fn_1d(state: torch.Tensor, identity: torch.Tensor) -> torch.Tensor:
    """Noisy 1-D reward whose hidden context is an unobserved fair coin.

    ``identity`` is ignored. For states >= 0.8 an independent coin flip per
    element decides whether the reward is doubled or zeroed. Below 0.8 the
    reward equals the state.
    """
    s = state.squeeze(-1)
    heads = torch.rand(s.shape, device=s.device) < 0.5
    high = s >= 0.8
    out = torch.where(high & heads, s * 2, s)
    return torch.where(high & ~heads, out * 0, out)


def reward_fn_2d(state: torch.Tensor, identity: torch.Tensor) -> torch.Tensor:
    """Stochastic 2-D reward; ``identity`` is ignored.

    With ``x = state[..., 0]`` and ``y = state[..., 1]``, each element keeps
    reward ``y / (1 - x*y)`` with probability ``1 - x*y`` (one independent
    uniform draw per element) and is 0 otherwise.
    """
    x = state[..., 0]
    y = state[..., 1]
    keep_prob = 1 - x * y
    kept = torch.rand(x.shape, device=x.device) < keep_prob
    rewards = torch.zeros_like(x)
    rewards[kept] = (y / keep_prob)[kept]
    return rewards


# for the safety vs speed experiment
def reward_fn_2d_identity(state: torch.Tensor, identity: torch.Tensor) -> torch.Tensor:
    """Safety-vs-speed reward for two hidden groups.

    ``state[..., 0]`` is safety and ``state[..., 1]`` is speed. Group 0 scores
    ``3 * safety + speed`` and group 1 scores ``safety + 3 * speed``. Any other
    identity value receives reward 0.
    """
    safety = state[..., 0]
    speed = state[..., 1]
    identity = identity.squeeze(-1)
    # zeros_like (not empty_like) so entries whose identity is neither 0 nor 1
    # are deterministic instead of uninitialized memory.
    rewards = torch.zeros_like(safety)
    group0 = identity == 0
    group1 = identity == 1
    rewards[group0] = 3 * safety[group0] + speed[group0]
    rewards[group1] = safety[group1] + 3 * speed[group1]
    return rewards


# this reward function implements cyclical preferences.
# if identity is 0, then the reward is equal to the state a
# if identity is 1, then the reward is (a + 0.66) mod 1, i.e. a + 0.66 if a < 0.34 and a - 0.34 otherwise
# if identity is 2, then the reward is (a + 0.33) mod 1, i.e. a + 0.33 if a < 0.67 and a - 0.67 otherwise
def reward_fn_cycles(state: torch.Tensor, identity: torch.Tensor) -> torch.Tensor:
    """Cyclic 1-D reward: each group shifts the state by a fixed offset modulo 1.

    Group 1 uses ``(state + 0.66) % 1`` and group 2 uses ``(state + 0.33) % 1``.
    Group 0, and any other identity, receives the state itself.
    """
    s = state.squeeze(-1)
    z = identity.squeeze(-1)
    out = s.clone()
    for group, offset in ((1, 0.66), (2, 0.33)):
        members = z == group
        out[members] = (s[members] + offset) % 1
    return out


def reward_fn_default(state: torch.Tensor, identity: torch.Tensor) -> torch.Tensor:
    """Identity-independent reward: a fresh copy of the state with its last axis squeezed."""
    values = state.squeeze(-1)
    return values.clone()


# this function removes all reward hypotheses that are equivalent to each other based on how they rank
def remove_same_rankings(returns):
    """Group reward hypotheses that order the states identically.

    ``returns`` is argsorted along dim 0 and de-duplicated along dim 1, so
    each column is treated as one hypothesis and each row as one state.
    Diagnostic information is printed along the way.

    NOTE(review): ``argsort`` is not stable, so hypotheses that tie on some
    states may end up in different groups.

    Returns:
        The ``inverse`` output of ``torch.unique(..., dim=1, return_inverse=True)``:
        for every column of ``returns``, the index of its ordering group.
    """
    print(f"SHAPE OF RETURNS: {returns.shape}")

    # Each hypothesis' ordering of the states.
    orderings = returns.argsort(dim=0)
    print(f"ranks: {orderings}")

    # Identical columns correspond to hypotheses that rank the states the same way.
    unique_result = torch.unique(orderings, dim=1, return_inverse=True)
    print(f"unique_ranks: {unique_result}")
    distinct_orderings, group_of_column = unique_result
    print(f"inverse: {group_of_column}")
    print(f"number unique: {distinct_orderings.shape[1]}")

    return group_of_column


def get_monotone(r_values):
    """Replace each value with its rank along dim 0 (0 = smallest).

    Used for plotting, where only the ordering induced by ``r_values``
    matters, not its scale.
    """
    order = r_values.argsort(dim=0)
    return order.argsort(dim=0)


def sample_state_fn(n: int) -> torch.Tensor:
    """Draw ``n`` one-dimensional states uniformly from [0, 1), shaped ``(n, 1)``."""
    return torch.rand(n, 1)


def sample_state_fn_2d(n: int) -> torch.Tensor:
    """Draw ``n`` two-dimensional states uniformly from [0, 1)^2, shaped ``(n, 2)``."""
    return torch.rand(n, 2)


# given a set of returns and a set of preferences, we count, for each preference,
# how many hypotheses are consistent with it (a nonzero count means at least one is)
def preference_pass_matrix(returns, preference_matrix):
    """Count, for every preference pair, how many hypotheses agree with it.

    Args:
        returns: tensor indexed as ``returns[state, ...]``; everything after
            the first axis holds one return per hypothesis.
        preference_matrix: ``(n_rows, n_cols)`` matrix where
            ``preference_matrix[i, j] == 1`` means state ``i`` is preferred to
            state ``j``. Any other value is treated as ``j`` preferred to ``i``.

    Returns:
        CPU ``torch.int`` tensor of shape ``(n_rows, n_cols)``. Entry ``[i, j]``
        is the number of hypotheses with ``returns[i] >= returns[j]`` when the
        preference is 1, and with ``returns[i] <= returns[j]`` otherwise. Ties
        count as agreeing in both directions.
    """
    n_rows, n_cols = preference_matrix.shape[0], preference_matrix.shape[1]
    passes = torch.zeros((n_rows, n_cols), dtype=torch.int)
    if n_rows == 0 or n_cols == 0:
        return passes

    # Compare one state against all candidate states at once instead of looping
    # over every (i, j) pair in Python; memory stays O(n_cols * n_hypotheses).
    others = returns[:n_cols].reshape(n_cols, -1)
    for i in range(n_rows):
        row = returns[i].reshape(1, -1)
        at_least = (row >= others).sum(dim=1)
        at_most = (row <= others).sum(dim=1)
        prefers_i = torch.as_tensor(preference_matrix[i] == 1, device=at_least.device)
        passes[i] = torch.where(prefers_i, at_least, at_most)

    return passes


# in the categorical case, we simply sample many times and see if the preference is consistent with the hypothesis
def preference_pass_matrix_categorical(returns, preferences):
    """Placeholder for the categorical variant of ``preference_pass_matrix``.

    Not implemented yet. Raising makes an accidental call fail here instead of
    silently returning ``None`` and breaking later.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "preference_pass_matrix_categorical is not implemented yet"
    )


@hydra.main(version_base=None, config_path="config", config_name="synthetic_sweep")
def main(cfg):
    """Run the synthetic hidden-context experiment described by the Hydra config ``cfg``.

    ``cfg.method`` is one of ``"rlhf"``, ``"b-rex"``, ``"popl"`` or ``"both"``
    (b-rex, then popl). For b-rex/popl the learned model and the sampled last
    layers are saved to the Hydra output directory, and a curated-individuals
    plot is written there by ``get_curated_reward_fns``.
    """
    # env name -> (ground-truth reward: state x identity -> reward, state dimension)
    environments: Dict[str, Tuple[Callable[[torch.Tensor, torch.Tensor], torch.Tensor], int]] = {
        "1d": (reward_fn_1d, 1),
        "1d_identity": (reward_fn_1d_identity, 1),
        "cycles": (reward_fn_cycles, 1),
        "default": (reward_fn_default, 1),
    }

    env_name = cfg.env_name
    if env_name not in environments:
        raise ValueError(f"Unknown environment name: {env_name}")
    reward_fn, state_dim = environments[env_name]

    reward_model_kwargs = {}
    num_iterations = cfg.num_iterations
    batch_size = cfg.batch_size
    flips = cfg.flips
    ratio = cfg.group_ratio  # 0 means all z=0, 1 means all z=1

    set_seed(cfg.seed)

    sample_state = sample_state_fn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if cfg.method == "rlhf":
        reward_model = BaseRewardModel(state_dim=state_dim, **reward_model_kwargs)
        # train_rlhf consumes pre-generated comparisons, so sample them first.
        preferences, state0, state1, _, _ = get_preferences(
            reward_model, sample_state, reward_fn, env_name, batch_size, flips, ratio, device)
        # NOTE(review): no dedicated RLHF learning rate is visible here; use
        # cfg.rlhf_lr when present, otherwise pretrain_lr (the lr train_rlhf is
        # given for pretraining). Confirm the intended config key.
        rlhf_lr = cfg.rlhf_lr if "rlhf_lr" in cfg else cfg.pretrain_lr
        reward_model = train_rlhf(
            reward_model=reward_model,
            preferences=preferences,
            state0=state0,
            state1=state1,
            lr=rlhf_lr,
            num_iterations=num_iterations,
            device=device,
        )
        reward_model.eval()

    reward_model_kwargs["num_layers"] = 2
    reward_model_kwargs["hidden_dim"] = 128
    if cfg.use_binning_encoding:
        reward_model_kwargs["use_encoding"] = True
        reward_model_kwargs["num_layers"] = 1
        reward_model_kwargs["hidden_dim"] = 10

    folder_name = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir

    if cfg.mutation_fn == "gaussian":
        mutation_fn = gaussian_mutation
    elif cfg.mutation_fn == "index":
        mutation_fn = index_mutation
    else:
        raise ValueError(f"Unknown mutation function: {cfg.mutation_fn}")

    if cfg.method == "b-rex" or cfg.method == "both":
        # Bayesian REX
        b_rex_model = BaseRewardModel(
            state_dim=state_dim, **reward_model_kwargs
        ).to(device)

        b_rex_model, brex_lastlayers = train_reward_extrapolation(
            reward_model=b_rex_model,
            method="b-rex",
            reward_fn=reward_fn,
            sample_state=sample_state,
            batch_size=batch_size,
            lr=cfg.brex_lr,
            popsize=0,  # required but unused
            num_iterations=num_iterations,
            device=device,
            flip_percentage=flips,
            group_ratio=ratio,
            env_name=env_name,
            confidence=cfg.brex_confidence,
            pretrain=cfg.pretrain,
            pretrain_steps=cfg.pretrain_steps,
            pretrain_lr=cfg.pretrain_lr,
            mutation_fn=mutation_fn,
        )
        b_rex_model.eval()
        brex_lastlayers = brex_lastlayers.squeeze(1)

        print(f"got the b_rex model: {b_rex_model}")
        print(f"b_rex lastlayers shape: {brex_lastlayers.shape}")

        # save to file
        torch.save(brex_lastlayers, f"{folder_name}/b_rex_lastlayers.pt")
        torch.save(b_rex_model, f"{folder_name}/b_rex_model.pt")

        # Curated reward functions; pass the device explicitly so this also
        # works on CPU-only machines.
        get_curated_reward_fns(
            brex_lastlayers,
            reward_fn,
            b_rex_model,
            sample_state,
            sample_num=cfg.sample_num,
            title="b_rex",
            device=device,
        )

    if cfg.method == "popl" or cfg.method == "both":
        # Lexicase
        popl_model = BaseRewardModel(
            state_dim=state_dim, **reward_model_kwargs)

        popl_model, popl_lastlayers = train_reward_extrapolation(
            reward_model=popl_model,
            method="popl",
            reward_fn=reward_fn,
            sample_state=sample_state,
            batch_size=batch_size,
            lr=cfg.popl_lr,
            popsize=cfg.lex_popsize,
            num_iterations=cfg.lex_iterations,
            device=device,
            flip_percentage=flips,
            elitism=cfg.elitism,
            group_ratio=ratio,
            env_name=env_name,
            downsample_level=cfg.lex_downsample_level,
            pretrain=cfg.pretrain,
            pretrain_steps=cfg.pretrain_steps,
            pretrain_lr=cfg.pretrain_lr,
            mutation_fn=mutation_fn,
        )
        popl_model.eval()

        print(f"got the popl model: {popl_model}")
        print(f"popl lastlayers shape: {popl_lastlayers.shape}")

        # save to file
        torch.save(popl_lastlayers, f"{folder_name}/popl_lastlayers.pt")
        torch.save(popl_model, f"{folder_name}/popl_model.pt")

        get_curated_reward_fns(
            popl_lastlayers,
            reward_fn,
            popl_model,
            sample_state,
            sample_num=cfg.sample_num,
            title="popl",
            device=device,
        )

        print("done with outer loop")


if __name__ == "__main__":
    # main is wrapped by hydra.main, which builds cfg itself; call it without arguments.
    main()
    print("done with main")
