import gym
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append("../../rl_zoo3/")
from action_space_transform_wrappers import ActionRedundancyWrapper
from stable_baselines3.common.buffers import ReplayBuffer, RolloutBuffer, DictRolloutBuffer
import torch
from scipy.special import betainc, betaincinv
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import torch as th
from gym import spaces
from torch import nn

from stable_baselines3.common.policies import BasePolicy#, ContinuousCritic

from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    CombinedExtractor,
    FlattenExtractor,
    NatureCNN,
    create_mlp,
    get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import Schedule
import numpy as np
from torch.nn import functional as F

from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
from torch.utils.tensorboard import SummaryWriter
from train_inverse import EmbeddingInvNetwork
from reshape_action_layer import reshape_action, reshape_action_for_tensor

# Pick the compute device once at import time and report which one is used.
device = th.device("cuda" if th.cuda.is_available() else "cpu")
print("CUDA is available. Using GPU." if device.type == "cuda" else "CUDA is not available. Using CPU.")


def save_theta(replay_buffer, inverse_net, writer, epoch, theta_ith, actuator_id = 0, correct_const = 1.0):
    """Write diagnostic figures for the current action mapping to TensorBoard.

    Evaluates ``inverse_net`` on every stored transition of ``replay_buffer``.
    For the selected actuator it computes an importance-weighted negative
    log-likelihood ("bias") of the mapped action ``e = h_theta(a)``::

        bias = (-log p(a|s,s') + correct_const * log|de/da|) * |de/da|

    With ``correct_const == 1`` this is the change-of-variables -log p(e|s,s'),
    weighted by ``|de/da|``. The correction has the same sign as the one used
    in the training loss.

    Three figures are logged at ``global_step=epoch``:

    * ``'Scatter'`` shows bias vs. sorted e.
    * ``'Hist'`` shows the mean bias per 0.1-wide bin of e over [-1, 1).
    * ``'Function_h'`` shows h_theta evaluated on a grid over [-1, 1].

    Uses the module-level globals ``bins`` (edges that discretise actions into
    class indices) and ``device``.

    Args:
        replay_buffer: Stable-Baselines3 replay buffer. Only its filled part
            (``replay_buffer.size()`` transitions) is evaluated.
        inverse_net: model called as ``inverse_net(obs, next_obs, actions)``.
            Its output is indexed as ``[transition, actuator, class]``.
        writer: TensorBoard ``SummaryWriter`` that receives the figures.
        epoch: global step attached to the figures.
        theta_ith: parameters of h for the selected actuator.
        actuator_id: index of the action dimension to analyse.
        correct_const: weight of the ``log|de/da|`` change-of-variables term.
            1.0 is the exact change of variables.
    """
    epsilon = 1e-8

    with th.no_grad():
        # Only the first size() slots hold real transitions. The remainder of a
        # buffer that is not yet full is zero padding and must not be plotted.
        n_stored = replay_buffer.size()
        all_data = replay_buffer._get_samples(np.arange(n_stored))
        state = all_data.observations.to(device)
        action = all_data.actions.to(device)
        s_prime = all_data.next_observations.to(device)

        predict = inverse_net(state, s_prime, action)
        # Class index of each action component. bucketize returns 0 for a value
        # equal to the first edge, which becomes -1 after the shift. Clamp into
        # the valid class range so cross_entropy does not fail when a == low.
        discrete_target = (th.bucketize(action, bins) - 1).clamp(0, bins.numel() - 2)

        test_predict = predict[:, actuator_id, :]
        test_target = action[:, actuator_id]
        test_discrete_target = discrete_target[:, actuator_id]

        # -log p(a_i|s,s') for every transition.
        nll_a = F.cross_entropy(test_predict, test_discrete_target, reduction='none').cpu().numpy()

        # e = h_theta(a) and its slope de/da under the current theta.
        e, slopes = reshape_action_for_tensor(x=test_target, theta=theta_ith, low=-1, high=1)
        abs_slopes = np.abs(slopes.detach().cpu().numpy()) + epsilon

        # Change of variables: p(e|s,s') = p(a|s,s') / |de/da|, so
        # -log p(e|s,s') = -log p(a|s,s') + log|de/da|.
        nll_e = nll_a + correct_const * np.log(abs_slopes)

        # Importance-sampling weight |de/da|.
        bias = nll_e * abs_slopes

        sorted_e, order = th.sort(e, dim=0)
        sorted_bias = bias[order.cpu().numpy()]
        sorted_e = sorted_e.squeeze().detach().cpu().numpy()

        # Evaluate h on a dense grid for the third figure, with no autograd graph.
        x = np.linspace(-1, 1, num=1000)
        y = reshape_action(x, theta=theta_ith, low=-1, high=1)

    # Figure 1: per-transition weighted NLL against the mapped action.
    fig1 = plt.figure(figsize=(10, 6))
    plt.scatter(sorted_e, sorted_bias, s=1, alpha=0.6)
    plt.title('Scatter Plot of Bias vs. Sorted Target')
    plt.xlabel('Sorted Target')
    plt.ylabel('Bias')
    plt.grid(True)
    writer.add_figure('Scatter', fig1, global_step=epoch)
    plt.close(fig1)

    # Figure 2: mean weighted NLL in fixed-width bins of e over [min_value, max_value).
    interval = 0.1
    min_value = -1
    max_value = 1
    bin_edges = np.arange(min_value, max_value, interval)
    # np.digitize gives 1..len(bin_edges) for values >= min_value.
    # Empty bins are plotted as 0.
    bin_indices = np.digitize(sorted_e, bin_edges)
    bin_means = []
    for k in range(1, len(bin_edges) + 1):
        members = sorted_bias[bin_indices == k]
        bin_means.append(members.mean() if members.size > 0 else 0)

    fig2 = plt.figure(figsize=(10, 6))
    # Bars are centred on their interval.
    plt.bar(bin_edges + interval / 2, bin_means, width=interval, edgecolor='black', alpha=0.7)
    plt.title('Mean of Bias in Each Interval of Sorted Target')
    plt.xlabel('Sorted Target Interval')
    plt.ylabel('Mean Bias')
    plt.grid(True)
    plt.xticks(np.arange(-1.2, 1.2, 0.1))
    writer.add_figure('Hist', fig2, global_step=epoch)
    plt.close(fig2)

    # Figure 3: the mapping h_theta, with the rounded parameters in the title.
    fig3 = plt.figure(figsize=(10, 6))
    plt.plot(x, y, label='function h')
    plt.xlabel('a')
    plt.ylabel('e')
    theta_list = [round(item, 2) for item in theta_ith.tolist()]
    plt.title(str(theta_list[:10]) + "\n" + str(theta_list[10:]))
    plt.grid(True)
    plt.legend()
    writer.add_figure('Function_h', fig3, global_step=epoch)
    plt.close(fig3)

# debug For plot
debug = True

# Paths to the trained inverse model p(a|s,s') and to the replay buffer it is
# evaluated on. Pass them on the command line
# (``python <script> EXP_PATH BUFFER_PATH``) or replace the ``None`` defaults.
exp_path = sys.argv[1] if len(sys.argv) > 1 else None
buffer_path = sys.argv[2] if len(sys.argv) > 2 else None
if exp_path is None or buffer_path is None:
    raise SystemExit("usage: python " + sys.argv[0] + " EXP_PATH BUFFER_PATH")

# Perfect inverse net p(a|s,s'). It is frozen here and only evaluated.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
inverse_net = th.load(exp_path + '/model.pth', map_location=device)
# Deterministic predictions: disables dropout / batch-norm statistic updates, if the model has any.
inverse_net.eval()

# Replay buffer collected with a uniform pi(a|s) and an identity action-mapping layer.
replay_buffer = load_from_pkl(buffer_path, verbose=2)
action_dim = replay_buffer.action_dim

import time
time_flag = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
# NOTE(review): assumes the experiment directory name contains a "_2024..."
# timestamp suffix; other years are not stripped.
exp_name = exp_path.split("_2024")[0]
exp_path = exp_name + "_" + "TrainAML" + "_" + time_flag
writer = SummaryWriter(exp_path)

# theta parameterises e = h_theta(a), the reshape-action layer.
# There is one row of `action_bins` parameters per action dimension.
action_bins = 20
bins = th.linspace(-1, 1, steps=action_bins + 1).to(device)

# Keep the optimised leaf tensor itself 2-D. Do not re-bind `theta` to a view of
# the Parameter: the values and the (elementwise) Adam updates are unchanged.
theta = nn.Parameter(th.zeros(action_dim, action_bins, device=device), requires_grad=True)
theta_opt = th.optim.Adam([theta], lr=1e-2)

# training process:
batch_size = 256
steps = 10000
actuator_id = 0
correct_flag = True
correct_const = 0.3

log_interval = 100
# max(1, ...) avoids a zero modulus when `steps` is small.
fig_log_interval = max(1, steps // 50)
save_model_interval = max(1, steps // 5)

epsilon = 1e-8

# Train theta
for i in range(steps):
    replay_data = replay_buffer.sample(batch_size)
    state = replay_data.observations.to(device)
    action = replay_data.actions.to(device)
    s_prime = replay_data.next_observations.to(device)

    # Focus on the `actuator_id`-th action dimension.
    test_target = action[:, actuator_id]
    theta_ith = theta[actuator_id, :]

    # log_pa = log p(a|s,s') from the frozen inverse model. No gradient flows
    # through it, so build no autograd graph for the forward pass.
    with th.no_grad():
        predict = inverse_net(state, s_prime, action)
        # bucketize maps a == -1 to class -1. Clamp into [0, action_bins - 1]
        # so that cross_entropy does not fail.
        discrete_target = (th.bucketize(action, bins) - 1).clamp(0, action_bins - 1)
        log_pa = - F.cross_entropy(predict[:, actuator_id, :], discrete_target[:, actuator_id], reduction='none')

    # e = h_theta(a) under the current theta; slopes = de/da.
    e, slopes = reshape_action_for_tensor(x=test_target, theta=theta_ith, low=-1, high=1)

    # log p(e|s,s') = log [ p(a|s,s') / |de/da| ] = log p(a|s,s') - log |de/da|
    log_slopes = th.log(th.abs(slopes) + epsilon)

    # Corrected p(e|s,s'). correct_const scales the change-of-variables term.
    if correct_flag:
        log_pe = log_pa - correct_const * log_slopes
    else:
        log_pe = log_pa

    # Self-normalised importance-sampling mean with weight |de/da|.
    IS_weight = th.abs(slopes) + epsilon
    IS_mean = (log_pe * IS_weight).sum() / IS_weight.sum()

    loss = -IS_mean

    theta_opt.zero_grad()
    loss.backward()
    theta_opt.step()

    if i % log_interval == 0:
        writer.add_scalar('Training Loss', loss.item(), i)
        print("Iter: " + str(i) + " Inv Loss: " + str(loss.item()))
        if correct_flag:
            writer.add_scalar('Correct const', correct_const, i)
    if i % fig_log_interval == 0:
        with th.no_grad():
            save_theta(replay_buffer=replay_buffer, inverse_net=inverse_net, writer=writer, epoch=i, theta_ith=theta_ith, actuator_id=actuator_id)
    if i % save_model_interval == 0:
        # TODO: persist theta here (e.g. th.save(theta.detach().cpu(), ...)); currently a no-op.
        pass

# Flush pending TensorBoard events and release the event file.
writer.close()

