import argparse
import random
import torch
import numpy as np
from time import time
import logging
import multiprocessing

from torch.utils.data import DataLoader

from datasets import EmbDataset
from models.rqvae import RQVAE
from trainer import  Trainer
import pickle
def _str2bool(value):
    """Convert a command-line token into a real ``bool``.

    ``argparse`` applies ``type`` to the raw string, and ``bool("False")`` is
    ``True`` because every non-empty string is truthy.  This converter parses
    the text instead, so ``--flag False`` actually disables the flag.

    Args:
        value: The raw option value (a string), or an existing ``bool``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept as False: that is what ``type=bool`` returned.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value))


def parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Boolean options (``--bn``, ``--graph_loss``, ``--kmeans_init``) take an
    explicit value such as ``--bn true`` / ``--bn false``; they are parsed by
    ``_str2bool`` rather than ``bool`` so that "False" really means False.

    Returns:
        argparse.Namespace: The parsed options, with defaults filled in.
    """
    parser = argparse.ArgumentParser(description="Index")

    # Optimisation / schedule options.
    parser.add_argument('--lr', type=float, default=1e-5, help='learning rate')
    parser.add_argument('--epochs', type=int, default=5000, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=5096, help='batch size')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of DataLoader workers')
    parser.add_argument('--eval_step', type=int, default=50, help='eval step')
    parser.add_argument('--learner', type=str, default="AdamW", help='optimizer')
    parser.add_argument('--lr_scheduler_type', type=str, default="constant", help='scheduler')
    parser.add_argument('--warmup_epochs', type=int, default=50, help='warmup epochs')
    parser.add_argument("--data_path", type=str,
                        default="./data/ToolBench/toolweaver-mean-embeddings-output-sentences-sbert-atomic.npy",
                        help="Input data path.")
    parser.add_argument("--weight_decay", type=float, default=0.0, help='l2 regularization weight')
    parser.add_argument("--dropout_prob", type=float, default=0.0, help="dropout ratio")
    parser.add_argument("--bn", type=_str2bool, default=False, help="use bn or not")
    parser.add_argument("--loss_type", type=str, default="mse", help="loss_type")
    parser.add_argument("--graph_loss", type=_str2bool, default=True, help="Similarity Regularization")
    parser.add_argument("--graph_lambda", type=float, default=0.001, help="")
    parser.add_argument("--kmeans_init", type=_str2bool, default=True, help="use kmeans_init or not")
    parser.add_argument("--kmeans_iters", type=int, default=100, help="max kmeans iters")
    parser.add_argument('--sk_epsilons', type=float, nargs='+', default=[0.01, 0.01], help="sinkhorn epsilons")
    parser.add_argument("--sk_iters", type=int, default=50, help="max sinkhorn iters")

    parser.add_argument("--device", type=str, default="cuda:3", help="gpu or cpu")

    # Quantizer / network shape options.
    parser.add_argument('--num_emb_list', type=int, nargs='+', default=[1024, 1024], help='emb num of every vq')
    parser.add_argument('--e_dim', type=int, default=64, help='vq codebook embedding size')
    parser.add_argument('--quant_loss_weight', type=float, default=1.0, help='vq quantion loss weight')
    parser.add_argument("--beta", type=float, default=2.25, help="Beta for commitment loss")
    parser.add_argument('--layers', type=int, nargs='+', default=[1024,512,256,128,64], help='hidden sizes of every layer')

    # Checkpointing options.
    parser.add_argument('--save_limit', type=int, default=5)
    parser.add_argument("--ckpt_dir", type=str, default="", help="output directory for model")

    return parser.parse_args()

class SimilarityDataset(torch.utils.data.Dataset):
    """Dataset wrapper that yields ``(sample, index)`` pairs.

    The index lets ``collate_fn`` look up, for every batch, the block of the
    similarity matrix that relates the samples of that batch to each other.

    NOTE(review): the matrix is assumed to be square with one row/column per
    element of ``data`` -- confirm against the file that produces it.
    """

    def __init__(self, data, similarity_matrix):
        """Store the wrapped dataset and a CPU tensor copy of the matrix."""
        # The full matrix is held as a CPU tensor.
        self.similarity_matrix = torch.tensor(similarity_matrix, device='cpu')
        self.data = data

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` together with ``idx`` itself."""
        sample = self.data[idx]
        return sample, idx

    def get_similarity_matrix(self, indices):
        """Return the sub-matrix whose rows and columns are both ``indices``."""
        # Pick the rows first, then the same positions along the columns.
        selected_rows = self.similarity_matrix[indices]
        return selected_rows[:, indices]

    def collate_fn(self, batch):
        """Merge ``(sample, index)`` pairs into one batch.

        Args:
            batch: ``[(data_1, idx_1), (data_2, idx_2), ..., (data_b, idx_b)]``

        Returns:
            A tuple ``(data_tensor, sub_similarity_matrix)``: the samples
            stacked along a new first dimension, and the similarity block
            for the sample indices in batch order.
        """
        # Split the pairs into a tuple of samples and a tuple of indices.
        samples, positions = zip(*batch)
        stacked = torch.stack(samples)  # (batch_size, ...)
        position_tensor = torch.LongTensor(positions)

        # Similarity block for exactly the samples in this batch.
        batch_similarity = self.get_similarity_matrix(position_tensor)

        # Keep the block on whatever device the stacked samples live on.
        batch_similarity = batch_similarity.to(stacked.device)

        return stacked, batch_similarity

if __name__ == '__main__':
    # Select the 'spawn' start method before any worker process is created.
    multiprocessing.set_start_method('spawn')

    # Fix the random seed of every RNG used here (Python, NumPy, torch CPU
    # and all CUDA devices) and request deterministic cuDNN behaviour.
    seed = 2024
    for apply_seed in (random.seed,
                       np.random.seed,
                       torch.manual_seed,
                       torch.cuda.manual_seed_all):
        apply_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Parse the command line and echo the resulting configuration.
    args = parse_args()
    separator = "================================================="
    print(separator, args, separator, sep="\n")

    logging.basicConfig(level=logging.DEBUG)

    # Build the embedding dataset and the RQ-VAE model sized to it.
    emb_data = EmbDataset(args.data_path)
    rqvae_options = dict(
        in_dim=emb_data.dim,
        num_emb_list=args.num_emb_list,
        e_dim=args.e_dim,
        layers=args.layers,
        dropout_prob=args.dropout_prob,
        bn=args.bn,
        loss_type=args.loss_type,
        quant_loss_weight=args.quant_loss_weight,
        beta=args.beta,
        kmeans_init=args.kmeans_init,
        kmeans_iters=args.kmeans_iters,
        sk_epsilons=args.sk_epsilons,
        sk_iters=args.sk_iters,
        graph_lambda=args.graph_lambda,
    )
    model = RQVAE(**rqvae_options)
    print(model)

    # Load the precomputed similarity matrix.
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with open('./data/similarity_matrix.pkl', 'rb') as pkl_file:
        similarity_matrix = pickle.load(pkl_file)["tool_similarity"]["similarity_matrix"]
    similarity_dataset = SimilarityDataset(emb_data, similarity_matrix)

    # Batches are assembled by the dataset's own collate function, which also
    # attaches the per-batch similarity block.
    data_loader = DataLoader(
        similarity_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=similarity_dataset.collate_fn,
        pin_memory=True,
    )

    # Train and report the best metrics returned by the trainer.
    num_batches = len(data_loader)
    trainer = Trainer(args, model, num_batches)
    best_loss, best_collision_rate = trainer.fit(data_loader)

    print("Best Loss", best_loss)
    print("Best Collision Rate", best_collision_rate)

