import os
import time
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from sklearn import metrics
import wandb

from myutils import AverageMeter, accuracy
from myDataLoader import get_train_val_test_loader

# =========================
# Helpers
# =========================

def _unwrap_logits(output):
    """If model returns (logits, feat) or [logits, ...], return logits."""
    if isinstance(output, (tuple, list)):
        return output[0]
    return output


def adjust_curlr_beta(epoch, args, optimizer=None):
    """Set ``args.curlr`` / ``args.curbeta`` for ``epoch`` and apply the LR.

    Schedule, with ``e = epoch + 1`` the 1-based epoch:

    - Base LR is ``args.local_lr`` when present and not None, else ``args.lr``.
    - Linear warmup ``base_lr * e / 5`` for ``e <= 5``.
    - After warmup, a step decay chosen by ``args.epochs``:
        120/150: x0.1 for e > 60, x0.01 for e > 90
        90:      x0.1 for e > 60
        60:      x0.1 for e >= 30
        30/40:   x0.1 for e >= 20
        200:     x0.01 for e >= 160, x0.0001 for e >= 180 (beta scaled alike)
        other:   constant base LR
    - ``args.curbeta`` is ``args.beta`` times that epoch's beta scale. The
      scale is 1 except in the 200-epoch schedule.
    - ``args.beta`` is never modified. The previous version multiplied it
      in place on every call past e=160. Beta therefore shrank
      geometrically per epoch, and ``curbeta`` lagged one epoch behind.

    Args:
        epoch: 0-based epoch index.
        args: namespace. Reads ``lr``, optional ``local_lr``, ``beta`` and
            ``epochs``. Writes ``curlr`` and ``curbeta``.
        optimizer: optional; every param group's ``lr`` is set to
            ``args.curlr``.
    """
    e = epoch + 1  # 1-based epoch

    base_lr = getattr(args, "local_lr", None)
    if base_lr is None:  # attribute missing or explicitly None
        base_lr = args.lr
    base_lr = float(base_lr)

    beta_scale = 1.0
    if e <= 5:
        lr = base_lr * e / 5.0
    elif args.epochs in (120, 150):
        if e > 90:
            lr = base_lr * 0.01
        elif e > 60:
            lr = base_lr * 0.1
        else:
            lr = base_lr
    elif args.epochs == 90:
        lr = base_lr * 0.1 if e > 60 else base_lr
    elif args.epochs == 60:
        lr = base_lr * 0.1 if e >= 30 else base_lr
    elif args.epochs in (30, 40):
        lr = base_lr * 0.1 if e >= 20 else base_lr
    elif args.epochs == 3:
        # Only reached when e > 5, which these short runs normally never hit.
        lr = base_lr * 0.1 if e > 2 else base_lr
    elif args.epochs == 2:
        lr = base_lr * 0.1 if e > 1 else base_lr
    elif args.epochs == 200:
        if e >= 180:
            lr, beta_scale = base_lr * 0.0001, 0.0001
        elif e >= 160:
            lr, beta_scale = base_lr * 0.01, 0.01
        else:
            lr = base_lr
    else:
        # generic fallback: constant LR after warmup
        lr = base_lr

    # Derive the current beta from the untouched base value (no compounding).
    args.curbeta = args.beta * beta_scale if beta_scale != 1.0 else args.beta
    args.curlr = float(lr)

    if optimizer is not None:
        for param_group in optimizer.param_groups:
            param_group["lr"] = args.curlr


# =========================
# Validation / Forward
# =========================

def forward(args, data_loader, model_new, criterion, epoch=0,
            training=True, optimizer=None, var=False):
    """Run ``model_new`` over ``data_loader`` and return averaged metrics.

    Forward passes run under ``torch.no_grad()`` when ``training`` is False.
    When ``training`` is True and ``optimizer`` is given, each batch is also
    back-propagated and the optimizer is stepped.

    Args:
        args: run configuration. The optional ``args.auc`` (default False)
            enables ROC-AUC computation with class 1 as the positive label.
        data_loader: iterable of ``(index, inputs, target)`` batches.
        model_new: model returning logits, or a tuple/list whose first
            element is the logits.
        criterion: loss applied as ``criterion(logits, target)``.
        epoch: seeds the global torch RNG with ``777 + epoch``.
        training: whether to build autograd graphs (and step ``optimizer``).
        optimizer: optimizer stepped once per batch when ``training``.
        var: unused; kept for call compatibility.

    Returns:
        ``(avg_loss, avg_top1, avg_topk, auc_score)``.
        - ``topk`` is ``min(5, num_classes)``.
        - ``auc_score`` is 0 unless ``args.auc`` is set and at least one
          batch was seen.
    """
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    auc_score = 0

    use_auc = getattr(args, "auc", False)
    # Collected per batch and concatenated once (avoids quadratic re-copying).
    target_chunks = []
    score_chunks = []

    # NOTE(review): reseeds the *global* torch RNG on every call, so it also
    # affects randomness consumed afterwards (e.g. the next epoch's shuffle).
    torch.manual_seed(777 + epoch)

    device = next(model_new.parameters()).device

    for _, inputs, target in data_loader:
        inputs = inputs.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        if training:
            output = model_new(inputs)
        else:
            with torch.no_grad():
                output = model_new(inputs)

        logits = _unwrap_logits(output)
        loss = criterion(logits, target)
        detached = logits.detach()

        if use_auc:
            target_chunks.append(target.detach().cpu().numpy())
            # A ROC curve needs a continuous score for the positive class
            # (pos_label=1), not the argmax label.
            if detached.dim() > 1 and detached.size(1) > 1:
                scores = torch.softmax(detached, dim=1)[:, 1]
            else:
                scores = torch.sigmoid(detached.reshape(-1))
            score_chunks.append(scores.cpu().numpy())

        # Top-5 is undefined with fewer than 5 classes; clamp k.
        maxk = min(5, detached.size(-1))
        prec1, prec5 = accuracy(detached, target, topk=(1, maxk))

        batch_size = inputs.size(0)
        losses.update(loss.item(), batch_size)
        top1.update(float(prec1), batch_size)
        top5.update(float(prec5), batch_size)

        if training and optimizer is not None:
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

    if use_auc and target_chunks:
        fpr, tpr, _ = metrics.roc_curve(np.concatenate(target_chunks),
                                        np.concatenate(score_chunks),
                                        pos_label=1)
        auc_score = metrics.auc(fpr, tpr)

    return losses.avg, top1.avg, top5.avg, auc_score


def forward_cifar_val_loader(args, data_loader, model_new, criterion, epoch=0,
                             training=True, optimizer=None, var=False):
    """Alias of ``forward`` kept so existing CIFAR call sites keep working.

    Every argument is passed through unchanged, and ``forward``'s result is
    returned as-is.
    """
    options = {"training": training, "optimizer": optimizer, "var": var}
    return forward(args, data_loader, model_new, criterion, epoch, **options)


def validate_cifar_val_loader(args, data_loader, model_new, criterion, epoch):
    """Evaluate ``model_new`` on ``data_loader`` via ``forward_cifar_val_loader``.

    Puts the model in eval mode, then runs without an optimizer and with
    ``training=False`` (no gradient tracking).
    """
    model_new.eval()
    eval_result = forward_cifar_val_loader(
        args, data_loader, model_new, criterion, epoch,
        training=False, optimizer=None, var=False,
    )
    return eval_result


def validate(args, data_loader, model_new, criterion, epoch):
    """Evaluate ``model_new`` on ``data_loader`` and return ``forward``'s metrics.

    The model is switched to eval mode first. ``forward`` is then called with
    ``training=False`` and no optimizer, so no parameters are updated.
    """
    model_new.eval()
    eval_options = dict(training=False, optimizer=None, var=False)
    return forward(args, data_loader, model_new, criterion, epoch, **eval_options)


# =========================
# Algorithms
# =========================

def FastDRO(args, model_new, results):
    """Train ``model_new`` with per-minibatch, KL-constrained DRO reweighting.

    For each minibatch the per-sample cross-entropy losses are turned into
    weights ``p = softmax((loss - max_loss) / (args.lamda0 + lamda))``.
    ``lamda`` is chosen by ``binary_search_FastDRO_lambda``. The model is
    then updated on ``sum(p * loss)`` with ``p`` detached, so gradients flow
    only through the losses.

    Args:
        args: run configuration namespace. Read here:
            - ``random_seed``, ``lr``, ``weight_decay``
            - ``resumed_epoch``, ``epochs``, ``lamda0``
            - ``print_freq``, ``dataset``
            - ``curlr``, which ``adjust_curlr_beta`` sets
            Helper functions read further fields.
        model_new: model trained in place.
        results: results recorder; ``add(...)`` and ``save()`` are called
            once per epoch when evaluation runs.

    Returns:
        None. Side effects:
        - starts a wandb run and seeds torch and numpy;
        - only when ``args.epochs > 10``: prints progress, evaluates on the
          train/val/test loaders, and logs to ``results`` and wandb.
    """
    wandb.init(config=vars(args), project="P", entity="aditi1_cse-wayne-state-university")

    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    train_loader, val_loader, test_loader = get_train_val_test_loader(args, None)

    ivd_criterion = nn.CrossEntropyLoss(reduction='none')  # per-sample losses for DRO weights
    CE_criterion = nn.CrossEntropyLoss()  # mean CE used for reporting and evaluation

    optimizerW = torch.optim.SGD(
        model_new.parameters(),
        lr=args.lr,
        momentum=0.9,
        weight_decay=args.weight_decay
    )

    device = next(model_new.parameters()).device

    best_test_acc1 = 0
    start_time = time.time()

    for epoch in range(args.resumed_epoch, args.epochs):
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        model_new.train()
        adjust_curlr_beta(epoch, args, optimizerW)

        for batch_idx, (_, inputs, targets) in enumerate(train_loader):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            out = model_new(inputs)
            outputs = _unwrap_logits(out)

            ivd_loss = ivd_criterion(outputs, targets)
            # Python float; subtracted before exp() for numerical stability.
            max_loss = torch.max(ivd_loss).item()
            lamda = binary_search_FastDRO_lambda(args, ivd_loss)

            # Tilted weights over the minibatch (a softmax of the shifted
            # losses); detached so they act as constants in the backward pass.
            stb_loss = ivd_loss - max_loss
            exploss = torch.exp(stb_loss / (args.lamda0 + lamda))
            p = exploss / torch.sum(exploss)
            p = p.detach()

            droloss = torch.sum(p * ivd_loss)

            optimizerW.zero_grad(set_to_none=True)
            droloss.backward()
            optimizerW.step()

            acc1, acc5 = accuracy(outputs.detach(), targets, topk=(1, 5))
            # Plain (unweighted) CE for logging only, computed from the
            # pre-step outputs. NOTE(review): this still records an autograd
            # graph; it could be wrapped in torch.no_grad().
            loss = CE_criterion(outputs, targets)

            losses.update(loss.item(), inputs.size(0))
            top1.update(float(acc1), inputs.size(0))
            top5.update(float(acc5), inputs.size(0))

            if batch_idx % args.print_freq == 0 and args.epochs > 10:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Train Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Train Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                        epoch, batch_idx, len(train_loader),
                        loss=losses, top1=top1, top5=top5))

        # epoch-end eval; skipped entirely for short (<= 10 epoch) runs
        if args.epochs > 10:
            train_loss, train_prec1, train_prec5, train_auc_score = validate(args, train_loader, model_new, CE_criterion, epoch)

            if 'cifar' in args.dataset:
                val_loss, val_prec1, val_prec5, val_auc_score = validate_cifar_val_loader(args, val_loader, model_new, CE_criterion, epoch)
            else:
                val_loss, val_prec1, val_prec5, val_auc_score = validate(args, val_loader, model_new, CE_criterion, epoch)

            if test_loader is not None:
                test_loss, test_prec1, test_prec5, test_auc_score = validate(args, test_loader, model_new, CE_criterion, epoch)
            else:
                # no test split: report validation metrics as test metrics
                test_loss, test_prec1, test_prec5 = val_loss, val_prec1, val_prec5

            overall_running_time = (time.time() - start_time) // 60  # whole minutes
            best_test_acc1 = max(best_test_acc1, test_prec1)

            results.add(epoch=epoch, val_loss=val_loss,
                        train_prec1=train_prec1, val_prec1=val_prec1,
                        train_prec5=train_prec5, val_prec5=val_prec5,
                        test_prec1=test_prec1, test_prec5=test_prec5,
                        overall_running_time=overall_running_time)
            results.save()

            wandb.log({"lr": args.curlr, "train_loss": train_loss,
                       "train_acc1": train_prec1, "val_acc1": val_prec1,
                       "test_acc1": test_prec1, "best_test_acc": best_test_acc1}, step=epoch)


def PDSGDDRO(args, model_new, results):
    """Train ``model_new`` with primal-dual SGD over a global sample distribution.

    A weight vector ``DROP`` reweights each minibatch's per-sample losses. It
    has one entry per training example (length ``P_length``). After every
    model step ``DROP`` is updated by ``PD_updatesP`` under the KL budget
    ``args.rho``.

    The learning rate is set once per epoch by ``adjust_curlr_beta``, the
    same schedule FastDRO and MBSGD use. Training runs for exactly
    ``args.epochs`` passes over ``train_loader``.

    Fixes relative to the previous version:
    - The outer loop used ``rounds <= numRounds`` and so ran one extra epoch.
    - A ``StepLR(step_size=5, gamma=0.5)`` scheduler was stepped every
      *batch*. It halved the LR every 5 iterations, driving it to ~0 within
      an epoch.

    Args:
        args: run configuration namespace (seed, optimizer, schedule, dataset
            and DRO fields such as ``rho``, ``plr``, ``lamda0``).
        model_new: model trained in place.
        results: results recorder; ``add(...)`` / ``save()`` are called per
            epoch when evaluation runs.

    Returns:
        None. Side effects:
        - starts a wandb run, seeds torch/numpy, and prints progress;
        - when ``args.epochs > 10``: evaluates after every epoch and logs to
          ``results`` and wandb.
    """
    wandb.init(config=vars(args), project="P", entity="aditi1_cse-wayne-state-university")

    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    train_loader, val_loader, test_loader = get_train_val_test_loader(args, None)

    device = next(model_new.parameters()).device

    # Inverse-frequency class weights for the evaluation criterion when the
    # dataset exposes its labels; otherwise uniform weights.
    if hasattr(train_loader.dataset, 'targets'):
        targets = train_loader.dataset.targets
        class_counts = torch.bincount(torch.tensor(targets))
        total_samples = class_counts.sum().item()
        class_weights = total_samples / (len(class_counts) * class_counts.clamp(min=1))
        class_weights = class_weights.float().to(device)
    else:
        class_weights = torch.ones(args.num_classes, device=device)

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    criterion_DRO = nn.CrossEntropyLoss(reduction='none')

    # No per-batch LR scheduler: adjust_curlr_beta owns the LR schedule.
    optimizerW = torch.optim.SGD(model_new.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)

    number_of_runs_per_epoch = len(train_loader)
    numRounds = number_of_runs_per_epoch * args.epochs
    data_len = len(train_loader.dataset)

    # NOTE(review): presumably sample ids index the full 50k CIFAR train set
    # even when the loader's dataset is a subset — confirm with the loader.
    P_length = 50000 if 'cifar' in args.dataset else data_len

    indices = None
    if 'cifar' in args.dataset and hasattr(train_loader.dataset, "indices"):
        indices = train_loader.dataset.indices
        print('length of training data:', len(indices))

    DROP = torch.full((P_length,), 1 / data_len, device=device)
    uniP = torch.full((P_length,), 1 / data_len, device=device)

    best_prec1, lamda = 0, 0
    rounds, epoch = 0, 0
    start_time = time.time()

    # Each pass of the inner loop adds len(train_loader) rounds, so `<` gives
    # exactly args.epochs epochs (and an empty loader no longer loops).
    while rounds < numRounds:
        if rounds % number_of_runs_per_epoch == 0:
            print('{0}/{1}, {2} epochs finished'.format(rounds, numRounds, rounds // number_of_runs_per_epoch))

        adjust_curlr_beta(epoch, args, optimizerW)
        model_new.train()

        for minibatchIdx, inputs, targets in train_loader:
            rounds += 1

            # Boolean mask of every P entry outside this minibatch. This is a
            # vectorised replacement for a Python set difference over P_length.
            otherbatchIdx = torch.ones(P_length, dtype=torch.bool, device=device)
            otherbatchIdx[minibatchIdx.to(device)] = False

            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            out = model_new(inputs)
            logits = _unwrap_logits(out)

            loss_DRO = criterion_DRO(logits, targets)

            # Rescale so the minibatch sum estimates the full-data weighted sum.
            unbiased_loss_DRO = loss_DRO * data_len / len(minibatchIdx)
            loss = torch.sum(DROP[minibatchIdx] * unbiased_loss_DRO)

            optimizerW.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model_new.parameters(), max_norm=1.0)
            optimizerW.step()

            DROP, lamda = PD_updatesP(args, DROP, uniP, minibatchIdx, otherbatchIdx,
                                      unbiased_loss_DRO, indices)

        epoch += 1

        if args.epochs > 10:
            train_loss, train_prec1, train_prec5, train_auc_score = validate(args, train_loader, model_new, criterion, epoch)

            if 'cifar' in args.dataset:
                val_loss, val_prec1, val_prec5, val_auc_score = validate_cifar_val_loader(args, val_loader, model_new, criterion, epoch)
            else:
                val_loss, val_prec1, val_prec5, val_auc_score = validate(args, val_loader, model_new, criterion, epoch)

            if test_loader is not None:
                test_loss, test_prec1, test_prec5, test_auc_score = validate(args, test_loader, model_new, criterion, epoch)
            else:
                test_loss, test_prec1, test_prec5 = val_loss, val_prec1, val_prec5

            best_prec1 = max(best_prec1, test_prec1)

            overall_training_time = (time.time() - start_time) // 60
            results.add(epoch=epoch, train_loss=train_loss, val_loss=val_loss,
                        train_prec1=train_prec1, val_prec1=val_prec1, test_prec1=test_prec1,
                        train_prec5=train_prec5, val_prec5=val_prec5, test_prec5=test_prec5,
                        overall_training_time=overall_training_time, auc_score=train_auc_score)
            results.save()

            wandb.log({"lr": args.curlr, "train_loss": train_loss,
                       "train_acc1": train_prec1, "val_acc1": val_prec1,
                       "test_acc1": test_prec1, "best_test_acc": best_prec1,
                       "lambda_constrained": float(lamda)}, step=epoch)


def MBSGD(args, model_new, results):
    """Baseline: train ``model_new`` with plain mini-batch SGD on mean cross-entropy.

    Unlike FastDRO/PDSGDDRO, the train/val/test evaluation and the logging
    run after every epoch regardless of ``args.epochs``.

    Args:
        args: run configuration namespace. Read here:
            - ``random_seed``, ``lr``, ``weight_decay``
            - ``resumed_epoch``, ``epochs``, ``dataset``
            - ``curlr``, which ``adjust_curlr_beta`` sets
        model_new: model trained in place.
        results: results recorder; ``add(...)`` / ``save()`` called per epoch.

    Returns:
        None. Side effects: seeds torch/numpy, starts a wandb run, and prints
        and logs per-epoch metrics.
    """
    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    train_loader, val_loader, test_loader = get_train_val_test_loader(args, None)
    CE_criterion = nn.CrossEntropyLoss()

    optimizerW = torch.optim.SGD(model_new.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)

    # Unlike the DRO trainers, wandb is initialised after the loaders are built.
    wandb.init(config=vars(args), project="P", entity="aditi1_cse-wayne-state-university")

    best_acc1 = 0
    start_time = time.time()

    device = next(model_new.parameters()).device

    for epoch in range(args.resumed_epoch, args.epochs):
        model_new.train()
        adjust_curlr_beta(epoch, args, optimizerW)

        for batch_idx, (_, inputs, targets) in enumerate(train_loader):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            out = model_new(inputs)
            logits = _unwrap_logits(out)

            loss = CE_criterion(logits, targets)

            optimizerW.zero_grad(set_to_none=True)
            loss.backward()
            optimizerW.step()

        train_loss, train_prec1, train_prec5, _ = validate(args, train_loader, model_new, CE_criterion, epoch)

        if 'cifar' in args.dataset:
            val_loss, val_prec1, val_prec5, _ = validate_cifar_val_loader(args, val_loader, model_new, CE_criterion, epoch)
        else:
            val_loss, val_prec1, val_prec5, _ = validate(args, val_loader, model_new, CE_criterion, epoch)

        if test_loader is not None:
            test_loss, test_prec1, test_prec5, _ = validate(args, test_loader, model_new, CE_criterion, epoch)
        else:
            # no test split: report validation metrics as test metrics
            test_loss, test_prec1, test_prec5 = val_loss, val_prec1, val_prec5

        best_acc1 = max(best_acc1, test_prec1)
        overall_running_time = (time.time() - start_time) // 60  # whole minutes

        # NOTE(review): when a test loader exists, the *test* accuracies are
        # recorded under the val_prec1/val_prec5 keys — confirm intended.
        results.add(epoch=epoch, val_loss=val_loss,
                    train_prec1=train_prec1, val_prec1=val_prec1 if test_loader is None else test_prec1,
                    train_prec5=train_prec5, val_prec5=val_prec5 if test_loader is None else test_prec5,
                    overall_running_time=overall_running_time)
        results.save()

        print(f"Epoch {epoch}/{args.epochs} | train_acc1={train_prec1:.3f} val_acc1={val_prec1:.3f} test_acc1={test_prec1:.3f}")

        wandb.log({"lr": args.curlr, "train_loss": train_loss,
                   "train_acc1": train_prec1, "val_acc1": val_prec1,
                   "test_acc1": test_prec1, "best_test_acc": best_acc1}, step=epoch)


# =========================
# FastDRO / PD helper functions
# =========================

def FastDROLambdamax(args, ivd_loss):
    """Return the first lambda in 0, 1, 3, 7, 15, ... with ``FastDRO_KL < args.rho``.

    Each candidate is ``2 * previous + 1``. There is no iteration cap: the
    loop ends only once the KL value drops below ``args.rho``.
    """
    candidate = 0
    while not (FastDRO_KL(args, ivd_loss, candidate) < args.rho):
        candidate = 2 * candidate + 1
    return candidate


def binary_search_FastDRO_lambda(args, ivd_loss, tol=1e-3, max_iter=100):
    """Find lambda with ``|FastDRO_KL(args, ivd_loss, lambda) - args.rho| <= tol``.

    ``FastDROLambdamax`` returns the first point of 0, 1, 3, 7, ... whose KL
    is below ``args.rho``. If that point is 0, the constraint is inactive and
    0 is returned. Otherwise the previous point ``(rlambda - 1) / 2`` was
    rejected (KL >= rho there), so ``[llambda, rlambda]`` brackets the target
    and is bisected.

    Bug fixed: the previous version re-checked ``KL(rlambda) < rho``, which
    ``FastDROLambdamax`` already guarantees. It therefore always returned
    ``rlambda`` immediately and the bisection never ran.

    The bisection assumes KL decreases as lambda grows. If ``tol`` is not met
    within ``max_iter`` halvings, the upper end (known to satisfy
    ``KL < rho``) is returned.
    """
    rlambda = FastDROLambdamax(args, ivd_loss)
    if rlambda == 0:
        return 0

    llambda = (rlambda - 1) / 2
    for _ in range(max_iter):
        mid = (rlambda + llambda) / 2
        val = FastDRO_KL(args, ivd_loss, mid)
        if val > args.rho + tol:
            llambda = mid  # too concentrated: need a larger lambda
        elif val < args.rho - tol:
            rlambda = mid  # slack in the budget: a smaller lambda suffices
        else:
            return mid
    return rlambda


def FastDRO_KL(args, ivd_loss, lamda):
    """KL(uniform || p) for the tilted weights p = softmax(loss / (lamda0 + lamda)).

    Args:
        args: namespace providing ``lamda0``.
        ivd_loss: per-sample losses; flattened and detached (no gradient).
        lamda: extra temperature added to ``args.lamda0``.

    Returns:
        0-dim tensor ``sum(u * (log u - log p))``, where ``u`` is uniform over
        the samples actually present.

    Fixes:
    - ``u`` was ``1 / args.batch_size``. That is wrong for a partial (e.g.
      last) batch, where ``u`` no longer sums to 1.
    - ``log p`` is computed with ``log_softmax``; ``exp`` then ``log`` could
      underflow to -inf and make the KL infinite.
    """
    losses = ivd_loss.detach().reshape(-1)
    log_p = torch.log_softmax(losses / (args.lamda0 + lamda), dim=0)
    uniform = torch.full_like(log_p, 1.0 / losses.numel())
    return F.kl_div(log_p, uniform, reduction='sum')


def PD_updatesP(args, DROP, uniP, minibatchIdx, otherbatchIdx, loss, indices=None):
    """Update the per-sample weight vector ``DROP`` from the current minibatch losses.

    Candidate weights are ``(DROP * exp(plr * shifted_loss)) ** (1 / ((lamda
    + lamda0) * plr + 1))``. Minibatch entries use their own shifted loss;
    all other entries use ``-max_loss``. The result is renormalised over
    ``indices`` when given, otherwise over the whole vector.

    ``lamda`` is 0 when the KL at lambda=0 is within ``args.rho``. Otherwise
    it comes from ``binary_search_PD_lambda``.

    Args:
        args: namespace providing ``rho``, ``plr``, ``lamda0``.
        DROP: current weight vector.
        uniP: uniform reference vector used by ``PD_KL``.
        minibatchIdx: indices of this minibatch's samples into ``DROP``.
        otherbatchIdx: every other index (a list or a boolean mask).
        loss: per-sample minibatch losses aligned with ``minibatchIdx``.
        indices: optional subset to renormalise over.

    Returns:
        ``(newP, lamda)``. If the initial KL is infinite, ``DROP`` is returned
        unchanged (as a copy) with ``lamda = 0``.
    """
    # Only the values are needed. Detaching stops PD_KL and the bisection
    # from recording autograd ops on a graph backward() already consumed.
    loss = loss.detach()
    max_loss = torch.max(loss)
    stb_loss = loss - max_loss

    disparity = PD_KL(args, DROP, uniP, minibatchIdx, otherbatchIdx, 0, loss, indices)
    if math.isinf(disparity):
        # Degenerate KL (e.g. a candidate weight underflowed to 0): keep P.
        return DROP.clone(), 0

    if disparity <= args.rho:
        lamda = 0
    else:
        lamda = binary_search_PD_lambda(args, DROP, uniP, minibatchIdx, otherbatchIdx, loss, indices)

    exponent = 1 / ((lamda + args.lamda0) * args.plr + 1)
    newP = torch.ones_like(DROP)
    newP[minibatchIdx] = (DROP[minibatchIdx] * torch.exp(args.plr * stb_loss)).pow(exponent)
    newP[otherbatchIdx] = (DROP[otherbatchIdx] * torch.exp(-args.plr * max_loss)).pow(exponent)
    newP = newP.detach()

    if indices is not None:
        newP[indices] = newP[indices] / torch.sum(newP[indices])
    else:
        newP = newP / torch.sum(newP)

    # newP is freshly allocated here, so no defensive deep copy is needed.
    return newP, lamda


def binary_search_PD_lambda(args, P, uniP, minibatchIdx, otherbatchIdx, loss, indices,
                            tol=1e-3, max_iter=100):
    """Find lambda with ``|PD_KL(..., lambda, ...) - args.rho| <= tol``.

    Search strategy:
    - The upper bracket is grown through 1, 3, 7, 15, ... until the KL is
      within ``args.rho``.
    - The bracket is then bisected; this assumes PD_KL decreases as lambda
      grows.
    - Both phases are capped at ``max_iter`` steps. A budget that cannot be
      met (e.g. ``args.rho <= 0``, or a floating-point precision floor)
      therefore no longer hangs training. In that case the upper end of the
      bracket is returned.

    ``tol`` and ``max_iter`` are new keyword arguments; their defaults keep
    the previous tolerance.
    """
    def kl_at(lam):
        return PD_KL(args, P, uniP, minibatchIdx, otherbatchIdx, lam, loss, indices)

    lo, hi = 0, 1
    for _ in range(max_iter):
        if not (kl_at(hi) > args.rho):
            break
        lo, hi = hi, 2 * hi + 1
    else:
        # Budget still violated after max_iter expansions: return the
        # largest (flattest) lambda reached.
        return hi

    for _ in range(max_iter):
        mid = (lo + hi) / 2
        val = kl_at(mid)
        if val > args.rho + tol:
            lo = mid
        elif val < args.rho - tol:
            hi = mid
        else:
            return mid
    # Tolerance not reached; hi is known to satisfy KL <= rho.
    return hi


def PD_KL(args, P, uniP, minibatchIdx, otherbatchIdx, lamda, loss, indices):
    """KL(uniP || newP) for the candidate weights ``PD_updatesP`` would produce.

    Candidate weights are ``(P * exp(plr * shifted_loss)) ** (1 / ((lamda +
    lamda0) * plr + 1))``:
    - minibatch entries use ``loss - max(loss)``;
    - every other entry uses ``-max(loss)``.
    The candidates are normalised over ``indices`` when given (and compared
    to ``uniP[indices]``), otherwise over the whole vector.

    ``loss`` is detached first. Only the value is needed, and this avoids
    recording autograd ops on every bisection step (the result previously
    carried ``loss``'s grad history).

    Returns:
        0-dim tensor with no autograd history.
    """
    loss = loss.detach()
    max_loss = torch.max(loss)
    stb_loss = loss - max_loss
    exponent = 1 / ((lamda + args.lamda0) * args.plr + 1)

    newP = torch.zeros_like(P)
    newP[minibatchIdx] = (P[minibatchIdx] * torch.exp(args.plr * stb_loss)).pow(exponent)
    newP[otherbatchIdx] = (P[otherbatchIdx] * torch.exp(-args.plr * max_loss)).pow(exponent)

    if indices is not None:
        subset = newP[indices] / torch.sum(newP[indices])
        return F.kl_div(subset.log(), uniP[indices], reduction='sum')

    newP = newP / torch.sum(newP)
    return F.kl_div(newP.log(), uniP, reduction='sum')