import gym
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append("../../rl_zoo3/")
# from action_space_transform_wrappers import ActionRedundancyWrapper
# from stable_baselines3.common.buffers import ReplayBuffer, RolloutBuffer, DictRolloutBuffer
import torch
from scipy.special import betainc, betaincinv
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import torch as th
from gym import spaces
from torch import nn

from stable_baselines3.common.policies import BasePolicy#, ContinuousCritic

from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    CombinedExtractor,
    FlattenExtractor,
    NatureCNN,
    create_mlp,
    get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import Schedule
import numpy as np
from torch.nn import functional as F

from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
from torch.utils.tensorboard import SummaryWriter
# from buffers import ActionProbReplayBuffer


class EmbeddingInvNetwork(nn.Module):
    """Inverse-dynamics classifier with one MLP head per action dimension.

    Head ``i`` receives the concatenation ``[state, next_state, action]`` in
    which component ``i`` of the action has been zeroed out, and produces
    ``action_bins`` scores for that hidden component. Elsewhere in this file
    the scores are fed to a cross-entropy loss against the discretised value
    of the action component.
    """

    def __init__(self, obs_shape: int, action_shape: int, action_bins: int, net_arch: Optional[List[int]] = None) -> None:
        """
        :param obs_shape: length of a flat observation vector
        :param action_shape: number of action dimensions (= number of heads)
        :param action_bins: number of outputs of every head
        :param net_arch: hidden layer sizes handed to ``create_mlp``;
            defaults to ``[400, 300]`` when omitted
        """
        super().__init__()
        # A ``None`` sentinel avoids sharing one mutable default list between
        # every instance of the class.
        if net_arch is None:
            net_arch = [400, 300]

        self.obs_shape = obs_shape
        self.action_shape = action_shape
        self.inverse_net_list = nn.ModuleList()
        for _ in range(self.action_shape):
            # Input is state + next state (2 * obs_shape) plus the masked action.
            layers = create_mlp(
                input_dim=obs_shape * 2 + action_shape,
                output_dim=action_bins,
                net_arch=net_arch,
            )
            self.inverse_net_list.append(nn.Sequential(*layers))

    def forward_ith(self, state: th.Tensor, next_state: th.Tensor, real_action: th.Tensor, action_dim: int):
        """Run only the head responsible for action component ``action_dim``.

        :param state: batch of observations, indexed as ``(batch, features)``
        :param next_state: batch of successor observations, same layout
        :param real_action: batch of executed actions, ``(batch, action_shape)``
        :param action_dim: index of the action component to predict
        :return: output of head ``action_dim`` for the batch
        """
        encode_state = state.to(th.float)
        encode_next_state = next_state.to(th.float)
        real_action = real_action.to(th.float)

        # Clone before masking: ``.to`` may return the caller's own tensor, and
        # the in-place write below must not modify it.
        masked_action = real_action.clone()
        masked_action[:, action_dim] = 0

        concat_state = th.cat((encode_state, encode_next_state, masked_action), 1)
        return self.inverse_net_list[action_dim](concat_state)

    def forward(self, state: th.Tensor, next_state: th.Tensor, real_action: th.Tensor):
        """Run every head and stack the results along dimension 1.

        :param state: batch of observations
        :param next_state: batch of successor observations
        :param real_action: batch of executed actions
        :return: tensor of head outputs stacked on ``dim=1``, i.e. indexed as
            ``(batch, action dimension, bin)``
        """
        # ``forward_ith`` performs the float conversion itself, so the inputs
        # are passed through untouched.
        per_dim_outputs = [
            self.forward_ith(state, next_state, real_action, action_dim=i)
            for i in range(self.action_shape)
        ]
        return th.stack(per_dim_outputs, dim=1)



def test_inverse_dist(replay_buffer, inverse_net, writer, epoch, actuator_id = 0, device=None,bins=None):
    """Log diagnostics of the inverse model's per-sample loss for one actuator.

    Two figures are written to ``writer``: a scatter plot of the per-sample
    cross-entropy ("bias") against the true action value of the chosen
    actuator, and a bar chart of the mean cross-entropy inside action-value
    intervals of width 0.1 on [-1, 1). The number of samples falling into each
    discretisation bin is also printed.

    :param replay_buffer: buffer exposing ``WeightedSample``; the returned
        object must provide ``observations``, ``actions`` and
        ``next_observations`` tensors
    :param inverse_net: model called as ``inverse_net(state, next_state, action)``
    :param writer: object with an ``add_figure(tag, figure, global_step=...)`` method
    :param epoch: value used as ``global_step`` for the logged figures
    :param actuator_id: index of the action dimension to analyse
    :param device: device the sampled tensors are moved to
    :param bins: 1-D tensor of bin edges used to discretise actions
    :raises ValueError: if ``bins`` is not provided
    """
    if bins is None:
        raise ValueError("bins (the discretisation bin edges) must be provided")

    # NOTE(review): WeightedSample is project code not visible here; passing
    # None presumably returns the whole buffer -- confirm.
    all_data = replay_buffer.WeightedSample(None, action_dim=0)
    state = all_data.observations.to(device)
    action = all_data.actions.to(device)
    s_prime = all_data.next_observations.to(device)

    # Pure evaluation: disable autograd so no graph is kept for this pass.
    with th.no_grad():
        predict = inverse_net(state, s_prime, action)

        # Keep targets strictly inside (-1, 1) before discretising them.
        target = th.clip(action, max=1-1e-5, min=-1+1e-5)
        discrete_target = th.bucketize(target, bins) - 1

        test_predict = predict[:,actuator_id,:]
        test_target = target[:,actuator_id]
        test_discrete_target = discrete_target[:,actuator_id]

        ith_counts = th.bincount(test_discrete_target, minlength=len(bins)-1)
        print(ith_counts)

        # Per-sample cross-entropy, referred to as "bias" in the plots.
        bias = F.cross_entropy(test_predict, test_discrete_target, reduction='none')

        sorted_target, indices = th.sort(test_target, dim=0)
        sorted_bias = bias[indices]

    target_np = sorted_target.detach().cpu().numpy().reshape(-1)
    bias_np = sorted_bias.detach().cpu().numpy().reshape(-1)

    fig1 = plt.figure(figsize=(10, 6))
    plt.scatter(target_np, bias_np, s=1, alpha=0.6)
    plt.title('Scatter Plot of Bias vs. Sorted Target')
    plt.xlabel('Sorted Target')
    plt.ylabel('Bias')
    plt.grid(True)

    writer.add_figure('Scatter actuator'+ str(actuator_id), fig1, global_step=epoch)
    plt.close(fig1)

    # Mean loss inside fixed-width intervals of the action value.
    interval = 0.1
    min_value = -1
    max_value = 1
    bins_2 = np.arange(min_value, max_value, interval)

    # np.digitize returns 1 for the first interval, hence the 1-based range.
    bin_indices = np.digitize(target_np, bins_2)
    bin_means = []
    for i in range(1, len(bins_2) + 1):
        in_bin = bias_np[bin_indices == i]
        # Empty intervals are drawn as zero-height bars.
        bin_means.append(float(in_bin.mean()) if in_bin.size > 0 else 0)

    fig2 = plt.figure(figsize=(10, 6))
    # Shift by half an interval so each bar is centred on its interval.
    plt.bar(bins_2 + interval / 2, bin_means, width=interval, edgecolor='black', alpha=0.7)
    plt.title('Mean of Bias in Each Interval of Sorted Target')
    plt.xlabel('Sorted Target Interval')
    plt.ylabel('Mean Bias')
    plt.grid(True)
    x_ticks = np.arange(-1.2, 1.2, 0.1)
    plt.xticks(x_ticks)
    writer.add_figure('Hist actuator'+str(actuator_id), fig2, global_step=epoch)
    plt.close(fig2)


if __name__ == "__main__":
    # Offline training of the inverse model on a pickled replay buffer:
    # load the buffer, train one classification head per action dimension,
    # and log losses / diagnostic figures to TensorBoard.
    import time

    if th.cuda.is_available():
        device = th.device("cuda")
        print("CUDA is available. Using GPU.")
    else:
        device = th.device("cpu")
        print("CUDA is not available. Using CPU.")

    # Actions are discretised into `action_bins` equal-width bins on [-1, 1].
    action_bins = 20
    bins = th.linspace(-1, 1, steps=action_bins + 1).to(device)
    path = "HalfCheetah-v3_saturation_L2.5R2.5_100000_2024-09-10-19-02-54"
    root_path = ""
    # NOTE(review): load_from_pkl presumably unpickles the file -- only load
    # buffers that come from a trusted source.
    replay_buffer = load_from_pkl(root_path + path, verbose=2)
    print("Train inv with data from "+ path)
    replay_buffer.sub_indices = None

    obs_shape = replay_buffer.obs_shape
    action_dim = replay_buffer.action_dim

    inverse_net = EmbeddingInvNetwork(obs_shape=obs_shape[0], action_shape=action_dim, action_bins=action_bins).to(device)
    inverse_optimizer = th.optim.Adam(inverse_net.parameters(), lr=1e-3)
    inv_batch_size = 64
    inv_gradient_steps = 100000

    time_flag = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())

    # Drop everything from "_2024" onwards (the timestamp in the buffer file
    # name) and append the current time instead.
    exp_name = path.split("_2024")[0]
    exp_path = root_path + 'runs/'+exp_name+"_"+time_flag
    writer = SummaryWriter(exp_path)
    log_interval = 100

    fig_log_interval = 10000

    save_model_interval = 50000

    # for test
    test_action_dist = False
    if test_action_dist:
        replay_data = replay_buffer.WeightedSample(10000, action_dim=0)
        action = replay_data.actions.to(device)
        test_bins = np.linspace(-1, 1, action_bins + 1)
        for i in range(action_dim):
            data = action[:, i].cpu().numpy()

            indices = np.digitize(data, test_bins) - 1

            counts = np.bincount(indices, minlength=action_bins)

            print("action_dim:"+str(i))
            print(counts)

    # Step 2.2: Train inv network
    count_every_bin = th.zeros(action_bins,device=device)
    inv_losses = []
    # The loss module is stateless, so build it once instead of per head per step.
    criterion = nn.CrossEntropyLoss()
    for i in range(inv_gradient_steps):
        inv_replay_data = replay_buffer.sample(inv_batch_size)
        state = inv_replay_data.observations.to(device)
        action = inv_replay_data.actions.to(device)
        s_prime = inv_replay_data.next_observations.to(device)

        # Sum of the classification losses of all action dimensions.
        total_inv_loss = 0
        for sample_action_dim in range(action_dim):
            predict = inverse_net.forward_ith(state, s_prime, action, action_dim=sample_action_dim)
            target = action[:,sample_action_dim]
            # Keep targets strictly inside (-1, 1) before discretising them.
            target = th.clip(target, max=1 - 1e-5, min=-1 + 1e-5)
            discrete_target = th.bucketize(target, bins) - 1

            predict = predict.view(-1,action_bins)
            discrete_target = discrete_target.view(-1)

            if test_action_dist:
                ith_counts = th.bincount(discrete_target, minlength=action_bins)
                count_every_bin = count_every_bin + ith_counts
                print(count_every_bin)

            total_inv_loss += criterion(predict, discrete_target)

        inv_loss = total_inv_loss

        # Read the scalar once; it is reused for the history and the logs.
        inv_loss_value = inv_loss.item()
        inv_losses.append(inv_loss_value)

        inverse_optimizer.zero_grad()
        inv_loss.backward()
        inverse_optimizer.step()

        if i % log_interval == 0:
            writer.add_scalar('Training Loss', inv_loss_value, i)
            print("Iter: " + str(i) + " Inv Loss: " + str(inv_loss_value))
        if i % fig_log_interval == 0:
            for actuator_id in range(action_dim):
                test_inverse_dist(replay_buffer, inverse_net, writer, epoch=i, actuator_id=actuator_id, device=device, bins=bins)
        if i % save_model_interval == 0:
            th.save(inverse_net, exp_path +'/model.pth')

    # The periodic checkpoint above never fires on the last step, so store the
    # fully trained model explicitly and flush pending TensorBoard events.
    th.save(inverse_net, exp_path +'/model.pth')
    writer.close()

    print("Model Path: " + exp_path)
    # test_load_ineverse_net = th.load(exp_path +'/model.pth')