import time
import torch
from tqdm import tqdm
from .utils import AverageMeter
from torch.cuda.amp import autocast
import torch.nn.functional as F
import torch.distributed as dist

def _forward_loss(train_config, model, loss_function, query, reference):
    """Move one batch to ``train_config.device``, run the model and return the loss.

    Shared by the mixed-precision and full-precision branches of ``train`` so the
    forward logic lives in a single place.
    """
    query = query.to(train_config.device)
    reference = reference.to(train_config.device)

    features1, features2 = model(query, reference)

    # Wrappers such as DistributedDataParallel expose the wrapped model as
    # ``.module``; ``logit_scale`` lives on the wrapped model itself.
    base_model = model.module if hasattr(model, 'module') else model
    # The stored parameter is exponentiated before being handed to the loss.
    logit_scale = base_model.logit_scale.exp()
    return loss_function(features1, features2, logit_scale)


def train(train_config, model, dataloader, loss_function, optimizer, scheduler=None, scaler=None):
    """Run one training epoch and return the average loss.

    Args:
        train_config: config object; reads ``verbose``, ``rank``, ``device``,
            ``clip_grad``, ``scheduler`` and ``world_size``.
        model: called as ``model(query, reference)`` and returning two feature
            tensors; must expose ``logit_scale`` directly or via ``.module``.
        dataloader: iterable yielding ``(query, reference, ids)`` batches;
            ``ids`` is not used here.
        loss_function: called as ``loss_function(features1, features2, logit_scale)``.
        optimizer: stepped once per batch.
        scheduler: stepped once per batch when ``train_config.scheduler`` is one
            of ``"polynomial"``, ``"cosine"`` or ``"constant"``.
        scaler: optional gradient scaler (e.g. ``GradScaler``). When given, the
            forward pass runs under ``autocast`` and gradients are scaled.

    Returns:
        float: the mean per-batch loss. When ``torch.distributed`` is
        initialized, the per-rank means are summed and divided by
        ``train_config.world_size``. Otherwise the local mean is returned.

    Raises:
        ValueError: if ``train_config.scheduler`` requests per-batch stepping
            but no ``scheduler`` was supplied.
    """
    step_scheduler = train_config.scheduler in ["polynomial", "cosine", "constant"]
    if step_scheduler and scheduler is None:
        raise ValueError(
            "train_config.scheduler is {!r} but no scheduler was passed to train()".format(
                train_config.scheduler
            )
        )

    model.train()
    losses = AverageMeter()

    # Short pause before the progress bar is created (presumably so pending
    # console output is flushed first).
    time.sleep(0.1)

    optimizer.zero_grad(set_to_none=True)

    show_progress = train_config.verbose and train_config.rank == 0
    bar = tqdm(dataloader, total=len(dataloader)) if show_progress else dataloader

    for query, reference, _ids in bar:
        if scaler:
            with autocast():
                loss = _forward_loss(train_config, model, loss_function, query, reference)
            # One host-device sync per batch; the value is reused for logging.
            loss_value = loss.item()
            losses.update(loss_value)

            scaler.scale(loss).backward()

            if train_config.clip_grad:
                # Gradients must be unscaled before clipping so the threshold
                # applies to their true magnitude.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_value_(model.parameters(), train_config.clip_grad)

            scaler.step(optimizer)
            scaler.update()
        else:
            loss = _forward_loss(train_config, model, loss_function, query, reference)
            loss_value = loss.item()
            losses.update(loss_value)

            loss.backward()

            if train_config.clip_grad:
                torch.nn.utils.clip_grad_value_(model.parameters(), train_config.clip_grad)

            optimizer.step()

        optimizer.zero_grad()

        if step_scheduler:
            scheduler.step()

        if show_progress:
            monitor = {
                "loss": "{:.4f}".format(loss_value),
                "loss_avg": "{:.4f}".format(losses.avg),
                "lr": "{:.6f}".format(optimizer.param_groups[0]['lr'])
            }
            bar.set_postfix(ordered_dict=monitor)

    if show_progress:
        bar.close()

    # Only reduce across processes when a process group actually exists;
    # calling all_reduce without one raises in single-process runs.
    if dist.is_available() and dist.is_initialized():
        avg_loss = torch.tensor(losses.avg).to(train_config.device)
        dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
        return avg_loss.item() / train_config.world_size

    return float(losses.avg)


def predict(train_config, model, dataloader):
    """Embed every batch from ``dataloader`` and collect the results.

    Args:
        train_config: config object; reads ``verbose``, ``rank``, ``device``,
            ``normalize_features`` and, if present, ``mixed_precision``.
        model: called as ``model(img)`` on each batch; switched to eval mode.
        dataloader: iterable of ``(img, ids)`` batches.

    Returns:
        tuple: ``(features, ids)``.

        - ``features`` is the float32 concatenation of all batch outputs
          along dim 0. It is L2-normalized over the last dim when
          ``normalize_features`` is set.
        - ``ids`` is the concatenation of all id batches, moved to
          ``train_config.device``.
    """
    model.eval()

    # Brief pause before the progress bar is created (presumably so earlier
    # console output is flushed first).
    time.sleep(0.1)

    show_progress = train_config.verbose and train_config.rank == 0
    batches = tqdm(dataloader, total=len(dataloader)) if show_progress else dataloader

    use_amp = hasattr(train_config, 'mixed_precision') and train_config.mixed_precision

    feature_chunks = []
    id_chunks = []

    with torch.no_grad():
        for batch, batch_ids in batches:
            id_chunks.append(batch_ids)

            if use_amp:
                with autocast():
                    out = model(batch.to(train_config.device))
            else:
                out = model(batch.to(train_config.device))

            if train_config.normalize_features:
                out = F.normalize(out, dim=-1)

            # Outputs produced under autocast may be half precision; store float32.
            feature_chunks.append(out.to(torch.float32))

        features = torch.cat(feature_chunks, dim=0)
        all_ids = torch.cat(id_chunks, dim=0).to(train_config.device)

    if show_progress:
        batches.close()

    return features, all_ids