import gym
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append("../../rl_zoo3/")
# from action_space_transform_wrappers import ActionRedundancyWrapper
from stable_baselines3.common.buffers import ReplayBuffer, RolloutBuffer, DictRolloutBuffer
import torch
from scipy.special import betainc, betaincinv
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import torch as th
from gym import spaces
from torch import nn

from stable_baselines3.common.policies import BasePolicy#, ContinuousCritic

from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    CombinedExtractor,
    FlattenExtractor,
    NatureCNN,
    create_mlp,
    get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import Schedule
import numpy as np
from torch.nn import functional as F

from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
from torch.utils.tensorboard import SummaryWriter
# from .reshape_action_layer import reshape_action, reshape_action_for_tensor
try:
    from .reshape_action_layer import reshape_action, reshape_action_for_tensor
    from .train_inverse import EmbeddingInvNetwork
except ImportError:
    from reshape_action_layer import reshape_action, reshape_action_for_tensor
    from train_inverse import EmbeddingInvNetwork



def save_theta(replay_buffer, inverse_net, writer, epoch, theta_ith, actuator_id=0, device=None, bins=None):
    """Log diagnostic figures for one actuator's action-reshaping parameters.

    Evaluates ``inverse_net`` on every transition currently stored in
    ``replay_buffer`` and writes up to three figures to ``writer`` at step ``epoch``:

    * ``Scatter/actuator_<id>``: per-transition importance-weighted,
      slope-corrected negative log-likelihood ("bias") against e = h_theta(a).
    * ``Hist/actuator_<id>``: mean of that bias in 0.1-wide bins of e on [-1, 1).
    * ``Function_h/actuator_<id>``: the mapping h_theta on a grid over [-1, 1].

    The first two figures are skipped when the buffer is empty.

    Args:
        replay_buffer: stable-baselines3 style buffer; ``_get_samples`` is
            called with an index array.
        inverse_net: called as ``inverse_net(s, s', a)``; its output is indexed
            as ``[:, actuator_id, :]`` and used as per-bin logits.
        writer: TensorBoard ``SummaryWriter`` receiving the figures.
        epoch: global step attached to the figures.
        theta_ith: parameters of h for this actuator (converted with ``tolist()``).
        actuator_id: index of the action dimension to analyse.
        device: device the sampled tensors are moved to.
        bins: bin edges used to discretise actions. Must match the number of
            classes predicted by ``inverse_net``.

    Raises:
        ValueError: if ``bins`` is not given.
    """
    if bins is None:
        raise ValueError("save_theta() needs `bins`: the edges used to discretise actions")

    # The buffer is only filled up to ``pos`` until it wraps around.
    upper_bound = replay_buffer.buffer_size if replay_buffer.full else replay_buffer.pos
    if upper_bound > 0:
        e, bias = _importance_weighted_bias(
            replay_buffer, inverse_net, np.arange(upper_bound), theta_ith, actuator_id, device, bins
        )
        _log_bias_figures(writer, epoch, actuator_id, e, bias)
    _log_h_figure(writer, epoch, actuator_id, theta_ith)


def _importance_weighted_bias(replay_buffer, inverse_net, indices, theta_ith, actuator_id, device, bins):
    """Return ``(e, bias)`` as flat numpy arrays for the transitions at ``indices``.

    ``e = h_theta(a)`` for the chosen actuator. ``bias`` is
    ``-log p(e|s,s') * |de/da|``: the slope-corrected negative log-likelihood,
    scaled by the same importance weight the training objective uses.
    """
    data = replay_buffer._get_samples(indices)
    state = data.observations.to(device)
    action = data.actions.to(device)
    s_prime = data.next_observations.to(device)

    # Pure evaluation of the inverse net: no autograd graph over the whole buffer.
    with th.no_grad():
        predict = inverse_net(state, s_prime, action)
        # Keep targets strictly inside (-1, 1) so bucketize yields 0..len(bins)-2.
        target = th.clip(action, min=-1 + 1e-5, max=1 - 1e-5)
        discrete_target = th.bucketize(target, bins) - 1
        # -log p(a_i | s, s') for every transition.
        nll_a = F.cross_entropy(
            predict[:, actuator_id, :], discrete_target[:, actuator_id], reduction='none'
        )

    # e = h_theta(a) and slopes = de/da under the current theta.
    e, slopes = reshape_action_for_tensor(x=target[:, actuator_id], theta=theta_ith, low=-1, high=1)

    epsilon = 1e-8
    abs_slopes = np.abs(slopes.detach().cpu().numpy()) + epsilon
    # p(e|s,s') = p(a|s,s') / |de/da|  =>  -log p(e|s,s') = -log p(a|s,s') + log|de/da|
    nll_e = nll_a.cpu().numpy() + np.log(abs_slopes)
    # Importance-sampling weight |de/da|, as in the training objective.
    bias = nll_e * abs_slopes
    return e.detach().cpu().numpy().reshape(-1), bias.reshape(-1)


def _log_bias_figures(writer, epoch, actuator_id, e, bias, interval=0.1, low=-1.0, high=1.0):
    """Log a scatter of ``bias`` vs ``e`` and a bar chart of its mean per ``interval``-wide bin of ``e``."""
    order = np.argsort(e)
    sorted_target, sorted_bias = e[order], bias[order]

    fig1, ax1 = plt.subplots(figsize=(10, 6))
    ax1.scatter(sorted_target, sorted_bias, s=1, alpha=0.6)
    ax1.set_title('Scatter Plot of Bias vs. Sorted Target')
    ax1.set_xlabel('Sorted Target')
    ax1.set_ylabel('Bias')
    ax1.grid(True)
    writer.add_figure('Scatter/actuator_' + str(actuator_id), fig1, global_step=epoch)
    plt.close(fig1)

    # Left bin edges. np.digitize returns k for edges[k-1] <= e < edges[k], and
    # len(edges) for e >= edges[-1]. It returns 0 for e < edges[0]; those are dropped.
    edges = np.arange(low, high, interval)
    bin_idx = np.digitize(sorted_target, edges)
    in_range = bin_idx > 0
    counts = np.bincount(bin_idx[in_range] - 1, minlength=len(edges))
    sums = np.bincount(bin_idx[in_range] - 1, weights=sorted_bias[in_range], minlength=len(edges))
    # Empty bins are plotted as 0.
    bin_means = np.divide(sums, counts, out=np.zeros(len(edges)), where=counts > 0)

    fig2, ax2 = plt.subplots(figsize=(10, 6))
    # Bars are centred on their bins.
    ax2.bar(edges + interval / 2, bin_means, width=interval, edgecolor='black', alpha=0.7)
    ax2.set_title('Mean of Bias in Each Interval of Sorted Target')
    ax2.set_xlabel('Sorted Target Interval')
    ax2.set_ylabel('Mean Bias')
    ax2.grid(True)
    ax2.set_xticks(np.arange(-1.2, 1.2, 0.1))  # a tick every 0.1 over [-1.2, 1.2)
    writer.add_figure('Hist/actuator_' + str(actuator_id), fig2, global_step=epoch)
    plt.close(fig2)


def _log_h_figure(writer, epoch, actuator_id, theta_ith):
    """Plot e = h_theta(a) on 1000 points of [-1, 1] and log it, with the rounded theta values as the title."""
    a_grid = np.linspace(-1, 1, num=1000)
    e_grid = reshape_action(a_grid, theta=theta_ith, low=-1, high=1)

    fig3, ax3 = plt.subplots(figsize=(10, 6))
    ax3.plot(a_grid, e_grid, label='function h')
    ax3.set_xlabel('a')
    ax3.set_ylabel('e')
    theta_list = [round(item, 2) for item in theta_ith.tolist()]
    # Title over two lines: the first 10 values, then the rest.
    ax3.set_title(str(theta_list[:10]) + "\n" + str(theta_list[10:]))
    ax3.grid(True)
    ax3.legend()
    writer.add_figure('Function_h/actuator_' + str(actuator_id), fig3, global_step=epoch)
    plt.close(fig3)


if __name__ == "__main__":
    import os
    import time

    if th.cuda.is_available():
        device = th.device("cuda")
        print("CUDA is available. Using GPU.")
    else:
        device = th.device("cpu")
        print("CUDA is not available. Using CPU.")

    # Pre-trained ("perfect") inverse dynamics model p(a|s,s'); it stays frozen here.
    exp_path = ""  # TODO: set to the experiment directory containing model.pth
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # NOTE(review): th.load unpickles arbitrary objects -- only load trusted checkpoints.
    inverse_net = th.load(os.path.join(exp_path, 'model.pth'), map_location=device)
    # Assumes th.load returns an nn.Module (it is called like one below).
    # eval() disables train-time layer behaviour (e.g. dropout) for inference.
    inverse_net.eval()

    # Replay buffer collected where pi(a|s) is uniform and the action mapping layer is the identity.
    buffer_path = ""  # TODO: set to the pickled replay buffer
    replay_buffer = load_from_pkl(buffer_path, verbose=2)
    action_dim = replay_buffer.action_dim

    time_flag = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    exp_name = exp_path.split("_2024")[0]
    exp_path = exp_name + "_" + "TrainAML" + "_" + time_flag
    writer = SummaryWriter(exp_path)

    # theta parameterises e = h_theta(a), the reshape-action layer.
    action_bins = 20
    bins = th.linspace(-1, 1, steps=action_bins + 1).to(device)

    # Flat leaf parameter optimised by Adam. A (action_dim, action_bins) view is
    # taken inside the loop so every backward pass uses a freshly built graph.
    theta_param = nn.Parameter(th.zeros(action_dim * action_bins, device=device), requires_grad=True)
    theta_opt = th.optim.Adam([theta_param], lr=1e-2)

    # Training configuration.
    batch_size = 256
    steps = 10000
    correct_flag = True
    correct_const = 0.3  # scale applied to the log|de/da| correction term
    epsilon = 1e-8

    log_interval = 100
    # max(1, ...) keeps the modulo below valid for small `steps`.
    fig_log_interval = max(1, steps // 50)
    # TODO: periodically checkpoint theta_param (not implemented).

    try:
        for i in range(steps):
            theta = theta_param.view(action_dim, action_bins)

            replay_data = replay_buffer.sample(batch_size)
            state = replay_data.observations.to(device)
            action = replay_data.actions.to(device)
            s_prime = replay_data.next_observations.to(device)

            # The inverse net and the discretised targets need no gradients.
            with th.no_grad():
                predict = inverse_net(state, s_prime, action)
                target = th.clip(action, min=-1 + 1e-5, max=1 - 1e-5)
                discrete_target = th.bucketize(target, bins) - 1

            total_loss = 0
            for actuator_id in range(action_dim):
                test_target = target[:, actuator_id]
                theta_ith = theta[actuator_id, :]

                # log p(a_i|s,s'); constant w.r.t. theta because `predict` carries no grad.
                log_pa = -F.cross_entropy(
                    predict[:, actuator_id, :], discrete_target[:, actuator_id], reduction='none'
                )

                # slopes = de/da under the current theta; gradients flow into theta from here.
                _, slopes = reshape_action_for_tensor(x=test_target, theta=theta_ith, low=-1, high=1)

                # log p(e|s,s') = log p(a|s,s') - log|de/da|, optionally scaled by correct_const.
                log_slopes = th.log(th.abs(slopes) + epsilon)
                log_pe = log_pa - correct_const * log_slopes if correct_flag else log_pa

                # Self-normalised importance-weighted mean of log p(e|s,s').
                is_weight = th.abs(slopes) + epsilon
                is_mean = (log_pe * is_weight).sum() / is_weight.sum()

                total_loss += -is_mean

            theta_opt.zero_grad()
            total_loss.backward()
            theta_opt.step()

            if i % log_interval == 0:
                loss_value = total_loss.item()
                writer.add_scalar('Training Loss', loss_value, i)
                print("Iter: " + str(i) + " Inv Loss: " + str(loss_value))
                if correct_flag:
                    writer.add_scalar('Correct const', correct_const, i)
            if i % fig_log_interval == 0:
                with th.no_grad():
                    theta_snapshot = theta_param.detach().view(action_dim, action_bins)
                    for actuator_id in range(action_dim):
                        save_theta(replay_buffer=replay_buffer, inverse_net=inverse_net, writer=writer, epoch=i,
                                   theta_ith=theta_snapshot[actuator_id, :], actuator_id=actuator_id,
                                   device=device, bins=bins)
    finally:
        # Flush pending TensorBoard events, even if training is interrupted.
        writer.close()

