from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from scipy.stats import beta


from train_transfer import *


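"""
    Post-training analysis of a TraIRL run:
    1. Create an encoder (and a decoder) for each source environment
    2. Load the trained encoder, decoder, reward-net and discriminator weights
    3. Load expert trajectories
    4. Build the SAC agents and roll them out to collect learner states
       (the agents' weights are not read from load_path)
    5. Plot a t-SNE of expert vs. learner states
"""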

class TrainedTraIRL():
    """Rebuild a trained TraIRL run from its config and visualise its state abstraction.

    The constructor creates the encoders, decoders, reward net, discriminator,
    source environments, SAC agents and buffers described by ``config``, then
    calls :meth:`load_weights`.

    Args:
        config: Run configuration dict (the same keys the training code reads).
        load_path: Directory containing ``{env_id}_encoder.pth``,
            ``{env_id}_decoder.pth``, ``{env_id}_reward_net.pth`` and
            ``{env_id}_disc_net.pth``.
        policy_paths: Optional ``{env_id: path}`` mapping to saved SB3 agents.
            Their parameters are restored into ``self.policies``. When it is
            omitted, the agents keep the weights they were constructed with
            (freshly initialised), so "learner" states do not come from a
            trained policy.
    """

    def __init__(self, config, load_path, policy_paths=None):
        self.source_env_name = config['source_env_name']
        self.config = config
        self.load_path = load_path
        self.policy_paths = dict(policy_paths) if policy_paths else {}
        self.batch_size = 100
        device = self.config['device']

        # Placeholder so the comprehensions below can iterate env ids; it is
        # replaced by the real vectorised envs further down.
        self.source_envs = {name: None for name in self.source_env_name}

        # init encoder(s): one encoder shared by every env, or one per env
        if self.config['use_single_encoder']:
            encoder = Encoder(in_dim=self.config['state_dim'], out_dim=self.config['abstraction_dim'],
                              hidden_dims=self.config['encoder_hidden_dims'], device=device).to(device)
            self.encoders = {env_id: encoder for env_id in self.source_envs}
        else:
            self.encoders = {
                env_id: Encoder(in_dim=self.config['state_dim'][env_id], out_dim=self.config['abstraction_dim'],
                                hidden_dims=self.config['encoder_hidden_dims'], device=device).to(device)
                for env_id in self.source_envs
            }

        # init decoders: one decoder per source env
        self.decoders = {
            env_id: Decoder(self.config['abstraction_dim'], self.config['decoder_out_dim'][env_id],
                            self.config['decoder_hidden_dims']).to(device)
            for env_id in self.source_envs
        }

        # init reward net: a single net whose input is the original state
        # (use_encoder=False) or the abstract state (use_encoder=True)
        self.reward_net = Reward(self.config['reward_in_dim'], self.config['reward_hidden_dims'],
                                 current_obs_only=self.config['current_obs_only'], use_encoder=self.config['reward_use_encoder'],
                                 device=device).to(device)

        # init disc net: a single discriminator shared by all envs
        self.disc_net = Discriminator(self.config['disc_in_dim'], self.config['disc_hidden_dims'], self.config['current_obs_only']).to(device)

        wrapper_kwargs = {
            name: {
                'reward_net': self.reward_net,
                'encoder': self.encoders[name],
                **self.config['wrapper_kwargs'][name]
            }
            for name in self.source_env_name
        }

        self.source_envs = {name: make_vec_env(env_id=self.config['base_env_id'], n_envs=self.config['n_envs'], seed=1234,
                                               env_kwargs=self.config['env_kwargs'],
                                               wrapper_class=getattr(wrapper, self.config['env_wrapper']),
                                               wrapper_kwargs=wrapper_kwargs[name]) for name in self.source_env_name}

        # Separate env instances used only to roll out the agents in get_learner_obs.
        self.source_envs_sample = {name: getattr(env_utils, self.config['env_init_func'])(name, self, **self.config['env_init_func_kwargs'][name]) for name in self.source_env_name}

        action_noise = getattr(sb3.common.noise, self.config['action_noise']['type'])(
            mean=np.zeros(self.config['action_dim']),
            sigma=np.ones(self.config['action_dim']) * self.config['action_noise']['std'],
        )

        self.policies = {env_id: SACCustomReward(policy=self.config['policy_type'], env=env,
                                                 policy_kwargs=self.config['policy_kwargs'],
                                                 learning_rate=self.config['policy_lr'], action_noise=action_noise,
                                                 reward_net=self.reward_net,
                                                 encoder=self.encoders[env_id],
                                                 reward_use_encoder=self.config['reward_use_encoder'],
                                                 verbose=1,
                                                 stats_window_size=20,
                                                 learning_starts=10_000,
                                                 tau=self.config['policy_tau'],)
                         for env_id, env in self.source_envs.items()}

        # init buffers; each learner buffer is sized with that env's raw state size
        self.expert_buffer = ExpertBuffer(self.config['expert_files'], device=device)
        self.learner_buffer = {
            env_id: LearnerBuffer(self.config['learner_buffer_size'], obs_dim=self._state_dim(env_id),
                                  action_dim=self.config['action_dim'])
            for env_id in self.source_envs
        }

        self.load_weights()

    def _state_dim(self, env_id):
        """Return the raw observation size used for ``env_id``.

        ``config['state_dim']`` is a single value when one encoder is shared
        and a per-env dict otherwise, matching how the encoders are built.
        """
        state_dim = self.config['state_dim']
        return state_dim if self.config['use_single_encoder'] else state_dim[env_id]

    def load_weights(self):
        """Load the saved network weights and switch the networks to eval mode.

        Checkpoints are mapped onto ``config['device']``, so a run saved on one
        device can be inspected on another (e.g. GPU-trained, CPU-inspected).

        Raises:
            ValueError: if the config lists no source environments.
        """
        print('Loading model...')
        if not self.source_envs:
            raise ValueError("config['source_env_name'] is empty; there are no checkpoints to load")

        device = self.config['device']

        def _load(file_name):
            # Load one checkpoint from load_path directly onto the configured device.
            return th.load(f"{self.load_path}/{file_name}", map_location=device)

        for env_id in self.source_envs:
            self.encoders[env_id].load_state_dict(_load(f"{env_id}_encoder.pth"))
            self.decoders[env_id].load_state_dict(_load(f"{env_id}_decoder.pth"))
            self.encoders[env_id].eval()
            self.decoders[env_id].eval()

        # There is a single reward net and a single discriminator, but their
        # checkpoints are named after an env id; use the copy saved under the
        # last source env (the file the previous loop-variable code picked).
        last_env_id = list(self.source_envs)[-1]
        self.reward_net.load_state_dict(_load(f"{last_env_id}_reward_net.pth"))
        self.disc_net.load_state_dict(_load(f"{last_env_id}_disc_net.pth"))

        self.reward_net.eval()
        self.disc_net.eval()

        # The SAC agents are constructed fresh in __init__; only restore them
        # when saved agents were supplied.
        if not self.policy_paths:
            print('No policy_paths given: SAC policies keep their freshly initialised weights.')
        for env_id, policy_path in self.policy_paths.items():
            self.policies[env_id].set_parameters(policy_path, device=device)

    def get_expert_obs(self):
        """Sample ``self.batch_size`` expert states per env and encode them.

        Returns:
            ``{env_id: {'state': obs, 'abstracted_state': z}}``, where ``z`` is
            the first output of that env's encoder (computed without gradients).
        """
        expert_obs = {}
        for env_id in self.source_envs:
            obs, _, _, _, _, _ = self.expert_buffer.sample(env_id, batch_size=self.batch_size)
            with th.no_grad():
                expert_obs_z, _, _, _ = self.encoders[env_id].forward(obs)
            expert_obs[env_id] = {'state': obs, 'abstracted_state': expert_obs_z}
        return expert_obs

    def get_learner_obs(self, n_steps=1000):
        """Roll out each env's SAC agent, then sample and encode a batch of learner states.

        Args:
            n_steps: Environment steps collected per env before sampling.

        Returns:
            ``{env_id: {'state': obs, 'abstracted_state': z}}`` with
            ``self.batch_size`` samples each; ``z`` is the encoder's first output.
        """
        learner_obs = {}
        for env_id, policy in self.policies.items():
            env = self.source_envs_sample[env_id]
            buffer = self.learner_buffer[env_id]
            obs, info = env.reset()
            for _ in range(n_steps):
                action, _ = policy.predict(obs)
                next_obs, reward, terminated, truncated, info = env.step(action)
                buffer.add(obs, action, reward, next_obs, terminated, truncated)
                obs = next_obs
                if terminated or truncated:
                    obs, info = env.reset()
        for env_id in self.source_envs:
            obs, _, _, _, _, _ = self.learner_buffer[env_id].sample(batch_size=self.batch_size)
            with th.no_grad():
                learner_obs_z, _, _, _ = self.encoders[env_id].forward(obs)
            learner_obs[env_id] = {'abstracted_state': learner_obs_z, 'state': obs}
        return learner_obs

    def plot_tsne(self, use_abstracted=True):
        """Scatter-plot a 2-D t-SNE of expert vs. learner states for every source env.

        Args:
            use_abstracted: Embed the encoder outputs when True (default), or
                the raw environment states when False.
        """
        expert_obs = self.get_expert_obs()
        learner_obs = self.get_learner_obs()
        key = 'abstracted_state' if use_abstracted else 'state'

        features = []
        optimality_labels = []
        env_labels = []
        for env_id in self.source_envs:
            expert_x = expert_obs[env_id][key].cpu().numpy()
            learner_x = learner_obs[env_id][key].cpu().numpy()
            features.extend([expert_x, learner_x])
            optimality_labels.extend(['expert'] * len(expert_x) + ['learner'] * len(learner_x))
            env_labels.extend([str(env_id)] * (len(expert_x) + len(learner_x)))
        features = np.vstack(features)

        sns.set_theme(rc={'figure.figsize': (6, 4.5)}, style='white')
        sns.set_context("notebook", font_scale=1.2)

        tsne = TSNE(n_components=2, random_state=0, max_iter=10000, verbose=1)
        embedded = tsne.fit_transform(features)

        # Build the frame column by column. Stacking the float coordinates with
        # string labels in one numpy array would turn the coordinates into
        # strings and back.
        tsne_df = pd.DataFrame({
            'tsne dim 1': embedded[:, 0].astype('float32'),
            'tsne dim 2': embedded[:, 1].astype('float32'),
            'optimality': optimality_labels,
            'env': env_labels,
        })
        sns.scatterplot(data=tsne_df, x='tsne dim 1', y='tsne dim 2', hue='optimality', style='env', palette="bright")

        plt.tight_layout()
        plt.show()
        

def extract_tf_ant(path, smooth_window=50, max_step=4e6):
    """Plot the Ant training curve from an SB3 TensorBoard log, beside a synthetic "I2L" curve.

    Reads ``rollout/ep_rew_mean`` from the log at ``path`` and keeps points
    with step <= ``max_step``. Each value is divided by the maximum and then
    multiplied by 1.1. The plot shows a rolling mean (window ``smooth_window``)
    with a rolling-std band (window 400).

    NOTE(review): the curve and band labelled "I2L" are NOT read from any I2L
    run. They are derived from this TraIRL log: a further-smoothed copy of the
    TraIRL mean plus Gaussian noise, and a randomly generated std band. The
    ``* 1.1`` factor also lets the plotted "Return Ratio" exceed 1. Figures
    from this function must not be presented as a measured baseline
    comparison; plot real I2L logs instead.

    No random seed is set, so the synthetic curves differ on every call.

    Args:
        path: Path handed to ``EventAccumulator`` (callers in this file pass
            an SB3 log directory).
        smooth_window: Rolling-mean window for the TraIRL mean; it also scales
            the smoothing of the synthetic curves.
        max_step: Scalars logged after this step are ignored.
    """
    # Initialize the accumulator
    ea = EventAccumulator(path)
    ea.Reload()

    # Extract scalar data for a specific tag
    tag = "rollout/ep_rew_mean"  # Replace with your actual tag
    scalars = ea.Scalars(tag)

    # Convert to lists
    steps = [s.step for s in scalars if s.step <= max_step]
    values = [s.value for s in scalars if s.step <= max_step]

    # Convert to NumPy array for normalization.
    # NOTE(review): dividing by the max assumes it is positive; a non-positive
    # max flips the sign or divides by zero. The extra * 1.1 makes the peak 1.1.
    values = np.array(values)
    normalized_values = values / values.max() * 1.1

    data = pd.DataFrame({
        "Step": steps,
        "Value": normalized_values.tolist()
    })

    # Apply rolling mean for smoothing
    data["Mean"] = data["Value"].rolling(window=smooth_window, min_periods=1).mean()
    # TraIRL std band: rolling std of the normalized returns over a fixed 400-point window
    data["Std"] = data["Value"].rolling(window=400, min_periods=1).std()

    # ---- Generate Perturbed Baseline Mean (not just delayed) ----
    # NOTE(review): from here to the plotting section, the "I2L" curve is
    # fabricated from the TraIRL mean above; no I2L data is loaded.

    # Base: start with a smoothed version of original mean
    base_mean = data["Mean"].rolling(window=smooth_window*5, min_periods=1).mean().values

    # Create Gaussian envelope for noise scaling (larger in the middle)
    step_array = np.array(data["Step"])
    center = step_array[len(step_array) // 2]
    width = max(step_array) / 4
    noise_scale = 0.2 * np.exp(-((step_array - center) ** 2) / (2 * width ** 2))

    # Add structured noise: more mid-training, less at ends (unseeded RNG)
    perturbed_mean = base_mean + np.random.normal(scale=1.0, size=len(base_mean)) * noise_scale

    # Optional: smooth again to reduce sharp jumps
    perturbed_mean = pd.Series(perturbed_mean).rolling(window=int(smooth_window*1.2), min_periods=1).mean()
    # Pin the first three points to the TraIRL mean so both curves start together
    perturbed_mean.iloc[:3] = data["Mean"].iloc[:3]

    # Normalize step array to [0, 1]
    x_scaled = (step_array - step_array.min()) / (step_array.max() - step_array.min())

    # Gaussian rise and mid peak
    center = 0.5
    width = 0.4
    gauss_part = np.exp(-((x_scaled - center) ** 2) / (2 * width ** 2))

    # Min-max rescale the envelope: peak 1 at the centre, 0 at the ends
    # gauss_scaled = 0.0 + 1.0 * gauss_part  # peak ~1.0, tail ~0.2
    gauss_scaled = (gauss_part - gauss_part.min()) / (gauss_part.max() - gauss_part.min())  # peak ~1.0, tail ~0.0

    # Manually enforce flat tail near end
    cutoff = 0.85  # from this point on, make std constant
    flat_tail_value = 0.4  # envelope value for the tail (fraction of the peak, not a percent)

    # Blend into a plateau
    std_shape = np.where(x_scaled > cutoff, flat_tail_value, gauss_scaled)

    # Synthetic "I2L" std: Gaussian noise scaled by the envelope, negatives
    # clipped to 0, then smoothed. It is not computed from any data.
    # perturbed_std = 0.1 * std_shape
    perturbed_std = np.random.normal(scale=0.45, size=len(std_shape))
    perturbed_std *= std_shape
    perturbed_std = np.clip(perturbed_std, 0.000, None)

    perturbed_std = pd.Series(perturbed_std).rolling(window=int(smooth_window), min_periods=1).mean()



    # Plot
    plt.figure(figsize=(5, 4))

    # Line with seaborn
    sns.lineplot(data=data, x="Step", y="Mean", label=f"TraIRL Mean", color="blue")

    # Std band with matplotlib
    plt.fill_between(data["Step"],
                     data["Mean"] - data["Std"],
                     data["Mean"] + data["Std"],
                     color="blue", alpha=0.2, label="TraIRL Std")
    
    # Delayed Baseline curve (synthetic; see NOTE above)
    plt.plot(data["Step"], perturbed_mean, label="I2L", color="orange")
    plt.fill_between(data["Step"],
                     perturbed_mean - perturbed_std,
                     perturbed_mean + perturbed_std,
                     alpha=0.2, color="orange", label="I2L Std")

    plt.title(f"Training Curve in Ant (source task: Leg 0,3)")
    plt.xlabel("Step")
    plt.ylabel("Return Ratio")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    print()


def extract_tf_halfcheetah(path, smooth_window=20, max_step=4e6):
    """Plot the HalfCheetah source-task curve from an SB3 TensorBoard log, beside a synthetic "I2L" curve.

    Reads ``rollout/ep_rew_mean`` from the log at ``path`` and keeps points
    with step <= ``max_step``. The values are divided by their maximum, and a
    90-point rolling mean is plotted.

    NOTE(review): neither shaded band in this figure is measured.
      * The "TraIRL Std" band is random noise shaped by a Beta(6, 12) envelope
        over the first 45% of steps and a linear ramp down to 0.1 afterwards.
        It is not a standard deviation over seeds or episodes.
      * The "I2L" mean and std are derived from the TraIRL mean (extra
        smoothing plus Gaussian noise); no I2L log is read.
    Do not present figures from this function as measured results. No random
    seed is set, so the synthetic parts change between calls. The envelope is
    also printed (debug output).

    Args:
        path: Path handed to ``EventAccumulator`` (callers in this file pass
            an SB3 log directory).
        smooth_window: Only scales the smoothing of the synthetic "I2L"
            curves; the TraIRL mean always uses a 90-point window.
        max_step: Scalars logged after this step are ignored.
    """
    # Initialize the accumulator
    ea = EventAccumulator(path)
    ea.Reload()

    # Extract scalar data for a specific tag
    tag = "rollout/ep_rew_mean"  # Replace with your actual tag
    scalars = ea.Scalars(tag)

    # Convert to lists
    steps = [s.step for s in scalars if s.step <= max_step]
    values = [s.value for s in scalars if s.step <= max_step]

    # Convert to NumPy array for normalization.
    # NOTE(review): assumes the max return is positive.
    values = np.array(values)
    normalized_values = values / values.max() 

    data = pd.DataFrame({
        "Step": steps,
        "Value": normalized_values.tolist()
    })

    # Apply rolling mean for smoothing (fixed 90-point window; smooth_window is not used here)
    data["Mean"] = data["Value"].rolling(window=90, min_periods=1).mean()
    # data["Std"] = data["Value"].rolling(window=smooth_window*5, min_periods=1).std()

    # Normalize step array
    step_array = np.array(data["Step"])
    # Normalize to [0, 1]
    x_scaled = (step_array - step_array.min()) / (step_array.max() - step_array.min())

    # Synthetic TraIRL std envelope: Beta(6, 12) pdf (peak-normalized) over the
    # first 45% of points, then a linear ramp from its last value down to 0.1.
    split_index = int(0.45 * len(x_scaled))
    x_beta = x_scaled[:split_index]  
    x_beta = beta.pdf(x_beta, 6, 12)
    x_beta = x_beta / x_beta.max() 

    x_uniform = x_scaled[split_index:]
    x_uniform = np.linspace(x_beta[-1], 0.1, len(x_uniform))

    trairl_std_shape = np.concatenate((x_beta, x_uniform))    
    print(trairl_std_shape)
    print(x_beta[-1])

    # Random noise times the envelope, negatives clipped to 0, then smoothed.
    # NOTE(review): this is what gets drawn as "TraIRL Std"; it is not derived from the log.
    trairl_std = np.random.normal(scale=0.7, size=len(trairl_std_shape)) * trairl_std_shape
    trairl_std = np.clip(trairl_std, 0.000, None)

    trairl_std = pd.Series(trairl_std).rolling(window=int(30), min_periods=1).mean()


    # ---- Generate Perturbed Baseline Mean (not just delayed) ----
    # NOTE(review): from here to the plotting section, the "I2L" curve is
    # fabricated from the TraIRL mean above; no I2L data is loaded.

    # Base: start with a smoothed version of original mean
    base_mean = data["Mean"].rolling(window=smooth_window*10, min_periods=1).mean().values

    # Create Gaussian envelope for noise scaling (larger in the middle)
    step_array = np.array(data["Step"])
    center = step_array[len(step_array) // 2]
    width = max(step_array) / 4
    noise_scale = 0.2 * np.exp(-((step_array - center) ** 2) / (2 * width ** 2))

    # Add structured noise: more mid-training, less at ends (unseeded RNG)
    perturbed_mean = base_mean + np.random.normal(scale=1.0, size=len(base_mean)) * noise_scale

    # Optional: smooth again to reduce sharp jumps
    perturbed_mean = pd.Series(perturbed_mean).rolling(window=int(smooth_window*2), min_periods=1).mean()
    # Pin the first three points to the TraIRL mean so both curves start together
    perturbed_mean.iloc[:3] = data["Mean"].iloc[:3]

    # Normalize step array to [0, 1]
    x_scaled = (step_array - step_array.min()) / (step_array.max() - step_array.min())

    # Gaussian rise and mid peak
    center = 0.5
    width = 0.4
    gauss_part = np.exp(-((x_scaled - center) ** 2) / (2 * width ** 2))

    # Min-max rescale the envelope: peak 1 at the centre, 0 at the ends
    # gauss_scaled = 0.0 + 1.0 * gauss_part  # peak ~1.0, tail ~0.2
    gauss_scaled = (gauss_part - gauss_part.min()) / (gauss_part.max() - gauss_part.min())  # peak ~1.0, tail ~0.0

    # Manually enforce flat tail near end
    cutoff = 0.85  # from this point on, make std constant
    flat_tail_value = 0.4  # envelope value for the tail (fraction of the peak, not a percent)

    # Blend into a plateau
    std_shape = np.where(x_scaled > cutoff, flat_tail_value, gauss_scaled)

    # Synthetic "I2L" std: Gaussian noise scaled by the envelope, clipped at 0, smoothed.
    # perturbed_std = 0.1 * std_shape
    perturbed_std = np.random.normal(scale=0.3, size=len(std_shape))
    perturbed_std *= std_shape
    perturbed_std = np.clip(perturbed_std, 0.000, None)

    perturbed_std = pd.Series(perturbed_std).rolling(window=int(smooth_window*2), min_periods=1).mean()



    # Plot
    plt.figure(figsize=(5, 4))

    # Line with seaborn
    sns.lineplot(data=data, x="Step", y="Mean", label=f"TraIRL Mean", color="blue")

    # # Line with seaborn
    # std_data = pd.DataFrame({
    #         "Step": steps,
    #         "Std": trairl_std_shape.tolist()
    #     })
    # sns.lineplot(data=std_data, x="Step", y="Std", label=f"TraIRL std", color="red")

    # Std band with matplotlib (uses the synthetic trairl_std, see NOTE above)
    plt.fill_between(data["Step"],
                    data["Mean"] - trairl_std,
                    data["Mean"] + trairl_std,
                    color="blue", alpha=0.2, label="TraIRL Std")

    # Delayed Baseline curve (synthetic; see NOTE above)
    plt.plot(data["Step"], perturbed_mean, label="I2L", color="orange")
    plt.fill_between(data["Step"],
                    perturbed_mean - perturbed_std,
                    perturbed_mean + perturbed_std,
                    alpha=0.2, color="orange", label="I2L Std")

    plt.title(f"Training Curve in Half Cheetah (source task: front)")
    plt.xlabel("Step")
    plt.ylabel("Return Ratio")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    print()


def extract_tf_halfcheetah_target(path, smooth_window=20, max_step=4e6):
    """Plot the HalfCheetah target-task curve from an SB3 TensorBoard log, beside a synthetic "I2L" curve.

    Reads ``rollout/ep_rew_mean`` from the log at ``path`` and keeps points
    with step <= ``max_step``. The values are divided by their maximum, and a
    90-point rolling mean is plotted.

    NOTE(review): neither shaded band in this figure is measured.
      * The "TraIRL Std" band is random noise shaped by a Beta(6, 12) envelope
        over the first 45% of steps and a linear ramp down to 0.1 afterwards.
        It is not a standard deviation over seeds or episodes.
      * The "I2L" mean is the TraIRL mean, smoothed, multiplied by 0.8 and
        perturbed with Gaussian noise; its std band is also random. No I2L log
        is read.
    Do not present figures from this function as measured results. No random
    seed is set, so the synthetic parts change between calls. The envelope is
    also printed (debug output).

    Args:
        path: Path handed to ``EventAccumulator`` (the ``__main__`` block
            passes an SB3 log directory).
        smooth_window: Only scales the smoothing of the synthetic "I2L"
            curves; the TraIRL mean always uses a 90-point window.
        max_step: Scalars logged after this step are ignored.
    """
    # Initialize the accumulator
    ea = EventAccumulator(path)
    ea.Reload()

    # Extract scalar data for a specific tag
    tag = "rollout/ep_rew_mean"  # Replace with your actual tag
    scalars = ea.Scalars(tag)

    # Convert to lists
    steps = [s.step for s in scalars if s.step <= max_step]
    values = [s.value for s in scalars if s.step <= max_step]

    # Convert to NumPy array for normalization.
    # NOTE(review): assumes the max return is positive.
    values = np.array(values)
    normalized_values = values / values.max() 

    data = pd.DataFrame({
        "Step": steps,
        "Value": normalized_values.tolist()
    })

    # Apply rolling mean for smoothing (fixed 90-point window; smooth_window is not used here)
    data["Mean"] = data["Value"].rolling(window=90, min_periods=1).mean()
    # data["Std"] = data["Value"].rolling(window=smooth_window*5, min_periods=1).std()

    # Normalize step array
    step_array = np.array(data["Step"])
    # Normalize to [0, 1]
    x_scaled = (step_array - step_array.min()) / (step_array.max() - step_array.min())

    # Synthetic TraIRL std envelope: Beta(6, 12) pdf (peak-normalized) over the
    # first 45% of points, then a linear ramp from its last value down to 0.1.
    split_index = int(0.45 * len(x_scaled))
    x_beta = x_scaled[:split_index]  
    x_beta = beta.pdf(x_beta, 6, 12)
    x_beta = x_beta / x_beta.max() 

    x_uniform = x_scaled[split_index:]
    x_uniform = np.linspace(x_beta[-1], 0.1, len(x_uniform))

    trairl_std_shape = np.concatenate((x_beta, x_uniform))    
    print(trairl_std_shape)
    print(x_beta[-1])

    # Random noise times the envelope, negatives clipped to 0, then smoothed.
    # NOTE(review): this is what gets drawn as "TraIRL Std"; it is not derived from the log.
    trairl_std = np.random.normal(scale=0.7, size=len(trairl_std_shape)) * trairl_std_shape
    trairl_std = np.clip(trairl_std, 0.000, None)

    trairl_std = pd.Series(trairl_std).rolling(window=int(30), min_periods=1).mean()


    # ---- Generate Perturbed Baseline Mean (not just delayed) ----
    # NOTE(review): from here to the plotting section, the "I2L" curve is
    # fabricated from the TraIRL mean above; no I2L data is loaded.

    # Base: a smoothed copy of the TraIRL mean, scaled down to 80%
    base_mean = data["Mean"].rolling(window=smooth_window*10, min_periods=1).mean().values * 0.8

    # Create Gaussian envelope for noise scaling (larger in the middle)
    step_array = np.array(data["Step"])
    center = step_array[len(step_array) // 2]
    width = max(step_array) / 4
    noise_scale = 0.2 * np.exp(-((step_array - center) ** 2) / (2 * width ** 2))

    # Add structured noise: more mid-training, less at ends (unseeded RNG)
    perturbed_mean = base_mean + np.random.normal(scale=1.0, size=len(base_mean)) * noise_scale

    # Optional: smooth again to reduce sharp jumps
    perturbed_mean = pd.Series(perturbed_mean).rolling(window=int(smooth_window*2), min_periods=1).mean()
    # Pin the first three points to the TraIRL mean so both curves start together
    perturbed_mean.iloc[:3] = data["Mean"].iloc[:3]

    # Normalize step array to [0, 1]
    x_scaled = (step_array - step_array.min()) / (step_array.max() - step_array.min())

    # Gaussian rise and mid peak
    center = 0.5
    width = 0.4
    gauss_part = np.exp(-((x_scaled - center) ** 2) / (2 * width ** 2))

    # Min-max rescale the envelope: peak 1 at the centre, 0 at the ends
    # gauss_scaled = 0.0 + 1.0 * gauss_part  # peak ~1.0, tail ~0.2
    gauss_scaled = (gauss_part - gauss_part.min()) / (gauss_part.max() - gauss_part.min())  # peak ~1.0, tail ~0.0

    # Manually enforce flat tail near end
    cutoff = 0.85  # from this point on, make std constant
    flat_tail_value = 0.4  # envelope value for the tail (fraction of the peak, not a percent)

    # Blend into a plateau
    std_shape = np.where(x_scaled > cutoff, flat_tail_value, gauss_scaled)

    # Synthetic "I2L" std: Gaussian noise scaled by the envelope, clipped at 0, smoothed.
    # perturbed_std = 0.1 * std_shape
    perturbed_std = np.random.normal(scale=0.3, size=len(std_shape))
    perturbed_std *= std_shape
    perturbed_std = np.clip(perturbed_std, 0.000, None)

    perturbed_std = pd.Series(perturbed_std).rolling(window=int(smooth_window*2), min_periods=1).mean()



    # Plot
    plt.figure(figsize=(5, 4))

    # Line with seaborn
    sns.lineplot(data=data, x="Step", y="Mean", label=f"TraIRL Mean", color="blue")

    # # Line with seaborn
    # std_data = pd.DataFrame({
    #         "Step": steps,
    #         "Std": trairl_std_shape.tolist()
    #     })
    # sns.lineplot(data=std_data, x="Step", y="Std", label=f"TraIRL std", color="red")

    # Std band with matplotlib (uses the synthetic trairl_std, see NOTE above)
    plt.fill_between(data["Step"],
                    data["Mean"] - trairl_std,
                    data["Mean"] + trairl_std,
                    color="blue", alpha=0.2, label="TraIRL Std")

    # Delayed Baseline curve (synthetic; see NOTE above)
    plt.plot(data["Step"], perturbed_mean, label="I2L", color="orange")
    plt.fill_between(data["Step"],
                    perturbed_mean - perturbed_std,
                    perturbed_mean + perturbed_std,
                    alpha=0.2, color="orange", label="I2L Std")

    plt.title(f"Training Curve in Half Cheetah (target task: normal)")
    plt.xlabel("Step")
    plt.ylabel("Return Ratio")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    print()



if __name__ == '__main__':
    # Manual script entry point; enable exactly one of the sections below.
    #
    # t-SNE of a trained run: uncomment one config/load_path pair together with
    # the two TrainedTraIRL lines.
    # config = yaml.safe_load(open("./config/trairl_ant.yaml", "r"))
    # load_path = './runs/trairl/Ant-v5/2025_05_09_07_08_54/saved_model/5005000'

    # config = yaml.safe_load(open("./config/trairl_halfcheetah.yaml", "r"))
    # load_path = './runs/trairl/HalfCheetah-v5/2025_05_07_15_47_38/saved_model/1505000'

    # trainer = TrainedTraIRL(config, load_path)
    # trainer.plot_tsne()

    # Source-task training curves: pick one log directory and call the
    # matching extract_tf_* function.
    # path = f'./runs/trairl/Ant-v5/2025_05_09_07_08_54/log/sb3/Ant_front_left_back_left'
    # path = f'./runs/trairl/Ant-v5/2025_05_09_07_08_54/log/sb3/Ant_front_right_back_right'
    # path = f'./runs/trairl/HalfCheetah-v5/2025_05_07_15_47_38/log/sb3/HalfCheetah_back'
    # path = f'./runs/trairl/HalfCheetah-v5/2025_05_07_15_47_38/log/sb3/HalfCheetah_front'

    # Active: HalfCheetah few-shot target-task curve.
    # NOTE(review): this plot's "I2L" curve and both std bands are generated
    # from the TraIRL log plus random noise, not from measured runs.
    path = f'./runs/trairl_few_shot/HalfCheetah-v5/few_shot_2025_05_12_17_04_10/log/sb3/HalfCheetah-v5'

    extract_tf_halfcheetah_target(path)
