import numpy as np
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
from utils.logger import Logger

def index_batch(offline_batch, indices):
    """Return a new dict holding the rows of every entry selected by ``indices``.

    Args:
        offline_batch: mapping from field name to an array-like that supports
            ``value[indices, ...]`` indexing (e.g. a NumPy array).
        indices: index expression applied to the first axis of every entry.

    Returns:
        A dict with the same keys as ``offline_batch`` whose values are the
        selected rows of the corresponding input entry.
    """
    return {
        field: offline_batch[field][indices, ...]
        for field in offline_batch.keys()
    }

    
def gen_net(in_size=1, out_size=1, H=256, n_layers=3, activation='tanh'):
    """Build the layer list of an MLP (to be wrapped in ``nn.Sequential``).

    The network is ``n_layers`` blocks of ``Linear -> LeakyReLU`` with hidden
    width ``H``, followed by a final ``Linear`` projecting to ``out_size``.

    Args:
        in_size: number of input features.
        out_size: number of output features.
        H: hidden layer width.
        n_layers: number of hidden ``Linear``/``LeakyReLU`` blocks.
        activation: output non-linearity; ``'tanh'`` appends ``nn.Tanh``,
            ``'sig'`` appends ``nn.Sigmoid``, anything else leaves the output
            linear.

    Returns:
        A list of ``nn.Module`` layers in forward order.
    """
    layers = []
    width = in_size
    for _ in range(n_layers):
        layers += [nn.Linear(width, H), nn.LeakyReLU()]
        width = H
    layers.append(nn.Linear(width, out_size))

    # Only the output squashing is configurable; unknown names mean "linear".
    if activation == 'tanh':
        layers.append(nn.Tanh())
    elif activation == 'sig':
        layers.append(nn.Sigmoid())

    return layers


class RewardModel(object):
    """Ensemble of MLP reward models trained from pairwise segment preferences.

    Every ensemble member maps a concatenated (observation, action) vector to a
    scalar reward. The score of a segment is the sum of its per-step rewards,
    and a pair of segments is compared with a softmax over the two scores.
    All members share a single Adam optimiser.
    """

    def __init__(self, args, observation_dim, action_dim, num_reward_ensemble=3, lr=3e-4, activation="tanh",
                 device="cuda"):
        """Create the ensemble and its optimiser.

        Args:
            args: configuration object; this class reads ``reward_data_size``,
                ``select_data``, ``mean_probs`` and ``std_probs`` from it.
            observation_dim: size of one observation vector.
            action_dim: size of one action vector.
            num_reward_ensemble: number of ensemble members.
            lr: Adam learning rate.
            activation: output activation passed to ``gen_net``.
            device: torch device string the models live on.
        """
        self.args = args
        self.observation_dim = observation_dim  # state: env.observation_space.shape[0]
        self.action_dim = action_dim  # action: env.action_space.shape[0]
        self.num_reward_ensemble = num_reward_ensemble
        self.lr = lr  # learning rate
        self.device = torch.device(device)

        # build network
        self.opt = None
        self.activation = activation
        self.ensemble = []
        self.paramlst = []
        self.construct_ensemble()

    def construct_ensemble(self):
        """Append ``num_reward_ensemble`` fresh MLPs and (re)build the shared optimiser."""
        for i in range(self.num_reward_ensemble):
            model = nn.Sequential(*gen_net(in_size=self.observation_dim + self.action_dim,
                                           out_size=1, H=256, n_layers=3,
                                           activation=self.activation)).float().to(self.device)
            self.ensemble.append(model)
            self.paramlst.extend(model.parameters())

        self.opt = torch.optim.Adam(self.paramlst, lr=self.lr)

    def save_model(self, path):
        """Save the state dicts of all members to ``<path>/reward.pth``."""
        state_dicts = [model.state_dict() for model in self.ensemble]
        torch.save(state_dicts, os.path.join(path, "reward.pth"))

    def load_model(self, path):
        """Load member state dicts from ``<path>/reward.pth`` in ensemble order."""
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        state_dicts = torch.load(os.path.join(path, "reward.pth"), map_location=self.device)
        # zip() stops at the shorter sequence, so a checkpoint with fewer
        # entries than ensemble members leaves the remaining members untouched.
        for model, state_dict in zip(self.ensemble, state_dicts):
            model.load_state_dict(state_dict)
            model.to(self.device)

    @staticmethod
    def _num_batches(num_samples, batch_size):
        """Return ceil(num_samples / batch_size), i.e. the count of non-empty batches."""
        if batch_size < 1:
            raise ValueError("batch size must be a positive integer, got {}".format(batch_size))
        return (num_samples + batch_size - 1) // batch_size

    @staticmethod
    def _mean(values):
        """Mean over all entries of a (possibly nested) list; NaN when there are none."""
        arr = np.asarray(values, dtype=float)
        return np.mean(arr) if arr.size else np.nan

    def _loss_if_nonempty(self, batch, member):
        """Return the preference loss of ``batch`` for ``member``, or None for an empty batch."""
        # An empty batch would make softXEnt_loss divide by zero and yield NaN.
        if batch["observations_1"].shape[0] == 0:
            return None
        loss, _ = self.calculate_loss(batch, member)
        return loss

    def _evaluate_accuracy(self, train_dataset, eval_dataset):
        """Return per-member accuracy lists ``(train_acc, test_acc)``."""
        train_acc, test_acc = [], []
        # Evaluation needs no gradients; skipping the autograd graph avoids
        # holding activations for the whole dataset in memory.
        with torch.no_grad():
            for member in range(self.num_reward_ensemble):
                _, test_correct = self.calculate_loss(eval_dataset, member)
                _, train_correct = self.calculate_loss(train_dataset, member)
                test_acc.append(test_correct)
                train_acc.append(train_correct)
        return train_acc, test_acc

    def pretrain(self, init_pref_real_dataset, n_epochs=100,  logger = Logger, batch_size=64):
        """Train the ensemble on real preference data only, then save it.

        The first ``args.reward_data_size`` samples are used for training and
        the following ``args.reward_data_size`` samples for evaluation.

        Args:
            init_pref_real_dataset: dict of arrays with keys ``observations_1``,
                ``actions_1``, ``observations_2``, ``actions_2`` and ``labels``.
            n_epochs: number of passes over the training split.
            logger: object providing ``log``, ``logkv``, ``set_timestep``,
                ``dumpkvs`` and a ``model_dir`` attribute.
            batch_size: mini-batch size per ensemble member.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        pref_real_dataset = self.construct_offline_data(init_pref_real_dataset, 0, self.args.reward_data_size)
        pref_eval_dataset = self.construct_offline_data(
            init_pref_real_dataset, self.args.reward_data_size, 2 * self.args.reward_data_size)

        num_real = pref_real_dataset["observations_1"].shape[0]
        # Ceiling division: the previous "floor + 1" produced an empty trailing
        # batch (and a NaN loss) whenever num_real was a multiple of batch_size.
        interval = self._num_batches(num_real, batch_size)
        logger.log("Start Pretraining Reward:")

        for epoch in range(1, n_epochs + 1):
            ensemble_losses = [[] for _ in range(self.num_reward_ensemble)]

            # every member iterates over the data in its own random order
            offline_batch_shuffled_idx = [
                np.random.permutation(num_real) for _ in range(self.num_reward_ensemble)
            ]

            for i in range(interval):
                self.opt.zero_grad()
                total_loss = 0
                start_pt = i * batch_size
                end_pt = min((i + 1) * batch_size, num_real)
                for member in range(self.num_reward_ensemble):
                    # get offline_batch
                    offline_batch = index_batch(pref_real_dataset, offline_batch_shuffled_idx[member][start_pt:end_pt])
                    # compute loss
                    curr_loss, _ = self.calculate_loss(offline_batch, member)
                    total_loss += curr_loss
                    ensemble_losses[member].append(curr_loss.item())
                total_loss.backward()
                self.opt.step()

            # evaluation
            ensemble_train_acc, ensemble_test_acc = self._evaluate_accuracy(pref_real_dataset, pref_eval_dataset)

            train_loss = self._mean(ensemble_losses)
            test_acc = self._mean(ensemble_test_acc)
            logger.logkv("1_train_loss", train_loss)
            logger.logkv("2_train_acc", self._mean(ensemble_train_acc))
            logger.logkv("3_test_acc", test_acc)
            logger.set_timestep(epoch)
            logger.dumpkvs(exclude=["policy_training_progress","transition_training_progress"])
            logger.log("loss:{} , accuarcy: {}".format(train_loss, test_acc))

        self.save_model(logger.model_dir)

    def train(self, init_pref_real_dataset, init_pref_fake_dataset, fake_ratio, n_epochs=200,  logger = Logger, batch_size=256):
        """Train the ensemble on a mix of real and generated ("fake") preferences, then save it.

        Each optimisation step uses, per member, a slice of the shuffled fake
        data of size ``int(fake_ratio * batch_size)`` plus a batch of real
        samples drawn with replacement that fills the rest of ``batch_size``.

        Args:
            init_pref_real_dataset: real preference dict (same keys as in
                ``pretrain``); split into train/eval by ``args.reward_data_size``.
            init_pref_fake_dataset: generated preference dict; filtered with
                ``select_data`` when ``args.select_data`` is truthy.
            fake_ratio: fraction of each mini-batch taken from the fake data.
            n_epochs: number of passes over the fake data.
            logger: object providing ``log``, ``logkv``, ``set_timestep``,
                ``dumpkvs`` and a ``model_dir`` attribute.
            batch_size: total mini-batch size (real + fake) per member.

        Raises:
            ValueError: if ``int(fake_ratio * batch_size)`` is not positive.
        """
        # processing dataset
        pref_real_dataset = self.construct_offline_data(init_pref_real_dataset, 0, self.args.reward_data_size)
        pref_eval_dataset = self.construct_offline_data(
            init_pref_real_dataset, self.args.reward_data_size, int(2 * self.args.reward_data_size))

        # uncertainty-aware-based selecting samples
        if self.args.select_data:
            pref_fake_dataset = self.select_data(init_pref_fake_dataset)
        else:
            pref_fake_dataset = init_pref_fake_dataset

        # data size
        fake_batch_size = int(fake_ratio * batch_size)
        real_batch_size = batch_size - fake_batch_size
        num_fake = pref_fake_dataset["observations_1"].shape[0]
        num_real = pref_real_dataset["observations_1"].shape[0]
        # Ceiling division avoids an empty trailing fake batch; keep at least
        # one step per epoch so real data is still used when no fake sample
        # survived the selection.
        interval = max(1, self._num_batches(num_fake, fake_batch_size))

        logger.log("Start training Reward:")

        for epoch in range(1, n_epochs + 1):
            ensemble_losses = [[] for _ in range(self.num_reward_ensemble)]
            fake_losses = [[] for _ in range(self.num_reward_ensemble)]
            real_losses = [[] for _ in range(self.num_reward_ensemble)]

            batch_shuffled_idx = [
                np.random.permutation(num_fake) for _ in range(self.num_reward_ensemble)
            ]

            for i in range(interval):
                self.opt.zero_grad()
                total_loss = 0
                start_pt = i * fake_batch_size
                end_pt = min((i + 1) * fake_batch_size, num_fake)
                for member in range(self.num_reward_ensemble):
                    # get batch
                    fake_batch = index_batch(pref_fake_dataset, batch_shuffled_idx[member][start_pt:end_pt])
                    real_batch_indexes = np.random.randint(0, num_real, size=real_batch_size)
                    real_batch = index_batch(pref_real_dataset, real_batch_indexes)

                    # compute loss (empty sub-batches are skipped instead of
                    # contributing a NaN term)
                    real_loss = self._loss_if_nonempty(real_batch, member)
                    fake_loss = self._loss_if_nonempty(fake_batch, member)

                    if real_loss is not None:
                        real_losses[member].append(real_loss.item())
                    if fake_loss is not None:
                        fake_losses[member].append(fake_loss.item())

                    terms = [term for term in (real_loss, fake_loss) if term is not None]
                    if not terms:
                        continue
                    loss = sum(terms[1:], terms[0])
                    total_loss += loss
                    ensemble_losses[member].append(loss.item())

                # total_loss stays the int 0 when every sub-batch was empty
                if torch.is_tensor(total_loss):
                    total_loss.backward()
                    self.opt.step()

            # evaluation
            ensemble_acc, ensemble_test_acc = self._evaluate_accuracy(pref_real_dataset, pref_eval_dataset)

            train_loss = self._mean(ensemble_losses)
            test_acc = self._mean(ensemble_test_acc)
            logger.logkv("1_train_loss", train_loss)
            logger.logkv("5_train_real_loss", self._mean(real_losses))
            logger.logkv("6_train_fake_loss", self._mean(fake_losses))
            logger.logkv("2_train_acc", self._mean(ensemble_acc))
            logger.logkv("3_test_acc", test_acc)
            logger.logkv("4_train_num_data", num_fake)
            logger.set_timestep(epoch)
            logger.dumpkvs(exclude=["policy_training_progress","transition_training_progress"])

            # # early stop
            # if np.mean(ensemble_acc) > 0.99999 and "antmaze" not in self.args.task:
            #     break

            logger.log("loss:{} , accuarcy: {}".format(train_loss, test_acc))

        self.save_model(logger.model_dir)

    def get_reward_offline_batch(self, x):
        """Return the ensemble-mean reward prediction for ``x`` as a NumPy array.

        Args:
            x: NumPy array of concatenated (observation, action) inputs.
        """
        r_hats = []
        # inference only: do not build an autograd graph
        with torch.no_grad():
            for member in range(self.num_reward_ensemble):
                r_hats.append(self.r_hat_member(x, member=member).detach().cpu().numpy())
        r_hats = np.array(r_hats)
        # r_mean = np.mean(r_hats, axis=0)
        # r_std = np.std(r_hats, axis=0)
        # r = r_mean + self.args.weight*r_std
        return np.mean(r_hats, axis=0)

    @staticmethod
    def _segment_inputs(batch):
        """Concatenate observations and actions of both segments along the last axis."""
        s_a_1 = np.concatenate([batch['observations_1'], batch['actions_1']], axis=-1)
        s_a_2 = np.concatenate([batch['observations_2'], batch['actions_2']], axis=-1)
        return s_a_1, s_a_2

    def calculate_loss(self, batch, member):
        """Compute the soft cross-entropy preference loss and accuracy for one member.

        Args:
            batch: dict with ``observations_*`` / ``actions_*`` of shape
                batch_size * len_query * dim and ``labels`` of shape
                batch_size * 2 (``[0.5, 0.5]`` marks an equal preference).
            member: index of the ensemble member to evaluate.

        Returns:
            ``(loss, correct)`` where ``loss`` is a scalar tensor and
            ``correct`` is the fraction of non-tied pairs whose preferred
            segment got the higher summed reward. When the batch contains no
            non-tied pair, ``correct`` is the placeholder value 0.7.
        """
        labels = batch['labels']  # batch_size * 2 (one-hot, for equal label)
        s_a_1, s_a_2 = self._segment_inputs(batch)

        # get comparable labels
        comparable_indices = np.where((labels != [0.5, 0.5]).any(axis=1))[0]
        comparable_labels = torch.from_numpy(np.argmax(labels, axis=1)).to(self.device)

        # get logits
        r_hat1 = self.r_hat_member(s_a_1, member)  # batch_size * len_query * 1
        r_hat2 = self.r_hat_member(s_a_2, member)
        r_hat1 = r_hat1.sum(axis=1)  # batch_size * 1
        r_hat2 = r_hat2.sum(axis=1)
        r_hat = torch.cat([r_hat1, r_hat2], axis=1)  # batch_size * 2

        # get labels
        labels = torch.from_numpy(labels).to(self.device)

        # compute loss
        curr_loss = self.softXEnt_loss(r_hat, labels)

        # compute acc (detached: the prediction must not take part in autograd)
        _, predicted = torch.max(r_hat.detach(), 1)

        if not len(comparable_indices):
            correct = 0.7  # TODO, for exception
        else:
            correct = (predicted[comparable_indices] == comparable_labels[comparable_indices]).sum().item() / len(
                comparable_indices)

        return curr_loss, correct

    def calculate_uncertainty(self, batch):
        """Return the mean of the two segments' averaged ``seq_penalty`` as a scalar tensor."""
        # get batch
        seq_penalty_1 = batch['seq_penalty_1']
        seq_penalty_2 = batch['seq_penalty_2']

        uncertainty = (seq_penalty_1 + seq_penalty_2) / 2
        uncertainty_mean = np.mean(uncertainty)
        return torch.tensor(uncertainty_mean, device=self.device)

    def select_data(self,pref_fake_dataset):
        """Keep the fake samples the ensemble agrees with confidently.

        A sample is kept when its mean rank probability is at least
        ``args.mean_probs`` and the standard deviation across members is at
        most ``args.std_probs``.
        """
        # uncertainty-aware-based selecting samples
        mean_probs, std_probs = self.get_rank_probability(pref_fake_dataset)
        keep = (mean_probs >= self.args.mean_probs) & (std_probs <= self.args.std_probs)
        fake_index = np.array(np.where(keep)).reshape(-1)
        selected_fake_dataset = index_batch(pref_fake_dataset, fake_index)
        return selected_fake_dataset

    def construct_offline_data(self, dataset, start,end):
        """Return a dict with every entry of ``dataset`` sliced to ``[start:end]``."""
        return {key: dataset[key][start:end] for key in dataset}

    def combine_dataset(self, dataset1, dataset2):
        """Append every entry of ``dataset2`` to the matching entry of ``dataset1`` along axis 1."""
        # NOTE(review): axis=1 joins along the second dimension, not the
        # sample dimension; confirm this is intended before merging two
        # datasets of different samples.
        return {key: np.append(dataset1[key], dataset2[key], axis=1) for key in dataset1}

    def r_hat_member(self, x, member):
        """Run ensemble member ``member`` on the NumPy array ``x`` and return a tensor."""
        return self.ensemble[member](torch.from_numpy(x).float().to(self.device))

    def softXEnt_loss(self, input, target):
        """Cross entropy between soft ``target`` rows and ``softmax(input)``, averaged over rows."""
        logprobs = nn.functional.log_softmax(input, dim=1)
        return -(target * logprobs).sum() / input.shape[0]

    def get_rank_probability(self, batch):
        """Return ensemble statistics of the pairwise preference probability.

        Returns:
            ``(mean_probs, std_probs)``, both NumPy arrays with one entry per
            pair. ``std_probs`` is the standard deviation across members of
            P(segment 1 preferred over segment 2). ``mean_probs`` is the
            ensemble mean of that probability, flipped to ``1 - p`` for pairs
            whose label prefers segment 2, so that for non-tied pairs it is
            the probability assigned to the labelled winner. Tied pairs keep
            the unflipped mean.
        """
        labels = batch['labels']
        s_a_1, s_a_2 = self._segment_inputs(batch)

        probs = np.array([
            self.p_hat_member(s_a_1, s_a_2, member=member).cpu().numpy()
            for member in range(self.num_reward_ensemble)
        ])
        mean_probs = np.mean(probs, axis=0)
        std_probs = np.std(probs, axis=0)

        # get comparable labels
        comparable_indices = np.where((labels != [0.5, 0.5]).any(axis=1))[0]
        winner = np.argmax(labels, axis=1)[comparable_indices]
        p_first = mean_probs[comparable_indices]
        mean_probs[comparable_indices] = p_first * (winner == 0) + (1 - p_first) * (winner == 1)

        return mean_probs, std_probs

    def get_entropy(self, batch):
        """Return mean and std across members of the preference-distribution entropy."""
        s_a_1, s_a_2 = self._segment_inputs(batch)
        probs = np.array([
            self.p_hat_entropy(s_a_1, s_a_2, member=member).cpu().numpy()
            for member in range(self.num_reward_ensemble)
        ])
        return np.mean(probs, axis=0), np.std(probs, axis=0)

    def _pair_logits(self, x_1, x_2, member):
        """Return the gradient-free summed rewards of both segments, concatenated on the last axis."""
        with torch.no_grad():
            # sum the per-step rewards of each segment
            r_hat1 = self.r_hat_member(x_1, member=member).sum(axis=1)
            r_hat2 = self.r_hat_member(x_2, member=member).sum(axis=1)
            return torch.cat([r_hat1, r_hat2], axis=-1)

    def p_hat_member(self, x_1, x_2, member=-1):
        """Return P(segment ``x_1`` preferred over ``x_2``) according to one member."""
        # softmaxing to get the probabilities according to eqn 1
        r_hat = self._pair_logits(x_1, x_2, member)

        # index 0 is the probability that trajectory 0 is preferred over trajectory 1
        return F.softmax(r_hat, dim=-1)[:, 0]

    def p_hat_entropy(self, x_1, x_2, member=-1):
        """Return the entropy of one member's two-way preference distribution per pair."""
        # softmaxing to get the probabilities according to eqn 1
        r_hat = self._pair_logits(x_1, x_2, member)

        ent = F.softmax(r_hat, dim=-1) * F.log_softmax(r_hat, dim=-1)
        ent = ent.sum(axis=-1).abs()
        return ent