import numpy as np
import torch

from utils.generate_nn import (
    generate_MLP,
    TransformerIndSA,
)
from utils.utils_fn import segment_data_fn


def log_sigmoid_loss(x):
    """Return the elementwise pairwise-preference loss ``-log(sigmoid(x))``.

    Implemented with ``torch.nn.functional.logsigmoid``, which is numerically
    stable for all ``x``. The naive ``log(sigmoid(x) + 1e-8)`` form saturates
    near ``-log(1e-8)`` (~18.4) for very negative ``x``. At that point the
    gradient vanishes, so confidently wrong predictions would stop being
    corrected.

    Args:
        x: Tensor of logits / reward margins (any shape).

    Returns:
        Tensor with the same shape as ``x`` holding non-negative losses.
    """
    return -torch.nn.functional.logsigmoid(x)


class AdvantageFn:
    """Learned reward model trained from pairwise preference feedback.

    The original author notes this reward model is used as an advantage
    function. The class does three things:

    - owns a network built by :meth:`new_model` and its AdamW optimizer;
    - trains them with ``self.pref_loss_fn`` on differences of summed
      per-step segment rewards;
    - exposes no-grad scoring helpers (:meth:`reward` and
      :meth:`reward_policy_update`).
    """

    def __init__(
        self,
        lr,
        env,
        agent,
        device,
        model_params,
        reward_update_epochs,
        reward_update_steps,
        train_pref_batch_size,
        logging_update_steps=1,
        use_agent_mean=False,
        weight_decay=1e-3,
        max_num_pref=0,
        use_init_reward_as_null=False,
        pref_train_acc_min=0.999,
    ) -> None:
        """Build the reward network and optimizer, and store training settings.

        Args:
            lr: AdamW learning rate.
            env: Object exposing ``obs_dim`` and ``action_dim``.
            agent: Initial value of ``self.current_agent``. The training
                methods overwrite it with the agent they receive.
            device: Device the network is built on and batches are moved to.
            model_params: Dict with a ``"type"`` key (``"mlp"`` or
                ``"transformer_ind_sa"``) plus optional architecture keys
                read by :meth:`new_model`.
            reward_update_epochs: Maximum number of passes over the
                preference data per :meth:`train_reward` call.
            reward_update_steps: :meth:`train_reward` only trains when its
                ``step`` is a multiple of this value.
            train_pref_batch_size: Preference pairs per optimizer step.
                ``<= 0`` means a single full batch.
            logging_update_steps: Summary metrics are logged every
                ``logging_update_steps`` reward updates.
            use_agent_mean: Stored only; not read by this class.
            weight_decay: AdamW weight decay.
            max_num_pref: Stored only; not read by this class.
            use_init_reward_as_null: Stored only; not read by this class.
            pref_train_acc_min: Training accuracy at which
                :meth:`train_reward` stops early.
        """
        self.device = device

        self.obs_dim = env.obs_dim
        self.action_dim = env.action_dim

        self.lr = lr
        self.weight_decay = weight_decay

        self.model_params = model_params

        # Relies on device, dims, lr, weight_decay and model_params set above.
        self.new_model(create_new=False)

        self.train_pref_batch_size = train_pref_batch_size

        self.reward_update_epochs = reward_update_epochs
        self.reward_update_steps = reward_update_steps

        self.logging_update_steps = logging_update_steps

        self.use_agent_mean = use_agent_mean

        self.use_init_reward_as_null = use_init_reward_as_null

        self.max_num_pref = max_num_pref

        self.current_agent = agent

        self.pref_loss_fn = log_sigmoid_loss

        self.pref_train_acc_min = pref_train_acc_min
        # Running count of preference-training passes; used as the logging step.
        self.total_train_epochs = 0

        # Global step of the most recent train_reward update (-1: none yet).
        self.reward_train_step = -1

    def new_model(self, create_new=False):
        """Build a fresh reward network and its AdamW optimizer.

        Args:
            create_new: If False, the new model and optimizer are also
                installed as ``self.model`` / ``self.opt``. If True, they are
                only returned.

        Returns:
            Tuple ``(model, opt)``.

        Raises:
            ValueError: If ``model_params["type"]`` is not a supported type.
        """
        model_type = self.model_params["type"]
        if model_type == "mlp":
            # Per-step network over the concatenated (obs, action) vector.
            model = generate_MLP(
                in_dim=self.obs_dim + self.action_dim,
                out_dim=1,
                width=self.model_params.get("width", 128),
                n_layers=self.model_params.get("n_layers", 3),
                final_activation=self.model_params.get("final_activation", None),
                add_norm=self.model_params.get("net_norm", "ln"),
                device=self.device,
            )
        elif model_type == "transformer_ind_sa":
            model = TransformerIndSA(
                obs_dim=self.obs_dim,
                action_dim=self.action_dim,
                out_dim=1,
                num_layers=self.model_params.get("num_layers", 3),
                num_heads=self.model_params.get("num_heads", 4),
                hidden_dim=self.model_params.get("hidden_dim", 256),
                dropout=self.model_params.get("dropout", 0.1),
                out_layer_hidden_dim=self.model_params.get("out_layer_hidden_dim", 128),
                out_layer_n_hidden=self.model_params.get("out_layer_n_hidden", 2),
                max_len=self.model_params.get("max_len", 64),
                pos_encoding=self.model_params.get("pos_encoding", "learned"),
                device=self.device,
            )
        else:
            raise ValueError(
                f"Unknown model type: {model_type}. "
                "Available options: mlp, transformer_ind_sa."
            )

        opt = torch.optim.AdamW(
            model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )

        if not create_new:
            self.model = model
            self.opt = opt
        return model, opt

    def train_pref(self, data, agent):
        """Run one training pass, dispatching on ``data.is_dense_pref``.

        Returns:
            ``(accuracy, losses, stats)`` from :meth:`train_pref_dense` or
            :meth:`train_pref_sparse`.
        """
        if data.is_dense_pref:
            return self.train_pref_dense(data, agent)
        return self.train_pref_sparse(data, agent)

    def train_reward(self, feedback_collector, agent, replay_buffer, logger, step):
        """Run up to ``reward_update_epochs`` preference-training passes.

        Does nothing unless ``step`` is a multiple of ``reward_update_steps``.
        Each pass calls :meth:`train_pref` and logs its stats under
        ``"pref_reward_training"``. The loop stops early once a pass's
        training accuracy reaches ``pref_train_acc_min``; that pass is still
        logged and counted. Summary metrics are logged under
        ``"adv_training"`` every ``logging_update_steps`` reward updates.

        Args:
            feedback_collector: Preference data source passed to
                :meth:`train_pref`.
            agent: Current agent, forwarded to :meth:`train_pref`.
            replay_buffer: Unused; kept for interface compatibility.
            logger: Object with a ``log(name, metrics, step, ...)`` method.
            step: Global training step.
        """
        if step % self.reward_update_steps != 0:
            return

        self.reward_train_step = step

        self.model.train()
        pref_train_acc = 0.0
        pref_loss_mean = 0.0
        pref_data_stats = {}

        epoch = 0
        while epoch < self.reward_update_epochs:
            pref_train_acc, pref_loss, pref_data_stats = self.train_pref(
                data=feedback_collector, agent=agent
            )
            # Mean over the pass's mini-batch losses (not just the last one).
            pref_loss_mean = float(np.mean(pref_loss))

            logger.log(
                "pref_reward_training",
                pref_data_stats,
                step=self.total_train_epochs + epoch,
                step_name="reward_epoch/step",
            )
            # Count the pass before any early exit: it already updated the model.
            epoch += 1

            if pref_train_acc >= self.pref_train_acc_min:
                print(f"Preference Accuracy >= {self.pref_train_acc_min}")
                break

        self.total_train_epochs += epoch
        # Copy rather than mutate the dict that may already have been logged.
        pref_data_stats = {**pref_data_stats, "num_epochs": epoch}

        if step % (self.logging_update_steps * self.reward_update_steps) == 0:
            pref_reward_training_metrics = {
                "pref_train_acc": pref_train_acc,
                "pref_loss_mean": pref_loss_mean,
                **pref_data_stats,
            }

            logger.log("adv_training", pref_reward_training_metrics, step)

    def _predict_no_grad(self, obs, act):
        """Score concatenated ``(obs, act)`` with the model in eval mode, without autograd.

        1-D ``obs`` / ``act`` are treated as a single sample (a leading batch
        dim is added). If the output's leading dim is 1, it is squeezed away.
        Leaves the model in eval mode.
        """
        self.model.eval()
        if obs.dim() == 1:
            obs = obs.unsqueeze(0)
        if act.dim() == 1:
            act = act.unsqueeze(0)
        with torch.no_grad():
            obs_act = torch.cat((obs, act), dim=-1)
            reward_value = self.model(obs_act)
            if reward_value.shape[0] == 1:
                reward_value = reward_value.squeeze(0)
        return reward_value

    def reward(self, obs, act):
        """Return the model's reward for ``(obs, act)`` (no grad; leaves eval mode)."""
        return self._predict_no_grad(obs, act)

    def reward_policy_update(self, obs, act, return_sa=True):
        """Return rewards for a policy update, then switch the model back to train mode.

        Args:
            obs: Observation tensor; 1-D means a single sample.
            act: Action tensor; 1-D means a single sample.
            return_sa: If True, flatten all leading dims so the result is
                ``(-1, last_dim)``.

        Returns:
            Reward tensor (no autograd graph).
        """
        reward_value = self._predict_no_grad(obs, act)
        self.model.train()
        if return_sa:
            return reward_value.reshape(-1, reward_value.shape[-1])
        return reward_value

    def train_pref_dense(self, data, agent):
        """One pass over dense preference pairs indexing shared demonstration segments.

        Pair ``i`` compares segment ``data.dense_pref_pairs1[i]`` with
        ``data.dense_pref_pairs2[i]``. The loss pushes the first segment's
        summed reward above the second's.

        Returns:
            ``(accuracy, losses, stats)``. ``accuracy`` is the fraction of
            pairs whose reward margin was positive before that batch's update.
            ``losses`` holds the per-batch losses. ``stats`` is a dict of
            logging values (``{}`` when there are no preferences).
        """
        losses = []
        accuracy = 0.0

        total_num_preferences = data.num_preferences()
        if total_num_preferences == 0:
            losses = [0.0]
            return accuracy, losses, {}

        self.model.train()
        self.current_agent = agent

        id_pref1_all = np.asarray(data.dense_pref_pairs1)
        id_pref2_all = np.asarray(data.dense_pref_pairs2)

        _, _, segment, _ = segment_data_fn(
            data=data.demonstrations,
            obs_dim=self.obs_dim,
        )
        # Loop-invariant: move the segments to the device once.
        segment = segment.to(self.device)

        batch_size = self.train_pref_batch_size
        if batch_size <= 0:
            id_pref1_batches = [id_pref1_all]
            id_pref2_batches = [id_pref2_all]
        else:
            order = np.random.permutation(total_num_preferences)
            batch_idxs = [
                order[start : start + batch_size]
                for start in range(0, total_num_preferences, batch_size)
            ]
            id_pref1_batches = [id_pref1_all[idx] for idx in batch_idxs]
            id_pref2_batches = [id_pref2_all[idx] for idx in batch_idxs]

        curr_loss_logging = 0.0
        for id_pref1, id_pref2 in zip(id_pref1_batches, id_pref2_batches):
            # Re-score every segment because the parameters change after each step.
            # Per-step rewards are summed over dim 1 (presumably time) into segment returns.
            reward_full = self.model(segment).sum(dim=1)

            margin = reward_full[id_pref1] - reward_full[id_pref2]
            curr_loss = self.pref_loss_fn(margin).mean()

            self.opt.zero_grad()
            curr_loss.backward()
            self.opt.step()

            curr_loss_logging = curr_loss.item()
            losses.append(curr_loss_logging)

            accuracy += (margin.detach() > 0).sum().item()

        pref_data_stats = {
            "accuracy": accuracy / total_num_preferences,
            "pref_loss": np.mean(losses),
            "curr_loss_logging": curr_loss_logging,
        }
        return (
            accuracy / total_num_preferences,
            losses,
            pref_data_stats,
        )

    def train_pref_sparse(self, data, agent):
        """One pass over labelled segment pairs from ``data.get_preference_data``.

        Label 0 means the first segment is preferred; label 1 means the second
        is. Labels are assumed to be ``(B,)`` or ``(B, 1)``. If
        ``data.num_validation_pref > 0``, validation accuracy is also reported.

        Returns:
            ``(accuracy, losses, stats)``. ``stats`` holds the mean predicted
            segment returns (``mean_r_1`` / ``mean_r_2``), the training accuracy
            and, when validation data exists, ``accuracy_val``. ``stats`` is
            ``{}`` when there are no preferences.
        """
        losses = []
        accuracy = 0.0
        sum_r_1 = 0.0
        sum_r_2 = 0.0

        total_num_preferences = data.num_preferences()

        if total_num_preferences == 0:
            losses = [0.0]
            return accuracy, losses, {}

        total_batch_index = np.random.permutation(total_num_preferences)

        # Same convention as the dense path: a non-positive size means one full batch.
        batch_size = self.train_pref_batch_size
        if batch_size <= 0:
            batch_size = total_num_preferences
        num_preferences = 0

        self.model.train()
        self.current_agent = agent
        for start in range(0, total_num_preferences, batch_size):
            self.opt.zero_grad()

            idxs = total_batch_index[start : start + batch_size]

            sa_t_1, sa_t_2, labels = data.get_preference_data(idxs=idxs)
            sa_t_1 = sa_t_1.to(self.device)
            sa_t_2 = sa_t_2.to(self.device)
            # Column vector so it lines up with delta_r below without broadcasting.
            labels = labels.to(self.device).reshape(-1, 1)

            _, _, segment_1, _ = segment_data_fn(data=sa_t_1, obs_dim=self.obs_dim)
            _, _, segment_2, _ = segment_data_fn(data=sa_t_2, obs_dim=self.obs_dim)

            num_preferences += labels.shape[0]

            r1_full = self.model(segment_1).sum(dim=1)
            r2_full = self.model(segment_2).sum(dim=1)
            sum_r_1 += r1_full.detach().sum().item()
            sum_r_2 += r2_full.detach().sum().item()

            delta_r = (r1_full - r2_full).reshape(-1, 1)

            # label=0 -> r1 better (+1), label=1 -> r2 better (-1)
            labels_scaled = -2 * labels + 1
            predicted = torch.sign(delta_r.detach())
            accuracy += (
                (predicted.reshape(-1) == labels_scaled.reshape(-1)).sum().item()
            )

            loss_equal = (1 - labels) * self.pref_loss_fn(
                delta_r
            ) + labels * self.pref_loss_fn(-delta_r)
            curr_loss = loss_equal.mean()

            curr_loss.backward()
            self.opt.step()

            losses.append(curr_loss.item())

        accuracy = accuracy / num_preferences

        pref_data_stats = {
            "mean_r_1": sum_r_1 / num_preferences,
            "mean_r_2": sum_r_2 / num_preferences,
            "accuracy": accuracy,
        }

        if data.num_validation_pref > 0:
            self.model.eval()
            with torch.no_grad():
                sa_t_1_val = data.val_pref1.to(self.device)
                sa_t_2_val = data.val_pref2.to(self.device)

                _, _, segment_1_val, _ = segment_data_fn(
                    data=sa_t_1_val, obs_dim=self.obs_dim
                )
                _, _, segment_2_val, _ = segment_data_fn(
                    data=sa_t_2_val, obs_dim=self.obs_dim
                )

                # Same decision rule as training: compare summed segment rewards.
                r_hat1_val = self.model(segment_1_val).sum(dim=1)
                r_hat2_val = self.model(segment_2_val).sum(dim=1)
                delta_r_val = r_hat1_val - r_hat2_val

                labels_scaled_val = -2 * data.val_labels + 1
                predicted_val = torch.sign(delta_r_val).to(labels_scaled_val.device)
                # Flatten both sides so (N,) vs (N, 1) cannot broadcast to (N, N).
                correct_val = (
                    (predicted_val.reshape(-1) == labels_scaled_val.reshape(-1))
                    .sum()
                    .item()
                )
                pref_data_stats["accuracy_val"] = (
                    correct_val / data.num_validation_pref
                )

        return (
            accuracy,
            losses,
            pref_data_stats,
        )
