import os
import json
import time
import wandb
import torch
import transformers
import numpy as np
import torch.distributed as dist

from tqdm import tqdm
from safetensors.torch import load_file
from datetime import timedelta
from torch.utils.data import DataLoader
from accelerate import (
    Accelerator, 
    DistributedDataParallelKwargs,
)

from loss import (
    SupConLoss,
    RouterLoss,
)
from data import create_dataloaders
from model import SiameseEncoder
from utils import (
    set_seed, args_parser,
)


def main():
    """Train a ``SiameseEncoder`` with a supervised contrastive loss via Accelerate.

    Flow:
      1. Parse CLI arguments, seed RNGs and create an ``Accelerator`` that logs
         to Weights & Biases under ``args.project_name`` / ``args.run_name``.
      2. Build the training dataloader, the model (loading a safetensors
         checkpoint from ``args.load_model`` unless ``args.do_pretrain`` is set),
         an AdamW optimizer and a cosine schedule with warmup.
      3. Train for ``args.epochs`` epochs. After each epoch, save an Accelerate
         checkpoint plus the unwrapped model ``state_dict`` into
         ``args.output_dir`` and print the mean per-batch loss, averaged across
         processes.
    """
    args = args_parser()
    set_seed(args.seed)

    # NOTE(review): find_unused_parameters=True is presumably needed because the
    # router logits returned by the model never enter the loss below — confirm.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs], log_with='wandb')
    DEVICE = accelerator.device

    accelerator.init_trackers(
        project_name=args.project_name,
        config=args,
        init_kwargs={"wandb": {"name": args.run_name}},
    )

    accelerator.print(args)
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    accelerator.print('Loading Data...')
    # Read the dataset description with a context manager so the file handle is
    # closed deterministically (JSON is UTF-8 by specification).
    with open(args.dataset_info_path, encoding='utf-8') as info_file:
        dataset_info = json.load(info_file)
    train_dataloader = create_dataloaders(
        data_dir=args.data_dir,
        folder_name=args.folder_name,
        do_pretrain=args.do_pretrain,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        dataset_info=dataset_info,
    )
    accelerator.wait_for_everyone()

    accelerator.print('Building Model, Optimizer and Loss...')
    model = SiameseEncoder(
        embedding_dim=args.embedding_dim,
        vocab_size=args.vocab_size,
        max_orig_positional_len=args.max_orig_positional_len,
        hidden_size=args.hidden_size,
        num_hidden_layers=args.num_hidden_layers,
        num_attention_heads=args.num_attention_heads,
        intermediate_size=args.intermediate_size,
        intermediate_size_expert=args.intermediate_size_expert,
        num_expert_heads=args.num_expert_heads,
        pad_token_id=args.pad_token_id,
        hidden_act=args.hidden_act,
        token_moe=args.token_moe,
        moe_type=args.moe_type,
        topk=args.topk,
        hash_list_path=args.hash_list_path,
        num_experts=args.num_experts,
        num_sparse_layers=args.num_sparse_layers,
        gradient_checkpointing=args.gradient_checkpointing,
    )
    # Fine-tuning starts from an existing safetensors checkpoint; pre-training
    # starts from the freshly initialized model.
    if not args.do_pretrain:
        state_dict = load_file(args.load_model)
        model.load_state_dict(state_dict)

    accelerator.print(model)
    model.to(DEVICE)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=args.adam_epsilon,
    )
    loss_supcon = SupConLoss(temperature=args.temperature)

    # Measured on the dataloader *before* accelerator.prepare() shards it.
    t_total = len(train_dataloader) * args.epochs
    if accelerator.is_main_process:
        accelerator.print(f'Total number of training steps: {t_total}')
    scheduler = transformers.get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )
    model, optimizer, train_dataloader, scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, scheduler
    )

    start_ep = 1
    step = 1  # global optimizer-step counter, used as the tracker x-axis

    accelerator.wait_for_everyone()
    accelerator.print('Training...')

    for ep in range(start_ep, args.epochs + 1):
        start_time = time.time()
        # On-device tensor accumulator: stays a tensor (reducible across
        # processes) even if this rank receives no batches.
        total_loss = torch.zeros((), device=DEVICE)
        num_batches = 0
        model.train()
        torch.set_grad_enabled(True)
        accelerator.print(f'Started epoch {ep} at {time.ctime(start_time)}')
        pbar = tqdm(train_dataloader, desc=f'Epoch {ep}', disable=not accelerator.is_local_main_process)

        for batch in pbar:
            optimizer.zero_grad()

            # The router logits are not used by any loss term here.
            query_emb, passage_emb, _query_router_logits, _passage_router_logits = model(
                batch, mean_pooling=args.mean_pooling,
            )
            loss = loss_supcon(query_emb, passage_emb, batch['qp_mat'].to(DEVICE))

            accelerator.backward(loss)
            optimizer.step()
            scheduler.step()

            # One .item() per step -> one host/device synchronization per step.
            loss_value = loss.item()
            total_loss += loss.detach()
            num_batches += 1

            if step % 10 == 0:
                accelerator.log({'supcon_loss': loss_value, 'loss': loss_value}, step=step)
            pbar.set_description(f'Epoch {ep}, Loss: {loss_value}')

            step += 1

        # save_state writes an Accelerate checkpoint directory at this path
        # (despite the ".pt" suffix).
        accelerator.save_state(os.path.join(args.output_dir, f'epoch_{ep}.pt'))
        # Mean per-batch loss on this rank, then averaged over all ranks.
        # accelerator.reduce also works in single-process runs (no process group)
        # and on non-NCCL backends, unlike dist.all_reduce with ReduceOp.AVG.
        epoch_loss = accelerator.reduce(total_loss / max(num_batches, 1), reduction='mean').item()
        accelerator.print(f'Epoch {ep} completed. Loss: {epoch_loss:.4f}, Time taken: {time.time() - start_time:.2f} seconds')
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        accelerator.save(unwrapped_model.state_dict(), os.path.join(args.output_dir, f'epoch_{ep}_unwrapped.pt'))

    if accelerator.is_main_process:
        accelerator.print('Training completed.')
    accelerator.end_training()


if __name__ == "__main__":
    # Run training only when this file is executed as a script, not on import.
    main()