import contextlib
from collections import defaultdict
import os
import datetime
from concurrent import futures
import time
from absl import app, flags
from argparse import ArgumentParser
from ml_collections import config_flags
from accelerate import Accelerator
from accelerate.utils import set_seed, ProjectConfiguration
from accelerate.logging import get_logger
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
import numpy as np
import ddpo_pytorch.rewards
from ddpo_pytorch.stat_tracking import PerPromptStatTracker
# from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob 
from ddpo_pytorch.diffusers_patch.difusco_logprob import difusco_with_logprob,categorical_denoise_step, difusco_with_logprob_mis, categorical_denoise_step_mis, gaussian_denoise_step_mis, _get_variance

from ddpo_pytorch.diffusers_patch.ddim_with_logprob import ddim_step_with_logprob

from difusco.models.gnn_encoder import AddLora
# from difusco.models.gnn_encoder import LoraGNN
# from difusco.models.gnn_encoder import AddLoraLayer, LoraLayer

import torch
import wandb
from functools import partial
import tqdm
import tempfile
from PIL import Image
# from config.difsuco_args import arg_parser
from difusco.utils.diffusion_schedulers import InferenceSchedule
from difusco.co_datasets.tsp_graph_dataset import TSPGraphEnvironment
import copy

import copy as cp
import os

# Rebind the name `tqdm` (imported above as a module) to `tqdm.tqdm` with
# dynamic_ncols=True pre-applied, so later `tqdm(...)` calls create a progress bar directly.
tqdm = partial(tqdm.tqdm, dynamic_ncols=True)

# Command-line flags. update_config() below copies every one of these flag values
# onto the config object loaded from the `config` file flag.
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", "config/base_co.py", "Training configuration.")

# Appends "_sweep" to the generated run name when enabled (see update_config).
flags.DEFINE_bool('use_sweep', False, 'Whether to use Weights & Biases sweep for hyperparameter search.')

# Used as the prefix of the generated run name (see update_config).
flags.DEFINE_string('run_name', "", 'Your name')
# flags.DEFINE_string('task', "tsp", 'tsp, mis, pctsp')
# NOTE: update_config() only handles 'tsp', 'mis_sat' and 'mis_er'; any other value
# (including 'pctsp' from the help text) raises ValueError there.
flags.DEFINE_string('task', "mis_er", 'tsp, mis_sat, mis_er, pctsp')

# task_size selects the TSP data split files; task_load_size selects the default checkpoint file name.
flags.DEFINE_integer('task_size', 100,"task size to solve")
flags.DEFINE_integer('task_load_size', 100,"the task size that should be loaded ckpt")

flags.DEFINE_integer('batch_size', 4, "a") # setting used for evaluation and for the initial sampling
flags.DEFINE_integer('gradient_accumulation_steps', 4,"a") # a divisor of 100
# sample_iters is stored as config.sample.num_iters_per_epoch and train_iters as
# config.train.num_inner_epochs (see update_config).
flags.DEFINE_integer('sample_iters', 8, "a")
flags.DEFINE_integer('train_iters', 1,"a")
flags.DEFINE_integer('num_epochs', 10000,"a")

flags.DEFINE_integer('eval_step', -50, "number of steps to evaluate, -1: evaluation for each epoch, 1000: evaluateion whenever for 1000 global steps")

flags.DEFINE_integer('num_workers', 1, "a")
flags.DEFINE_integer('inference_diffusion_steps', 20, "a")
flags.DEFINE_integer('sparse_factor', 50, "a")
flags.DEFINE_integer('two_opt_iterations', 1000, "a")
flags.DEFINE_integer('reward_2opt', 0, "1: with 2opt, 0: without 2opt")
flags.DEFINE_bool('reward_gap', False, "False: do not use reward gap, 1: use reward gap")   
flags.DEFINE_string('resume_from', "", 'resume from checkpoint')
flags.DEFINE_bool("use_activation_checkpoint", False, "use activation checkpoint")
# flags.DEFINE_bool("decode_heatmap", 0, "use heat map decode")
# flags.DEFINE_string('resume_from', " /nfsdata/home/anonymous_user/pcb/rl_finetuning/ddpo_co/logs/MIS_sat_finalmis_sat100_steps10_bs16,st400,ti2,accu16,lr1e-05_last_sd1_2024.05.18_15.16.49/checkpoints/checkpoint_43", 'resume from checkpoint')
# flags.DEFINE_string('resume_from', "/nfsdata/home/anonymous_user/pcb/rl_finetuning/ddpo_co/logs/1000_1e-5tsp1000_steps10_bs16,st400,ti2,accu4,lr1e-05_useenv_last_sd50803_2024.05.10_16.31.29/checkpoints/checkpoint_8", 'resume from checkpoint')

# flags.DEFINE_bool("lora", True, "use lora")
flags.DEFINE_string("mixed_precision", 'fp16', "check mixed precision, no or fp16")
flags.DEFINE_float('learning_rate', 1e-5, "learning rate")
flags.DEFINE_bool('use_env', False, 'generate datasamples during training')
flags.DEFINE_bool('last_train', True, 'train only last layer')
# A positive lora_rank adds a "_lora<rank>" suffix to the run name (see update_config).
flags.DEFINE_integer("lora_rank", -1, "use lora")
flags.DEFINE_integer("lora_range", 0, "use lora")
flags.DEFINE_integer('num_testset', 500, 'number of testset')
flags.DEFINE_integer('seed', -1, 'seed for random number generator, -1: random seed')
flags.DEFINE_integer('num_loop', 20, 'Loop')
flags.DEFINE_integer('onpolicy', 0, '0: importance sampling, 1: onpolicy')

# Module-level logger obtained from accelerate.logging.get_logger.
logger = get_logger(__name__)
# torch.use_deterministic_algorithms(True)

def update_config(config, input_var):
    """Copy command-line flag values onto ``config`` and derive run metadata.

    The config is unlocked and mutated in place. Besides the plain flag
    copies, this function builds ``config.run_name`` from the main
    hyper-parameters and sets ``config.ckpt_path``, ``config.training_split``
    and ``config.test_split`` according to ``config.task``.

    Args:
        config: Config object exposing ``unlock()``, attribute assignment,
            ``sample`` / ``train`` sub-configs and the pre-existing fields
            ``storage_path``, ``diffusion_type`` and (for the ``tsp`` task)
            ``data_dist``. Presumably an ``ml_collections.ConfigDict`` --
            TODO confirm against the config file.
        input_var: Object exposing the parsed flags as attributes (the caller
            in this file passes ``FLAGS``).

    Returns:
        The same ``config`` object, after mutation.

    Raises:
        ValueError: If ``config.task`` is not ``tsp``, ``mis_sat`` or ``mis_er``.
        AssertionError: If ``tsp`` or ``mis_sat`` is selected while
            ``config.diffusion_type`` is not ``"categorical"``.
    """
    config.unlock()

    # Flags stored on the config under exactly the same name.
    same_named_fields = (
        "task", "task_size", "task_load_size", "batch_size", "num_loop",
        "resume_from", "num_workers", "inference_diffusion_steps",
        "sparse_factor", "two_opt_iterations", "reward_2opt", "learning_rate",
        "use_env", "seed", "lora_rank", "num_epochs", "last_train",
        "num_testset", "eval_step", "reward_gap", "use_sweep", "lora_range",
        "mixed_precision", "use_activation_checkpoint", "onpolicy",
    )
    for field in same_named_fields:
        setattr(config, field, getattr(input_var, field))

    # Flags that live under a different name / sub-config.
    config.sample.num_iters_per_epoch = input_var.sample_iters
    config.train.num_inner_epochs = input_var.train_iters
    config.train.gradient_accumulation_steps = input_var.gradient_accumulation_steps

    # Human-readable run name encoding the main hyper-parameters.
    config.run_name = (
        f'{input_var.run_name}{input_var.task}{input_var.task_size}'
        f'_steps{input_var.inference_diffusion_steps}'
        f'_bs{config.batch_size},st{input_var.sample_iters},ti{input_var.train_iters}'
        f',accu{config.train.gradient_accumulation_steps},lr{input_var.learning_rate}'
    )

    if config.use_env:
        config.run_name += "_useenv"
    if config.last_train:
        config.run_name += "_last"
    # -1 is the documented "random seed" sentinel; every other value (including 0)
    # is an explicit seed and is recorded in the run name.
    if config.seed != -1:
        config.run_name += f"_sd{config.seed}"
    if config.lora_rank > 0:
        config.run_name += f"_lora{config.lora_rank}"
    if config.reward_2opt:
        config.run_name += "_2opt"
    if config.use_sweep:
        config.run_name += "_sweep"

    # Default checkpoint name; the MIS branches below replace it.
    config.ckpt_path = os.path.join(
        config.storage_path, "checkpoints", f'{config.task}{config.task_load_size}.ckpt'
    )

    if config.task == 'tsp':
        assert config.diffusion_type == "categorical", "tsp requires categorical diffusion"

        config.training_split = os.path.join(
            'data/tsp_custom/', f'{config.task}{config.task_size}_train_{config.data_dist}.txt'
        )
        config.test_split = os.path.join(
            'data/tsp_custom/', f'{config.task}{config.task_size}_test_{config.data_dist}.txt'
        )

    elif config.task == "mis_sat":
        assert config.diffusion_type == "categorical", "mis_sat requires categorical diffusion"
        config.ckpt_path = os.path.join(
            config.storage_path, "checkpoints", f'{config.task}_{config.diffusion_type}.ckpt'
        )
        config.training_split = os.path.join('data/MIS_SAT_train/', '*gpickle')
        # The "test" split here is actually used as the validation split.
        config.test_split = os.path.join('data/MIS_SAT_test/', '*gpickle')

    elif config.task == "mis_er":
        # The diffusion type is forced to categorical for this task; the conditional
        # is kept so that switching the line above to "gaussian" stays a one-line change.
        config.diffusion_type = "categorical"
        if config.diffusion_type == "categorical":
            config.hidden_dim = 128
            print("config.hidden_dim",config.hidden_dim)
        config.ckpt_path = os.path.join(
            config.storage_path, "checkpoints", f'{config.task}_{config.diffusion_type}.ckpt'
        )
        config.training_split = os.path.join('data/MIS_ER/er_train', '*gpickle')
        config.test_split = os.path.join('data/MIS_ER/er_test', '*gpickle')

    else:
        raise ValueError("wrong task comes")

    return config

def train_with_sweep(config=None):
    # basic Accelerate and logging setup
    config = FLAGS.config
    unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S")

    config = update_config(config, FLAGS)

    if not config.run_name:
        config.run_name = unique_id
    else:
        config.run_name += "_" + unique_id

    model, model_class, saving_mode = load_model(config)
    if config.last_train or config.lora_rank > 0:
        model.model.requires_grad_(False)
    if config.sparse_factor>0 or config.task=='mis_sat' or config.task=='mis_er':
        sparse =True
    else :
        sparse =False
    # number of timesteps within each trajectory to train on
    num_train_timesteps = int(config.inference_diffusion_steps * config.train.timestep_fraction)
    accelerator_config = ProjectConfiguration(
        project_dir=os.path.join(config.logdir, config.run_name),
        automatic_checkpoint_naming=True,
        total_limit=config.num_checkpoint_limit,
    )
    
    accelerator = Accelerator(
        log_with="wandb",
        mixed_precision=config.mixed_precision,
        project_config=accelerator_config,
        # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
        # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
        # the total number of optimizer steps to accumulate across.
        gradient_accumulation_steps=config.train.gradient_accumulation_steps
        * num_train_timesteps, 
    )


    model.to(accelerator.device)

    inference_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        inference_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        inference_dtype = torch.bfloat16
    if config.last_train :        
        # model.model.time_embed_layers[-1].requires_grad_(True)
        model.to(accelerator.device, dtype=inference_dtype)
        model.model.per_layer_out[-1].to(dtype=torch.float32)
        model.model.layers[-1].to(dtype=torch.float32)
        model.model.out.to(dtype=torch.float32)

        model.model.per_layer_out[-1].requires_grad_(True)
        model.model.layers[-1].requires_grad_(True)
        model.model.out.requires_grad_(True)
        

    if config.lora_rank > 0:
        model = AddLora(model, config)
    # if config.task=='tsp':
    model.model = torch.compile(model.model)

    #accelerator.device = device
    if accelerator.is_main_process:
        accelerator.init_trackers(
            project_name="CO_RLFINETUNE",
            config=config.to_dict(),
            init_kwargs={"wandb": {"name": config.run_name}},
        )
    logger.info(f"\n{config}")

    # set seed (device_specific is very important to get different prompts on different devices)
    if config.seed != -1:   
        set_seed(config.seed, device_specific=True)

    # load scheduler, tokenizer and models.
    #model = model_class.load_from_checkpoint(ckpt_path, param_args=config)

    # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora model) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # Initialize the optimizer
    if config.train.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        model.parameters(),
        lr=config.learning_rate,
        betas=(config.train.adam_beta1, config.train.adam_beta2),
        weight_decay=config.train.adam_weight_decay,
        eps=config.train.adam_epsilon,
    )


    # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
    # more memory
    autocast = accelerator.autocast if config.use_lora else accelerator.autocast
    # autocast = accelerator.autocast

    # Prepare everything with our `accelerator`.
    if config.use_env:
        train_dataloader = TSPGraphEnvironment(config.task_size,sparse_factor=config.sparse_factor)
    else:
        train_dataloader = model.train_dataloader()
    test_dataloader = model.test_dataloader()
    # rewards_mean, rewards_std = - model.test_dataset.cost_mean, model.test_dataset.cost_std
    model_args = copy.deepcopy(model.args)
    model_diffusion = model.diffusion
    
    model, optimizer = accelerator.prepare(model, optimizer)
    # import pdb
    # pdb.set_trace()


    # accelerator.load_state('~/pcb/rl_finetuning/ddpo_co/logs/tsp100_steps7_bs4,st3,ti2_lr0.001_useenv_last_2024.05.08_13.56/checkpoints/checkpoint_1')

    # top_checkpoints = []

    # if config.task=='mis_sat':
    best_score = -np.inf

    # def save_checkpoint(model, optimizer, epoch, cost):
    #     checkpoint_path = "checkpoint_epoch_{}_cost_{:.4f}.pt".format(epoch, cost)
    #     accelerator.save_state(checkpoint_path)
        
    #     # 상위 5개 체크포인트 리스트 업데이트
    #     top_checkpoints.append((checkpoint_path, cost))
    #     top_checkpoints.sort(key=lambda x: x[1])  # 손실(loss)을 기준으로 정렬
    #     if len(top_checkpoints) > 5:
    #         # 상위 5개를 초과하는 체크포인트 삭제
    #         old_checkpoint, _ = top_checkpoints.pop()
    #         os.remove(old_checkpoint)

    # # 체크포인트 로드 함수
    # def load_checkpoint(checkpoint_path):
    #     accelerator.load_state(checkpoint_path)

    # # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks which makes a request to a
    # # remote server running llava inference.
    # # executor = futures.ThreadPoolExecutor(max_workers=2)

    # Train!
    print('config.sample.num_iters_per_epoch', config.sample.num_iters_per_epoch)
    if config.sample.num_iters_per_epoch == -1:
        config.sample.num_iters_per_epoch = int(len(model.train_dataset)/config.batch_size)
        print("config.sample.num_iters_per_epoch", config.sample.num_iters_per_epoch)

    samples_per_epoch = (
        config.batch_size
        * accelerator.num_processes
        * config.sample.num_iters_per_epoch
    )
    total_train_batch_size = (
        config.batch_size
        * accelerator.num_processes
        * config.train.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num Epochs = {config.num_epochs}")
    logger.info(f"  Sample batch size per device = {config.batch_size}")
    logger.info(f"  Train batch size per device = {config.batch_size}")
    logger.info(
        f"  Gradient Accumulation steps = {config.train.gradient_accumulation_steps}"
    )
    logger.info("")
    logger.info(f"  Total number of samples per epoch = {samples_per_epoch}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}"
    )
    logger.info(
        f"  Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}"
    )
    logger.info(f"  Number of inner epochs = {config.train.num_inner_epochs}")

    assert config.batch_size >= config.batch_size
    assert config.batch_size % config.batch_size == 0

    assert samples_per_epoch % total_train_batch_size == 0
    _ = None

    global_step = 0
    epoch = 0

    if config.resume_from:
        logger.info(f"Resuming from {config.resume_from}")
        accelerator.load_state(config.resume_from)
        try:
            epoch = accelerator.state.epoch
        except:
            pass
        if config.resume_from=='MIS_parametersmis_er100_steps10_bs16,st400,ti2,accu4,lr3e-05_last_sd1_2024.05.20_23.29.20' or config.resume_from=='MIS_parametersmis_er100_steps10_bs16,st400,ti2,accu4,lr3e-05_last_sd2_2024.05.20_23.09.13':
            epoch = 42
    #     evaluate(model, model_diffusion, model_args, test_dataloader, _, accelerator,global_step, sparse, use_env=False, inference=True, reward_2opt=config.reward_2opt, print_log=True)
    #     evaluate(model, model_diffusion, model_args, test_dataloader, _, accelerator,global_step, sparse, use_env=False, inference=True, reward_2opt=config.reward_2opt, print_log=True)
    #     evaluate(model, model_diffusion, model_args, test_dataloader, _, accelerator,global_step, sparse, use_env=False, inference=True, reward_2opt=config.reward_2opt, print_log=True)
    #     evaluate(model, model_diffusion, model_args, test_dataloader, _, accelerator,global_step, sparse, use_env=False, inference=True, reward_2opt=config.reward_2opt, print_log=True)

    #     exit(0)
    # else:
    # first_epoch = 0
    first_loop = 0
    num_batch = 0
    time_eval = time.perf_counter()
    print("evaluation stars")
    evaluate(model, model_diffusion, model_args, test_dataloader, _, accelerator,global_step, sparse, use_env=False, inference=True, reward_2opt=config.reward_2opt)

    # # exit(0)
    eval_count = 0
    save_freq = config.num_epochs // config.save_freq

    samples = []
    rewards_log = []
    for loop in range(first_loop, config.num_loop):
        if loop != 0 and loop % save_freq == 0 and accelerator.is_main_process:
            accelerator.state.epoch = epoch
            accelerator.save_state()
        # with torch.no_grad():

        model.eval()
        for i, batch in enumerate(tqdm(
            train_dataloader,
            desc=f"Epoch {epoch}: iterating batches",
            disable=not accelerator.is_local_main_process,
            position=0,
            total=config.sample.num_iters_per_epoch,
        )):
            num_batch += 1
            # with autocast():
            if model_args.task=='tsp':
                latents, edge_index, log_probs, rewards, timesteps  = difusco_with_logprob(
                    model,
                    model_diffusion,
                    model_args,
                    batch,
                    inference=False,
                    use_env = config.use_env,
                    sparse = sparse, 
                    reward_gap=config.reward_gap
                )
                points = batch[1]
                if sparse:
                    points = batch[1].x.view([config.batch_size, -1, 2]).contiguous()
                    edge_index = torch.transpose(edge_index, 0, 1).contiguous()
                    edge_index = edge_index.view(config.batch_size, -1, 2) # [Batch, Num_instance*sparse_factor, 2]
                    edge_index = edge_index - config.task_size*torch.arange(config.batch_size).view([config.batch_size,1,1]).to(edge_index.device) 
                    latents = [latent.view(config.batch_size, -1) for latent in latents]
            
            elif model_args.task=='mis_sat' or model_args.task=='mis_er':
                latents, edge_length, edge_index, log_probs, rewards, timesteps  = difusco_with_logprob_mis(model,
                    model_diffusion,
                    model_args,
                    batch,
                    inference=False,
                    sparse = sparse
                )
                ### 이 때 log_probs = [num_steps, batch_size] list-tensor 구조


            ## latents [inference_steps + 1, node_size_1+node_size_2+...+node_size_batch_size]
            latents = torch.stack(
                latents, dim=1
            )  ## [node_size_1+node_size_2+...+node_size_batch_size, inference_steps + 1]            
            # (batch_size, num_steps + 1, 4, 64, 64)
            log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
            timesteps = torch.stack(timesteps, dim=0) # (batch_size, num_steps)  # (batch_size, num_steps)
            if config.task != 'tsp':
                timesteps = timesteps.repeat(config.batch_size,1,1)
                # (batch_size, num_steps)
            
            sample_dict={
                    "timesteps": timesteps.to('cpu'),
                    "latents": latents[
                        :, :-1
                    ].to('cpu'),  # each entry is the latent before timestep t
                    "next_latents": latents[
                        :, 1:
                    ].to('cpu'),  # each entry is the latent after timestep t
                    "log_probs": log_probs.to('cpu'),
                    "rewards": rewards.to('cpu'),
                }
            if config.task=='tsp':
                sample_dict["points"] = points.to('cpu')
            # if sparse:
            #     edge_index = torch.transpose(edge_index, 0, 1).contiguous()
            if sparse:
                if config.task=='tsp':
                    sample_dict['edge_index'] = edge_index.to('cpu')
                elif config.task=='mis_sat' or config.task=='mis_er':
                    point_indicator = batch[2]
                    idx_edge_index = 0
                    idx_point_indicator = 0
                    edge_index_list = []
                    for i in range(config.batch_size):
                        edge_index_list.append((edge_index[:,idx_edge_index:idx_edge_index+edge_length[i]]-idx_point_indicator).to('cpu'))
                        idx_point_indicator += point_indicator[i]
                        idx_edge_index += edge_length[i]
                    sample_dict['edge_index'] = edge_index_list

                    idx_latents = 0
                    latents_list = []
                    for i in range(config.batch_size):
                        latents_list.append(latents[idx_latents:idx_latents+point_indicator[i]][:,:-1].to('cpu'))
                        idx_latents += point_indicator[i]
                    sample_dict['latents'] = latents_list
                    sample_dict['point_indicator'] = point_indicator

                    idx_next_latents = 0
                    next_latents_list = []
                    for i in range(config.batch_size):
                        next_latents_list.append(latents[idx_next_latents:idx_next_latents+point_indicator[i]][:,1:].to('cpu'))
                        idx_next_latents += point_indicator[i]
                    sample_dict['next_latents'] = next_latents_list

                """
                이 부분이 지금 [2, length1+ length2+ length3] 이런 식으로 되어 있는데, 이걸 각각의 sample로 바꾸고, list화 or dictionary 화하고 sorting해야됨
                각 node index도 조정해줘야됨. edge_length는 말그랟로 edge_length고
                node index의 각각의 크기가 있음.

                아마 latent도 수정해야할 듯?
                """
            samples.append(
                sample_dict
            )
        # break
        # time.sleep(0)
            if num_batch >= config.sample.num_iters_per_epoch:
                # wait for all rewards to be computed
                for sample in tqdm(
                    samples,
                    desc="Waiting for rewards",
                    disable=not accelerator.is_local_main_process,
                    position=0,
                ):
                    rewards = sample["rewards"]
                    sample["rewards"] = torch.as_tensor(rewards).view([-1])

                # collate samples into dict where each entry has shape (num_batches_per_epoch * batch_size, ...)
                """
                이건 결국 딴게 아니고, samples에서 total batch로 펴주는 상황
                timesteps = [total_batch_size, num_timesteps, 2]
                edge_index
                """
                if config.task=='tsp':
                    samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}

                elif config.task=='mis_sat' or config.task=='mis_er':
                    # import pdb
                    # pdb.set_trace()
                    samples_temp = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys() if k not in ["edge_index", "latents", "next_latents"]}

                    samples_temp["edge_index"] = []
                    samples_temp["latents"] = []
                    samples_temp["next_latents"] = []
                    
                    for s in samples:
                        samples_temp["edge_index"].extend(s["edge_index"])
                        samples_temp["latents"].extend(s["latents"])
                        samples_temp["next_latents"].extend(s["next_latents"])
                    
                    samples = samples_temp

                # gather rewards across processes
                rewards =samples["rewards"].cpu().numpy()


                rewards_log += rewards.tolist()
                # log rewards and images
                # accelerator.log(
                #     {
                #         "reward": rewards,
                #         "epoch": epoch,
                #         "reward_mean": rewards.mean(),
                #         "reward_std": rewards.std(),
                #     },
                #     step=global_step,
                # )
                # if rewards.mean() > best_score_train and accelerator.is_main_process:
                #     best_score_train = rewards.mean()
                    # accelerator.state.epoch = epoch
                #     accelerator.save_state()

                advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
                # advantages = (rewards - rewards_mean) / (rewards_std)


                # ungather advantages; we only need to keep the entries corresponding to the samples on this process
                samples["advantages"] = (
                    torch.as_tensor(advantages)
                    .reshape(accelerator.num_processes, -1)[accelerator.process_index]
                    .to(accelerator.device)
                )

                del samples["rewards"]

                total_batch_size, num_timesteps, _ = samples["timesteps"].shape
                # print('total_batch_size',total_batch_size)
                assert (
                    total_batch_size
                    == config.batch_size * config.sample.num_iters_per_epoch
                )
                assert num_timesteps == config.inference_diffusion_steps
                #################### TRAINING ####################

                for inner_epoch in range(config.train.num_inner_epochs):
                    # shuffle samples along batch dimension
                    perm = torch.randperm(total_batch_size, device='cpu').tolist()
                    
                    samples = {k: sort_list_by_index(v, perm) if type(v)==type([]) else v[perm] for k, v in samples.items()}
                    """
                    여기서 섞을 때, 이미 edge_index랑 latent가 깔끔하게 분리 되어있어야됨
                    """
                    # shuffle along time dimension independently for each sample
                    if sparse:
                        perms = torch.stack(
                            [
                                torch.arange(num_timesteps, device='cpu')
                                for _ in range(total_batch_size)
                            ]
                        )
                    else:
                        perms = torch.stack(
                            [
                                torch.randperm(num_timesteps, device='cpu')
                                for _ in range(total_batch_size)
                            ]
                        )
                    # samples_cat_perm = dict()
                    if not sparse:
                        for key in ["timesteps", "latents", "next_latents", "log_probs"]:
                            samples[key] = samples[key][
                                torch.arange(total_batch_size, device='cpu')[:, None],
                                perms,
                            ]
                        """
                        여기서 섞을 때, 이미 edge_index랑 latent가 깔끔하게 분리 되어있어야됨
                        """

                    # rebatch for training
                    """
                    여기 부분에서 단순 batch_size로 reshape만 하는게 아니라, edge index, latent, next_latents를 붙여주고, 그 값들을 보정해주는 작업이 필요하다
                    """
                    if config.task=='tsp':
                        samples_batched = {
                            k: v.reshape(-1, config.batch_size, *v.shape[1:])
                            for k, v in samples.items()
                        }
                    elif config.task=='mis_sat' or config.task=='mis_er':
                        samples_batched = {
                            k: v.reshape(-1, config.batch_size, *v.shape[1:]) if not isinstance(v, list) else reshape_list(v, batch_size=config.batch_size)
                            for k, v in samples.items()
                        }
                        # reshape_list


                    # dict of lists -> list of dicts for easier iteration
                    samples_batched = [
                        dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())
                    ]

                    # train
                    info = defaultdict(list)
                    for i, sample in tqdm(
                        list(enumerate(samples_batched)),
                        desc=f"Epoch {epoch}.{inner_epoch}: training",
                        position=0,
                        disable=not accelerator.is_local_main_process,
                    ):
                        if sparse:
                            if config.task=='tsp':
                                sample['edge_index'] += config.task_size * torch.arange(config.batch_size).view([config.batch_size,1,1]).to(sample['edge_index'].device)
                                sample['edge_index'] = sample['edge_index'].view(-1, 2).transpose(0,1).contiguous()
                                if config.task=='tsp':
                                    sample['points'] = sample['points'].view(-1, 2).contiguous()
                                sample['latents'] = sample['latents'].transpose(1,2).reshape(-1, config.inference_diffusion_steps).contiguous()
                            elif config.task=='mis_sat' or config.task=='mis_er':
                                edge_index_shift = 0
                                edge_index_list = []
                                for edge_index_i in sample['edge_index']:
                                    index_shift = edge_index_i.max()+1
                                    edge_index_i = cp.deepcopy(edge_index_i) + edge_index_shift
                                    edge_index_list.append(edge_index_i)
                                    edge_index_shift += index_shift
                                    # print('edge_index_shift', edge_index_shift)

                                sample['edge_index'] = torch.cat(edge_index_list, dim=1).contiguous()
                                sample['latents'] = torch.cat(sample['latents'], dim=0).contiguous()
                                sample['next_latents'] = torch.cat(sample['next_latents'], dim=0).contiguous()

                        sample = {k: v.to(accelerator.device) for k, v in sample.items()}
                        """
                        여기 아래도 tsp 전용으로 바꿔야됨.
                        """

                        for j in tqdm(
                            range(num_train_timesteps),
                            desc="Timestep",
                            position=1,
                            leave=False,
                            disable=not accelerator.is_local_main_process,
                        ):
                            model.train()
                            # print_gpu_memory_usage('train',model.model)

                            with accelerator.accumulate(model):
                                with autocast():
                                    # else:
                                    if sparse:
                                        t_start, t_target =  torch.tensor([sample['timesteps'][0,j,0]]), torch.tensor([sample['timesteps'][0,j,1]])
                                    else:
                                        t_start, t_target = sample['timesteps'][:,j,0], sample['timesteps'][:,j,1]
                                    if config.task=='tsp':
                                        _, log_prob, _ = categorical_denoise_step(
                                            model,
                                            model_diffusion,
                                            sample['points'], 
                                            sample["latents"][:, j], 
                                            t_start, 
                                            model.device, 
                                            edge_index= sample['edge_index'] if sparse else None, 
                                            target_t=t_target,
                                            next_xt= sample["next_latents"][:, j],
                                            batch_t=not sparse,
                                            inference=False,
                                            sparse=sparse
                                            )
                                    elif config.task=='mis_sat':
                                        _, log_prob, _ = categorical_denoise_step_mis(
                                            model,
                                            model_diffusion,
                                            model_args,
                                            sample["latents"][:, j], 
                                            t_start, 
                                            model.device, 
                                            edge_index= sample['edge_index'] if sparse else None, 
                                            target_t=t_target,
                                            next_xt= sample["next_latents"][:, j],
                                            batch_t=not sparse,
                                            inference=False,
                                            sparse=sparse, 
                                            point_indicator = sample['point_indicator']
                                            )
                                    elif config.task =='mis_er':
                                        if config.diffusion_type=='gaussian':
                                            _, log_prob, _ = gaussian_denoise_step_mis(
                                            model,
                                            model_diffusion,
                                            sample["latents"][:, j], 
                                            t_start, 
                                            model.device, 
                                            edge_index= sample['edge_index'] if sparse else None, 
                                            target_t=t_target,
                                            next_xt= sample["next_latents"][:, j],
                                            batch_t=not sparse,
                                            inference=False,
                                            sparse=sparse, 
                                            point_indicator = sample['point_indicator']
                                            )
                                        else:
                                            _, log_prob, _ = categorical_denoise_step_mis(
                                            model,
                                            model_diffusion,
                                            model_args,
                                            sample["latents"][:, j], 
                                            t_start, 
                                            model.device, 
                                            edge_index= sample['edge_index'] if sparse else None, 
                                            target_t=t_target,
                                            next_xt= sample["next_latents"][:, j],
                                            batch_t=not sparse,
                                            inference=False,
                                            sparse=sparse, 
                                            point_indicator = sample['point_indicator']
                                            )
                                advantages = torch.clamp(
                                    sample["advantages"],
                                    -config.train.adv_clip_max,
                                    config.train.adv_clip_max,
                                )
                                ratio = torch.exp(log_prob - sample["log_probs"][:, j])
                                unclipped_loss = -advantages * ratio
                                clipped_loss = -advantages * torch.clamp(
                                    ratio,
                                    1.0 - config.train.clip_range,
                                    1.0 + config.train.clip_range,
                                )
                                loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss))

                                # debugging values
                                # John Schulman says that (ratio - 1) - log(ratio) is a better
                                # estimator, but most existing code uses this so...
                                # http://joschu.net/blog/kl-approx.html
                                info["approx_kl"].append(
                                    0.5
                                    * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2)
                                )
                                info["clipfrac"].append(
                                    torch.mean(
                                        (
                                            torch.abs(ratio - 1.0) > config.train.clip_range
                                        ).float()
                                    )
                                )
                                info["loss"].append(loss)

                                # backward pass
                                accelerator.backward(loss)
                                if accelerator.sync_gradients:
                                    accelerator.clip_grad_norm_(
                                        model.parameters(), config.train.max_grad_norm
                                    )
                                optimizer.step()
                                optimizer.zero_grad()

                            if accelerator.sync_gradients:
                                assert (j == num_train_timesteps - 1) and (
                                    i + 1
                                ) % config.train.gradient_accumulation_steps == 0
                                # log training-related stuff
                                info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                                info = accelerator.reduce(info, reduction="mean")
                                info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                                accelerator.log(info, step=global_step)
                                global_step += 1
                                info = defaultdict(list)

                                if config.eval_step > 0 and global_step % config.eval_step == 0:
                                    _ = None
                                    time_eval = time.perf_counter()
                                    evaluate(model,model_diffusion,model_args,test_dataloader, _, accelerator,global_step,sparse,use_env=False,inference=True, reward_2opt=config.reward_2opt)


                                    reward_mean = np.mean(rewards_log)
                                    reward_std = np.std(rewards_log)
                                    print('epoch', epoch, "evaluation done", time.perf_counter()-time_eval, 'reward_mean', reward_mean, 'reward_std', reward_std)
                                    accelerator.log({"reward_mean": reward_mean, "epcoh": epoch, "reward_std": reward_std,},
                                                        step=global_step)
                                    rewards_log = []


                                    if ckpt_cost > best_score and accelerator.is_main_process:
                                        best_score = ckpt_cost
                                        accelerator.state.epoch = epoch
                                        accelerator.save_state()


                    assert accelerator.sync_gradients
                    
                    if config.eval_step<0:
                        # model.eval()
                        eval_count +=1
                        if eval_count % - config.eval_step == 0 and accelerator.is_main_process:
                            
                            _ = None
                            time_eval = time.perf_counter()
                            # print_gpu_memory_usage('eval',model.model)
                            ckpt_cost = evaluate(model,model_diffusion,model_args,test_dataloader, _, accelerator,global_step, sparse, use_env=False, inference=True, reward_2opt=config.reward_2opt)
                            # print("evaluation done", time.perf_counter()-time_eval)

                            reward_mean = np.mean(rewards_log)
                            reward_std = np.std(rewards_log)
                            print('epoch', epoch, "evaluation done", time.perf_counter()-time_eval, 'reward_mean', reward_mean, 'reward_std', reward_std)
                            accelerator.log({"reward_mean": reward_mean, "epcoh": epoch, "reward_std": reward_std,},
                                                step=global_step)
                            rewards_log = []
                    
                            if ckpt_cost > best_score and accelerator.is_main_process:
                                best_score = ckpt_cost
                                accelerator.state.epoch = epoch
                                accelerator.save_state()
                        # save_checkpoint(model, optimizer, epoch, ckpt_cost)

                samples = [] 
                num_batch = 0
                epoch += 1
    

    # best_checkpoint, _ = top_checkpoints[0]
    
import gc
def print_gpu_memory_usage(string, model=None):
    """Print how much CUDA memory is held by a model's parameters or by live tensors.

    Args:
        string: Label appended to the ``"GPU Memory Usage:"`` header line.
        model: Optional object exposing ``named_parameters()``. When given, the
            combined size of its CUDA parameters is printed, followed by one
            line per CUDA parameter ordered from largest to smallest. When
            omitted, one line is printed for every CUDA tensor returned by
            ``gc.get_objects()``.

    Returns:
        None. All output goes to stdout.
    """
    print("GPU Memory Usage:" + string)
    if model is not None:
        # Gather (name, shape, bytes) for each CUDA parameter in a single pass;
        # both the total and the sorted listing are derived from this list.
        memory_usage = [
            (name, param.size(), param.element_size() * param.nelement())
            for name, param in model.named_parameters()
            if param.is_cuda
        ]

        total_memory = sum(memory for _, _, memory in memory_usage)
        print(f" Total GPU Memory Usage: {total_memory / 1024 / 1024:.2f} MB")

        # Largest parameters first.
        memory_usage.sort(key=lambda entry: entry[2], reverse=True)
        for name, size, memory in memory_usage:
            print(f" {name} ({size}) - {memory / 1024 / 1024:.2f} MB")
    else:
        # No model given: report every CUDA tensor the garbage collector knows about.
        for obj in gc.get_objects():
            if torch.is_tensor(obj) and obj.is_cuda:
                print(f" {type(obj).__name__} ({obj.size()}) - {obj.element_size() * obj.nelement() / 1024 / 1024:.2f} MB")

# def sort_by_another_list(list_to_sort, index_list):
## wrong code
#     return [x for _, x in sorted(zip(index_list, list_to_sort))]

def sort_list_by_index(lst, index_list):
    """Return the items of ``lst`` picked out in the order given by ``index_list``.

    The i-th element of the result is ``lst[index_list[i]]``; the result has
    the same length as ``index_list``.
    """
    reordered = []
    for position in index_list:
        reordered.append(lst[position])
    return reordered
    
def reshape_list(list_to_reshape, batch_size):
    """Split a flat list into consecutive chunks of ``batch_size`` items.

    Args:
        list_to_reshape: Sequence to split; its length must be a multiple of
            ``batch_size``.
        batch_size: Number of items per chunk; must be positive.

    Returns:
        A list of slices of ``list_to_reshape``, each of length ``batch_size``,
        in the original order. An empty input yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if the length of
            ``list_to_reshape`` is not divisible by ``batch_size``.
    """
    # A zero batch size would otherwise surface as ZeroDivisionError, and a
    # negative one could pass the modulo check and silently return [].
    if batch_size <= 0:
        raise ValueError("batch_size should be positive")
    if len(list_to_reshape) % batch_size != 0:
        raise ValueError("length of list should be divisible by batch_size")
    return [
        list_to_reshape[start:start + batch_size]
        for start in range(0, len(list_to_reshape), batch_size)
    ]

def evaluate(model, model_diffusion, model_args, test_dataloader, _, accelerator, global_step, sparse, use_env, inference, reward_2opt, print_log=False):
    """Evaluate ``model`` on ``test_dataloader`` and return a mean cost as the score.

    The model is switched to eval mode (it is not switched back here) and all
    batches are processed under ``torch.no_grad()``.

    Args:
        model: Network to evaluate; ``model.eval()`` is called on it.
        model_diffusion: Forwarded unchanged to the sampling routine.
        model_args: Configuration object; ``model_args.task`` selects the
            sampling routine and, for the MIS tasks, ``model_args.reward_2opt``
            is forwarded as ``decode_heatmap``.
        test_dataloader: Iterable of evaluation batches.
        _: Unused placeholder kept for call-site compatibility.
        accelerator: Object whose ``log(dict, step=...)`` receives the metrics.
        global_step: Step value passed to ``accelerator.log``.
        sparse: Forwarded unchanged to the sampling routine.
        use_env: Forwarded to ``difusco_with_logprob`` (TSP only).
        inference: Forwarded unchanged to the sampling routine.
        reward_2opt: TSP only. When truthy the returned score is the mean of
            the "wo_2opt" costs, otherwise the mean of the solution costs.
        print_log: When truthy the metrics are only printed and
            ``accelerator.log`` is skipped.

    Returns:
        The mean cost selected as described above (``np.mean`` over all
        evaluated samples).

    Raises:
        NotImplementedError: If ``model_args.task`` is not one of ``'tsp'``,
            ``'mis_sat'`` or ``'mis_er'``. Previously such tasks ran through,
            logged empty metrics and then failed with ``UnboundLocalError``.
    """
    is_tsp = model_args.task == 'tsp'
    is_mis = model_args.task == 'mis_sat' or model_args.task == 'mis_er'
    if not (is_tsp or is_mis):
        # Fail before touching the model instead of crashing at the return.
        raise NotImplementedError(
            f"evaluate() does not support task {model_args.task!r}"
        )

    all_gt_costs, all_wo_2opt_costs, all_solution_costs = [], [], []
    model.eval()
    time_eval = time.perf_counter()
    with torch.no_grad():
        if is_tsp:
            for batch in test_dataloader:
                gt_costs, wo_2opt_costs, solution_costs = difusco_with_logprob(model, model_diffusion, model_args, batch, inference, use_env, sparse)
                all_gt_costs += gt_costs
                all_wo_2opt_costs += wo_2opt_costs
                all_solution_costs += solution_costs
        else:
            for batch in test_dataloader:
                # NOTE(review): gt_costs is returned but never accumulated, so
                # "test_gt_costs" below is the mean of an empty list for MIS
                # tasks -- confirm whether it should be added to all_gt_costs.
                gt_costs, solved_costs_xt, solved_costs_prob = difusco_with_logprob_mis(model, model_diffusion, model_args, batch, inference, sparse, decode_heatmap=model_args.reward_2opt)
                all_wo_2opt_costs += solved_costs_xt
                all_solution_costs += solved_costs_prob

        log_cost = {
            "num of samples": len(all_solution_costs),
            "test_gt_costs": np.mean(all_gt_costs),
            "test_model_costs": np.mean(all_solution_costs),
            "test_model_costs_wo2opt": np.mean(all_wo_2opt_costs),
        }

        print('time', time.perf_counter()-time_eval, log_cost)

        if not print_log:
            accelerator.log(
                log_cost,
                step=global_step,
            )

        if is_tsp and reward_2opt:
            eval_score = np.mean(all_wo_2opt_costs)
        else:
            # TSP without reward_2opt, and both MIS tasks.
            eval_score = np.mean(all_solution_costs)

    return eval_score
    
            
def load_model(config):
    """Pick the model class for ``config`` and load it from a checkpoint.

    The class is chosen from ``config.task`` and, for the TSP and PCTSP tasks,
    ``config.diffusion_type`` and ``config.return_condition``. The instance is
    created with ``model_class.load_from_checkpoint(config.ckpt_path,
    param_args=config)``.

    Returns:
        A ``(model, model_class, saving_mode)`` tuple, where ``saving_mode`` is
        ``"max"`` for the MIS tasks and ``"min"`` otherwise (presumably the
        direction in which the checkpoint metric improves -- confirm with the
        caller).

    Raises:
        Exception: If ``config.parallel_sampling`` is not 1.
        NotImplementedError: For an unknown task or an unsupported diffusion
            type within a task.
    """
    from difusco.pl_tsp_model import TSPModel
    from difusco.pl_tsp_classifier_guided_model import TSPGuidedModel
    from difusco.pl_tsp_reward_weighted_model import TSPReward_Weighted_Model
    from difusco.pl_mis_model import MISModel
    from difusco.pl_tsp_model_free import TSPModelFreeGuide
    from difusco.pl_pctsp_model_free import PCTSPModelFreeGuide
    from difusco.pl_pctsp_model import PCTSPModel

    if config.parallel_sampling != 1:
        raise Exception("parallel_sampling is removed!")

    standard_diffusion = ("gaussian", "categorical")
    task = config.task

    if task == "tsp":
        diffusion_type = config.diffusion_type
        if diffusion_type in standard_diffusion:
            model_class = TSPModelFreeGuide if config.return_condition else TSPModel
        elif diffusion_type == "classifier":
            model_class = TSPGuidedModel
        elif diffusion_type == "reward":
            model_class = TSPReward_Weighted_Model
        else:
            raise NotImplementedError
        saving_mode = "min"
    elif task in ("mis_sat", "mis_er"):
        model_class = MISModel
        saving_mode = "max"
    elif task == "pctsp":
        if config.diffusion_type not in standard_diffusion:
            raise NotImplementedError
        model_class = PCTSPModelFreeGuide if config.return_condition else PCTSPModel
        saving_mode = "min"
    else:
        raise NotImplementedError

    model = model_class.load_from_checkpoint(config.ckpt_path, param_args=config)

    return model, model_class, saving_mode

def main(_):
    """Entry point handed to ``app.run``; the positional argument is ignored.

    Without ``--use_sweep`` this calls ``train_with_sweep()`` once. With it, a
    random-search Weights & Biases sweep over ``train.learning_rate`` is
    registered under the ``ddpo-pytorch`` project and ``train_with_sweep`` is
    run through ``wandb.agent``.
    """
    if not FLAGS.use_sweep:
        # Plain training run, no sweep involved.
        train_with_sweep()
        return

    # Sweep definition; further hyperparameters can be added under 'parameters'.
    metric_spec = {
        'name': 'reward_mean',
        'goal': 'maximize'
    }
    parameter_space = {
        'train.learning_rate': {
            'values': [3e-7, 1e-6, 3e-6]
        },
    }
    sweep_config = {
        'method': 'random',
        'metric': metric_spec,
        'parameters': parameter_space,
    }
    print('sweep mode', sweep_config)

    # Register the sweep, then let the agent drive train_with_sweep.
    sweep_id = wandb.sweep(sweep_config, project='ddpo-pytorch')
    wandb.agent(sweep_id, function=train_with_sweep)

# Script entry point: only runs when the file is executed directly, not on import.
if __name__ == "__main__":
    # absl's app.run parses the command-line flags defined above, then calls main.
    app.run(main)