import argparse
import os
import random
import sys
import time

import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import trange

from src.lexicase import select_from_scores

# NOTE(review): presumably the usual workaround for the "duplicate OpenMP
# runtime" abort that can occur when several libraries each bundle their own
# copy of the OpenMP library -- confirm it is still required.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# numpy


def np_softmax(x):
    """Return the softmax of ``x`` computed along its last axis (NumPy).

    The maximum of each row is subtracted before exponentiating; this leaves
    the result unchanged while keeping ``exp`` from overflowing.
    """
    shifted = x - np.amax(x, axis=-1, keepdims=True)
    exponentials = np.exp(shifted)
    normalizer = np.sum(exponentials, axis=-1, keepdims=True)
    return exponentials / normalizer

# torch version


def softmax(x):
    """Return the softmax of tensor ``x`` computed along its last dimension.

    The maximum along the last dimension is subtracted before exponentiating;
    this leaves the result unchanged while keeping ``exp`` from overflowing.
    """
    row_max = x.max(dim=-1, keepdim=True).values
    exponentials = torch.exp(x - row_max)
    normalizer = exponentials.sum(dim=-1, keepdim=True)
    return exponentials / normalizer


# torch version (ported from an earlier NumPy implementation)
def calc_popl_policy_score(last_layer, features0, features1, action0, action1, prefs, alpha=None):
    """Score every candidate policy head on every pairwise preference.

    Each candidate's last layer is applied to the pre-computed features of
    both trajectories of a pair; the resulting action predictions are compared
    with the recorded actions to obtain one "advantage" per trajectory, and a
    candidate scores 1.0 on a pair when its advantage ordering agrees with the
    preference label.

    Args:
        last_layer: candidate heads, (population, action, feature) as required
            by the ``'lmk'`` einsum operand.
        features0, features1: features of the first / second trajectory of
            each pair, (batch, snippet, feature).
        action0, action1: recorded actions, (batch, snippet, action). Only
            ``action0`` is inspected to decide between the one-hot and the
            continuous branch.
        prefs: per-pair label; compared for equality against
            ``adv1 > adv0``, so a truthy label means trajectory 1 is preferred.
        alpha: when not ``None`` the call is delegated to
            ``calc_policy_scores`` (probabilistic Bradley-Terry scores).

    Returns:
        Float tensor (population, batch) of 0.0 / 1.0 agreement scores when
        ``alpha`` is ``None``; otherwise whatever ``calc_policy_scores``
        returns.
    """
    if alpha is not None:
        return calc_policy_scores(last_layer, features0, features1, action0, action1, prefs, alpha)

    # (batch, snippet, feature) x (population, action, feature)
    #   -> (population, batch, snippet, action)
    pred0 = torch.einsum('ijk,lmk->lijm', features0, last_layer)
    pred1 = torch.einsum('ijk,lmk->lijm', features1, last_layer)

    # `&` binds tighter than `==` in Python, so the row-sum comparison must be
    # parenthesised explicitly; otherwise non-binary actions whose rows sum to
    # zero are mistaken for one-hot vectors.
    is_one_hot = (
        (action0.sum(dim=-1) == 1)
        & (action0.eq(0) | action0.eq(1)).all(dim=-1)
    ).all()

    if is_one_hot:
        # Discrete actions: log of the output assigned to the taken action,
        # summed over the snippet dimension.
        # NOTE(review): the outputs are raw (un-normalised) values because the
        # softmax is disabled; negative values make the log NaN -- confirm
        # this is intended.
        adv0 = torch.sum(
            torch.log(torch.sum(pred0 * action0[None], dim=-1) + 1e-8), dim=-1)
        adv1 = torch.sum(
            torch.log(torch.sum(pred1 * action1[None], dim=-1) + 1e-8), dim=-1)
    else:
        # Continuous actions: negative squared error per time step.
        adv0 = -torch.square(pred0 - action0[None]).sum(dim=-1)
        adv1 = -torch.square(pred1 - action1[None]).sum(dim=-1)

    # The continuous branch still carries the snippet dimension; reduce it so
    # both branches yield (population, batch).
    if adv0.dim() > 2:
        adv0 = torch.sum(adv0, dim=-1)
        adv1 = torch.sum(adv1, dim=-1)

    return ((adv1 > adv0) == prefs).float()


# calculates a fitness like metric for a candidate, given a pairwise preference in the dataset, used for selection.
def calc_policy_scores(last_layer, features0, features1, action_true0, action_true1, prefs, alpha=10, bias=1):
    """Bradley-Terry probability that each candidate assigns to each preference.

    Each candidate's last layer is applied to the features of both
    trajectories of a pair, the predictions are compared with the recorded
    actions to obtain a log-likelihood-like value per trajectory, and the
    score is ``exp(alpha * l_pref) / (exp(alpha * l_0) + exp(alpha * l_1))``.

    Args:
        last_layer: candidate heads, (population, action, feature).
        features0, features1: (batch, snippet, feature).
        action_true0, action_true1: recorded actions, (batch, snippet, action).
            Only ``action_true0`` is inspected to choose between the one-hot
            and the continuous branch.
        prefs: per-pair label; 0 selects trajectory 0 as the preferred one,
            anything else selects trajectory 1.
        alpha: scale applied to the log-probabilities before the comparison.
        bias: unused; kept for interface compatibility.

    Returns:
        Float tensor (population, batch) with values in [0, 1].
    """
    # (batch, snippet, feature) x (population, action, feature)
    #   -> (population, batch, snippet, action)
    actions0 = torch.einsum('ijk,lmk->lijm', features0, last_layer)
    actions1 = torch.einsum('ijk,lmk->lijm', features1, last_layer)

    # `&` binds tighter than `==` in Python, so the row-sum comparison must be
    # parenthesised explicitly; otherwise non-binary actions whose rows sum to
    # zero are mistaken for one-hot vectors.
    is_one_hot = (
        (action_true0.sum(dim=-1) == 1)
        & (action_true0.eq(0) | action_true0.eq(1)).all(dim=-1)
    ).all()

    if is_one_hot:
        # NOTE(review): the outputs are raw (un-normalised) values; negative
        # values make the log NaN -- confirm this is intended.
        log_probs0 = torch.sum(
            torch.log(torch.sum(actions0 * action_true0[None], dim=-1) + 1e-8), dim=-1)
        log_probs1 = torch.sum(
            torch.log(torch.sum(actions1 * action_true1[None], dim=-1) + 1e-8), dim=-1)
    else:
        log_probs0 = -torch.square(actions0 - action_true0[None]).sum(dim=-1)
        log_probs1 = -torch.square(actions1 - action_true1[None]).sum(dim=-1)

    # The continuous branch still carries the snippet dimension; reduce it.
    if log_probs0.dim() > 2:
        log_probs0 = torch.sum(log_probs0, dim=-1)
        log_probs1 = torch.sum(log_probs1, dim=-1)

    # exp(a) / (exp(a) + exp(b)) == sigmoid(a - b). The sigmoid form stays
    # finite when both exponentials would underflow to 0 (or overflow to inf),
    # which the explicit ratio turns into NaN.
    margin = alpha * (log_probs1 - log_probs0)
    return torch.sigmoid(torch.where(prefs == 0, -margin, margin))

# Take as input a compressed pretrained network or run T_REX before hand
# Then run MCMC and save posterior chain


def calc_lexicase_scores(last_layer, features0, features1, prefs, bt=None):
    """Per-preference pass/fail scores of every candidate reward vector.

    A candidate's return for a trajectory is the dot product of its weight
    vector with the trajectory features (summed over the snippet dimension
    when the features are 3-D). A candidate passes a pair when
    ``return1 > return0`` agrees with the preference label.

    Args:
        last_layer: candidate weight vectors, (population, feature).
        features0, features1: features of the first / second item of each
            pair, either (batch, feature) or (batch, snippet, feature).
        prefs: per-pair label, truthy when item 1 is preferred.
        bt: when not ``None`` the call is delegated to ``calc_prob_scores``
            with ``bt`` as the Boltzmann temperature.

    Returns:
        Boolean tensor (population, batch) when ``bt`` is ``None``; otherwise
        whatever ``calc_prob_scores`` returns.
    """
    if bt is not None:
        return calc_prob_scores(last_layer, features0, features1, prefs, bt)

    popsize = last_layer.shape[0]

    def _returns(features):
        """Return the (population, batch) returns for one side of the pairs."""
        if features.dim() != 3:
            return last_layer @ features.T
        num_pairs, snippet_len, feature_dim = features.shape
        per_step = last_layer @ features.reshape(-1, feature_dim).T
        # Flattening (batch, snippet) is pair-major: column = pair * snippet_len
        # + step. The un-flattening must therefore be (pairs, steps); summing
        # the last axis then adds up the steps of a single pair only.
        return per_step.reshape(popsize, num_pairs, snippet_len).sum(dim=-1)

    returns0 = _returns(features0)
    returns1 = _returns(features1)

    # A candidate "predicts 1" when it ranks item 1 above item 0; it scores
    # True exactly when that prediction equals the label (XNOR).
    outputs = returns1 > returns0
    return torch.logical_not(torch.logical_xor(prefs, outputs))


def calc_prob_scores(last_layer, features0, features1, prefs, bt=1):
    """Bradley-Terry probability each candidate assigns to each preference.

    For every candidate weight vector and every pair, returns
    ``exp(r_pref / bt) / (exp(r_0 / bt) + exp(r_1 / bt))`` where ``r_k`` is the
    dot product of the candidate with the features of item ``k``.

    Args:
        last_layer: candidate weight vectors, (population, feature).
        features0, features1: (batch, feature); 3-D (batch, snippet, feature)
            inputs are reduced by summing over the snippet dimension, which
            for a linear reward equals summing the per-step rewards.
        prefs: per-pair label; 0 means item 0 is preferred, otherwise item 1.
        bt: Boltzmann temperature dividing the returns.

    Returns:
        NumPy array (population, batch) of probabilities.
    """
    def _returns(features):
        """Return the temperature-scaled (population, batch) returns."""
        if features.dim() == 3:
            features = features.sum(dim=1)
        return (last_layer @ features.T) / bt

    returns0 = _returns(features0)
    returns1 = _returns(features1)

    prefs = torch.as_tensor(prefs, device=returns0.device)

    # exp(a) / (exp(a) + exp(b)) == sigmoid(a - b); the sigmoid form does not
    # overflow to inf / inf = NaN for large returns. Staying in torch until the
    # very end also keeps this working for CUDA tensors.
    margin = torch.where(prefs == 0, returns0 - returns1, returns1 - returns0)
    return torch.sigmoid(margin).detach().cpu().numpy()


def popl_search(
    reward_model,  # could be None, if None, we randomly init
    pairwise_prefs,
    features0,
    features1,
    popsize,
    num_steps,
    step_stdev,
    normalize,
    downsample_level,
    elitism,
    bt,
    mutation_fn,
):
    """Evolve a population of linear reward vectors against pairwise preferences.

    Every generation the population is scored on a random down-sample of the
    preferences, parents are chosen with ``select_from_scores``, and the
    children are produced by ``mutation_fn``.

    Args:
        reward_model: ``None`` to start from uniform random weights in
            [-1, 1), otherwise an object whose ``last_layer.weight`` is copied
            ``popsize`` times.
        pairwise_prefs: preference labels, indexable by an index tensor.
        features0, features1: features of the first / second item of each
            pair; the last dimension is the feature size.
        popsize: number of candidates in the population.
        num_steps: number of generations.
        step_stdev: forwarded to ``mutation_fn`` as its second argument.
        normalize: when truthy, rows are scaled to unit 2-norm before the
            search and again before returning.
        downsample_level: each generation is scored on
            ``len(pairwise_prefs) // downsample_level`` random pairs.
        elitism: when truthy, ``selected[0]`` is carried over unmutated.
        bt: forwarded to ``calc_lexicase_scores`` and, as ``beta``, to
            ``select_from_scores``.
        mutation_fn: called as ``mutation_fn(parents, step_stdev)``.

    Returns:
        ``(population, scores, None)`` where ``scores`` are the scores of the
        returned population on the last down-sample (computed before the final
        normalisation).
    """
    num_prefs = len(pairwise_prefs)
    size_of_ds = num_prefs // int(downsample_level)

    def _score_on_new_sample(candidates):
        """Score ``candidates`` on a fresh random subset of the preferences."""
        ds_indices = torch.randperm(num_prefs)[:size_of_ds]
        return calc_lexicase_scores(
            candidates,
            features0[ds_indices],
            features1[ds_indices],
            pairwise_prefs[ds_indices],
            bt=bt,
        )

    print(f"reward_mdoel: {reward_model}")

    if reward_model is None:
        # random numbers between -1 and 1
        population = (torch.rand(popsize, features0.shape[-1]) * 2) - 1
        population = population.to(features0.device)
    else:
        # take the last layer weights, and duplicate them to popsize
        population = reward_model.last_layer.weight.data.repeat(popsize, 1)

    # normalize the weight vectors to have unit 2-norm
    if normalize:
        population = population / torch.norm(population, dim=1)[:, None]

    scores = _score_on_new_sample(population)

    real_scores = calc_lexicase_scores(
        population, features0, features1, pairwise_prefs, bt=None
    )  # don't use BT to get boolean scores

    print(
        f"score: {torch.max(torch.sum(real_scores, dim=1))}/{real_scores.shape[1]}")

    for _ in trange(num_steps, desc="POPL Steps"):
        selected = select_from_scores(
            scores, elitism=elitism, epsilon=False, beta=bt
        )

        # with elitism the first selected index is reserved for the elite
        new_pop = population[selected[int(elitism):]]

        # mutate pop with gaussian noise
        with torch.no_grad():
            new_pop = mutation_fn(new_pop, step_stdev)

        if elitism:
            new_pop = torch.cat(
                (torch.unsqueeze(population[selected[0]], dim=0), new_pop), dim=0
            )

        population = new_pop

        # Re-score the *new* population on a fresh down-sample; selecting with
        # the previous generation's scores would index the new population with
        # results that no longer describe it.
        scores = _score_on_new_sample(population)

    # project to norm 2 ball
    if normalize:
        population = population / torch.norm(population, dim=1)[:, None]

    return population, scores, None  # , best_scores


# instead of searching for a reward function, we can instead search directly for the optimal policy for each group
def popl_policy_search(
    policy_model,
    pairwise_prefs,
    features0,
    features1,
    actions0,
    actions1,
    popsize,
    num_steps,
    step_stdev,
    normalize,
    mutation_fn,
    downsample_level=1,
    elitism=False,
    alpha=None,
):
    """Evolve a population of policy heads directly against pairwise preferences.

    Args:
        policy_model: ``None`` for a uniform random start in [-1, 1), a tensor
            that is used as the initial population as-is, or an object whose
            ``last_layer.weight`` is copied ``popsize`` times.
        pairwise_prefs: preference labels, indexable by an index tensor.
        features0, features1: features of both trajectories of each pair; the
            last dimension is the feature size.
        actions0, actions1: recorded actions of both trajectories; the last
            dimension is the action size.
        popsize: number of candidates.
        num_steps: number of generations.
        step_stdev: forwarded to ``mutation_fn`` as its second argument.
        normalize: when truthy, the population is divided by its norm over
            dim 1 before the search and after every generation.
        mutation_fn: called as ``mutation_fn(parents, step_stdev)``.
        downsample_level: each generation is scored on
            ``len(pairwise_prefs) // downsample_level`` random pairs.
        elitism: when truthy, ``selected[0]`` is carried over unmutated.
        alpha: forwarded to ``calc_popl_policy_score`` and, as ``beta``, to
            ``select_from_scores``.

    Returns:
        ``(population, scores, None)`` where ``scores`` are the
        ``alpha=None`` scores of the final population on all preferences.
    """
    if policy_model is None:
        # The scoring einsum ('ijk,lmk->lijm') needs one
        # (action, feature) matrix per candidate, so the random start must be
        # 3-D; a (popsize, feature) tensor cannot be scored.
        population = (
            torch.rand(popsize, actions0.shape[-1], features0.shape[-1]) * 2) - 1
        population = population.to(features0.device)
    elif isinstance(policy_model, torch.Tensor):
        # already a stack of last layers
        population = policy_model
    else:
        # take the last layer weights, and duplicate them to popsize
        population = policy_model.last_layer.weight.data.repeat(popsize, 1, 1)

    num_prefs = len(pairwise_prefs)
    size_of_ds = num_prefs // int(downsample_level)

    def _normalized(candidates):
        # NOTE(review): for a 3-D population dim 1 is the action axis, so this
        # is not a per-candidate unit norm -- confirm this is intended.
        return candidates / torch.norm(candidates, dim=1)[:, None]

    def _score_on_new_sample(candidates):
        """Score ``candidates`` on a fresh random subset of the preferences."""
        ds_indices = torch.randperm(num_prefs)[:size_of_ds]
        return calc_popl_policy_score(
            candidates, features0[ds_indices], features1[ds_indices],
            actions0[ds_indices], actions1[ds_indices],
            pairwise_prefs[ds_indices], alpha=alpha)

    if normalize:
        population = _normalized(population)

    scores = _score_on_new_sample(population)
    real_scores = calc_popl_policy_score(
        population, features0, features1, actions0, actions1, pairwise_prefs, alpha=None)

    print(f"Best_score_starting: {torch.max(torch.sum(real_scores, dim=1))}/{real_scores.shape[1]}")

    for _ in trange(num_steps, desc="POPL Policy Search Steps"):
        selected = select_from_scores(
            scores, elitism=elitism, epsilon=False, num_to_select=popsize, beta=alpha)

        # with elitism the first selected index is reserved for the elite
        new_pop = population[selected[int(elitism):]]

        # mutate pop with gaussian noise
        new_pop = mutation_fn(new_pop, step_stdev)

        if elitism:
            new_pop = torch.cat(
                (torch.unsqueeze(population[selected[0]], dim=0), new_pop), dim=0)

        population = new_pop

        if normalize:
            population = _normalized(population)

        # calculate scores for the new population
        scores = _score_on_new_sample(population)

    # final evaluation on the complete preference set with boolean scoring
    scores = calc_popl_policy_score(
        population, features0, features1, actions0, actions1, pairwise_prefs, alpha=None)
    return population, scores, None  # , best_scores


def select_one(population, features0, features1, pairwise_prefs):
    """Pick a single candidate via ``select_from_scores``.

    The population is scored with ``calc_lexicase_scores`` (default ``bt``),
    and the first index returned by ``select_from_scores`` (no elitism, no
    epsilon) determines the winner.

    Returns:
        ``(candidate, candidate_scores)`` for the chosen index.
    """
    per_pref_scores = calc_lexicase_scores(
        population, features0, features1, pairwise_prefs)
    winner = select_from_scores(
        per_pref_scores, elitism=False, epsilon=False)[0]
    return population[winner], per_pref_scores[winner]


def select_one_best(population, features0, features1, pairwise_prefs):
    """Return the candidate that satisfies the most pairwise preferences.

    A candidate satisfies a pair when ``features1 @ w > features0 @ w`` agrees
    with the preference label. Ties are resolved in favour of the earliest
    candidate.

    Args:
        population: sequence / tensor of candidate weight vectors.
        features0, features1: features of the first / second item of each pair.
        pairwise_prefs: per-pair label, truthy when item 1 is preferred.

    Returns:
        ``(best_candidate, number_of_satisfied_preferences)``.

    Raises:
        ValueError: if ``population`` is empty.
    """
    if len(population) == 0:
        raise ValueError("population must contain at least one candidate")

    # Start below any achievable score so that a candidate is always chosen,
    # even when nobody satisfies a single preference; otherwise the index
    # would remain None and `population[None]` would be returned.
    max_score = -1
    best_ind = 0
    for i, last_layer in enumerate(population):
        output0 = features0 @ last_layer
        output1 = features1 @ last_layer

        pred_output = output1 > output0
        satisfied = torch.logical_not(
            torch.logical_xor(pairwise_prefs, pred_output))
        score = torch.sum(satisfied).item()
        if score > max_score:
            max_score = score
            best_ind = i

    return population[best_ind], max_score


def generate_feature_counts_from_model(demos, reward_net, n, device):
    """Build a (len(demos), n) feature matrix by running ``reward_net`` per demo.

    Each demo is converted to a float tensor on ``device`` and passed to
    ``reward_net``; the output is stored as the corresponding row.

    Returns:
        The feature matrix moved to ``device``.
    """
    num_demos = len(demos)
    feature_cnts = torch.zeros(num_demos, n)  # no bias
    for row in range(num_demos):
        trajectory = torch.from_numpy(np.array(demos[row])).float().to(device)
        feature_cnts[row, :] = reward_net(trajectory)

    return feature_cnts.to(device)


def generate_feature_counts_from_state(demos, n, device):
    """Build a (len(demos), n) feature matrix by summing each demo over axis 0.

    Each demo is converted to a float tensor on ``device`` and summed along
    its first axis; the result is stored as the corresponding row. The matrix
    and its shape are printed before it is returned.

    Returns:
        The feature matrix moved to ``device``.
    """
    num_demos = len(demos)
    feature_cnts = torch.zeros(num_demos, n)  # no bias

    for row in range(num_demos):
        states = torch.from_numpy(np.array(demos[row])).float().to(device)
        feature_cnts[row] = states.sum(dim=0)

    print(f"feautre counts: {feature_cnts}")
    print(f"feature counts shape: {feature_cnts.shape}")
    return feature_cnts.to(device)
