import glob, time, sys
import torch
import torch.nn.functional as F
import torch.optim as opt
from torch.utils.data import DataLoader
from torch.utils.data import WeightedRandomSampler
import argparse
import dgl
import os
sys.path.append('../')
from src.utils import BranchDataset, dgl_collate
from src.logger import Logger
from src.model import GCNN_Net

from src.utils import *
from src.model import *
from src.sampler import *
import pickle

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from datetime import datetime, timedelta

from collections import Counter

# Histogram keyed by the number of entries in a sample's action set;
# main() increments it while it labels the training files.
length_counter = Counter()

# Logger is not instantiated here: set_logger() is called on the imported
# object directly and the shared `logger` attribute is read back from it.
logger_obj = Logger
logger_obj.set_logger(env='web')   # automatically saved to a timestamped file

# Module-wide logger handle used by the functions below.
logger = logger_obj.logger

# Runs at import time, so every process that imports this file prints it.
print('torch.__version__', torch.__version__)
#srun -p sugon --gres=gpu:1 python behavior_cloning.py --code_id new_33_1 --weight_type 1 --lr 1e-3 --shrink_lr 0.3 --loss ce --batch_size 96 --debug no
def save_train_data(data, path):
    """Serialize ``data`` with pickle into the file at ``path``.

    The file is opened in binary write mode, so an existing file at ``path``
    is overwritten.
    """
    with open(path, mode='wb') as sink:
        pickle.Pickler(sink).dump(data)

def train_select(args, model, trainData, optimizer, epoch, log_interval, device):
    """Run one training epoch with the routine that matches ``args.code_id``.

    ``'new_20'`` is handled by ``train_new_20_cons`` (which additionally
    receives ``args.alpha``); ``'new_21'`` and ``'new_21_aug'`` are both
    handled by ``train_new_21``.  The selected routine's return value is
    passed through unchanged.

    Note: the ``log_interval`` parameter is not forwarded; the routines are
    given ``args.log_interval`` instead.

    Raises:
        Exception: if ``args.code_id`` is not one of the ids listed above.
    """
    code_id = args.code_id

    if code_id == 'new_20':
        return train_new_20_cons(
            args, model, trainData, optimizer, epoch,
            args.log_interval, device, args.alpha,
        )

    # The plain and the augmented variant share one training routine.
    if code_id in ('new_21', 'new_21_aug'):
        return train_new_21(
            args, model, trainData, optimizer, epoch,
            args.log_interval, device,
        )

    raise Exception("train_select can not find code_id")


def test_select(args, model, validData, optimizer, epoch, device):
    if args.code_id == 'new_20':
        return test_new_20_cons(args, model, validData, optimizer, epoch, device)
    elif args.code_id == 'new_21':
        return test_new_21(args, model, validData, optimizer, epoch, device)
    elif args.code_id == 'new_21_aug':
        return test_new_21(args, model, validData, optimizer, epoch, device)
    

    else:
        raise Exception("test_select can not find code_id")
        
def setup(rank, world_size, timeout_seconds=20):
    """Select this process's GPU and join the NCCL process group.

    Args:
        rank: Index of this process.  It is used both as the CUDA device
            index and as the process rank inside the group.
        world_size: Number of processes that make up the group.
        timeout_seconds: Timeout handed to ``init_process_group``.  Defaults
            to 20 seconds; pass a larger value when a rank may legitimately
            lag behind the others (slow data loading, checkpoint I/O, ...).
    """
    torch.cuda.set_device(rank)
    dist.init_process_group(
        "nccl",
        init_method='env://',
        rank=rank,
        world_size=world_size,
        timeout=timedelta(seconds=timeout_seconds),
    )
    

def cleanup():
    """Destroy the default process group if one is currently initialised.

    Calling this when no group exists (for example because initialisation
    never happened or failed) is a no-op instead of raising, so it can be
    used unconditionally on shutdown paths.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def _collect_sample_files(args):
    """Return ``(train_list, valid_list)`` of sample pickle paths for this run.

    ``new_21`` and ``new_21_aug`` read the samples stored under a ``new_20``
    directory; any other ``code_id`` reads from a directory named after
    itself.  With ``args.debug == 'yes'`` a fixed list of four training files
    and one validation file is returned without consulting the disk;
    otherwise the ``train``/``valid`` directories are globbed for ``*.pkl``.
    """
    debug = args.debug == 'yes'

    if args.code_id == "new_21":
        root = f'/home/lutk/linglong-data/2.rl4bb/data/samples/{args.ins_type}/new_20'
        debug_valid_name = 'sample_100.pkl'
    elif args.code_id == "new_21_aug":
        root = f'/home/lutk/milp/our_model/data/samples/{args.ins_type}/new_20'
        debug_valid_name = 'sample_100.pkl'
    elif debug:
        root = f'../data/samples/{args.ins_type}/{args.code_id}'
        debug_valid_name = 'sample_1.pkl'
    else:
        root = f'/home/lutk/linglong-data/2.rl4bb/data/samples/{args.ins_type}/{args.code_id}'
        debug_valid_name = None

    if debug:
        train_list = [f'{root}/train/sample_{i}.pkl' for i in range(1, 5)]
        valid_list = [f'{root}/valid/{debug_valid_name}']
    else:
        # glob returns files in arbitrary order; sort so that every rank
        # indexes the same file with the same dataset index.
        train_list = sorted(glob.glob(f'{root}/train/*.pkl'))
        valid_list = sorted(glob.glob(f'{root}/valid/*.pkl'))
    return train_list, valid_list


def _label_samples_by_action_count(train_list, low_threshold=83.61, high_threshold=166.97):
    """Assign each training sample a bucket label from its action-set size.

    Each file is a gzip-compressed pickle holding a dict whose ``'data'``
    entry unpacks into six items, the third being the action set.  The label
    is 0 when ``len(action_set) <= low_threshold``, 1 when it is
    ``<= high_threshold`` and 2 otherwise.  The module-level
    ``length_counter`` is updated with every observed size.

    NOTE(review): pickle can execute arbitrary code while loading; only point
    this at sample files from a trusted source.
    """
    import gzip  # local import: the file has no explicit top-level gzip import

    labels = []
    for sample_path in train_list:
        # Context manager so the gzip handle is closed after every file.
        with gzip.open(sample_path, 'rb') as sample_file:
            sample = pickle.load(sample_file)['data']
        _, _, action_set, _, _, _ = sample

        n_actions = len(action_set)
        length_counter[n_actions] += 1
        if n_actions <= low_threshold:
            labels.append(0)
        elif n_actions <= high_threshold:
            labels.append(1)
        else:
            labels.append(2)
    return labels


def _load_checkpoint(ddp_model, checkpoint_path, map_location):
    """Load ``checkpoint_path`` into the network wrapped by ``ddp_model``.

    Checkpoints written by ``main`` come from ``model.module.state_dict()``
    and therefore carry no ``module.`` key prefix, so they must be loaded
    into ``ddp_model.module`` rather than into the DDP wrapper.  A
    ``module.`` prefix, if present, is stripped first so that both key
    layouts are accepted.
    """
    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

    state_dict = torch.load(checkpoint_path, map_location=map_location)
    consume_prefix_in_state_dict_if_present(state_dict, 'module.')
    ddp_model.module.load_state_dict(state_dict)


def main(args):
    """Train and/or evaluate the branching model with DistributedDataParallel.

    The process reads its index from the ``LOCAL_RANK`` environment variable,
    joins the process group via ``setup`` and then, for ``args.epoch_num``
    epochs, runs ``train_select`` (only when ``args.model_id == "train"``)
    followed by ``test_select``.  Per-epoch metrics are collected in a dict
    keyed ``bc_<epoch>``.  Rank 0 alone writes the checkpoint
    ``bc_<epoch>.pt``, the loss/accuracy plot and ``train_data.pkl`` into
    ``../check_points/<ins_type>/<code_id>``.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).

    Raises:
        KeyError: if ``LOCAL_RANK`` is not set in the environment.
    """
    # Kept for its side effect: the log directory is created up front.
    os.makedirs("train_logs", exist_ok=True)

    rank = int(os.environ["LOCAL_RANK"])
    # Use the launcher's WORLD_SIZE when it is set, so that starting fewer
    # processes than there are visible GPUs does not stall the rendezvous;
    # fall back to the GPU count otherwise.
    world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
    setup(rank, world_size)

    device = torch.device(f'cuda:{rank}')
    is_training = args.model_id == "train"

    logger.info(f'data: {args.ins_type}')

    train_list, valid_list = _collect_sample_files(args)
    logger.info('The length of train_list:{}'.format(len(train_list)))
    logger.info('The length of valid_list:{}'.format(len(valid_list)))

    # Output directory for checkpoints, plots and the metrics pickle.
    out_dir = f'../check_points/{args.ins_type}/{args.code_id}'
    os.makedirs(out_dir, exist_ok=True)

    labels = _label_samples_by_action_count(train_list)
    sampler = DistributedProportionalSampler(
        labels=labels,
        batch_size=args.batch_size,
        num_replicas=dist.get_world_size(),
        rank=rank,
    )

    # TODO: modify DataLoader and model

    print('alpha', args.alpha)

    trainData = DataLoader(
        BranchDataset_new_20_depth(train_list),
        batch_size=args.batch_size,
        collate_fn=balanced_collate_fn,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
    )
    validData = DataLoader(
        BranchDataset_new_20_depth(valid_list),
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=dgl_collate_new_20_depth,
        num_workers=4,
        pin_memory=True,
    )

    model = GCNN_Net_new_20_cons(v_dim=17).to(device)
    model = DDP(model, device_ids=[rank])

    if args.load_model is not None:
        if args.code_id in ["new_31", "new_32"]:
            checkpoint_path = f'../check_points/{args.ins_type}/{args.code_id}/{args.model_id}/{args.load_model}.pt'
        else:
            checkpoint_path = f'../check_points/{args.ins_type}/{args.code_id}/{args.load_model}.pt'
        # Log the path that is actually read.
        logger.info(f'load model from {checkpoint_path}')
        _load_checkpoint(model, checkpoint_path, device)

    optimizer = opt.Adam(model.parameters(), lr=args.lr)
    acc = 0
    best_epoch = -1
    print('finish generating train and test dataset')

    epoch_list = []
    loss_list = []
    acc_list = []
    # Learning-rate shrinking is currently disabled, so this stays empty; it
    # is still handed to get_loss_img, which takes it as an argument.
    shrink_epoch_list = []
    train_data_dict = {}

    for epoch in range(args.epoch_num):
        sampler.set_epoch(epoch)

        epoch_list.append(epoch)
        epoch_stats = {}
        train_data_dict[f"bc_{epoch + 1}"] = epoch_stats

        if is_training:
            loss, accuracy = train_select(args, model, trainData, optimizer, epoch, args.log_interval, device)
            epoch_stats["train_loss"] = loss
            epoch_stats["train_accuracy"] = accuracy

        loss, accuracy, _hard_acc = test_select(args, model, validData, optimizer, epoch, device)
        epoch_stats["test_loss"] = loss
        epoch_stats["test_accuracy"] = accuracy

        loss_list.append(loss)
        acc_list.append(accuracy)

        # Only rank 0 writes the checkpoint: all ranks would otherwise write
        # the very same file at the same time.
        if is_training and rank == 0:
            torch.save(model.module.state_dict(), f'{out_dir}/bc_{epoch + 1}.pt')

        if acc < accuracy:
            best_epoch = epoch
            acc = accuracy
        if epoch == 0:
            # The first epoch is the best one seen so far even when its
            # accuracy did not exceed the initial 0.
            best_epoch = 0

        if rank == 0:
            if is_training:
                get_loss_img(epoch_list, acc_list, loss_list, shrink_epoch_list, args)
            save_train_data(train_data_dict, f'{out_dir}/train_data.pkl')

    if rank == 0:
        logger.info(f'Training finished, best check point is at epoch {best_epoch+1}, best accuracy is {acc}')
        if is_training:
            get_loss_img(epoch_list, acc_list, loss_list, shrink_epoch_list, args)

    cleanup()

if __name__ == '__main__':
    # Command-line entry point: build the option parser, parse, run main().
    parser = argparse.ArgumentParser()

    # Options that carry help text.
    parser.add_argument('-d', '--device', type=int, default=0, help='GPU index')
    parser.add_argument(
        '--ins_type',
        type=str,
        default='setcover_400r_1000c_0.05d_100mc_0se',
        help='type for the instances',
    )
    parser.add_argument(
        '--ins_config',
        type=str,
        default='train',
        help='configuration for the instances',
    )

    # Options without help text, as (flag, type, default).  Registration
    # order only affects the layout of the --help output.
    for _flag, _type, _default in (
        ('--batch_size', int, 10),
        ('--log_interval', int, 5),
        ('--lr', float, 1e-3),
        ('--shrink_lr', float, 0.3),
        ('--load_model', str, None),
        ('--code_id', str, 'baseline'),
        ('--debug', str, 'yes'),
        ('--loss', str, 'ce'),
        ('--model_id', str, 'train'),
        ('--regu_mse', float, 0.00003),
        ('--epoch_num', int, 300),
        ('--test_data', str, None),
        ('--weight_type', int, 0),
        ('--top_k', int, 0),
        ('--alpha', float, 0.5),
    ):
        parser.add_argument(_flag, type=_type, default=_default)

    parser.add_argument(
        '--local_rank',
        type=int,
        default=0,
        help='local rank for DistributedDataParallel',
    )

    args = parser.parse_args()
    main(args)
    
    
 
