import time
# Start-up timestamp, taken before the remaining imports run so that the
# "Initialization requires" figure printed in train_with_sweep also covers import time.
time_init = time.perf_counter()

import contextlib
from collections import defaultdict
import os
import datetime
from concurrent import futures
from absl import app, flags
from argparse import ArgumentParser
from ml_collections import config_flags
from accelerate import Accelerator
from accelerate.utils import set_seed, ProjectConfiguration
from accelerate.logging import get_logger
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
import numpy as np
import ddpo_pytorch.rewards
from ddpo_pytorch.stat_tracking import PerPromptStatTracker
# from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob 
from ddpo_pytorch.diffusers_patch.difusco_logprob import difusco_with_logprob,categorical_denoise_step, test_eval_co
from ddpo_pytorch.diffusers_patch.ddim_with_logprob import ddim_step_with_logprob
from torch_geometric.data import DataLoader as GraphDataLoader
from difusco.models.gnn_encoder import AddLora
# from difusco.models.gnn_encoder import LoraGNN
# from difusco.models.gnn_encoder import AddLoraLayer, LoraLayer


import torch
import wandb
from functools import partial
import tqdm
import tempfile
from PIL import Image
# from config.difsuco_args import arg_parser
from difusco.utils.diffusion_schedulers import InferenceSchedule
from difusco.co_datasets.tsp_graph_dataset import TSPGraphEnvironment,TSPLIBGraphDataset
import copy

import os

# Rebind the module name `tqdm` to a progress-bar factory with dynamic_ncols=True preset.
# After this line `tqdm` no longer refers to the module itself.
tqdm = partial(tqdm.tqdm, dynamic_ncols=True)

# Command-line flags. update_config() copies almost every flag below onto the
# ml_collections config loaded through --config.
# NOTE(review): many flags carry the placeholder help text "a"; their meaning has to be
# taken from where the corresponding config field is consumed.
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file("config", "config/base_co.py", "Training configuration.")

flags.DEFINE_bool('use_sweep', False, 'Whether to use Weights & Biases sweep for hyperparameter search.')
# Used as the prefix of the generated run name (see update_config).
flags.DEFINE_string('run_name', "", 'Your name')
flags.DEFINE_string('task', "tsp", 'tsp, mis, pctsp, tsplib')
flags.DEFINE_integer('task_size', 100,"task size to solve")
flags.DEFINE_integer('task_load_size', 100,"the task size that should be loaded ckpt")

flags.DEFINE_integer('batch_size', 1, "a") # setting used for evaluation and for the initial sampling pass
flags.DEFINE_integer('gradient_accumulation_steps', 1,"a") # should be a divisor of 100
flags.DEFINE_integer('sample_iters', 4,"a")
flags.DEFINE_integer('train_iters', 2,"a")
flags.DEFINE_integer('num_epochs', 4,"a")
flags.DEFINE_integer('print_sl_loss', 0, "a")

flags.DEFINE_integer('eval_step', -1,"number of steps to evaluate, -1: evaluation for each epoch, 1000: evaluateion whenever for 1000 global steps")

flags.DEFINE_integer('num_workers', 1,"a")
flags.DEFINE_integer('inference_diffusion_steps', 20, "a")

flags.DEFINE_string('tsp_decoder', 'cython_merge', "cython_merge or am_decoding or farthest")

# train_with_sweep treats any value > 0 as "sparse" mode.
flags.DEFINE_integer('sparse_factor', -1, "a")
flags.DEFINE_integer('two_opt_iterations', 1000, "a")
flags.DEFINE_integer('reward_2opt', 1, "1: with 2opt, 0: without 2opt")
flags.DEFINE_bool('reward_gap', False, "False: do not use reward gap, 1: use reward gap")   
# When non-empty, update_config loads "<storage_path>/checkpoints/<subopt>.ckpt".
flags.DEFINE_string('subopt', "", "name for reading")
# flags.DEFINE_string('resume_from', "", 'resume from checkpoint')

flags.DEFINE_integer("rewrite_steps", 0, "rewrite")
flags.DEFINE_integer("rewrite", 1, "rewrite")
# NOTE(review): the help text "save frequency" looks copy-pasted — confirm what rewrite_ratio controls.
flags.DEFINE_float("rewrite_ratio", 0.25, "save frequency")
flags.DEFINE_integer("inference_steps", 10, "inference steps rewrite")
flags.DEFINE_integer("parallel_sampling", 3,"parallel sampling")


# When non-empty, train_with_sweep passes this path to accelerator.load_state().
flags.DEFINE_string('resume_from', "", 'resume from checkpoint')
# Earlier hard-coded checkpoint locations, kept for reference:
#flags.DEFINE_string('resume_from', "/workspace/pair/squads/diff_nco/icml24_ckpts/tsp100/checkpoint", 'resume from checkpoint')
# flags.DEFINE_string('resume_from', "/nfsdata/home/anonymous_user/pcb/rl_finetuning/ddpo_co/logs/500_to_500_3e-5tsp500_steps10_bs16,st400,ti2,accu4,lr3e-05_useenv_last_sd1_lora1_2024.05.18_18.58.11/checkpoints/checkpoint_35", 'resume from checkpoint')

#flags.DEFINE_string('resume_from', "/nfsdata/home/anonymous_user/GCO/reward-guided-difusco/logs/TSP500_high_lrtsp500_steps5_bs8,st8,ti1,accu1,lr1e-05_tb64_last1_lora2_sd1_2024.09.23_22.49.04/checkpoints/checkpoint_1", 'checkpoint load')
# flags.DEFINE_string('resume_from', "/nfsdata/home/anonymous_user/pcb/rl_finetuning/ddpo_co/logs/1000_1e-5tsp1000_steps10_bs16,st400,ti2,accu4,lr1e-05_useenv_last_sd50803_2024.05.10_16.31.29/checkpoints/checkpoint_8", 'resume from checkpoint')

# flags.DEFINE_bool("lora", True, "use lora")
flags.DEFINE_string("mixed_precision", 'no', "check mixed precision, no or fp16")
flags.DEFINE_float('learning_rate', 1e-5, "learning rate")
flags.DEFINE_bool('use_env', True, 'generate datasamples during training')
flags.DEFINE_bool('last_train', True, 'train only last layer')
# LoRA is added in train_with_sweep only when lora_rank > 0.
flags.DEFINE_integer("lora_rank", -1, "use lora")
flags.DEFINE_integer("lora_range", 1, "use lora")
flags.DEFINE_integer('num_testset', 1280, 'number of testset')
flags.DEFINE_integer('seed', -1, 'seed for random number generator, -1: random seed')
flags.DEFINE_bool('reward_shaping', False, 'use reward shaping')
# Logger from accelerate.logging; train_with_sweep uses it for the config dump and setup messages.
logger = get_logger(__name__)
# torch.use_deterministic_algorithms(True)


def update_config(config,input_var):
    """Overlay flag values onto ``config`` and derive the run name and file paths.

    Args:
        config: Config object exposing ``unlock()``, ``sample``, ``train``,
            ``storage_path`` and ``data_dist``; it is modified in place.
        input_var: Object holding the parsed flags as attributes (``FLAGS``
            in this script).

    Returns:
        The same ``config`` object, with the flag values copied over and
        ``run_name``, ``ckpt_path``, ``training_split`` and ``test_split`` set.
    """
    config.unlock()

    # Flags that land on the top level of the config under their own name.
    mirrored_fields = (
        "run_name", "task", "task_size", "task_load_size", "tsp_decoder",
        "batch_size", "subopt", "resume_from", "num_workers",
        "inference_diffusion_steps", "sparse_factor", "two_opt_iterations",
        "reward_2opt", "learning_rate", "use_env", "seed", "lora_rank",
        "num_epochs", "last_train", "num_testset", "eval_step", "reward_gap",
        "use_sweep", "print_sl_loss", "lora_range", "mixed_precision",
        "reward_shaping", "rewrite_steps", "rewrite", "rewrite_ratio",
        "inference_steps", "parallel_sampling",
    )
    for field in mirrored_fields:
        setattr(config, field, getattr(input_var, field))

    # Flags whose destination is nested and/or named differently.
    config.sample.num_iters_per_epoch = input_var.sample_iters
    config.train.num_inner_epochs = input_var.train_iters
    config.train.gradient_accumulation_steps = input_var.gradient_accumulation_steps

    # Run name: hyperparameter summary followed by one tag per enabled option.
    name_parts = [
        f'{input_var.run_name}{input_var.task}{input_var.task_size}_steps{input_var.inference_diffusion_steps}_bs{config.batch_size},st{input_var.sample_iters},ti{input_var.train_iters},accu{config.train.gradient_accumulation_steps},lr{input_var.learning_rate}'
    ]
    if config.use_env:
        name_parts.append("_useenv")
    if config.last_train:
        name_parts.append("_last")
    if config.seed>0:
        name_parts.append(f"_sd{config.seed}")
    if config.lora_rank > 0:
        name_parts.append(f"_lora{config.lora_rank}")
    if config.reward_2opt:
        name_parts.append("_2opt")
    if config.use_sweep:
        name_parts.append("_sweep")
    config.run_name = "".join(name_parts)

    # File names use the task with any "lib" removed, so "tsplib" maps to "tsp".
    task_stem = config.task.replace("lib","")

    # An explicit --subopt name takes precedence over the per-task default checkpoint.
    ckpt_stem = config.subopt if config.subopt else f'{task_stem}{config.task_load_size}'
    config.ckpt_path = os.path.join(config.storage_path,"checkpoints",f'{ckpt_stem}.ckpt')

    data_dir = 'data/tsp_custom/'
    config.training_split = os.path.join(data_dir,f'{task_stem}{config.task_size}_train_{config.data_dist}.txt')
    config.test_split = os.path.join(data_dir,f'{task_stem}{config.task_size}_test_{config.data_dist}.txt')

    return config

def train_with_sweep(config=None):
    """Set up the model, Accelerate context and optimizer, then evaluate on the test set and exit.

    Despite its name, this function contains no training loop: once the setup
    is done it calls ``evaluate_test`` a fixed number of times, prints the
    scores and terminates the process.

    Args:
        config: Ignored. The configuration is always read from
            ``FLAGS.config`` and then overlaid with the command-line flags by
            ``update_config``.

    Raises:
        ImportError: If ``config.train.use_8bit_adam`` is set but
            ``bitsandbytes`` cannot be imported.
        AssertionError: If the number of samples per epoch is not a multiple
            of the total train batch size.
        SystemExit: Always, with status 0, once the test scores are printed.
    """
    import sys  # function-scope import; only needed for the final sys.exit()

    # Number of back-to-back evaluation passes over the test set.
    num_eval_runs = 4

    # basic Accelerate and logging setup
    unique_id = datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S")
    config = update_config(FLAGS.config, FLAGS)

    # Make the run name unique by appending the start timestamp.
    if not config.run_name:
        config.run_name = unique_id
    else:
        config.run_name += "_" + unique_id

    model, _, _ = load_model(config)
    if config.last_train or config.lora_rank > 0:
        # Freeze the whole network; the trainable parts are re-enabled below.
        model.model.requires_grad_(False)

    sparse = config.sparse_factor > 0

    # number of timesteps within each trajectory to train on
    num_train_timesteps = int(config.inference_diffusion_steps * config.train.timestep_fraction)
    accelerator_config = ProjectConfiguration(
        project_dir=os.path.join(config.logdir, config.run_name),
        automatic_checkpoint_naming=True,
        total_limit=config.num_checkpoint_limit,
    )

    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        project_config=accelerator_config,
        # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
        # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
        # the total number of optimizer steps to accumulate across.
        gradient_accumulation_steps=config.train.gradient_accumulation_steps
        * num_train_timesteps,
    )
    model.to(accelerator.device)
    model.model.aux = False

    inference_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        inference_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        inference_dtype = torch.bfloat16

    if config.last_train:
        # Cast the model to the inference dtype, then re-enable gradients on the
        # last per-layer output, the last layer and the output head, and cast
        # those three back to float32.
        model.to(accelerator.device, dtype=inference_dtype)
        model.model.per_layer_out[-1].requires_grad_(True)
        model.model.per_layer_out[-1].to(dtype=torch.float32)

        model.model.layers[-1].requires_grad_(True)
        model.model.layers[-1].to(dtype=torch.float32)

        model.model.out.requires_grad_(True)
        model.model.out.to(dtype=torch.float32)

    if config.lora_rank > 0:
        model = AddLora(model, config)

    if accelerator.is_main_process:
        # NOTE(review): no `log_with` is passed to Accelerator above, so presumably
        # no tracker is actually created here — confirm against the Accelerate docs.
        accelerator.init_trackers(
            project_name="CO_RLFINETUNE",
            config=config.to_dict(),
            init_kwargs={"wandb": {"name": config.run_name}},
        )
    logger.info(f"\n{config}")

    # Seed only when a seed was requested (-1 means "leave unseeded").
    # NOTE(review): this runs after the model was loaded and LoRA was added, so
    # any random initialisation done there is not covered by the seed — confirm
    # this is intended.
    if config.seed != -1:
        set_seed(config.seed, device_specific=True)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # Initialize the optimizer
    if config.train.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError as err:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            ) from err

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        model.parameters(),
        lr=config.learning_rate,
        betas=(config.train.adam_beta1, config.train.adam_beta2),
        weight_decay=config.train.adam_weight_decay,
        eps=config.train.adam_epsilon,
    )

    # Test data: TSPLIB instances are read one at a time from disk; every other
    # task uses the dataloader provided by the model itself.
    if config.task == 'tsplib':
        test_dataset = TSPLIBGraphDataset(folder_path=os.path.join(config.storage_path,"data/tsp_custom/tsplib"))
        test_dataloader = GraphDataLoader(test_dataset,batch_size=1,shuffle=False)
    else:
        test_dataloader = model.test_dataloader()

    # Keep handles to the args and the diffusion object taken before
    # accelerator.prepare() returns the prepared model.
    model_args = copy.deepcopy(model.args)
    model_diffusion = model.diffusion

    # Prepare everything with our `accelerator`.
    model, optimizer = accelerator.prepare(model, optimizer)

    samples_per_epoch = (
        config.batch_size
        * accelerator.num_processes
        * config.sample.num_iters_per_epoch
    )
    total_train_batch_size = (
        config.batch_size
        * accelerator.num_processes
        * config.train.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num Epochs = {config.num_epochs}")
    logger.info(f"  Sample batch size per device = {config.batch_size}")
    logger.info(f"  Train batch size per device = {config.batch_size}")
    logger.info(
        f"  Gradient Accumulation steps = {config.train.gradient_accumulation_steps}"
    )
    logger.info("")
    logger.info(f"  Total number of samples per epoch = {samples_per_epoch}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}"
    )
    logger.info(
        f"  Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}"
    )
    logger.info(f"  Number of inner epochs = {config.train.num_inner_epochs}")

    # Sampling and training share config.batch_size, so the only remaining
    # consistency requirement is that an epoch splits into whole train batches.
    assert samples_per_epoch % total_train_batch_size == 0

    global_step = 0
    if config.resume_from:
        accelerator.load_state(config.resume_from)
        logger.info(f"Resuming from {config.resume_from}")

    time2 = time.perf_counter()
    print("Initialization requires", time2-time_init)

    test_scores = []
    for run_idx in range(num_eval_runs):
        score = evaluate_test(model, model_diffusion, model_args, test_dataloader, run_idx, accelerator,global_step, sparse, use_env=False, inference=True, reward_2opt=config.reward_2opt)
        test_scores.append(score)

    print('test_scores',np.mean(test_scores), test_scores)

    # Evaluation-only script: stop the process here.
    sys.exit(0)
   
    
import gc
def print_gpu_memory_usage(string, model=None):
    """Print the CUDA memory held by a model's parameters or by all live CUDA tensors.

    Args:
        string: Label appended to the "GPU Memory Usage:" header line.
        model: Optional object exposing ``named_parameters()``. When given,
            the total size of its CUDA parameters is printed, followed by one
            line per CUDA parameter, largest first. When ``None``, every CUDA
            tensor currently tracked by the garbage collector is listed.

    Returns:
        None. All output goes to stdout.
    """
    bytes_per_mb = 1024 * 1024
    print("GPU Memory Usage:" + string)

    if model is not None:
        # A single pass gathers what both the total and the listing need.
        cuda_params = [
            (name, param.size(), param.element_size() * param.nelement())
            for name, param in model.named_parameters()
            if param.is_cuda
        ]

        total_memory = sum(nbytes for _, _, nbytes in cuda_params)
        print(f" Total GPU Memory Usage: {total_memory / bytes_per_mb:.2f} MB")

        # Largest parameters first; ties keep their named_parameters() order.
        for name, size, nbytes in sorted(cuda_params, key=lambda entry: entry[2], reverse=True):
            print(f" {name} ({size}) - {nbytes / bytes_per_mb:.2f} MB")
    else:
        for obj in gc.get_objects():
            if torch.is_tensor(obj) and obj.is_cuda:
                print(f" {type(obj).__name__} ({obj.size()}) - {obj.element_size() * obj.nelement() / bytes_per_mb:.2f} MB")

def evaluate_test(model,model_diffusion,model_args,test_dataloader, _,accelerator,global_step,sparse,use_env,inference, reward_2opt):
    """Run ``test_eval_co`` on every test batch and report the mean costs.

    The model is switched to eval mode, each batch is passed to
    ``test_eval_co``, and the two cost lists it returns are accumulated over
    the whole dataloader. The elapsed time and the mean costs are printed.

    Args:
        model, model_diffusion, model_args: Forwarded to ``test_eval_co``.
            ``model_args.print_sl_loss`` additionally enables the collection
            of ``model.categorical_training_step(batch, 0)`` per batch.
        test_dataloader: Iterable of batches.
        _, accelerator, global_step, reward_2opt: Accepted but not used here.
        sparse, use_env, inference: Forwarded to ``test_eval_co``.

    Returns:
        ``np.mean`` of all "solved" costs gathered over the dataloader.
    """
    model.eval()
    collect_sl_loss = model_args.print_sl_loss

    solved_all, old_all, sl_losses = [], [], []
    started = time.perf_counter()
    for batch in test_dataloader:
        solved_cost, old_cost = test_eval_co(
            model, model_diffusion, model_args, batch, inference, use_env, sparse
        )
        if collect_sl_loss:
            sl_losses.append(model.categorical_training_step(batch, 0).detach())
        solved_all.extend(solved_cost)
        old_all.extend(old_cost)

    mean_solved = np.mean(solved_all)
    log_cost = {
        "solved_cost": mean_solved,
        "old_cost": np.mean(old_all),
    }
    if collect_sl_loss:
        log_cost['test_supervised_loss'] = torch.tensor(sl_losses).mean()

    # Results are only printed here, not sent to the accelerator's trackers.
    print('time',time.perf_counter()-started, log_cost)
    return mean_solved




def evaluate(model,model_diffusion,model_args,test_dataloader, _,accelerator,global_step,sparse,use_env,inference, reward_2opt):
    """Run ``difusco_with_logprob`` on every test batch, then log and return the mean costs.

    The model is switched to eval mode and the three cost lists returned by
    ``difusco_with_logprob`` are accumulated over the whole dataloader. The
    means are printed together with the elapsed time and passed to
    ``accelerator.log`` at ``global_step``.

    Args:
        model, model_diffusion, model_args: Forwarded to
            ``difusco_with_logprob``. ``model_args.print_sl_loss``
            additionally enables the collection of
            ``model.categorical_training_step(batch, 0)`` per batch.
        test_dataloader: Iterable of batches.
        _, reward_2opt: Accepted but not used here.
        accelerator: Receives the metrics through ``accelerator.log``.
        global_step: Step value attached to the logged metrics.
        sparse, use_env, inference: Forwarded to ``difusco_with_logprob``.

    Returns:
        Tuple ``(mean of the second cost list, mean of the third cost list)``;
        judging by the variable names these are the model costs without and
        with 2-opt — confirm against ``difusco_with_logprob``.
    """
    model.eval()
    collect_sl_loss = model_args.print_sl_loss

    gt_all, wo_2opt_all, solved_all, sl_losses = [], [], [], []
    started = time.perf_counter()
    for batch in test_dataloader:
        gt_costs, wo_2opt_costs, solution_costs = difusco_with_logprob(
            model, model_diffusion, model_args, batch, inference, use_env, sparse
        )
        if collect_sl_loss:
            sl_losses.append(model.categorical_training_step(batch, 0).detach())
        gt_all.extend(gt_costs)
        wo_2opt_all.extend(wo_2opt_costs)
        solved_all.extend(solution_costs)

    mean_wo_2opt = np.mean(wo_2opt_all)
    mean_solved = np.mean(solved_all)
    log_cost = {
        "test_gt_costs": np.mean(gt_all),
        "test_model_costs": mean_solved,
        "test_model_costs_wo2opt": mean_wo_2opt,
    }
    if collect_sl_loss:
        log_cost['test_supervised_loss'] = torch.tensor(sl_losses).mean()

    print('time',time.perf_counter()-started, log_cost)
    accelerator.log(
        log_cost,
        step=global_step,
    )
    return mean_wo_2opt, mean_solved

            
def load_model(config):
    """Pick the model class for the configured task and load it from ``config.ckpt_path``.

    Selection rules:
      * task ``"tsp"`` / ``"tsplib"``: ``diffusion_type`` ``"gaussian"`` or
        ``"categorical"`` gives ``TSPModelFreeGuide`` when
        ``config.return_condition`` is truthy, else ``TSPModel``;
        ``"classifier"`` gives ``TSPGuidedModel``; ``"reward"`` gives
        ``TSPReward_Weighted_Model``.
      * task ``"mis"``: ``MISModel`` (``diffusion_type`` is not consulted).
      * task ``"pctsp"``: ``diffusion_type`` ``"gaussian"`` or
        ``"categorical"`` gives ``PCTSPModelFreeGuide`` when
        ``config.return_condition`` is truthy, else ``PCTSPModel``.

    Args:
        config: Configuration providing ``task``, ``diffusion_type``,
            ``return_condition`` and ``ckpt_path``; it is also handed to the
            model as ``param_args``.

    Returns:
        Tuple ``(model, model_class, saving_mode)`` where ``saving_mode`` is
        ``"max"`` for the ``"mis"`` task and ``"min"`` otherwise.

    Raises:
        NotImplementedError: For an unknown task, or a ``diffusion_type`` the
            selected task does not support.
    """
    # All model classes are imported up front, exactly as before, so import
    # side effects and failure behaviour stay unchanged.
    from difusco.pl_tsp_model import TSPModel
    from difusco.pl_tsp_classifier_guided_model import TSPGuidedModel
    from difusco.pl_tsp_reward_weighted_model import TSPReward_Weighted_Model
    from difusco.pl_mis_model import MISModel
    from difusco.pl_tsp_model_free import TSPModelFreeGuide
    from difusco.pl_pctsp_model_free import PCTSPModelFreeGuide
    from difusco.pl_pctsp_model import PCTSPModel

    plain_diffusion_types = ("gaussian", "categorical")

    if config.task in ("tsp", "tsplib"):
        if config.diffusion_type in plain_diffusion_types:
            model_class = TSPModelFreeGuide if config.return_condition else TSPModel
        elif config.diffusion_type == "classifier":
            model_class = TSPGuidedModel
        elif config.diffusion_type == "reward":
            model_class = TSPReward_Weighted_Model
        else:
            raise NotImplementedError(
                f"Unsupported diffusion_type {config.diffusion_type!r} for task {config.task!r}"
            )
        saving_mode = "min"

    elif config.task == "mis":
        model_class = MISModel
        saving_mode = "max"

    elif config.task == "pctsp":
        if config.diffusion_type in plain_diffusion_types:
            model_class = PCTSPModelFreeGuide if config.return_condition else PCTSPModel
        else:
            raise NotImplementedError(
                f"Unsupported diffusion_type {config.diffusion_type!r} for task {config.task!r}"
            )
        saving_mode = "min"

    else:
        raise NotImplementedError(f"Unsupported task {config.task!r}")

    model = model_class.load_from_checkpoint(config.ckpt_path, param_args=config)

    return model, model_class, saving_mode

def main(_):
    """absl entry point: run once directly, or hand the run to a W&B sweep agent.

    Args:
        _: Positional arguments supplied by ``app.run``; unused.

    NOTE(review): ``train_with_sweep`` reads its settings from ``FLAGS``, so
    the ``'train.learning_rate'`` values sampled by the sweep do not appear to
    be applied anywhere in this file — confirm.
    """
    if not FLAGS.use_sweep:
        # No sweep requested: perform a single run.
        train_with_sweep()
        return

    # Sweep setup: random search that maximises the 'reward_mean' metric.
    search_space = {
        'train.learning_rate': {
            'values': [3e-7, 1e-6, 3e-6]
        },
        # Further hyperparameters can be listed here.
    }
    sweep_config = {
        'method': 'random',
        'metric': {
            'name': 'reward_mean',
            'goal': 'maximize'
        },
        'parameters': search_space,
    }
    print('sweep mode', sweep_config)

    # Register the sweep, then let the agent invoke train_with_sweep for each trial.
    sweep_id = wandb.sweep(sweep_config, project='ddpo-pytorch')
    wandb.agent(sweep_id, function=train_with_sweep)

# Script entry point: absl parses the command-line flags and then calls main().
if __name__ == "__main__":
    
    app.run(main)