import torch
import matplotlib.pyplot as plt
import itertools
import numpy as np

def exploration(env_sampler, memory, num_steps):
    """Fill ``memory`` with ``num_steps`` exploratory transitions.

    Each transition comes from ``env_sampler.sample(None, exploration=True)``
    and is stored as ``(state, action, reward, next_state, done)``.  The
    sampler's ``current_state`` and ``path_length`` are reset afterwards.

    Args:
        env_sampler: Object exposing ``sample`` plus the ``current_state`` and
            ``path_length`` attributes.
        memory: Buffer exposing ``push(state, action, reward, next_state, done)``.
        num_steps: Number of transitions to collect.
    """
    for _ in range(num_steps):
        transition = env_sampler.sample(None, exploration=True)
        state, act, following_state, rew, terminal, _info = transition
        # The buffer stores reward before next_state, unlike the sampler tuple.
        memory.push(state, act, rew, following_state, terminal)

    # Leave the sampler ready to start a fresh episode.
    env_sampler.current_state = None
    env_sampler.path_length = 0
    print('Exploration completed')


def evaluation(args, env_sampler, agent):
    """Roll out ``agent`` for five evaluation episodes and summarise them.

    Each episode resets the sampler, then steps with
    ``env_sampler.sample(agent, eval_t=True, clipping=False)`` until the
    sampler reports ``done`` or ``args.max_path_length`` steps were taken.

    Args:
        args: Namespace providing ``max_path_length`` and ``env_name``.
        env_sampler: Sampler exposing ``sample``, ``current_state`` and
            ``path_length``.
        agent: Policy handed through to ``env_sampler.sample``.

    Returns:
        Tuple ``(mean_return, flagged_fraction)``: the mean undiscounted
        episode return, and the fraction of all evaluated steps whose
        ``info['x_position']`` was below -3 (``HalfCheetah-v4``) or above 0.5
        (``Swimmer-v4``).  The fraction is ``0.0`` when no step was taken.
    """
    returns = []
    land_right = 0
    total = 0
    for _ in range(5):
        env_sampler.current_state = None
        env_sampler.path_length = 0
        done = False
        test_step = 0
        G = 0
        while (not done) and (test_step != args.max_path_length):
            cur_state, action, next_state, reward, done, info = env_sampler.sample(
                agent, eval_t=True, clipping=False)
            if args.env_name == 'HalfCheetah-v4':
                if info['x_position'] < -3:
                    land_right += 1
            if args.env_name == 'Swimmer-v4':
                if info['x_position'] > 0.5:
                    land_right += 1

            G += reward
            test_step += 1
        total += test_step
        returns.append(G)
    mean_return = np.mean(returns)
    # No step may have run (e.g. max_path_length == 0); avoid dividing by zero.
    flagged_fraction = land_right / total if total else 0.0
    return mean_return, flagged_fraction

def entropic(action_value_dist, agent):
    """Collapse an atom distribution into one exponentially weighted sum per action.

    Every atom probability is multiplied by ``exp((v - 100.0) / agent.beta)``,
    where ``v`` is the matching entry of ``agent.V_RANGE``, and the products
    are summed over the atom axis.

    Args:
        action_value_dist: Tensor whose last axis indexes atoms and broadcasts
            against ``agent.V_RANGE``, e.g. ``(N_ENVS, N_ACTIONS, N_ATOM)`` or
            ``(N_ACTIONS, N_ATOM)``.
        agent: Object providing the ``V_RANGE`` tensor and the scalar ``beta``.

    Returns:
        Tensor with the atom axis removed.
    """
    # NOTE(review): the 100.0 offset looks like the largest support value,
    # used to keep the exponent non-positive -- confirm against V_RANGE.
    weights = torch.exp((agent.V_RANGE - 100.0) / agent.beta)
    # Sum over the last axis rather than a hard-coded dim=2, so the 2-D
    # (N_ACTIONS, N_ATOM) tensor produced after squeeze(0) is accepted too.
    return torch.sum(action_value_dist * weights, dim=-1)

def _plot_action_value_dist(agent, action_value_dist, title):
    """Plot rows 0-3 of ``action_value_dist`` against ``agent.V_RANGE_numpy``."""
    plt.title(title)
    for action_index in range(4):
        plt.plot(agent.V_RANGE_numpy,
                 action_value_dist[action_index, :].cpu().numpy(),
                 label=action_index)
    plt.legend()
    plt.show()


def belief(args, env_sampler, agent):
    """Roll out one episode and plot the agent's value distributions along the way.

    The distribution predicted by ``agent.pred_net`` for the current state is
    plotted every 100 steps (together with a printout of ``entropic``) and
    again, titled ``'Horrible'``, whenever a reward of exactly ``-100.0`` is
    received.

    Args:
        args: Namespace providing ``max_path_length``.
        env_sampler: Sampler exposing ``sample``, ``current_state`` and
            ``path_length``.
        agent: Object providing ``pred_net``, ``device``, ``V_RANGE_numpy`` and
            whatever ``entropic`` reads.
    """
    env_sampler.current_state = None
    # Reset the step counter as well, matching the other rollouts in this file.
    env_sampler.path_length = 0
    done = False
    test_step = 0
    while (not done) and (test_step != args.max_path_length):
        cur_state, action, next_state, reward, done, info = env_sampler.sample(agent, eval_t=True)
        periodic = test_step % 100 == 0
        horrible = reward == -100.0
        if periodic or horrible:
            # Always evaluate the state of *this* step; the penalty plot used
            # to reuse the tensor built at the last multiple-of-100 step.
            x = torch.tensor(cur_state, dtype=torch.float32, device=agent.device).unsqueeze(0)
            with torch.no_grad():
                # pred_net output is (N_ENVS, N_ACTIONS, N_ATOM); drop the batch axis.
                action_value_dist = agent.pred_net(x).squeeze(0)
                if periodic:
                    _plot_action_value_dist(agent, action_value_dist, test_step)
                    # NOTE(review): action_value_dist is assumed 2-D here;
                    # confirm entropic reduces over the atom axis for this rank.
                    z = entropic(action_value_dist, agent)
                    print(test_step, z)
                if horrible:
                    _plot_action_value_dist(agent, action_value_dist, 'Horrible')

        test_step += 1


def plot_variational_variance(args, env_sampler, agent, prior_model):
    """Roll out one episode and draw the model's predicted variance per step.

    For every transition the concatenated ``(state, action, next_state)``
    vector is fed to ``prior_model.model``.  A circle is drawn at
    ``(cur_state[0], step)`` with radius ``sqrt(var) / 10``; it is red when
    ``next_state[0] < -3`` and default-coloured otherwise.

    Args:
        args: Namespace providing ``max_path_length``.
        env_sampler: Sampler exposing ``sample``, ``current_state`` and
            ``path_length``.
        agent: Policy handed to the sampler; ``agent.device`` hosts the input.
        prior_model: Object whose ``model`` returns ``(mean, log_var)``; the
            log-variance must hold a single element.
    """
    env_sampler.current_state = None
    env_sampler.path_length = 0
    done = False
    test_step = 0
    while (not done) and (test_step != args.max_path_length):
        cur_state, action, next_state, reward, done, info = env_sampler.sample(agent, eval_t=True)
        columns = (
            cur_state.reshape((-1, 1)),
            action.reshape((-1, 1)),
            next_state.reshape((-1, 1)),
        )
        inputs = np.concatenate(columns, axis=0)
        # Flatten the column vector, then add a batch axis of size one.
        inputs = torch.tensor(inputs, dtype=torch.float32, device=agent.device).squeeze()
        inputs = inputs.unsqueeze(0)
        # Pure inference: do not record an autograd graph for every step.
        with torch.no_grad():
            _, log_var_p = prior_model.model(inputs)
            var_p = torch.exp(log_var_p)
        radius = np.sqrt(var_p.cpu().item()) / 10

        circle_kwargs = {'fill': False}
        if next_state[0] < -3:
            circle_kwargs['color'] = 'r'
        marker = plt.Circle((cur_state[0], test_step), radius, **circle_kwargs)
        plt.gca().add_patch(marker)
        test_step += 1
    plt.ylim(0, 500)
    plt.xlim(-5, 5)
    plt.show()

def evaluate_Q(args, state, policy_net, device):
    """Return the extreme network output for a single state.

    Args:
        args: Namespace providing ``beta``; a positive value selects the
            maximum output, anything else the minimum.
        state: Array-like state, converted to a float32 batch of size one.
        policy_net: Callable mapping the state batch to a tensor.
        device: Torch device on which the state tensor is created.

    Returns:
        The selected output as a Python float.
    """
    batch = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    reducer = torch.max if args.beta > 0 else torch.min
    return reducer(policy_net(batch)).item()

def plot_action_values(policy_net, device):
    """Show a heat-map of the first network output over a grid of states.

    States are ``[i, 0, j, 0]`` with ``i`` in ``arange(-2, 2.5, 0.5)`` (9
    values) and ``j`` in ``arange(-0.2, 0.21, 0.1)`` (5 values).  The image has
    one row per ``j`` and one column per ``i``.

    Args:
        policy_net: Callable mapping a ``(45, 4)`` float tensor to a 2-D tensor
            of per-action values.
        device: Torch device on which the state batch is placed.
    """
    dim0_values = np.arange(-2, 2.5, 0.5)
    dim2_values = np.arange(-0.2, 0.21, 0.1)
    # Row-major enumeration: i is the slow index, j the fast one.
    state = [[i, 0, j, 0] for i in dim0_values for j in dim2_values]
    state_batch = torch.FloatTensor(state).to(device)
    Q = policy_net(state_batch)
    Q = Q.cpu().detach().numpy()
    print(Q[0, 0])
    # Rebuild the grid in generation order (i, j) and transpose so the nine i
    # values run along the x-axis; reshape(-1, 9) mixed the two coordinates.
    grid = Q[:, 0].reshape(len(dim0_values), len(dim2_values)).T
    plt.imshow(grid, vmin=0, vmax=5)
    plt.show()

def plot_Z(args, memory, device, policy_net, target_net):
    """Print the terms of the multiplicative Bellman target for one sampled transition.

    Despite the name nothing is plotted: the function prints the current
    estimate for the taken action, ``exp(reward / beta)``, the bootstrapped
    factor, both masks, and the full network output for the sampled state.

    Args:
        args: Namespace providing ``beta`` and ``gamma``.
        memory: Replay buffer whose ``sample(1)`` is unpacked as
            ``(state, action, reward, next_state, done)`` -- the order used by
            ``exploration``'s ``push`` and the other samplers in this file.
            NOTE(review): confirm against the buffer implementation.
        device: Torch device for all tensors.
        policy_net: Network evaluated on the sampled state.
        target_net: Network evaluated on the sampled next state.
    """
    state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(1)
    mask_batch_1 = mask_batch      # 1 where the transition ended the episode
    mask_batch = ~mask_batch       # 1 where bootstrapping is allowed

    state_batch = torch.FloatTensor(state_batch).to(device)
    action_batch = torch.IntTensor(action_batch).to(device).type(torch.int64).unsqueeze(1)
    next_state_batch = torch.FloatTensor(next_state_batch).to(device)
    reward_batch = torch.FloatTensor(reward_batch).to(device).unsqueeze(1)
    mask_batch = torch.FloatTensor(mask_batch).to(device).unsqueeze(1)
    mask_batch_1 = torch.FloatTensor(mask_batch_1).to(device).unsqueeze(1)

    state_action_values = policy_net(state_batch).gather(1, action_batch)

    with torch.no_grad():
        next_values = target_net(next_state_batch)
        if args.beta > 0:
            min_qf_next_target, floor = next_values.max(1)[0], 0
        else:
            min_qf_next_target, floor = next_values.min(1)[0], 0.0005
        # Clamp before exponentiating so the base of pow() cannot be negative.
        min_qf_next_target = torch.clamp_min(min_qf_next_target.unsqueeze(1), floor)
        min_qf_next_target = torch.pow(min_qf_next_target, args.gamma)

    print('Run')
    print(state_action_values.item(),
          torch.exp(reward_batch / args.beta).item(),
          (mask_batch_1 + mask_batch * min_qf_next_target).item(),
          mask_batch_1.item(),
          mask_batch.item())
    print(policy_net(state_batch))

def compute_error(args, memory, trainer):
    """Score the variational model on every transition held in ``memory``.

    Inputs are ``[state, one_hot(action, 4), next_state]`` and the target is
    the reward.  Evaluation runs without gradients in chunks of 512 rows.

    Args:
        args: Unused.
        memory: Buffer supporting ``len()`` and ``sample(n)`` returning
            ``(state, action, reward, next_state, done)`` arrays.
        trainer: Object providing ``device`` and
            ``variational_model.variational_model``, which returns
            ``(mean, logvar)``.

    Returns:
        Tuple ``(loss, squared_error)``: the per-sample averages of
        ``(mean - reward)^2 / var + logvar`` and of ``(mean - reward)^2``.
    """
    state, action, reward, next_state, done = memory.sample(len(memory))
    row_count = action.shape[0]
    one_hot = np.zeros((row_count, 4))
    one_hot[np.arange(row_count), action] = 1

    features = np.concatenate((state, one_hot, next_state), axis=-1)
    targets = reward
    total_rows = features.shape[0]
    chunk = 512

    with torch.no_grad():
        loss_sum = 0
        square_sum = 0
        for begin in range(0, total_rows, chunk):
            end = begin + chunk
            batch_x = torch.from_numpy(features[begin:end]).float().to(trainer.device)
            batch_y = torch.from_numpy(targets[begin:end]).float().to(trainer.device).unsqueeze(-1)
            mean, logvar = trainer.variational_model.variational_model(batch_x)

            scaled_error = torch.pow(mean - batch_y, 2) / torch.exp(logvar)
            batch_loss = torch.mean(scaled_error + logvar)

            # Weight each chunk mean by its row count so the final averages
            # are per-sample even when the last chunk is short.
            rows = batch_y.size(0)
            loss_sum += batch_loss.item() * rows
            square_sum += torch.mean(torch.pow(mean - batch_y, 2)).item() * rows

        return (loss_sum / total_rows), (square_sum / total_rows)

class EarlyStopper:
    """Patience-based early stopping on a validation loss.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, patience=3, min_delta=0):
        # Number of counted "worse" evaluations that triggers a stop.
        self.patience = patience
        # Margin above the best loss within which a result is not counted.
        self.min_delta = min_delta
        # Worse evaluations seen since the last improvement.
        self.counter = 0
        # Best (lowest) validation loss observed so far.
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        """Record ``validation_loss`` and report whether training should stop.

        A loss below the best so far becomes the new best and clears the
        counter.  A loss more than ``min_delta`` above the best increments the
        counter.  Anything in between changes nothing.

        Returns:
            ``True`` once the counter has reached ``patience``, else ``False``.
        """
        best = self.min_validation_loss
        if validation_loss < best:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        # Written as a negated '>' so NaN losses are ignored, not counted.
        if not (validation_loss > (best + self.min_delta)):
            return False

        self.counter += 1
        if self.counter >= self.patience:
            return True
        return False

def _gaussian_nll(mean, log_var, target):
    """Return the batch mean of ``(mean - target)^2 / exp(log_var) + log_var``."""
    return torch.mean(torch.pow(mean - target, 2) / torch.exp(log_var) + log_var)


def supervised_training(args, memory, trainer):
    """Fit the reward model on all of ``memory`` and mirror it into the variational model.

    Inputs are ``[state, one_hot(action, 4), next_state]`` and the target is
    the reward.  Ten percent of the shuffled data (at least one row) is held
    out.  Training runs in mini-batches of 64 until ``EarlyStopper`` (default
    settings) fires on the holdout loss.  The reward model's parameters are
    then copied into ``trainer.variational_model.variational_model``.

    Args:
        args: Namespace providing ``batch_size``, used only as the chunk size
            for holdout evaluation.
        memory: Buffer supporting ``len()`` and ``sample(n)`` returning
            ``(state, action, reward, next_state, done)`` arrays.
        trainer: Object providing ``device``, ``reward_model.model``,
            ``reward_model.optimizer`` and ``variational_model.variational_model``.

    Raises:
        ValueError: Fewer than two transitions are available, so no
            train/holdout split exists.
        RuntimeError: The holdout loss became NaN; early stopping can never
            trigger on NaN, so training would otherwise loop forever.
    """
    batch_size = 64
    state, action, reward, next_state, done = memory.sample(len(memory))
    action_one_hot_batch = np.zeros((action.shape[0], 4))
    action_one_hot_batch[np.arange(action.shape[0]), action] = 1

    inputs = np.concatenate((state, action_one_hot_batch, next_state), axis=-1)
    labels = reward

    num_samples = inputs.shape[0]
    # Keep at least one holdout row: int(n * 0.1) is 0 for n < 10, which
    # previously ended in a division by zero when averaging the holdout loss.
    num_holdout = max(1, int(num_samples * 0.1))
    num_train = num_samples - num_holdout
    if num_train < 1:
        raise ValueError(
            'supervised_training needs at least 2 transitions, got {}'.format(num_samples))

    permutation = np.random.permutation(num_samples)
    inputs, labels = inputs[permutation], labels[permutation]
    train_inputs, train_labels = inputs[num_holdout:], labels[num_holdout:]
    holdout_inputs, holdout_labels = inputs[:num_holdout], labels[:num_holdout]

    early = EarlyStopper()
    reward_model = trainer.reward_model
    for _epoch in itertools.count():
        train_idx = np.random.permutation(num_train)
        for start_pos in range(0, num_train, batch_size):
            idx = train_idx[start_pos: start_pos + batch_size]
            train_input = torch.from_numpy(train_inputs[idx]).float().to(trainer.device)
            train_label = torch.from_numpy(train_labels[idx]).float().to(trainer.device).unsqueeze(-1)
            mean, log_var = reward_model.model(train_input)
            loss = _gaussian_nll(mean, log_var, train_label)

            reward_model.optimizer.zero_grad()
            loss.backward()
            reward_model.optimizer.step()

        with torch.no_grad():
            val_loss = 0.0
            for start_pos in range(0, num_holdout, args.batch_size):
                stop_pos = start_pos + args.batch_size
                holdout_input = torch.from_numpy(holdout_inputs[start_pos:stop_pos]).float().to(trainer.device)
                holdout_label = torch.from_numpy(holdout_labels[start_pos:stop_pos]).float().to(trainer.device).unsqueeze(-1)
                holdout_mean, holdout_logvar = reward_model.model(holdout_input)
                chunk_loss = _gaussian_nll(holdout_mean, holdout_logvar, holdout_label)
                # Weight by chunk size so the epoch value is a per-sample mean.
                val_loss += chunk_loss.item() * holdout_label.size(0)

        mean_val_loss = val_loss / num_holdout
        if np.isnan(mean_val_loss):
            # Every comparison against NaN is False, so the stopper would
            # neither improve nor count and the loop would never terminate.
            raise RuntimeError('validation loss became NaN during supervised training')
        if early.early_stop(mean_val_loss):
            break

    for var_model, model in zip(trainer.variational_model.variational_model.parameters(),
                                trainer.reward_model.model.parameters()):
        var_model.data.copy_(model.data)


def supervised_cont(args, memory, trainer):
    """Fit the reward model on ``memory`` and mirror it into the variational model.

    This entry point performs exactly the same procedure as
    ``supervised_training`` (its body was a line-for-line copy), so it
    delegates to it instead of keeping a second implementation in sync.

    Args:
        args: Namespace providing ``batch_size``.
        memory: Buffer supporting ``len()`` and ``sample(n)`` returning
            ``(state, action, reward, next_state, done)`` arrays.
        trainer: Object providing ``device``, ``reward_model`` and
            ``variational_model``.
    """
    # NOTE(review): if "cont" was meant to cover continuous actions, the 4-way
    # one-hot action encoding used by supervised_training needs replacing.
    supervised_training(args, memory, trainer)

def pre_train_model(args, memory, trainer, num_iterations=10000, log_every=1000):
    """Pre-train the reward model, then copy its parameters into the variational model.

    Args:
        args: Passed through to ``trainer.reward_model.optimize_model``.
        memory: Passed through to ``trainer.reward_model.optimize_model``.
        trainer: Object providing ``reward_model`` (with ``optimize_model`` and
            ``model``) and ``variational_model.variational_model``.
        num_iterations: Number of ``optimize_model`` calls to make.
        log_every: Print ``(iteration, loss)`` every this many iterations;
            a falsy value disables logging.
    """
    for i in range(num_iterations):
        loss = trainer.reward_model.optimize_model(args, memory)
        if log_every and i % log_every == 0:
            print(i, loss)

    # Parameters are paired positionally, so both networks are expected to
    # share the same architecture.
    for var_model, model in zip(trainer.variational_model.variational_model.parameters(),
                                trainer.reward_model.model.parameters()):
        var_model.data.copy_(model.data)