import torch
from tqdm import tqdm
from utils import AverageMeter, accuracy
import numpy as np
from .metrics import ECELoss, SCELoss, AdaptiveECELoss

def reduce_average_meter(meter):
    """
    Return the average of an AverageMeter across all processes.

    The meter's ``sum`` and ``count`` attributes are summed over every rank
    with ``all_reduce`` and the global average ``sum / count`` is returned.
    When ``torch.distributed`` is unavailable or no process group has been
    initialised, the local ``sum / count`` is returned instead, so the helper
    is safe to call from single-process runs.

    Args:
        meter: Object exposing numeric ``sum`` and ``count`` attributes.

    Returns:
        float: The (global) average, or ``0.0`` when the total count is zero
        (i.e. the meter was never updated on any rank).
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        # NCCL only reduces CUDA tensors; fall back to CPU for CPU-only
        # setups (e.g. the gloo backend on a machine without GPUs).
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # float64 keeps large running sums and counts exact enough; the
        # default float32 cannot even represent counts above 2**24 exactly.
        stats = torch.tensor(
            [meter.sum, meter.count], dtype=torch.float64, device=device
        )
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)
        total, count = stats.tolist()
    else:
        total, count = float(meter.sum), float(meter.count)

    # The emptiness check happens *after* the collective so that every rank
    # still takes part in the all_reduce and none of them dead-locks.
    if count == 0:
        return 0.0
    return total / count

def train(trainloader, model, optimizer, criterion):
    """
    Train ``model`` for one pass over ``trainloader``.

    Every batch is moved to the current CUDA device, passed through the model,
    scored with ``criterion`` and used for a single optimizer step. Running
    averages of the loss and of the top-1 value returned by ``accuracy`` are
    kept (weighted by batch size) and shown on a tqdm progress bar.

    Args:
        trainloader: Iterable of ``(inputs, targets)`` batches with a length.
        model: Module to optimise; it is put into training mode.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        criterion: Callable ``criterion(outputs, targets)`` returning the loss.

    Returns:
        tuple: ``(loss, top1)``. When a ``torch.distributed`` process group is
        initialised, both values are averaged over all processes through
        ``reduce_average_meter``; otherwise they are the local averages.
    """
    model.train()

    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    num_batches = len(trainloader)

    progress = tqdm(enumerate(trainloader), total=num_batches)
    for step, (inputs, targets) in progress:
        inputs = inputs.cuda()
        targets = targets.cuda()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Book-keeping: weight each running average by the batch size
        batch_size = inputs.size(0)
        (prec1,) = accuracy(outputs.data, targets.data, topk=(1, ))
        loss_meter.update(loss.item(), batch_size)
        top1_meter.update(prec1.item(), batch_size)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Each process reports its own (local) running statistics
        status = '({batch}/{size}) Loss: {loss:.8f} | top1: {top1: .4f}'.format(
            batch=step + 1,
            size=num_batches,
            loss=loss_meter.avg,
            top1=top1_meter.avg,
        )
        progress.set_postfix_str(status)

    # Single-process run: the local averages are already the final result.
    if not torch.distributed.is_initialized():
        return (loss_meter.avg, top1_meter.avg)

    # Distributed run: combine the meters of all processes (loss first, then
    # top-1, so every rank issues the collectives in the same order).
    global_loss = reduce_average_meter(loss_meter)
    global_top1 = reduce_average_meter(top1_meter)
    return (global_loss, global_top1)

@torch.no_grad()
def test(testloader, model, criterion):
    """
    Evaluate ``model`` on ``testloader`` and compute calibration metrics.

    The model is switched to eval mode and run over every batch on the current
    CUDA device. Batch-size-weighted running averages of the loss and of the
    top-1/3/5 values returned by ``accuracy`` are tracked, and all model
    outputs and targets are collected as NumPy arrays. When a
    ``torch.distributed`` process group is initialised, the arrays of every
    rank are gathered with ``all_gather_object`` and the scalar meters are
    averaged over all processes via ``reduce_average_meter``.

    Args:
        testloader: Iterable of ``(inputs, targets)`` batches with a length.
        model: Module to evaluate.
        criterion: Callable ``criterion(outputs, targets)`` returning the loss.

    Returns:
        tuple: ``(loss, top1, top3, top5, sce, ece, aece)`` where the last
        three come from ``SCELoss``, ``ECELoss`` (both with ``n_bins=15``) and
        ``AdaptiveECELoss`` evaluated on the collected outputs and targets.
    """
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top3_meter = AverageMeter()
    top5_meter = AverageMeter()

    target_chunks = []
    output_chunks = []

    model.eval()

    num_batches = len(testloader)
    progress = tqdm(enumerate(testloader), total=num_batches)
    for step, (inputs, targets) in progress:
        inputs = inputs.cuda()
        targets = targets.cuda()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        batch_size = inputs.size(0)
        prec1, prec3, prec5 = accuracy(outputs.data, targets.data, topk=(1, 3, 5))
        loss_meter.update(loss.item(), batch_size)
        for meter, prec in (
            (top1_meter, prec1),
            (top3_meter, prec3),
            (top5_meter, prec5),
        ):
            meter.update(prec.item(), batch_size)

        # Keep host-side NumPy copies so they can be gathered afterwards
        target_chunks.append(targets.cpu().numpy())
        output_chunks.append(outputs.cpu().numpy())

        status = '({batch}/{size}) Loss: {loss:.8f} | top1: {top1: .4f} | top3: {top3: .4f} | top5: {top5: .4f}'.format(
            batch=step + 1,
            size=num_batches,
            loss=loss_meter.avg,
            top1=top1_meter.avg,
            top3=top3_meter.avg,
            top5=top5_meter.avg,
        )
        progress.set_postfix_str(status)

    # Stack this process' per-batch arrays
    all_outputs = np.concatenate(output_chunks, axis=0)
    all_targets = np.concatenate(target_chunks, axis=0)

    meters = (loss_meter, top1_meter, top3_meter, top5_meter)
    dist = torch.distributed
    if dist.is_initialized():
        # Collect every rank's arrays (outputs first, then targets) ...
        world_size = dist.get_world_size()
        per_rank_outputs = [None] * world_size
        per_rank_targets = [None] * world_size
        dist.all_gather_object(per_rank_outputs, all_outputs)
        dist.all_gather_object(per_rank_targets, all_targets)
        all_outputs = np.concatenate(per_rank_outputs, axis=0)
        all_targets = np.concatenate(per_rank_targets, axis=0)

        # ... and then average the scalar meters across ranks.
        global_loss, global_top1, global_top3, global_top5 = [
            reduce_average_meter(meter) for meter in meters
        ]
    else:
        global_loss, global_top1, global_top3, global_top5 = [
            meter.avg for meter in meters
        ]

    # Calibration metrics on the full set of outputs and targets
    ece = ECELoss().loss(all_outputs, all_targets, n_bins=15)
    sce = SCELoss().loss(all_outputs, all_targets, n_bins=15)
    aece = AdaptiveECELoss().forward(all_outputs, all_targets)

    return (global_loss, global_top1, global_top3, global_top5, sce, ece, aece)

@torch.no_grad()
def get_logits_from_model_dataloader(testloader, model):
    """
    Return model outputs and targets on CPU, aggregated over all processes.

    The model is put into eval mode and run on every batch of ``testloader``
    (inputs are moved to the current CUDA device). Per-batch outputs and
    targets are copied to the CPU and concatenated along dim 0. If a
    ``torch.distributed`` process group is initialised, the per-rank tensors
    are collected with ``all_gather_object`` and concatenated in rank order.

    ``all_gather_object`` is used rather than ``all_gather`` because the
    tensors live on the CPU (which a pure NCCL process group cannot
    communicate directly) and because ranks may hold different numbers of
    samples; it is also what ``test()`` in this module uses.

    Args:
        testloader: Iterable of ``(inputs, targets)`` batches with a length.
        model: Module to evaluate.

    Returns:
        tuple: ``(outputs, targets)`` as CPU torch tensors.
    """
    model.eval()
    output_chunks = []
    target_chunks = []

    bar = tqdm(testloader, total=len(testloader), desc="Evaluating logits")
    for inputs, targets in bar:
        outputs = model(inputs.cuda())
        output_chunks.append(outputs.cpu())
        target_chunks.append(targets.cpu())

    local_outputs = torch.cat(output_chunks, dim=0)
    local_targets = torch.cat(target_chunks, dim=0)

    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        # Single-process run: nothing to gather, no placeholders needed.
        return local_outputs, local_targets

    # Object gathering pickles each rank's tensors, so differing first
    # dimensions across ranks are fine. Outputs are gathered before targets
    # on every rank to keep the collectives aligned.
    world_size = dist.get_world_size()
    gathered_outputs = [None] * world_size
    gathered_targets = [None] * world_size
    dist.all_gather_object(gathered_outputs, local_outputs)
    dist.all_gather_object(gathered_targets, local_targets)

    global_outputs = torch.cat(gathered_outputs, dim=0)
    global_targets = torch.cat(gathered_targets, dim=0)
    return global_outputs, global_targets

