import os
import json
import torch
from tqdm import tqdm
from multiprocessing import Pool
from pathlib import Path
from argparse import ArgumentParser
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast
import logging
from pytorch_lightning import seed_everything
import copy
import wandb

from utils import read_json, AverageMeterSet, Ranker
from optimization import *
from recformer import RecformerModel, RecformerForPretraining2, RecformerTokenizer, RecformerConfig, Predictor
from collator import PretrainDataCollatorWithPadding2, EvalDataCollatorWithPadding
from dataloader import RecformerTrainDataset, RecformerEvalDataset

def load_data(args):
    """Load the training sequences, item metadata and item-id mappings.

    Args:
        args: Parsed command-line namespace. Reads ``data_path``,
            ``train_file``, ``meta_file`` and ``item2id_file``.

    Returns:
        A 4-tuple ``(train, item_meta_dict_filted, item2id, id2item)``:
        the training data as returned by ``read_json``, the item metadata
        restricted to items present in ``item2id``, the item -> id mapping,
        and its inverse (id -> item).
    """
    # NOTE(review): the meaning of read_json's second positional flag is
    # defined in utils.read_json and is not visible here -- confirm there.
    train = read_json(os.path.join(args.data_path, args.train_file), True)

    # Use a context manager so the metadata file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(os.path.join(args.data_path, args.meta_file)) as meta_fp:
        item_meta_dict = json.load(meta_fp)

    item2id = read_json(os.path.join(args.data_path, args.item2id_file))
    id2item = {v: k for k, v in item2id.items()}

    # Keep only metadata for items that actually have an id assigned.
    item_meta_dict_filted = {k: v for k, v in item_meta_dict.items() if k in item2id}

    return train, item_meta_dict_filted, item2id, id2item

# Module-level tokenizer handle read by _par_tokenize_doc; main() assigns it
# before the tokenization worker pool is created.
tokenizer_glb: RecformerTokenizer = None
def _par_tokenize_doc(doc):
    """Tokenize a single ``(item_id, item_attr)`` pair with the global tokenizer.

    Args:
        doc: A 2-element pair of an item identifier and its attributes.

    Returns:
        ``(item_id, input_ids, token_type_ids)`` where the last two values
        are whatever ``tokenizer_glb.encode_item`` returns for the attributes.
    """
    item_key, attributes = doc
    encoded = tokenizer_glb.encode_item(attributes)
    ids, type_ids = encoded
    return item_key, ids, type_ids

def train_one_epoch_warmup_predictor(model, predictor, target, dataloader, optimizer, scheduler, scaler, epoch, args):
    """Run one predictor warm-up epoch.

    Only ``predictor`` is switched to train mode here; the caller in this
    file freezes the parameters of ``model`` before invoking this function.
    The loss is ``output.loss`` plus two MSE terms (each weighted by
    ``model.config.mlm_weight``) between the predictor outputs and the two
    embeddings returned by ``target``.

    Args:
        model: Context encoder. Called as ``model(**batch, is_context=True)``;
            its output must expose ``loss``, ``h_CLS_history``,
            ``h_CLS_target_mask`` and ``item_position_embeddings_jepa``.
        predictor: Module mapping context embeddings to predicted embeddings.
        target: Target encoder. Called as ``target(**batch, is_context=False)``
            and expected to return a 2-tuple of embeddings.
        dataloader: Iterable yielding dict batches.
        optimizer: Optimizer to step.
        scheduler: Learning-rate scheduler, stepped once per optimizer step.
        scaler: ``GradScaler`` used when ``args.fp16`` is set (may be ``None``
            otherwise).
        epoch: Current epoch index (unused here; kept for a uniform signature).
        args: Namespace providing ``device``, ``fp16`` and
            ``gradient_accumulation_steps``.
    """
    predictor.train()

    # The criterion is stateless, so build it once rather than once per batch.
    jepa_l2_loss = torch.nn.MSELoss()

    for step, batch in enumerate(tqdm(dataloader, ncols=100, desc='Training')):
        # Move everything that supports `.to(...)` onto the target device and
        # leave other entries untouched. An explicit check replaces the
        # previous bare `except:`, which also hid genuine device errors.
        for k, v in batch.items():
            if hasattr(v, 'to'):
                batch[k] = v.to(args.device)

        # autocast(enabled=False) is a no-op, so a single code path serves
        # both the fp16 and the full-precision configuration.
        with autocast(enabled=args.fp16):
            output = model(**batch, is_context=True)
            predict_mask_embedding_target = predictor(output.h_CLS_history, output.h_CLS_target_mask, True)
            predict_mask_embedding_history = predictor(output.h_CLS_history, output.item_position_embeddings_jepa, False)
            target_embedding, masked_item_history_embedding = target(**batch, is_context=False)
            loss = output.loss + model.config.mlm_weight * jepa_l2_loss(predict_mask_embedding_target, target_embedding) + \
                   model.config.mlm_weight * jepa_l2_loss(predict_mask_embedding_history, masked_item_history_embedding)

        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        if args.fp16:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        if (step + 1) % args.gradient_accumulation_steps == 0:
            if args.fp16:
                # GradScaler skips the optimizer step when it finds inf/nan
                # gradients and then lowers its scale; only advance the LR
                # schedule when the scale did not drop.
                scale_before = scaler.get_scale()
                scaler.step(optimizer)
                scaler.update()
                optimizer_was_run = scale_before <= scaler.get_scale()
                optimizer.zero_grad()
                if optimizer_was_run:
                    scheduler.step()
            else:
                # optimizer.step() must precede scheduler.step(); the reverse
                # order skips the first value of the LR schedule and differs
                # from the fp16 path above.
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

def train_one_epoch(model, predictor, target, dataloader, optimizer, scheduler, momentum_scheduler, scaler, epoch, args):
    """Train ``model`` and ``predictor`` for one epoch and EMA-update ``target``.

    The loss is ``output.loss`` plus two MSE terms (each weighted by
    ``model.config.mlm_weight``) between the predictor outputs and the two
    embeddings returned by ``target``. After every batch the parameters of
    ``target`` are moved towards those of ``model`` with the momentum value
    drawn from ``momentum_scheduler``.

    Args:
        model: Context encoder. Called as ``model(**batch, is_context=True)``;
            its output must expose ``loss``, ``h_CLS_history``,
            ``h_CLS_target_mask``, ``item_position_embeddings_jepa``,
            ``cl_correct_num`` and ``cl_total_num``.
        predictor: Module mapping context embeddings to predicted embeddings.
        target: Target encoder. Called as ``target(**batch, is_context=False)``
            and expected to return a 2-tuple of embeddings.
        dataloader: Sized iterable yielding dict batches.
        optimizer: Optimizer to step.
        scheduler: Learning-rate scheduler, stepped once per optimizer step.
        momentum_scheduler: Iterator yielding one EMA momentum per batch.
        scaler: ``GradScaler`` used when ``args.fp16`` is set (may be ``None``
            otherwise).
        epoch: Current epoch index, used for logging.
        args: Namespace providing ``device``, ``fp16``,
            ``gradient_accumulation_steps`` and (optionally) ``no_wandb``.

    Returns:
        The mean over batches of ``cl_correct_num / cl_total_num``.
    """
    model.train()
    predictor.train()

    # wandb.log requires an active run; main() only calls wandb.init() when
    # --no_wandb is absent, so logging has to respect the same switch.
    log_to_wandb = not getattr(args, 'no_wandb', False)

    # The criterion is stateless, so build it once rather than once per batch.
    jepa_l2_loss = torch.nn.MSELoss()

    # Per-batch statistics used for the epoch averages.
    epoch_losses = []
    epoch_acc = []

    for step, batch in enumerate(tqdm(dataloader, ncols=100, desc='Training')):
        # Move everything that supports `.to(...)` onto the target device and
        # leave other entries untouched. An explicit check replaces the
        # previous bare `except:`, which also hid genuine device errors.
        for k, v in batch.items():
            if hasattr(v, 'to'):
                batch[k] = v.to(args.device)

        # autocast(enabled=False) is a no-op, so a single code path serves
        # both the fp16 and the full-precision configuration.
        with autocast(enabled=args.fp16):
            output = model(**batch, is_context=True)
            predict_mask_embedding_target = predictor(output.h_CLS_history, output.h_CLS_target_mask, True)
            predict_mask_embedding_history = predictor(output.h_CLS_history, output.item_position_embeddings_jepa, False)
            target_embedding, masked_item_history_embedding = target(**batch, is_context=False)
            loss = output.loss + model.config.mlm_weight * jepa_l2_loss(predict_mask_embedding_target, target_embedding) + \
                   model.config.mlm_weight * jepa_l2_loss(predict_mask_embedding_history, masked_item_history_embedding)

        # Record the un-scaled loss (before dividing for gradient accumulation).
        batch_loss = loss.item()
        batch_acc = output.cl_correct_num / output.cl_total_num
        epoch_losses.append(batch_loss)
        epoch_acc.append(batch_acc)

        # Log batch statistics periodically.
        if log_to_wandb and step % 50 == 0:
            wandb.log({
                "batch_loss": batch_loss,
                "batch_acc": batch_acc,
                "epoch": epoch,
                "step": step + epoch * len(dataloader)
            })

        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        if args.fp16:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        if (step + 1) % args.gradient_accumulation_steps == 0:
            if args.fp16:
                # GradScaler skips the optimizer step when it finds inf/nan
                # gradients and then lowers its scale; only advance the LR
                # schedule when the scale did not drop.
                scale_before = scaler.get_scale()
                scaler.step(optimizer)
                scaler.update()
                optimizer_was_run = scale_before <= scaler.get_scale()
                optimizer.zero_grad()
                if optimizer_was_run:
                    scheduler.step()
            else:
                # optimizer.step() must precede scheduler.step(); the reverse
                # order skips the first value of the LR schedule and differs
                # from the fp16 path above.
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

        # Exponential moving average of the context encoder into the target
        # encoder. This runs once per batch (not once per optimizer step),
        # matching how main() sizes the momentum schedule.
        with torch.no_grad():
            m = next(momentum_scheduler)
            for param_q, param_k in zip(model.parameters(), target.parameters()):
                param_k.data.mul_(m).add_((1.-m) * param_q.detach().data)

    # Epoch-level averages.
    avg_epoch_loss = sum(epoch_losses) / len(epoch_losses)
    avg_epoch_acc = sum(epoch_acc) / len(epoch_acc)
    if log_to_wandb:
        wandb.log({
            "epoch": epoch,
            "train_loss": avg_epoch_loss,
            "train_acc": avg_epoch_acc,
            "learning_rate": scheduler.get_last_lr()[0]
        })
    return avg_epoch_acc

def main():
    """Command-line entry point for JEPA-style Recformer pre-training.

    Parses arguments, optionally starts a wandb run, tokenizes (or loads the
    cached tokenization of) the item metadata, builds the context encoder,
    predictor and a frozen deep copy used as target encoder, optionally warms
    up the predictor, and then trains with early stopping on the average
    training accuracy returned by ``train_one_epoch``. The best model weights
    are written to ``<output_dir>/movies/<ckpt>``.
    """
    parser = ArgumentParser()
    # path and file
    parser.add_argument('--data_path', type=str, default=None, required=True)
    parser.add_argument('--output_dir', type=str, default='full_ckpt/checkpoint_jepa_pretrain_10epochs')
    parser.add_argument('--ckpt', type=str, default='best_model.bin')
    parser.add_argument('--pred_ckpt', type=str, default='warmup_predictor.bin')
    parser.add_argument('--model_name_or_path', type=str, default='allenai/longformer-base-4096')
    parser.add_argument('--train_file', type=str, default='train.json')
    parser.add_argument('--dev_file', type=str, default='val.json')
    parser.add_argument('--test_file', type=str, default='test.json')
    parser.add_argument('--item2id_file', type=str, default='smap.json')
    parser.add_argument('--meta_file', type=str, default='meta_data.json')

    # data process
    parser.add_argument('--preprocessing_num_workers', type=int, default=8)
    parser.add_argument('--dataloader_num_workers', type=int, default=4)

    # model
    parser.add_argument('--temp', type=float, default=0.05)

    # train
    parser.add_argument('--num_train_epochs', type=int, default=16)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=8)
    parser.add_argument('--finetune_negative_sample_size', type=int, default=1000)
    parser.add_argument('--metric_ks', nargs='+', type=int, default=[10, 50])
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--learning_rate', type=float, default=5e-5)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--warmup_steps', type=int, default=100)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('--fix_word_embedding', action='store_true')
    parser.add_argument('--warmup_predictor', action='store_true')
    parser.add_argument('--warmup_predictor_epochs', type=int, default=3)
    parser.add_argument('--verbose', type=int, default=2)
    parser.add_argument('--mlm_probability', type=float, default=0.15)
    parser.add_argument('--longformer_ckpt', type=str, default='longformer_ckpt/longformer-base-4096.bin')

    # WandB arguments
    parser.add_argument('--wandb_project', type=str, default='jepa4rec-pretraining')
    parser.add_argument('--wandb_entity', type=str, default="abcminhvangiang-viettel")
    parser.add_argument('--wandb_run_name', type=str, default="jepa no pos embed")
    parser.add_argument('--no_wandb', action='store_true', help='Disable wandb logging')

    args = parser.parse_args()
    print(args)
    seed_everything(42)
    # A negative --device selects the CPU; otherwise it is the CUDA index.
    args.device = torch.device('cuda:{}'.format(args.device)) if args.device >= 0 else torch.device('cpu')

    # Initialize WandB
    if not args.no_wandb:
        wandb_run_name = args.wandb_run_name or f"recformer_bs{args.batch_size}_lr{args.learning_rate}"
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=wandb_run_name,
            config={
                "learning_rate": args.learning_rate,
                "architecture": "RecformerForPretraining2",
                "dataset": Path(args.data_path).name,
                "epochs": args.num_train_epochs,
                "batch_size": args.batch_size,
                "mlm_probability": args.mlm_probability,
                "finetune_negative_sample_size": args.finetune_negative_sample_size,
                "temp": args.temp,
                "gradient_accumulation_steps": args.gradient_accumulation_steps,
                "fix_word_embedding": args.fix_word_embedding,
                "model_name_or_path": args.model_name_or_path,
            }
        )

    # Enable cuDNN benchmarking
    torch.backends.cudnn.benchmark = True

    train, item_meta_dict, item2id, id2item = load_data(args)

    config = RecformerConfig.from_pretrained(args.model_name_or_path)
    config.max_attr_num = 3
    config.max_attr_length = 32
    config.max_item_embeddings = 51
    config.attention_window = [64] * 12
    config.max_token_num = 1024
    config.item_num = len(item2id)
    config.finetune_negative_sample_size = args.finetune_negative_sample_size
    tokenizer = RecformerTokenizer.from_pretrained(args.model_name_or_path, config)

    # Publish the tokenizer through the module global read by
    # _par_tokenize_doc. NOTE(review): pool workers only see this assignment
    # when they are forked from this process; under the "spawn" start method
    # the global would still be None in the workers -- confirm the platform.
    global tokenizer_glb
    tokenizer_glb = tokenizer

    path_corpus = Path(args.data_path)
    dir_preprocess = path_corpus / 'preprocess_jepa'
    dir_preprocess.mkdir(exist_ok=True)

    path_output = Path(args.output_dir) / 'movies'
    path_output.mkdir(exist_ok=True, parents=True)
    path_ckpt = path_output / args.ckpt
    path_predictor_ckpt = path_output / args.pred_ckpt

    path_tokenized_items = dir_preprocess / f'tokenized_items_{path_corpus.name}'

    if path_tokenized_items.exists():
        print(f'[Preprocessor] Use cache: {path_tokenized_items}')
    else:
        print(f'Loading attribute data {path_corpus}')
        # The context manager guarantees the worker pool is released even if
        # tokenization raises part-way through.
        with Pool(processes=args.preprocessing_num_workers) as pool:
            pool_func = pool.imap(func=_par_tokenize_doc, iterable=item_meta_dict.items())
            doc_tuples = list(tqdm(pool_func, total=len(item_meta_dict), ncols=100, desc=f'[Tokenize] {path_corpus}'))
            pool.close()
            pool.join()
        tokenized_items = {item2id[item_id]: [input_ids, token_type_ids] for item_id, input_ids, token_type_ids in doc_tuples}

        torch.save(tokenized_items, path_tokenized_items)

    # Always read back from disk so cached and freshly tokenized runs share
    # the exact same code path.
    tokenized_items = torch.load(path_tokenized_items)
    print(f'Successfully load {len(tokenized_items)} tokenized items.')

    train_collator = PretrainDataCollatorWithPadding2(tokenizer, tokenized_items, mlm_probability=args.mlm_probability, mask_jepa_prob=0.45)
    train_data = RecformerTrainDataset(train, collator=train_collator)

    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              collate_fn=train_data.collate_fn,
                              num_workers=args.dataloader_num_workers,
                              pin_memory=True)

    model = RecformerForPretraining2(config, id2item)
    # Load the initial weights onto the CPU first so the checkpoint can be
    # read regardless of the device it was saved from; the model is moved to
    # args.device right after. (To resume instead, load path_ckpt /
    # path_predictor_ckpt into model / predictor here.)
    model.load_state_dict(torch.load(args.longformer_ckpt, map_location='cpu'))

    model.to(args.device)
    predictor = Predictor(config.hidden_size).to(args.device)

    if args.fix_word_embedding:
        print('Fix word embeddings.')
        for param in model.longformer.embeddings.word_embeddings.parameters():
            param.requires_grad = False

    # The target encoder starts as a frozen copy of the context encoder and
    # is only changed through the EMA update inside train_one_epoch.
    target_encoder = copy.deepcopy(model)
    target_encoder = target_encoder.to(args.device)
    for p in target_encoder.parameters():
        p.requires_grad = False

    # Size the LR schedule by the epochs that will actually run: warm-up
    # epochs only count when --warmup_predictor is given. Counting them
    # unconditionally made the schedule longer than the run.
    steps_per_epoch = int(len(train_loader) / args.gradient_accumulation_steps)
    scheduled_epochs = args.num_train_epochs
    if args.warmup_predictor:
        scheduled_epochs += args.warmup_predictor_epochs
    num_train_optimization_steps = steps_per_epoch * scheduled_epochs
    optimizer, scheduler = create_optimizer_and_scheduler_jepa(model, predictor, num_train_optimization_steps, args)

    # -- momentum schedule: linear ramp from ema[0] to ema[1] with one value
    # per training batch (train_one_epoch draws one value per batch).
    ipe = len(train_loader)
    num_epochs = args.num_train_epochs
    ipe_scale = 1.0
    ema = [0.996, 1.0]
    momentum_scheduler = (ema[0] + i*(ema[1]-ema[0])/(ipe*num_epochs*ipe_scale)
                          for i in range(int(ipe*num_epochs*ipe_scale)+1))

    if args.fp16:
        scaler = torch.cuda.amp.GradScaler()
    else:
        scaler = None

    # Warm up predictor: freeze the encoder, train, then restore the flags.
    if args.warmup_predictor:
        print('Warm Up Predictor.')
        for param in model.parameters():
            param.requires_grad = False
        for epoch in range(args.warmup_predictor_epochs):
            train_one_epoch_warmup_predictor(model, predictor, target_encoder, train_loader, optimizer, scheduler, scaler, epoch, args)
        for param in model.parameters():
            param.requires_grad = True
        # Unfreezing everything above also unfroze the word embeddings, so
        # re-apply the --fix_word_embedding choice.
        if args.fix_word_embedding:
            for param in model.longformer.embeddings.word_embeddings.parameters():
                param.requires_grad = False
        torch.save(predictor.state_dict(), path_predictor_ckpt)

    # Early-stopping state.
    best_acc = 0
    patience = 5
    patience_counter = 0

    # Main training loop.
    for epoch in range(args.num_train_epochs):
        # Train for one epoch and get the average training accuracy.
        avg_epoch_acc = train_one_epoch(model, predictor, target_encoder, train_loader, optimizer, scheduler, momentum_scheduler, scaler, epoch, args)

        # Check whether the accuracy improved.
        if avg_epoch_acc > best_acc:
            best_acc = avg_epoch_acc
            patience_counter = 0
            # Save the best checkpoint so far.
            torch.save(model.state_dict(), path_ckpt)

            # Log to wandb
            if not args.no_wandb:
                wandb.log({
                    "best_acc": best_acc,
                    "best_epoch": epoch
                })
            print(f"Epoch {epoch}: New best accuracy: {best_acc:.4f}")
        else:
            patience_counter += 1
            print(f"Epoch {epoch}: Accuracy did not improve. Patience: {patience_counter}/{patience}")

        # Stop once the accuracy has not improved for `patience` epochs.
        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            if not args.no_wandb:
                wandb.log({"early_stopped": True, "stopped_epoch": epoch})
            break

    print("Training completed!")
        
# Script entry point: run pre-training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()