# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable, Optional

import torch
import torch.nn.functional as F
import numpy as np
from torch import vmap

from timm.data import Mixup
from timm.utils import accuracy, ModelEma

from losses import DistillationLoss
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import utils
from typing import Tuple
from poly.wd_regularization_torch import polynomial_regularization, precompute_chebyshev_matrix


def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True, args=None):
    """Run one training epoch, optionally adding a polynomial regularization term.

    For every batch the base loss is computed under CUDA autocast, an optional
    regularization term built from interpolated image pairs is added, and the
    backward pass / optimizer step is delegated to ``loss_scaler``.

    Args:
        model: Network being trained.
        criterion: Called as ``criterion(samples, outputs, targets)``. Replaced
            by ``torch.nn.BCEWithLogitsLoss`` when ``args.cosub`` is truthy.
        data_loader: Iterable yielding ``(samples, targets)`` batches.
        optimizer: Optimizer handed to ``loss_scaler``; its first param group's
            ``lr`` is logged.
        device: Device the batches are moved to.
        epoch: Epoch index, used in the log header and passed to
            ``utils.adjust_lambda_reg_sin``.
        loss_scaler: Invoked as ``loss_scaler(loss, optimizer, clip_grad=...,
            parameters=..., create_graph=...)``.
        max_norm: Forwarded to ``loss_scaler`` as ``clip_grad``.
        model_ema: If given, ``model_ema.update(model)`` is called after each step.
        mixup_fn: If given, applied to the training batch. The regularizer
            always sees the batch as it was before mixup.
        set_training_mode: Forwarded to ``model.train``.
        args: Options object. Attributes read here are ``bf16``, ``cosub``,
            ``bce_loss``, ``lambda_reg``, ``resolution``, ``nums_pairs``,
            ``max_degree``, ``miu``, ``remove_const``, ``use_norm``, ``square``,
            ``degree_mode``, ``random_alpha``, ``label``, ``smooth``,
            ``pca_reg``, ``reg_anneal``, ``nb_classes`` and ``min_lambda_reg``;
            every one of them falls back to a default when missing. When
            ``None``, an empty namespace is used, which disables the
            regularizer, cosub and the BCE target transform.

    Returns:
        Dict mapping each meter name to its ``global_avg``.

    Side effects:
        Writes ``args._ema_train_loss`` (reset to 0.0 at the start of the epoch
        and updated every batch). Calls ``sys.exit(1)`` if the loss becomes
        non-finite.
    """
    if args is None:
        # The signature advertises args=None; fall back to an empty namespace so
        # that every option below resolves to its default instead of crashing.
        from types import SimpleNamespace
        args = SimpleNamespace()

    # fp16 autocast by default, bf16 when requested via args.bf16.
    amp_dtype = torch.bfloat16 if getattr(args, 'bf16', False) else torch.float16

    def _compute_wd_reg_from_batch(
        images: torch.Tensor,
        targets: torch.Tensor,
    ) -> Tuple[torch.Tensor, float, bool]:
        """Compute the polynomial regularization term for one batch.

        Pairs of images are sampled from ``images``, blended at ``resolution``
        coefficients, pushed through the model, and the resulting softmax
        sequences are handed to ``polynomial_regularization`` (one call per
        pair via ``vmap``).

        Returns:
            ``(reg_term, lambda_reg_current, did_compute)``. When the term is
            skipped, ``reg_term`` is a zero scalar and ``did_compute`` is False.
        """
        def _skipped(lambda_value: float) -> Tuple[torch.Tensor, float, bool]:
            # Uniform "nothing computed" result.
            return torch.tensor(0.0, device=device), lambda_value, False

        # Check the switch first: the schedule value is not needed (and was
        # discarded) when regularization is disabled.
        if getattr(args, "lambda_reg", 0.0) <= 0:
            return _skipped(0.0)
        lambda_reg_base = float(utils.adjust_lambda_reg_sin(epoch, args))

        resolution = int(getattr(args, "resolution", 0))
        if resolution < 2:
            return _skipped(lambda_reg_base)

        num_pairs = max(1, int(getattr(args, "nums_pairs", 1)))
        batch_size = images.size(0)
        if batch_size < 2:
            return _skipped(lambda_reg_base)

        # Ensure enough distinct samples; mimic CIFAR safety check.
        if batch_size * (batch_size - 1) < num_pairs * 2:
            return _skipped(lambda_reg_base)

        max_degree = int(getattr(args, "max_degree", 40))
        miu = float(getattr(args, "miu", 0.0))
        have_const = not bool(getattr(args, "remove_const", False))
        use_norm = bool(getattr(args, "use_norm", False))
        square = bool(getattr(args, "square", False))
        degree_mode = getattr(args, "degree_mode", "index")
        random_alpha = bool(getattr(args, "random_alpha", False))
        use_label = bool(getattr(args, "label", False))
        do_smooth = bool(getattr(args, "smooth", False))  # read for parity; not used below
        pca_k = int(getattr(args, "pca_reg", 0) or 0)

        # Hard-label endpoints need integer class targets of shape [B]. Bail out
        # BEFORE the (expensive) model forward rather than after it.
        label_endpoints = use_label and resolution > 2
        if label_endpoints and targets.dim() != 1:
            return _skipped(lambda_reg_base)

        # ---- sample pairs (x1, x2), allow repeats but ensure x2 != x1 ----
        # offset is in [1, batch_size - 1], so x2_idx can never equal x1_idx.
        x1_idx = torch.randint(low=0, high=batch_size, size=(num_pairs,), device=device)
        offset = torch.randint(low=1, high=batch_size, size=(num_pairs,), device=device)
        x2_idx = (x1_idx + offset) % batch_size

        x1 = images[x1_idx]  # [P,C,H,W]
        x2 = images[x2_idx]  # [P,C,H,W]

        # ---- alpha nodes on [-1,1] (Chebyshev domain) ----
        if random_alpha:
            n_inner = resolution - 2
            if n_inner > 0:
                # One jittered node per stratum, squeezed away from the ends by
                # `padding`, then mapped through -cos(theta) and sorted per pair.
                steps = torch.arange(n_inner, device=device, dtype=torch.float32)
                jitter = torch.rand((num_pairs, n_inner), device=device, dtype=torch.float32)
                raw_fraction = (steps.unsqueeze(0) + jitter) / float(n_inner)

                padding = 1.0 / float(resolution)
                compressed = padding + raw_fraction * (1.0 - 2.0 * padding)
                theta = compressed * np.pi
                alpha_inner = -torch.cos(theta)
                alpha_inner, _ = torch.sort(alpha_inner, dim=1)
            else:
                alpha_inner = torch.empty((num_pairs, 0), device=device, dtype=torch.float32)

            alpha_full = torch.cat(
                [
                    -torch.ones((num_pairs, 1), device=device, dtype=torch.float32),
                    alpha_inner,
                    torch.ones((num_pairs, 1), device=device, dtype=torch.float32),
                ],
                dim=1,
            )  # [P,R]
        else:
            cached = precompute_chebyshev_matrix(resolution, max_degree, device)
            alpha_values = cached["alpha_values"].to(torch.float32)
            alpha_full = alpha_values.unsqueeze(0).expand(num_pairs, -1)  # [P,R]

        # ---- map alpha to image interpolation coefficient ----
        # alpha_full itself stays in [-1,1] for the polynomial fit; only the
        # blending coefficient is remapped:
        #   label=True  -> [0, 1]
        #   label=False -> [-0.1, 1.1] (slight extrapolation past both images)
        if use_label:
            alpha_mix_full = (alpha_full + 1.0) * 0.5
        else:
            alpha_mix_full = -0.1 + (alpha_full + 1.0) * 0.6

        def _forward_probs(alpha_mix: torch.Tensor, n_points: int) -> torch.Tensor:
            """Blend each pair at ``alpha_mix`` [P,n_points]; return softmax [P,n_points,C]."""
            # reshape (not view): alpha_mix may be a non-contiguous column slice.
            coeff = alpha_mix.reshape(num_pairs, n_points, 1, 1, 1)
            blended = x1.unsqueeze(1) + coeff * (x2.unsqueeze(1) - x1.unsqueeze(1))
            flat = blended.reshape(-1, *images.shape[1:])
            with torch.autocast(device_type="cuda", dtype=amp_dtype):
                logits = model(flat)
            # .float() keeps the regularizer math in FP32 regardless of amp_dtype.
            return F.softmax(logits.float(), dim=1).view(num_pairs, n_points, -1)

        # ---- build samples and forward ----
        if label_endpoints:
            # Only the inner points are forwarded; the two endpoint rows start
            # as zeros and are filled in by utils.label_for_samples.
            probs_inner = _forward_probs(alpha_mix_full[:, 1:-1], resolution - 2)

            num_classes = probs_inner.size(-1)
            zeros_endpoint = torch.zeros(
                (num_pairs, 1, num_classes), device=device, dtype=probs_inner.dtype
            )
            full_sequence = torch.cat([zeros_endpoint, probs_inner, zeros_endpoint], dim=1)  # [P,R,C]

            label1 = targets[x1_idx].long()
            label2 = targets[x2_idx].long()
            alpha_mix_full_exp = alpha_mix_full.reshape(num_pairs, resolution, 1, 1, 1)
            full_sequence = utils.label_for_samples(
                alpha_mix_full_exp,
                full_sequence,
                num_pairs,
                label1,
                label2,
                device,
            )
        else:
            full_sequence = _forward_probs(alpha_mix_full, resolution)

        if pca_k > 0:
            full_sequence = utils.pca(full_sequence, num_pairs, k=pca_k)

        # ---- vmap compute reg per pair ----
        def _reg_wrapper(alpha_vec, y_seq):
            return polynomial_regularization(
                alpha_vec,
                y_seq,
                resolution,
                miu,
                max_degree,
                have_const,
                use_norm,
                random_alpha,
                square,
                degree_mode,
            )

        # Both inputs are batched over dim 0 (the pair dimension).
        reg_terms = vmap(_reg_wrapper, in_dims=(0, 0))(alpha_full, full_sequence)
        reg_term = torch.mean(reg_terms)

        # Optional annealing by loss scale (match CIFAR behavior)
        if bool(getattr(args, "reg_anneal", False)):
            train_loss_max = math.log(float(getattr(args, "nb_classes", 1000)))
            # Use EMA loss stored on args if present, else fall back to base lambda.
            ema_loss = float(getattr(args, "_ema_train_loss", 0.0))
            # train_loss_max is 0 when nb_classes == 1; skip annealing instead
            # of dividing by zero.
            if ema_loss > 0 and train_loss_max > 0:
                min_lambda = float(getattr(args, "min_lambda_reg", 0.0))
                lambda_reg_base = max(
                    min_lambda,
                    float(getattr(args, "lambda_reg", 0.0)) * ema_loss / train_loss_max,
                )

        return reg_term, float(lambda_reg_base), True

    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('reg', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('lambda_reg', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    use_cosub = bool(getattr(args, "cosub", False))
    use_bce_targets = bool(getattr(args, "bce_loss", False))

    if use_cosub:
        criterion = torch.nn.BCEWithLogitsLoss()

    # EMA of the base loss for reg_anneal (kept on args so the helper can read it)
    ema_alpha = 0.1
    setattr(args, "_ema_train_loss", 0.0)

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # Keep originals for the regularizer (before mixup/label transforms)
        samples_for_reg = samples
        targets_for_reg = targets

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if use_cosub:
            samples = torch.cat((samples, samples), dim=0)

        if use_bce_targets:
            targets = targets.gt(0.0).type(targets.dtype)

        with torch.autocast(device_type="cuda", dtype=amp_dtype):
            outputs = model(samples)
            if not use_cosub:
                loss = criterion(samples, outputs, targets)
            else:
                # Two halves of the duplicated batch supervise each other.
                outputs = torch.split(outputs, outputs.shape[0]//2, dim=0)
                loss = 0.25 * criterion(outputs[0], targets)
                loss = loss + 0.25 * criterion(outputs[1], targets)
                loss = loss + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid())
                loss = loss + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid())

        # Update EMA loss for reg_anneal (use base loss without reg)
        base_loss_value = float(loss.item())
        prev_ema = float(getattr(args, "_ema_train_loss", 0.0))
        if prev_ema <= 0:
            setattr(args, "_ema_train_loss", base_loss_value)
        else:
            setattr(args, "_ema_train_loss", ema_alpha * base_loss_value + (1.0 - ema_alpha) * prev_ema)

        # Regularizer is computed with gradients enabled so it trains the model.
        reg_term, lambda_reg_current, did_reg = _compute_wd_reg_from_batch(samples_for_reg, targets_for_reg)
        if did_reg and lambda_reg_current > 0:
            loss = loss + lambda_reg_current * reg_term
            loss_value = loss.item()
        else:
            # Loss tensor is unchanged; reuse the value already fetched above.
            loss_value = base_loss_value

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=is_second_order)

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(reg=float(reg_term.item()))
        metric_logger.update(lambda_reg=float(lambda_reg_current))
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device, args=None):
    """Evaluate ``model`` on ``data_loader`` and report loss and top-1/top-5 accuracy.

    Args:
        data_loader: Iterable yielding ``(images, target)`` batches.
        model: Network to evaluate; switched to eval mode here.
        device: Device the batches are moved to.
        args: Optional options object; only ``args.bf16`` is read, to select
            bfloat16 instead of float16 for autocast.

    Returns:
        Dict mapping each meter name (``loss``, ``acc1``, ``acc5``) to its
        ``global_avg``.
    """
    criterion = torch.nn.CrossEntropyLoss()

    # fp16 autocast by default, bf16 when requested; getattr tolerates args=None.
    amp_dtype = torch.bfloat16 if getattr(args, 'bf16', False) else torch.float16

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.autocast(device_type="cuda", dtype=amp_dtype):
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        # Accuracy meters are weighted by the number of samples in the batch.
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}