# In gridworld, we have a state and partial observation. Here the state is the observation.

import argparse
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
from dataclasses import dataclass, field
from typing import NamedTuple, Optional
import math
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="gym.spaces.box")
import pickle
import tqdm
# limit the number of threads to 64
os.environ["OMP_NUM_THREADS"] = "64"
os.environ["OPENBLAS_NUM_THREADS"] = "64"
os.environ["MKL_NUM_THREADS"] = "64"
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from sklearn.cluster import KMeans
import utils.stat

import optax
import chex
from flax.training.train_state import TrainState
from flax import serialization
import jax
import jax.numpy as jnp
import numpy as np
import d4rl
import gym
import wandb
import h5py
import matplotlib.pyplot as plt

from utils.networks import ScannedRNN, ContinuousActorRNN, DiscreteActorRNN, VAE, EncoderWrapper,VQVAE,VQVAE_gumble_softmax,CAAE, MLP_3_LoRA
from gridworld.env import SingleAgentGridworld, FixedGridworld, ExtraRewardGridworld, MDPGridworld, MDPtakeball
from utils.plot_tools import plot_and_save_curves, plot_and_save_bar, plot_and_save_bars, plot_and_save_heatmap
from utils.load_dataset import Transitions, load, load_env

@dataclass
class TrainConfig:
    """Run configuration for the clustering/training script.

    Declared dataclass fields are stored as normal attributes. Any attribute
    assigned later whose name is NOT a declared field is stored in
    ``extra_attributes`` instead and can be read back through normal
    attribute access (see ``__setattr__`` / ``__getattr__``).
    """
    # General
    device: str = "cuda"
    device_id: list = field(default_factory=lambda: [0])
    debug: bool = False
    disable_jit: bool = False
    # Holds attributes assigned at runtime that are not declared fields.
    extra_attributes: dict = field(default_factory=dict, init=False)
    make_plots: bool = True
    # Experiment
    alg: str = "VAE+Kmeans"  # Algorithm name
    env: str = "MiniGrid-Reacher"  # Minigrid environment name
    seed: int = 1  # Sets Gym, Jax and Numpy seeds
    max_updates: int = 200  # Number of updates
    n_episodes: int = 5  # How many episodes run during evaluation
    checkpoints_path: Optional[str] = None  # Save path
    batch_size: int = 64  # Batch size for all networks
    load_from_rule_based_dataset: bool = True
    true_k_available: bool = True
    dataset_paths: list[str] = field(default_factory=lambda: [
        "datasets/MiniGrid-Reacher-extra-med/good",
        "datasets/MiniGrid-Reacher-extra-med/bad",
        "datasets/MiniGrid-Reacher-extra-med/med",
    ])
    rule_based_dataset_files: list[str] = field(default_factory=lambda: [
        "datasets/rule_based/MiniGrid-Reacher-MDP/balanced_20000.pkl",
        "datasets/rule_based/MiniGrid-Reacher-MDP/rightfirst_20000.pkl",
        "datasets/rule_based/MiniGrid-Reacher-MDP/downfirst_20000.pkl",
        "datasets/rule_based/MiniGrid-Reacher-MDP/zigzag1_20000.pkl",
        "datasets/rule_based/MiniGrid-Reacher-MDP/zigzag2_20000.pkl",
    ])
    dataset_sizes: list[int] = field(default_factory=lambda: [5, 5, 5])
    extra_reward: float = 10.0
    epsilon: float = 0.1
    # Network
    hidden_dim: int = 64
    policy_r: int = 9
    max_grad_norm: float = 0.5
    learning_rate: float = 1e-3

    adam_eps: float = 1e-8
    # Kmeans
    k_value: int = 5 # Number of clusters
    max_traj_len: int = 20  # Max trajectory length
    normalize: bool = True  # Normalize states
    # Wandb logging
    project: str = "1017VAEKmeans"
    group: str = "PKmeans"
    name: str = ""
    algo: str = "SORL"

    take_ball_target: int = 0

    def __post_init__(self):
        """Derive the run name and nest the checkpoint directory under it."""
        # The run name is suffixed with the environment and the k value that
        # were passed at construction time.
        self.name = f"{self.name}-{self.env}-{self.k_value}"
        if self.checkpoints_path is not None:
            self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)

    def __setattr__(self, key, value):
        """Store declared fields normally; route everything else to ``extra_attributes``."""
        if key in self.__dataclass_fields__:
            super().__setattr__(key, value)
        else:
            # Go through __dict__ so this also works on an instance whose
            # ``extra_attributes`` has not been populated yet (e.g. one created
            # via ``__new__``), instead of recursing through ``__getattr__``.
            self.__dict__.setdefault("extra_attributes", {})[key] = value

    def __getattr__(self, key):
        """Resolve attributes that normal lookup missed from ``extra_attributes``.

        Raises:
            AttributeError: if ``key`` is not stored in ``extra_attributes``.
        """
        # __getattr__ only runs after normal lookup fails. Reading
        # ``self.extra_attributes`` here would re-enter __getattr__ forever when
        # that field is not in __dict__ yet (copy/pickle probe attributes such
        # as ``__setstate__`` on a freshly created, empty instance).
        extras = self.__dict__.get("extra_attributes")
        if extras is not None and key in extras:
            return extras[key]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")

def _str_to_bool(text):
    """Convert a command-line token such as "true"/"False"/"0" into a bool."""
    if isinstance(text, bool):
        return text
    lowered = text.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {text!r}")


def parse_args_and_update_config(config_class):
    """Build ``config_class`` from command-line arguments.

    One ``--<field>`` option is registered for every dataclass field whose
    default is an ``int``, ``float``, ``str``, ``bool`` or ``None``; fields
    using ``default_factory`` are skipped, except ``rule_based_dataset_files``
    which is registered explicitly as a list option.

    Args:
        config_class: a dataclass type; its ``__dataclass_fields__`` drive the
            generated options and it is called with the parsed values.

    Returns:
        An instance of ``config_class``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--rule_based_dataset_files",
        nargs='+',  # one or more paths
        default=None,
        help="List of rule-based dataset file paths"
    )

    # Register one option per scalar dataclass field.
    for field_name, field_info in config_class.__dataclass_fields__.items():
        default = field_info.default
        if isinstance(default, bool):
            # bool must be handled before int (bool is an int subclass), and
            # ``type=bool`` is wrong for argparse: bool("False") is True.
            # ``nargs="?"`` with ``const=True`` also accepts a bare ``--flag``.
            parser.add_argument(
                f"--{field_name}",
                type=_str_to_bool,
                nargs="?",
                const=True,
                default=default,
                help=f"Default: {default}",
            )
        elif isinstance(default, (int, float, str)):
            parser.add_argument(f"--{field_name}", type=type(default), default=default, help=f"Default: {default}")
        elif default is None:
            parser.add_argument(f"--{field_name}", type=str, default=None, help="Default: None")

    args = parser.parse_args()

    # Drop options that were not supplied (value None) so the dataclass
    # defaults / default factories apply instead of being overwritten by None.
    kwargs = {key: value for key, value in vars(args).items() if value is not None}
    return config_class(**kwargs)

def compute_mean_std(data, eps=1e-3):
    """Return ``(mean, std + eps)`` of ``data`` reduced over its first two axes.

    ``eps`` is added to the standard deviation so a later division by it
    cannot hit zero.
    """
    leading_axes = (0, 1)
    return (
        jnp.mean(data, axis=leading_axes),
        eps + jnp.std(data, axis=leading_axes),
    )

def train(config):
    """Fit the cluster-conditioned policy on the loaded dataset and report clustering metrics.

    Steps, in order:
      1. Seed JAX/NumPy, build the environment and record ``state_dim`` /
         ``action_dim`` on ``config``.
      2. Load the dataset, drop trajectories of length <= 3, optionally
         normalize observations, and shuffle.
      3. Train ``MLP_3_LoRA`` by full-dataset gradient steps until
         ``config.max_updates`` is reached or the loss plateaus.
      4. Assign each trajectory to the cluster with the highest score, log
         NMI/ARI and a cluster-vs-label heatmap to wandb, and print
         per-cluster statistics.

    Args:
        config: a ``TrainConfig``-like object. ``state_dim``, ``action_dim``
            and (when ``true_k_available``) ``k_value`` are written onto it.
    """
    # Set random seed
    rng = jax.random.PRNGKey(config.seed)
    np.random.seed(config.seed)
    env = load_env(config)
    D4RL_envs = ["halfcheetah-medium-expert-v2","walker2d-medium-expert-v2","hopper-medium-expert-v2","ant-medium-expert-v2","halfcheetah-medium-replay-v2"]
    print(config.env, config.env in D4RL_envs)

    if config.env in D4RL_envs:
        config.state_dim = env.observation_space.shape[0]
        config.action_dim = env.action_space.shape[0]
    else:
        config.action_dim = env.action_dim
        obs_shape = env.observation_shape
        config.state_dim = obs_shape.prod()

    print(f"State dim: {config.state_dim}, Action dim: {config.action_dim}")

    # data_idx is used below as the reference label of each trajectory
    # (presumably the index of the source dataset -- confirm against `load`).
    dataset, data_idx = load(config, 0.25)

    if config.true_k_available:
        # Use the number of distinct reference labels as the cluster count.
        true_k = int(jnp.max(data_idx)) + 1
        config.k_value = true_k
    print("Run on config: ", config)

    # filter out episodes with short length, which can be easily learned by the network
    traj_lengths = jnp.argmax(dataset.done, axis=1) + 1
    # needed_episode_idx = jnp.where(total_return > 0.6)[0]
    needed_episode_idx = jnp.where(traj_lengths > 3)[0]
    dataset = Transitions(
        dataset.obs[needed_episode_idx],
        dataset.action[needed_episode_idx],
        dataset.reward[needed_episode_idx],
        dataset.done[needed_episode_idx]
    )
    data_idx = data_idx[needed_episode_idx]
    print("Dataset filtered, remaining dataset size: ", dataset.obs.shape[0])

    if config.normalize:
        state_mean, state_std = compute_mean_std(dataset.obs, eps=1e-3)
    else:
        state_mean, state_std = 0, 1
    dataset = dataset._replace(obs=(dataset.obs - state_mean) / state_std)

    # Shuffle the dataset (labels are permuted with the same indices).
    # NOTE(review): `rng` is consumed here and split again below for model
    # init; JAX recommends splitting before each use. Left unchanged so that
    # results for a given seed stay the same.
    idx = jax.random.permutation(rng, len(dataset.obs))
    dataset = Transitions(dataset.obs[idx], dataset.action[idx], dataset.reward[idx], dataset.done[idx])
    # data_idx = jnp.concatenate([jnp.zeros(expert_start_idx), jnp.ones(len(dataset.obs) - expert_start_idx)])
    data_idx = data_idx[idx]
    print("Dataset shape: obs ", dataset.obs.shape, " action ", dataset.action.shape, " reward ", dataset.reward.shape, " done ", dataset.done.shape)

    DiscreteEnvNames = ["MiniGrid-Reacher", "MiniGrid-Binary-Reacher", "MiniGrid-Reacher-noisy", "MiniGrid-Reacher-extra-good", "MiniGrid-Reacher-extra-bad", "MiniGrid-Reacher-extra-med", "MiniGrid-Reacher-MDP", "MDPtakeball", "MDPtakeball-hard"]

    model = MLP_3_LoRA(config.policy_r, config.action_dim, config.hidden_dim, config.k_value, config.env in DiscreteEnvNames)

    # Initialize model and optimizer
    init_x = jnp.zeros((2, 1, config.state_dim))
    init_id = jnp.zeros((1, config.k_value))
    rng, init_rng = jax.random.split(rng)
    network_params = model.init(init_rng, init_x, init_id)

    # Initialize optimizer
    tx = optax.chain(
        optax.clip_by_global_norm(config.max_grad_norm),
        optax.adam(learning_rate=config.learning_rate, eps=config.adam_eps),
    )
    train_state = TrainState.create(
        apply_fn=model.apply,
        params=network_params,
        tx=tx,
    )

    # obs (batch_size, max_traj_len, state_dim), act (batch_size, max_traj_len, action_dim), done (batch_size, max_traj_len), cluster_onehot (1, k_value)
    # ty == 0 : sum of masked log-probabilities (log of the product)
    # ty != 0 : sum of exp(masked log-probabilities)
    @jax.checkpoint
    @jax.jit
    def trace_prob(params, obs, act, done, cluster_onehot, ty):
        pi = model.apply(params, obs, cluster_onehot)
        prob = pi.log_prob(act) # (batch_size, max_traj_len)
        # cumprod of (1 - done) is 0 from the first done step onward
        # (including the done step itself) and 1 before it.
        done_mask = jnp.cumprod(1 - done.astype(jnp.int32), axis=1)
        # NOTE(review): for ty != 0 the masked entries are 0 before exp, so
        # every masked step adds exp(0) = 1 to the sum -- confirm intended.
        prob = jax.lax.cond(ty==0, lambda x:jnp.sum(x, axis=1), lambda x:jnp.sum(jnp.exp(x), axis=1),done_mask * prob)
        return prob

    # Row j is the one-hot identifier of cluster j; built once, not per batch.
    one_hots = jnp.eye(config.k_value)

    def p_mat(params, obs, act, done, ty):
        """Return the (num_trajectories, k_value) matrix of `trace_prob` scores."""
        n_traj = obs.shape[0]
        # Pad only up to the next multiple of batch_size. At least one batch is
        # kept so an empty input still yields an empty (0, k_value) result.
        num_batches = max(1, -(-n_traj // config.batch_size))
        padding_size = num_batches * config.batch_size - n_traj
        pad_obs = jnp.pad(obs, ((0, padding_size), (0, 0), (0, 0)), mode='constant', constant_values=0)
        if config.env in DiscreteEnvNames:
            pad_act = jnp.pad(act, ((0, padding_size), (0, 0)), mode='constant', constant_values=0)
        else:
            pad_act = jnp.pad(act, ((0, padding_size), (0, 0), (0, 0)), mode='constant', constant_values=0)
        pad_done = jnp.pad(done, ((0, padding_size), (0, 0)), mode='constant', constant_values=0)

        # Disabled vmap-based alternative, kept for reference:
        # one_hots = jnp.eye(config.k_value)[:, None, :]
        # batched_obs = pad_obs.reshape(-1, config.batch_size, *obs.shape[1:])
        # batched_act = pad_act.reshape(-1, config.batch_size, *act.shape[1:])
        # batched_done = pad_done.reshape(-1, config.batch_size, *done.shape[1:])
        # def process_batch(obs,act,done):
        #     prob = jax.vmap(lambda oh: trace_prob(
        #         params,
        #         obs,
        #         act,
        #         done,
        #         oh
        #     ))(one_hots)  # (k_value, batch_size)
        #     return prob.T
        # p=jax.vmap(process_batch)(batched_obs, batched_act, batched_done) # (batch_size, k_value, batch_size)

        per_batch = []
        for b in range(num_batches):
            rows = slice(b * config.batch_size, (b + 1) * config.batch_size)
            batch_obs, batch_act, batch_done = pad_obs[rows], pad_act[rows], pad_done[rows]
            prob = [
                trace_prob(params, batch_obs, batch_act, batch_done, one_hots[None, j, :], ty)
                for j in range(config.k_value)
            ]
            prob = jnp.stack(prob, axis=0) # (k_value, batch_size)
            per_batch.append(prob.transpose((1, 0)))

        # Drop the rows that belong to padding.
        return jnp.concatenate(per_batch, axis=0)[:n_traj]

    def loss_fn(params, obs, act, done):
        # p1: per-trajectory scores (ty=1) normalized over clusters.
        p1 = p_mat(params, obs, act, done, 1)
        p1 = p1/jnp.sum(p1, axis=1, keepdims=True)
        # p0: per-trajectory summed log-probabilities (ty=0).
        p0 = p_mat(params, obs, act, done, 0)
        # p1 = jax.nn.softmax(p0, axis=1)
        loss = -(p0 * p1).sum(axis=1).mean()
        return loss

    # Built once; it does not depend on the update index.
    loss_grad = jax.value_and_grad(loss_fn)

    loss_history = []
    for update in range(config.max_updates):
        loss, grads = loss_grad(train_state.params, dataset.obs, dataset.action, dataset.done)
        train_state = train_state.apply_gradients(grads=grads)
        loss_history.append(loss)
        print(f"Update {update}, Loss: {loss}")
        wandb.log({"Loss": loss})
        # Stop once the loss is within 1e-3 of the mean of the last 10 losses.
        if len(loss_history) > 10 and jnp.abs(loss - jnp.mean(jnp.array(loss_history[-10:]))) < 1e-3:
            break
    # Report the number of updates actually performed (also defined when
    # max_updates is 0, where the loop variable would not exist).
    print("Training finished after ", len(loss_history), " updates")

    # p=p_mat(train_state.params, dataset.obs, dataset.action, dataset.done, 0)
    p = p_mat(train_state.params, dataset.obs, dataset.action, dataset.done, 1)
    predicted_labels = jnp.argmax(p, axis=1)
    print(dataset.obs.shape, dataset.action.shape, dataset.done.shape,p.shape, predicted_labels.shape)
    nmi = normalized_mutual_info_score(data_idx, predicted_labels)
    ari = adjusted_rand_score(data_idx, predicted_labels)
    print(f"NMI: {nmi}, ARI: {ari}")
    wandb.log({"NMI": nmi, "ARI": ari})

    # heatmap_matrix[c, j] = number of trajectories assigned to cluster c
    # whose reference label is j.
    dataset_idxs = [jnp.where(predicted_labels == c)[0] for c in range(config.k_value)]
    num_categories = int(jnp.max(data_idx)) + 1
    heatmap_matrix = np.zeros((config.k_value, num_categories))
    for c in range(config.k_value):
        for j in range(num_categories):
            heatmap_matrix[c, j] = jnp.sum(data_idx[dataset_idxs[c]] == j)

    plot_save_path = f"results/{config.alg}/{config.env}/{config.k_value}/plots"
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(plot_save_path, exist_ok=True)
    dataset_categorical_image = plot_and_save_heatmap(
        matrix=heatmap_matrix,
        save_path=os.path.join(plot_save_path, "final_dataset_categorical.png")
    )
    wandb.log({"dataset_categorical": wandb.Image(dataset_categorical_image)})

    reward = dataset.reward
    action = dataset.action
    obs = dataset.obs
    # print(utils.stat.g(reward, 0.99))
    # print(utils.stat.u(obs, action))
    # Per-cluster statistics, each reported relative to the same statistic
    # computed on the whole dataset; empty clusters are skipped.
    for c in range(config.k_value):
        if dataset_idxs[c].shape[0] > 0:
            print(f"Cluster {c}:")
            print(" TQ: ", utils.stat.g(reward[dataset_idxs[c]], 0.99)/utils.stat.g(reward, 0.99))
            print(" SACo: ", utils.stat.u(obs[dataset_idxs[c]], action[dataset_idxs[c]])/utils.stat.u(obs, action))
            print(" Number of samples: ", len(dataset_idxs[c]))




if __name__ == "__main__":
    # Build the run configuration from the command line.
    config = parse_args_and_update_config(TrainConfig)
    if config.debug:
        # Debug runs print the configuration and force JIT off.
        print(config)
        config.disable_jit = True
    wandb.init(
        project=config.project,
        entity="policy-clustering",
        group=config.group,
        name=config.name,
        config=config,
        # Debug runs do not log to wandb.
        mode="disabled" if config.debug else "online",
        # mode="disabled",
    )

    print("---------------------------------------")
    print(f"Training SORL, Env: {config.env}, Seed: {config.seed}")
    print("---------------------------------------")
    # Honor the disable_jit flag: it was set above but never applied.
    # jax.disable_jit(False) is a no-op, so non-debug runs are unaffected.
    with jax.disable_jit(config.disable_jit):
        train(config)
