from options import get_experiment_config
from set_up import setup_experiment
from data import create_dataloaders
from models import create_cirqrs_models
import wandb
from statistics import mean


import math
from tqdm import tqdm
import torch
import torch.nn.functional as F
import torch.optim as optim
import os

import numpy as np

# Relative path to a JSON config under ./blip_model (presumably the BLIP "med" config).
# NOTE(review): not referenced anywhere else in the visible part of this file —
# confirm whether another module imports it before removing.
med_config_path = './blip_model/med_config.json'

def main():
    """Run the full training loop for the model built by ``create_cirqrs_models``.

    Steps, in order:
      1. Load the experiment config and build the model, text processors,
         dataloaders and an AdamW optimizer over the trainable parameters.
      2. For each epoch: optionally call the training dataset's
         ``rerank_score`` (halving ``topk`` afterwards), set the learning rate
         with ``cosine_lr_schedule``, and train with autocast + ``GradScaler``.
      3. Log loss / lr / temperature / topk to wandb unless the experiment
         description is ``'debug'``.
      4. Every 10 epochs, when the dataset is ``'cirr'``, save the model
         state dict and append a line to a log file.
    """
    configs = get_experiment_config()
    export_root, configs = setup_experiment(configs)

    # Mixed precision is only meaningful on CUDA; on the CPU fallback we
    # disable it explicitly instead of relying on PyTorch's warn-and-disable.
    use_amp = torch.cuda.is_available()
    device = torch.device(f"cuda:{configs['device_idx']}") if use_amp else "cpu"

    model, txt_processors = create_cirqrs_models(configs, device)

    # Only the training loader is used by this script.
    train_dataloader, _test_dataloaders, _train_val_dataloaders = create_dataloaders(configs)

    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.AdamW(trainable_params, lr=configs['init_lr'],
                                  weight_decay=configs['weight_decay'])

    max_epochs = configs['epoch']
    topk = configs['init_topk']

    model.train()
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    for epoch in range(max_epochs):
        # Rerank at evenly spaced epochs: `negative_definition_epoch_num`
        # reranks spread over `max_epochs`. With `rerank_warmup` set, the
        # rerank that would happen at epoch 0 is skipped.
        if configs["use_rerank"]:
            mode = configs['rerank_mode']  # topk, lower_target_topk
            # The integer division yields 0 when there are fewer epochs than
            # requested reranks; clamp to 1 so the modulo cannot divide by
            # zero (the dataset is then reranked every epoch).
            rerank_interval = max(1, max_epochs // configs["negative_definition_epoch_num"])
            if epoch % rerank_interval == 0:
                if not configs['rerank_warmup'] or epoch != 0:
                    train_dataloader.dataset.rerank_score(model, device, topk, txt_processors, configs, mode)
                    # NOTE(review): repeated halving can reach 0 for small
                    # init_topk — confirm rerank_score tolerates topk == 0.
                    topk = topk // 2

        # Re-enter training mode after the optional rerank, in case
        # rerank_score switched the model to eval mode for inference.
        model.train()

        epoch_running_loss = 0.0
        # Defined up front so the epoch summary below does not raise
        # UnboundLocalError when the dataloader yields no batches.
        cur_loss = float("nan")
        cosine_lr_schedule(optimizer, epoch, max_epochs, configs["init_lr"], configs["init_lr"] / 100)
        train_dataloader_tqdm = tqdm(train_dataloader, desc="Epoch {}".format(epoch+1))
        for batch_idx, (ref_images, tar_images, negative_images, sentences) in enumerate(train_dataloader_tqdm):
            optimizer.zero_grad()

            ref_images = ref_images.to(device)
            tar_images = tar_images.to(device)
            negative_images = negative_images.to(device)

            with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                sentences = [txt_processors["eval"](caption) for caption in sentences]
                scores = model(ref_images, tar_images, sentences, negative_images, configs['use_temp'])

                # Log-softmax over dim=1, then KL divergence against a one-hot
                # target that puts all probability mass on column 0.
                # NOTE(review): presumably column 0 is the true target's
                # score — confirm against the model's forward().
                log_softmax_scores = F.log_softmax(scores, dim=1)

                target_probs = torch.zeros_like(log_softmax_scores)
                target_probs[:, 0] = 1.0  # p0 = 1 for each data point

                loss = F.kl_div(log_softmax_scores, target_probs, reduction='batchmean')

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Running mean of the per-batch loss, shown in the progress bar.
            epoch_running_loss += loss.item()
            cur_loss = epoch_running_loss / (batch_idx + 1)
            train_dataloader_tqdm.set_postfix({'loss': cur_loss})

        train_results = {'loss' : cur_loss, 'lr' : get_current_lr(optimizer)}
        print(f"[EPOCH {epoch+1}/{max_epochs}] Loss : {train_results['loss']} lr : {train_results['lr']}")
        if configs["experiment_description"] != 'debug':
            wandb.log({"loss" : cur_loss, "epoch" : epoch})
            wandb.log({'lr' : get_current_lr(optimizer), "epoch": epoch})
            wandb.log({'temp': model.temp.item(), "epoch": epoch})
            wandb.log({'k': topk, "epoch": epoch})

        # Save the model every 10 epochs (CIRR only).
        if (epoch + 1) % 10 == 0 and configs['dataset'] == 'cirr':
            # NOTE(review): the checkpoint directory is the literal "-" and
            # `export_root` from setup_experiment is never used — confirm
            # whether checkpoints were meant to go under export_root.
            saving_path = "-"
            os.makedirs(saving_path, exist_ok=True)
            torch.save(model.state_dict(), f'{saving_path}/cirqrs.pth')
            append_log(f'{saving_path}/log.txt', f"Epoch : {epoch + 1}\n")


def warmup_lr_schedule(optimizer, step, warmup_step, max_step, init_lr):
    """Hold the learning rate at ``init_lr``, then decay it linearly to zero.

    Despite the name, the rate is not ramped up: it stays at ``init_lr``
    while ``step < warmup_step`` and then falls linearly, reaching 0 at
    ``max_step``. The result is written to every parameter group of
    ``optimizer``.

    Args:
        optimizer: Object exposing ``param_groups`` (an iterable of dicts).
        step: Current step index.
        warmup_step: Number of initial steps held at ``init_lr``.
        max_step: Step at which the rate reaches 0.
        init_lr: Learning rate used during the hold phase.
    """
    if step < warmup_step:
        lr = init_lr
    elif step >= max_step:
        # Decay is finished. This branch also covers max_step == warmup_step,
        # which would otherwise divide by zero, and keeps the rate from going
        # negative once step runs past max_step.
        lr = 0.0
    else:
        decay_fraction = (step - warmup_step) / (max_step - warmup_step)
        lr = init_lr - decay_fraction * init_lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Set every parameter group's learning rate from a half-cosine curve.

    The rate equals ``init_lr`` at ``epoch == 0`` and falls to ``min_lr`` at
    ``epoch == max_epoch``.
    """
    half_range = (init_lr - min_lr) * 0.5
    cosine_factor = 1. + math.cos(math.pi * epoch / max_epoch)
    new_lr = half_range * cosine_factor + min_lr
    for group in optimizer.param_groups:
        group.update(lr=new_lr)

def get_current_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns ``None`` when the optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']

def append_log(file_path, log_message):
    """Append ``log_message`` followed by a newline to the file at ``file_path``.

    The file is opened in append mode, so it is created if it does not exist.
    """
    with open(file_path, mode='a') as log_file:
        entry = log_message + '\n'
        log_file.write(entry)

if __name__ == '__main__':
    main()