import gym
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append("../../rl_zoo3/")
from action_space_transform_wrappers import ActionRedundancyWrapper
from stable_baselines3.common.buffers import ReplayBuffer, RolloutBuffer, DictRolloutBuffer
import torch
from scipy.special import betainc, betaincinv
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import torch as th
from gym import spaces
from torch import nn

from stable_baselines3.common.policies import BasePolicy#, ContinuousCritic

from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    CombinedExtractor,
    FlattenExtractor,
    NatureCNN,
    create_mlp,
    get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import Schedule
import numpy as np
from torch.nn import functional as F

from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
from torch.utils.tensorboard import SummaryWriter


class EmbeddingInvNetwork(nn.Module):
    """Inverse-dynamics classifier with one MLP head per action dimension.

    Head ``i`` receives ``[state, next_state, action with component i zeroed]``
    (concatenated along dim 1) and outputs ``action_bins`` logits for the
    discretised value of action component ``i``.

    :param obs_shape: Length of the flat observation vector (states are
        concatenated along dim 1, so they are expected as ``(batch, obs_shape)``).
    :param action_shape: Number of action dimensions; one head is built per dimension.
    :param action_bins: Number of logits produced by each head.
    :param net_arch: Hidden layer sizes handed to ``create_mlp``.
        ``None`` (the default) means ``[400, 300]``.
    """

    def __init__(
        self,
        obs_shape: int,
        action_shape: int,
        action_bins: int,
        net_arch: Optional[List[int]] = None,
    ) -> None:
        super().__init__()
        # ``None`` sentinel instead of a list literal default, so instances never
        # share (and can never accidentally mutate) one default list object.
        if net_arch is None:
            net_arch = [400, 300]
        self.obs_shape = obs_shape
        self.action_shape = action_shape
        self.inverse_net_list = nn.ModuleList()
        for _ in range(self.action_shape):
            # Input = state ++ next_state ++ masked action.
            layers = create_mlp(
                input_dim=obs_shape * 2 + action_shape,
                output_dim=action_bins,
                net_arch=list(net_arch),
            )
            self.inverse_net_list.append(nn.Sequential(*layers))

    def forward(self, state: th.Tensor, next_state: th.Tensor, real_action: th.Tensor) -> th.Tensor:
        """Return the logits of every head stacked along dim 1.

        :param state: Current observations, 2-D ``(batch, obs_shape)``.
        :param next_state: Next observations, same layout as ``state``.
        :param real_action: Actions taken, 2-D ``(batch, action_shape)``; not modified.
        :return: Float tensor ``(batch, action_shape, head_output_dim)``.
        """
        encode_state = state.to(th.float)
        encode_next_state = next_state.to(th.float)
        real_action = real_action.to(th.float)

        predict_actions = []
        for i, inverse_net in enumerate(self.inverse_net_list):
            # Hide the component this head must predict so it cannot simply copy it.
            # ``clone`` keeps the caller's tensor untouched.
            masked_action = real_action.clone()
            masked_action[:, i] = 0
            concat_state = th.cat((encode_state, encode_next_state, masked_action), 1)
            predict_actions.append(inverse_net(concat_state))

        return th.stack(predict_actions, dim=1)


def test_inverse_dist(
    replay_buffer,
    inverse_net,
    writer,
    epoch,
    actuator_id=0,
    device: Optional[th.device] = None,
    bins: Optional[th.Tensor] = None,
):
    """Log per-sample inverse-model loss for one actuator over the replay buffer.

    Runs ``inverse_net`` on every filled transition of ``replay_buffer``, computes
    the per-sample cross-entropy ("bias") between the logits for ``actuator_id``
    and the discretised real action, and logs two figures to ``writer`` at step
    ``epoch``:

    * ``'Scatter'``: bias vs. action value.
    * ``'Hist'``: mean bias in each 0.1-wide action interval.

    :param replay_buffer: Buffer exposing ``size()`` and ``_get_samples(indices)``
        (stable-baselines3 style).
    :param inverse_net: Module called as ``inverse_net(state, next_state, action)``
        returning logits indexed as ``[sample, actuator, bin]``.
    :param writer: TensorBoard ``SummaryWriter`` receiving the figures.
    :param epoch: Global step for the logged figures.
    :param actuator_id: Action dimension to analyse.
    :param device: Evaluation device. Defaults to the device of ``inverse_net``'s
        parameters.
    :param bins: Bin edges used to discretise actions. Defaults to
        ``n_bins + 1`` evenly spaced edges over ``[-1, 1]``, where ``n_bins`` is
        the network's logit count.
    """
    if device is None:
        device = next(inverse_net.parameters()).device

    # Only the first ``size()`` slots hold real transitions until the buffer wraps.
    n_samples = replay_buffer.size()
    if n_samples == 0:
        return
    all_data = replay_buffer._get_samples(np.arange(n_samples))
    state = all_data.observations.to(device)
    action = all_data.actions.to(device)
    s_prime = all_data.next_observations.to(device)

    # Evaluation only: don't build an autograd graph over the whole buffer.
    with th.no_grad():
        predict = inverse_net(state, s_prime, action)
        n_bins = predict.shape[-1]
        if bins is None:
            bins = th.linspace(-1, 1, steps=n_bins + 1)
        edges = bins.to(device=action.device, dtype=action.dtype)
        # With right=False, bucketize maps a value equal to edges[0] (e.g. an
        # action of exactly -1) to 0, i.e. class -1 after the shift. That is out
        # of range for cross_entropy, so clamp into the valid class range.
        discrete_target = (th.bucketize(action, edges) - 1).clamp(0, n_bins - 1)
        bias = F.cross_entropy(
            predict[:, actuator_id, :],
            discrete_target[:, actuator_id],
            reduction='none',
        )

    target_np = action[:, actuator_id].cpu().numpy()
    bias_np = bias.cpu().numpy()
    order = np.argsort(target_np, kind="stable")
    sorted_target = target_np[order]
    sorted_bias = bias_np[order]

    fig1 = plt.figure(figsize=(10, 6))
    plt.scatter(sorted_target, sorted_bias, s=1, alpha=0.6)
    plt.title('Scatter Plot of Bias vs. Sorted Target')
    plt.xlabel('Sorted Target')
    plt.ylabel('Bias')
    plt.grid(True)
    writer.add_figure('Scatter', fig1, global_step=epoch)
    plt.close(fig1)

    interval = 0.1
    min_value = -1
    max_value = 1
    # Left edges -1.0, -0.9, ...; np.digitize returns k with edges[k-1] <= x < edges[k],
    # so k = 1..len(edges) selects bar k-1 (the last bar also collects x >= its edge).
    interval_edges = np.arange(min_value, max_value, interval)
    bin_indices = np.digitize(sorted_target, interval_edges)
    bin_means = []
    for k in range(1, len(interval_edges) + 1):
        in_bin = sorted_bias[bin_indices == k]
        bin_means.append(float(in_bin.mean()) if in_bin.size > 0 else 0)

    fig2 = plt.figure(figsize=(10, 6))
    plt.bar(interval_edges + 0.05, bin_means, width=interval, edgecolor='black', alpha=0.7)
    plt.title('Mean of Bias in Each Interval of Sorted Target')
    plt.xlabel('Sorted Target Interval')
    plt.ylabel('Mean Bias')
    plt.grid(True)
    x_ticks = np.arange(-1.2, 1.2, 0.1)
    plt.xticks(x_ticks)
    writer.add_figure('Hist', fig2, global_step=epoch)
    plt.close(fig2)


if __name__ == "__main__":
    import time

    if th.cuda.is_available():
        device = th.device("cuda")
        print("CUDA is available. Using GPU.")
    else:
        device = th.device("cpu")
        print("CUDA is not available. Using CPU.")

    action_bins = 20
    # ``device`` and ``bins`` live at module scope here; helpers in this file may
    # look them up as globals.
    bins = th.linspace(-1, 1, steps=action_bins + 1).to(device)
    # TODO: set these to the pickled replay buffer's location before running.
    path = ""
    root_path = ""
    replay_buffer = load_from_pkl(root_path + path, verbose=2)

    obs_shape = replay_buffer.obs_shape
    action_dim = replay_buffer.action_dim

    inverse_net = EmbeddingInvNetwork(obs_shape=obs_shape[0], action_shape=action_dim, action_bins=action_bins).to(device)
    inverse_optimizer = th.optim.Adam(inverse_net.parameters(), lr=1e-3)
    inv_batch_size = 64
    inv_gradient_steps = 100000

    time_flag = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())

    # Experiment name = buffer file name with its "_2024..." suffix stripped.
    exp_name = path.split("_2024")[0]
    exp_path = root_path + 'runs/' + exp_name + "_" + time_flag
    writer = SummaryWriter(exp_path)
    log_interval = 100
    fig_log_interval = 10000
    save_model_interval = 50000
    model_file = exp_path + '/model.pth'

    # Step 2.2: Train inv network
    criterion = nn.CrossEntropyLoss()
    for i in range(inv_gradient_steps):
        inv_replay_data = replay_buffer.sample(inv_batch_size)
        state = inv_replay_data.observations.to(device)
        action = inv_replay_data.actions.to(device)
        s_prime = inv_replay_data.next_observations.to(device)

        predict = inverse_net(state, s_prime, action)

        # bucketize (right=False) maps an action exactly equal to bins[0] (-1) to
        # class -1 after the shift, which CrossEntropyLoss rejects; clamp it.
        discrete_target = (th.bucketize(action, bins) - 1).clamp(0, action_bins - 1)

        # Flatten (batch, action_dim, bins) -> (batch * action_dim, bins) so every
        # action dimension contributes one classification sample.
        inv_loss = criterion(predict.view(-1, action_bins), discrete_target.view(-1))

        inverse_optimizer.zero_grad()
        inv_loss.backward()
        inverse_optimizer.step()

        if i % log_interval == 0:
            # Read the loss once: .item() forces a device sync.
            loss_value = inv_loss.item()
            writer.add_scalar('Training Loss', loss_value, i)
            print("Iter: " + str(i) + " Inv Loss: " + str(loss_value))
        if i % fig_log_interval == 0:
            test_inverse_dist(replay_buffer, inverse_net, writer, epoch=i)
        if i % save_model_interval == 0:
            th.save(inverse_net, model_file)

    # The periodic save last fires well before the final step; persist the fully
    # trained network and flush pending TensorBoard events.
    th.save(inverse_net, model_file)
    writer.close()

    print("Model Path: " + exp_path)
    # Reload later with: th.load(exp_path + '/model.pth')