from typing import Dict, List, Optional, Tuple

import time, os
import torch
import numpy as np
import tensorflow as tf
import sys
import umap
import matplotlib.pyplot as plt
sys.path.append('/DiffCRL/continual_tune/')

# from continualworld.sac.replay_buffers import EpisodicMemory
from continualworld.sac.sac import SAC

# /*** import regarding diffusion model ***/
import data_augmentation.diffuser.utils as diffuser_utils

from data_augmentation.diffuser.datasets.sequence import SegmentDataset, ElasticSegmentDataset
from data_augmentation.diffuser.config.default_config import Diffuser_Config

from argparse import Namespace
from datetime import datetime

# Enable dynamic GPU memory growth for TensorFlow on every visible GPU.
gpus = tf.config.experimental.list_physical_devices('GPU')

if not gpus:
    print("No GPU devices found.")
else:
    try:
        # Turn on dynamic memory growth for each GPU in turn.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

        # TensorFlow code that follows runs with memory growth enabled.

    except RuntimeError as err:
        # On some systems enabling memory growth may raise a RuntimeError;
        # report it and carry on.
        print(err)

# Problem dimensions used throughout the script.
num_tasks = 5
observation_dim = 44
action_dim = 4

# Turn the class attributes of Diffuser_Config into a mutable Namespace,
# skipping every entry whose name starts with a double underscore.
diffuser_keys = list(vars(Diffuser_Config).keys())
diffuser_config = Namespace(**{
    _name: _value
    for _name, _value in vars(Diffuser_Config).items()
    if not _name.startswith('__')
})

# Script-specific overrides. The diffusion model sees observations without
# the last `num_tasks` entries; the task is supplied through `cond_dim`.
diffuser_config.cond_dim = num_tasks
diffuser_config.observation_dim = observation_dim - num_tasks
diffuser_config.action_dim = action_dim
diffuser_config.transition_dim = (
    diffuser_config.observation_dim + diffuser_config.action_dim
)
diffuser_config.n_tasks = num_tasks

# Training length: n_train_steps in total, evaluated every n_steps_per_epoch.
diffuser_config.n_train_steps = 100000
diffuser_config.n_steps_per_epoch = 10000

# Output root for this run (alternative kept for reference).
# diffuser_config.bucket = '/DiffCRL/continual_tune/logs/diffusion_models/1000stp-selfclone'
diffuser_config.bucket = '/DiffCRL/continual_tune/logs/diffusion_models/1000stp-bcfiltercontinual'

# /* --- Define the diffusion model --- */
# Each diffuser_utils.Config below bundles a target (class or config entry)
# with keyword arguments; the resulting object is called later in this file
# to build the actual instance (e.g. `backbone_config()`).
# NOTE(review): `savepath` presumably names a file the Config is serialised
# to -- confirm against diffuser_utils.Config.

# Backbone network: built from the entry stored in `diffuser_config.backbone`.
backbone_config = diffuser_utils.Config(
    diffuser_config.backbone,
    savepath='diffusion_backbone.pkl',
    horizon=diffuser_config.horizon,
    transition_dim=diffuser_config.transition_dim,
    cond_dim=diffuser_config.cond_dim,
    dim_mults=diffuser_config.dim_mults,
    input_condition=diffuser_config.input_condition,
    dim=diffuser_config.dim,
    condition_dropout=diffuser_config.condition_dropout,
    calc_energy=diffuser_config.calc_energy,
    device=diffuser_config.device,
)
# Diffusion wrapper: built from `diffuser_config.diffusion`; it is later
# called with the instantiated backbone as its positional argument.
diffusion_config = diffuser_utils.Config(
    diffuser_config.diffusion,
    savepath='diffusion_config.pkl',
    horizon=diffuser_config.horizon,
    observation_dim=diffuser_config.observation_dim,
    action_dim=diffuser_config.action_dim,
    n_timesteps=diffuser_config.n_diffusion_steps,
    loss_type=diffuser_config.loss_type,
    clip_denoised=diffuser_config.clip_denoised,
    predict_epsilon=diffuser_config.predict_epsilon,
    ## loss weighting
    action_weight=diffuser_config.action_weight,
    loss_weights=diffuser_config.loss_weights,
    loss_discount=diffuser_config.loss_discount,
    input_condition=diffuser_config.input_condition,
    condition_guidance_w=diffuser_config.condition_guidance_w,
    device=diffuser_config.device,
)
# Trainer: later called as `trainer_config(diffusion, dataset)`.
trainer_config = diffuser_utils.Config(
    diffuser_utils.Trainer,
    savepath='trainer_config.pkl',
    train_batch_size=diffuser_config.batch_size,
    train_lr=diffuser_config.learning_rate,
    gradient_accumulate_every=diffuser_config.gradient_accumulate_every,
    ema_decay=diffuser_config.ema_decay,
    sample_freq=diffuser_config.sample_freq,
    save_freq=diffuser_config.save_freq,
    log_freq=diffuser_config.log_freq,
    # Integer number of training steps per save: n_train_steps // n_saves.
    label_freq=int(diffuser_config.n_train_steps // diffuser_config.n_saves),
    save_parallel=diffuser_config.save_parallel,
    bucket=diffuser_config.bucket,
    n_reference=diffuser_config.n_reference,
    train_device=diffuser_config.device,
    save_checkpoints=diffuser_config.save_checkpoints,
)


# Instantiate the backbone first, then the diffusion model built around it.
diffusion_backbone = backbone_config()
diffusion = diffusion_config(diffusion_backbone)

# Number of trajectories requested per generate_samples() call in the
# sequential-training section, and the number kept after filtering there.
n_trajectories = 100

# Row i of this identity matrix is the one-hot condition vector of task i.
# It lives on CUDA when available, otherwise on the CPU.
task_conds = torch.eye(
    num_tasks,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

def _build_expert_loaders(data_root, batch_size=32):
    """Build one shuffled DataLoader per task from ``Expert-task{i}.pt`` files.

    Each file under ``data_root`` is read with ``torch.load`` and must hold a
    mapping with an ``'expert_traj'`` entry that is sliced on its last axis:
    the first ``action_dim`` entries are taken as actions, and the entries
    from ``action_dim`` up to (excluding) the last ``num_tasks`` are taken as
    observations.

    Args:
        data_root: Directory containing ``Expert-task0.pt`` ...
            ``Expert-task{num_tasks - 1}.pt``.
        batch_size: Batch size of every returned DataLoader.

    Returns:
        A list with ``num_tasks`` DataLoaders, indexed by task id.
    """
    loaders = []
    for task_id in range(num_tasks):
        expert_data_path = os.path.join(data_root, f"Expert-task{task_id}.pt")
        dataset = ElasticSegmentDataset(n_tasks=num_tasks)
        # NOTE: torch.load unpickles the file; only use trusted data files.
        expert_traj = torch.load(expert_data_path)['expert_traj']
        total_action = expert_traj[:, :, :action_dim]
        total_obs = expert_traj[:, :, action_dim:-num_tasks]
        dataset.add_data(total_obs, total_action, task_id)
        loaders.append(
            torch.utils.data.DataLoader(
                dataset, batch_size=batch_size, num_workers=0, shuffle=True, pin_memory=True
            )
        )
    return loaders


# Per-task loaders over the training expert data.
data_direc = "/DiffCRL/continual_tune/bc_data/er_success"
all_data_loaders = _build_expert_loaders(data_direc)

# Per-task loaders over a second expert data directory, used as test data.
test_data_direc = "/DiffCRL/continual_tune/bc_data/er_success2"
all_test_data_loaders = _build_expert_loaders(test_data_direc)

def generate_samples(task_id, batch_size, horizon=200):
    """Sample a batch of trajectories from the diffusion model for one task.

    Args:
        task_id: Index into ``task_conds`` selecting the task condition row.
        batch_size: Number of trajectories to sample; the condition row is
            repeated this many times.
        horizon: Horizon forwarded to ``diffusion.conditional_sample``.
            Defaults to 200, the value that used to be hard-coded here.

    Returns:
        The first element returned by ``diffusion.conditional_sample``. The
        second element (requested via ``return_diffusion=True``) is discarded.
        Callers in this file slice the last axis as ``[actions | observations]``.
    """
    condition = task_conds[task_id].repeat(batch_size, 1)
    samples, _ = diffusion.conditional_sample(
        cond_input=condition,
        horizon=horizon,
        verbose=False,
        return_diffusion=True,
    )
    return samples

# Start-of-run timestamp, used later in this file as a log-file name prefix.
time_str = f"{datetime.now():%Y-%m-%d-%H:%M:%S}"

def load_diffusion(trainer, model_path):
    """Restore a trainer's step counter, model weights and EMA weights.

    Args:
        trainer: Object exposing a ``step`` attribute plus ``model`` and
            ``ema_model`` attributes that support ``load_state_dict``.
        model_path: Path of a checkpoint readable by ``torch.load`` that
            holds the keys ``'step'``, ``'model'`` and ``'ema'``.
    """
    # map_location="cpu" lets a checkpoint saved from a CUDA device be read
    # on a host without that device. load_state_dict() copies the values into
    # the existing parameters, so the modules stay on their current device.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    data = torch.load(model_path, map_location="cpu")
    trainer.step = data['step']
    trainer.model.load_state_dict(data['model'])
    trainer.ema_model.load_state_dict(data['ema'])

def load_diffusion_model(diffusion, model_path):
    """Load the ``'model'`` weights of a checkpoint into ``diffusion``.

    Args:
        diffusion: Object supporting ``load_state_dict``.
        model_path: Path of a checkpoint readable by ``torch.load`` that
            holds a ``'model'`` state dict.
    """
    # map_location="cpu" lets a checkpoint saved from a CUDA device be read
    # on a host without that device. load_state_dict() copies the values into
    # the existing parameters, so the module stays on its current device.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    data = torch.load(model_path, map_location="cpu")
    diffusion.load_state_dict(data['model'])

if False:
    # Disabled experiment: train the diffusion model on task-0 expert data
    # only, reporting its task-0 metric on the train and test loaders after
    # every epoch.
    expert_data_path = os.path.join(data_direc, "Expert-task0.pt")
    dataset = ElasticSegmentDataset(n_tasks=num_tasks)
    expert_data = torch.load(expert_data_path)
    total_action = expert_data['expert_traj'][:, :, :action_dim]
    total_obs = expert_data['expert_traj'][:, :, action_dim:-num_tasks]
    dataset.add_data(total_obs, total_action, 0)

    diffuser_trainer = trainer_config(diffusion, dataset)
    diffuser_trainer.bucket = os.path.join(diffuser_config.bucket, 'expert_trained_t0')
    for train_iter in range(int(diffuser_config.n_train_steps // diffuser_config.n_steps_per_epoch)):
        diffuser_trainer.train(diffuser_config.n_steps_per_epoch)
        # Index the per-task loader lists; the previous code subscripted a
        # single DataLoader and referenced an undefined `test_data_loader`.
        train_loss = diffuser_trainer.metrics_test(all_data_loaders[0])
        test_loss = diffuser_trainer.metrics_test(all_test_data_loaders[0])
        print('Computed losses:', train_loss)
        print('Test computed losses:', test_loss)
        
if False:
    # Disabled experiment: restore one trained checkpoint, sample 100
    # task-0 trajectories from it and store them as a NumPy array.
    total_generated_samples = []
    # Alternative checkpoints:
    # model_path = f"/DiffCRL/continual_tune/logs/diffusion_models/1000stp-bcfilter/checkpoint/state_99999.pt"
    # model_path = f"/DiffCRL/continual_tune/logs/diffusion_models/1000stp-ocfilter/checkpoint/state_99999.pt"
    model_path = "/DiffCRL/continual_tune/logs/diffusion_models/1000stp-nofilter/checkpoint/state_99999.pt"
    load_diffusion_model(diffusion, model_path)
    generated = generate_samples(0, 100)
    samples = generated.cpu().numpy()
    save_path = "/DiffCRL/continual_tune/bc_data/no_filter_generated/Generated.pt"
    torch.save(samples, save_path)

if False:
    # Disabled experiment: for each of the five continual checkpoints, sample
    # task-0 trajectories, then embed them together with the task-0 expert
    # data in 2-D with UMAP and draw one scatter group per source.
    total_generated_samples = []
    for task_idx in range(5):
        # load_model
        model_path = f"/DiffCRL/continual_tune/logs/diffusion_models/1000stp-continual/task-{task_idx}/checkpoint/state_99999.pt"
        data = torch.load(model_path)
        diffusion.load_state_dict(data['model'])
        # generate samples
        samples = generate_samples(0, 100).cpu().numpy()
        total_generated_samples.append(samples)
        print(f'Successfully load model-{task_idx} and generate samples!!!')

    expert_data_path = os.path.join(data_direc, "Expert-task0.pt")
    expert_data = torch.load(expert_data_path)['expert_traj'][..., :-num_tasks]
    total_generated_samples.append(expert_data)

    # Flatten every group to (rows, feature_dim) on its own, so groups may
    # contain different numbers of rows, then stack them for UMAP.
    # NOTE(review): assumes expert_data converts via np.asarray (a CPU tensor
    # or an ndarray) -- confirm.
    feature_dim = expert_data.shape[-1]
    groups = [np.asarray(group).reshape(-1, feature_dim) for group in total_generated_samples]
    vectors_array = np.concatenate(groups, axis=0)

    # Reduce to two dimensions with UMAP.
    reducer = umap.UMAP(n_components=2)
    embedding = reducer.fit_transform(vectors_array)

    # 2-D scatter plot, one colour per group.
    plt.figure(figsize=(8, 6))
    num_vectors = len(groups)

    # Slice the embedding by the real row count of each group. The previous
    # fixed stride of 100 rows did not match the flattened
    # (trajectories x timesteps) layout and plotted only the first rows.
    start_index = 0
    for i, group in enumerate(groups):
        end_index = start_index + len(group)
        labeling = f'Generated-{i}' if i != num_vectors - 1 else 'Expert'
        plt.scatter(embedding[start_index:end_index, 0], embedding[start_index:end_index, 1], s=10, label=labeling)
        start_index = end_index

    plt.title('UMAP Visualization of Vectors with Different Colors')
    plt.xlabel('UMAP Dimension 1')
    plt.ylabel('UMAP Dimension 2')
    # Groups are categorical, so show the scatter labels in a legend; a
    # colorbar has no scalar mapping to display here.
    plt.legend()
    plt.tight_layout()
    plt.savefig('Output.pdf')

### Load model and conduct bc learning
if False:
    # Disabled experiment: restore the task-0 continual checkpoint and save
    # 100 generated trajectories per task index (only task 0 at present).
    model_path = "/DiffCRL/continual_tune/logs/diffusion_models/1000stp-continual/task-0/checkpoint/state_99999.pt"
    diffusion.load_state_dict(torch.load(model_path)['model'])

    # save_path = "/DiffCRL/continual_tune/bc_data/generated/1000stp-pure-task0.pt"
    save_direc = "/DiffCRL/continual_tune/bc_data/er_generated/"
    for task_idx in range(1):
        save_path = os.path.join(save_direc, f'Generated-task{task_idx}.pt')
        torch.save(generate_samples(task_idx, 100), save_path)

### Sequential training

# Optional warm start: load diffusion weights from the task-0 checkpoint.
if False:
    load_diffusion_model(
        diffusion,
        model_path="/DiffCRL/continual_tune/logs/diffusion_models/1000stp-continual/task-0/checkpoint/state_99999.pt",
    )
    print('------ success load!!!')


from continualworld.sac.models import MlpActor as TFMlpActor
from continualworld.tasks import TASK_SEQS
from continualworld.envs import get_cl_env, get_single_env, get_task_name
from continualworld.utils.utils import get_activation_from_str as tf_get_activation_from_str
from continualworld.utils.utils_torch import get_activation_from_str

# Hyperparameter constants. Within this file only `activation`,
# `hidden_sizes` and `use_layer_norm` are read below.
activation = 'lrelu'
alpha = 'auto'
batch_size = 128
gamma = 0.99
hidden_sizes = [256, 256, 256, 256]
log_every = 2000
lr = 0.001
replay_size = 1000000
seed = 0
steps = 1000000
target_output_std = 0.089
task = 'hammer-v2'
update_after = 1000
update_every = 50
use_layer_norm = True

# Keyword arguments for the torch-side actor.
actor_kwargs = {
    'hidden_sizes': hidden_sizes,
    'activation': get_activation_from_str(activation),
    'use_layer_norm': use_layer_norm,
}

# Build the CW5 task sequence twice (train and test env), both starting at
# sequence index 0.
tasks = TASK_SEQS['CW5']
env = get_cl_env(tasks, 100_000_000)
test_env = get_cl_env(tasks, 100_000_000)
for _cl_env in (env, test_env):
    _cl_env.cur_seq_idx = 0

actor_kwargs.update(
    action_space=env.action_space,
    input_dim=env.observation_space.shape[0],
)

# TensorFlow actor used further down to score generated trajectories.
# NOTE(review): num_heads=5 presumably mirrors num_tasks -- confirm.
tf_actor_kwargs = dict(
    action_space=actor_kwargs['action_space'],
    input_dim=actor_kwargs['input_dim'],
    hidden_sizes=[256, 256, 256, 256],
    use_layer_norm=True,
    num_heads=5,
    hide_task_id=True,
    activation=tf_get_activation_from_str('lrelu'),
)
print(tf_actor_kwargs)
eval_actor = TFMlpActor(**tf_actor_kwargs)

if True:
    # Sequential (continual) training of the diffusion model.
    #
    # For stage k the training set is the real expert data of task k plus,
    # for every earlier task j < k, trajectories generated by the current
    # diffusion model and filtered with the task-j actor checkpoint. After
    # every epoch the metric on each task's training loader is printed and
    # appended to a log file.
    data_direc = "/DiffCRL/continual_tune/bc_data/er_success"
    for train_stage in range(0, num_tasks):
        expert_data_path = os.path.join(data_direc, f"Expert-task{train_stage}.pt")
        dataset = ElasticSegmentDataset(n_tasks=num_tasks)
        # Add generated old expert data to the dataset
        for task_id in range(train_stage):
            # Draw two batches (2 * n_trajectories candidates); the best
            # n_trajectories are kept after scoring below.
            tobs, tacs = [], []
            for _ in range(2):
                samples = generate_samples(task_id, n_trajectories)
                # The last axis is laid out as [actions | observations].
                fake_acs, fake_obs = samples[..., :action_dim], samples[..., action_dim:]
                tobs.append(fake_obs.cpu().numpy())
                tacs.append(fake_acs.cpu().numpy())
            tobs, tacs = np.concatenate(tobs, axis=0), np.concatenate(tacs, axis=0)

            # Alternative filter (disabled): one-class SVM on a UMAP embedding.
            # import umap
            # from sklearn.svm import OneClassSVM
            # reducer = umap.UMAP(n_components=2)
            # embedding = reducer.fit_transform(tobs.reshape(2*n_trajectories, -1))

            # model_OCSVM = OneClassSVM()
            # model_OCSVM.fit(embedding)

            # distances = model_OCSVM.decision_function(embedding)
            # # sorted_indices = np.argsort(distances)[-int(n_trajectories*3/2):]
            # sorted_indices = np.argsort(distances)[-n_trajectories:]
            # tobs, tacs = tobs[sorted_indices], tacs[sorted_indices]

            # load actor model
            model_path = f"/DiffCRL/continual_tune/logs/er_clipnorm5e-5/2023_12_20__21_29_59_kptGRF/checkpoints/task{task_id}"
            model_path = os.path.join(model_path, 'actor_finished')
            eval_actor.load_weights(model_path)

            # Score every generated (observation, action) pair with the
            # actor's Gaussian output (mu, exp(log_std)).
            # NOTE(review): tobs has the width the diffusion model emits
            # (no task-id entries) while eval_actor was built with
            # input_dim=env.observation_space.shape[0] and hide_task_id=True
            # -- confirm the actor accepts this input width.
            flat_obs = tobs.reshape(-1, tobs.shape[-1])
            flat_acs = tacs.reshape(-1, tacs.shape[-1])
            tobs_tensor = tf.convert_to_tensor(flat_obs, dtype=tf.float32)
            mu, log_std, _, _ = eval_actor(tobs_tensor)
            std = np.exp(log_std)
            # Per-dimension Gaussian density of each generated action.
            likelihood = (1.0 / (np.sqrt(2 * np.pi) * std)) * np.exp(-0.5 * ((flat_acs - mu) / std)**2)
            # NOTE(review): densities are summed (not multiplied, and not
            # summed as logs) over action dimensions and then over timesteps
            # -- confirm this scoring heuristic is intended.
            likelihood = likelihood.sum(axis=-1)

            # One score per trajectory. The row count comes from the data
            # rather than a hard-coded trajectory length of 200.
            traj_likelihood = likelihood.reshape(tobs.shape[0], -1).sum(axis=-1)
            # argsort is ascending, so the tail holds the best-scoring ones.
            traj_sorted = np.argsort(traj_likelihood)
            keep = traj_sorted[-n_trajectories:]
            tobs = tobs[keep]
            tacs = tacs[keep]

            dataset.add_data(tobs, tacs, task_id)

        # Add expert data to the dataset
        expert_data = torch.load(expert_data_path)
        total_action = expert_data['expert_traj'][:, :, :action_dim]
        total_obs = expert_data['expert_traj'][:, :, action_dim:-num_tasks]
        dataset.add_data(total_obs, total_action, train_stage)

        # A fresh trainer per stage wraps the same `diffusion` object and
        # writes under a per-task sub-directory of the bucket.
        diffuser_trainer = trainer_config(diffusion, dataset)
        diffuser_trainer.bucket = os.path.join(diffuser_config.bucket, f'task-{train_stage}')
        for train_iter in range(int(diffuser_config.n_train_steps // diffuser_config.n_steps_per_epoch)):
            diffuser_trainer.train(diffuser_config.n_steps_per_epoch)
            tot_loss = []
            for data_loader in all_data_loaders:
                tot_loss.append(diffuser_trainer.metrics_test(data_loader))
            print('Computed losses:', tot_loss)
            # One comma-separated line per epoch (plain text despite .json).
            with open(f'{time_str}_bcfilter.json', 'a+') as f:
                f.write(','.join([str(item) for item in tot_loss]) + '\n')


### Multi-task Training
# dataset = ElasticSegmentDataset(n_tasks=num_tasks)
# num_train_tasks = 3
# for train_stage in range(num_train_tasks):
#     if train_stage == 0:
#         expert_data_path = os.path.join(data_direc, f"Expert-task{train_stage}.pt")
#         expert_data = torch.load(expert_data_path)
#         total_action = expert_data['expert_traj'][:, :, :action_dim]
#         total_obs = expert_data['expert_traj'][:, :, action_dim:-num_tasks]
#         dataset.add_data(total_obs, total_action, train_stage)

# diffuser_trainer = trainer_config(diffusion, dataset)


# # # load diffusion model
# if False:
#     model_path = "/DiffCRL/continual_world/logs/diffusion_models/mt-exp2/checkpoint/state_140000.pt"
#     load_diffusion(diffuser_trainer, model_path)
#     print('------ success load!!!')

# for train_iter in range(int(diffuser_config.n_train_steps // diffuser_config.n_steps_per_epoch)):
#     diffuser_trainer.train(diffuser_config.n_steps_per_epoch)
#     tot_loss, tot_test_loss = [], []
#     for data_loader in all_data_loaders:
#         tot_loss.append(diffuser_trainer.metrics_test(data_loader))
#     for test_data_loader in all_test_data_loaders:
#         tot_test_loss.append(diffuser_trainer.metrics_test(test_data_loader))
#     print('Computed losses:', tot_loss)
#     print('Computed test losses:', tot_test_loss)
#     with open(f'MT-{time_str}.json', 'a+') as f:
#         f.write(','.join([str(item) for item in tot_loss]) + '\n')
#     with open(f'MT-{time_str}-test.json', 'a+') as f:
#         f.write(','.join([str(item) for item in tot_test_loss]) + '\n')