import argparse
import datetime
import os
import sys
import platform
import pprint
import random
from collections import defaultdict
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm import create_model
from torchvision import models
from torch.utils.data import Subset, DataLoader
from torch.optim.lr_scheduler import ChainedScheduler, LinearLR

from datasets_lt import get_dataloader, get_num_workers
from backbones import resnet32, wrn28, CosineClassifier
from loss_functions import BalancedSoftmaxLoss
from utils import WarmupMultiStepLR, WarmupCosineAnnealingLR

from sklearn.metrics import confusion_matrix, accuracy_score, classification_report

# Turn on PyTorch's expandable_segments CUDA allocator option, but only on Linux hosts.
if platform.system() == "Linux":
    os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})

# Per-process run stamp (YYYYmmdd_HHMMSS) embedded in checkpoint and result file names.
_TIMESTAMP = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"


def set_seed(seed: int = 42) -> None:
    """Seed every global RNG this script draws from.

    Seeds, in order: Python's ``random``, NumPy's legacy global RNG,
    PyTorch's CPU generator, and the generators of all CUDA devices.

    Args:
        seed: value passed unchanged to every seeding call.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)


def count_classes_in_subset(subset: Subset, num_classes: int) -> List[int]:
    """Count how many samples of each class a ``Subset`` selects.

    Labels come from the wrapped dataset's ``targets`` attribute, or from
    ``labels`` when ``targets`` is absent.

    Args:
        subset: the subset whose ``indices`` are counted.
        num_classes: minimum length of the returned list (``np.bincount``
            ``minlength``); it is longer if a label >= ``num_classes`` occurs.

    Returns:
        Per-class sample counts as Python ints.

    Raises:
        ValueError: if the dataset exposes neither ``targets`` nor ``labels``.
    """
    base = subset.dataset
    chosen = subset.indices

    label_attr = next((attr for attr in ("targets", "labels") if hasattr(base, attr)), None)
    if label_attr is None:
        raise ValueError("Dataset has no label attribute")

    every_label = np.array(getattr(base, label_attr))
    return np.bincount(every_label[chosen], minlength=num_classes).tolist()


def create_imbalanced_subset(dataset, num_classes: int, threshold: int) -> Subset:
    """Cap every class at ``threshold`` randomly drawn samples.

    Classes with fewer than ``threshold`` samples keep all of them, which is
    what makes the result imbalanced. Draws are made without replacement from
    NumPy's global RNG, one class at a time in class-id order, so results are
    reproducible after ``np.random.seed``.

    Args:
        dataset: object exposing ``targets`` (preferred) or ``labels``.
        num_classes: class ids ``0 .. num_classes - 1`` are sampled.
        threshold: per-class cap.

    Returns:
        A ``Subset`` whose indices are grouped by class (class 0 first).

    Raises:
        ValueError: if the dataset exposes neither label attribute.
    """
    for attr in ("targets", "labels"):
        if hasattr(dataset, attr):
            labels = getattr(dataset, attr)
            break
    else:
        raise ValueError("Dataset has no label attribute")

    by_class = defaultdict(list)
    for position, lab in enumerate(labels):
        by_class[int(lab)].append(position)

    picked: List[int] = []
    for cls in range(num_classes):
        pool = by_class[cls]
        draw = np.random.choice(pool, size=min(len(pool), threshold), replace=False)
        picked += draw.tolist()

    return Subset(dataset, picked)


def get_class_groups(class_counts: List[int],
                     many_thresh: int = 100,
                     few_thresh: int = 20) -> dict:
    """Split class ids into many/medium/few-shot groups by training-sample count.

    Args:
        class_counts: number of training samples per class id.
        many_thresh: classes with at least this many samples are "many".
        few_thresh: classes with fewer than this many samples are "few".

    Returns:
        ``{"many": [...], "medium": [...], "few": [...]}`` of ascending class
        ids; "medium" is ``few_thresh <= count < many_thresh``.
    """
    counts = np.array(class_counts)
    masks = {
        "many": counts >= many_thresh,
        "medium": (counts >= few_thresh) & (counts < many_thresh),
        "few": counts < few_thresh,
    }
    return {group: np.flatnonzero(mask).tolist() for group, mask in masks.items()}


def get_model_paths(args, thresholds: List[int]) -> List[Path]:
    """Pick the most recently modified checkpoint in ``args.models_in`` for each threshold.

    A file matches threshold ``count`` when its name fits the glob
    ``*{model}_{loss}_*c{count}*.pth`` *and* ``c{count}`` is not immediately
    followed by another digit. The digit check is needed because the glob's
    trailing ``*`` alone lets ``c10`` also match ``c100``/``c105``. Without it,
    the newest file for threshold 10 could be a checkpoint for threshold 100.

    Thresholds without a match are skipped with a warning.

    Args:
        args: namespace providing ``models_in``, ``model`` and ``loss``.
        thresholds: per-stage sample counts to look up, in stage order.

    Returns:
        One path per matched threshold, in ``thresholds`` order.

    Raises:
        SystemExit: via ``sys.exit`` when no threshold matched any file.
    """
    import re  # local import keeps this block self-contained

    model_paths: List[Path] = []
    base_dir = Path(args.models_in)
    for count in thresholds:
        pattern = f"*{args.model}_{args.loss}_*c{count}*.pth"
        exact_count = re.compile(rf"c{count}(?!\d)")
        candidates = [p for p in base_dir.glob(pattern) if exact_count.search(p.name)]
        found = sorted(candidates, key=os.path.getmtime, reverse=True)
        if found:
            model_paths.append(found[0])
            print(f"[INFO] Using model for threshold {count}: {found[0]}")
        else:
            print(f"[WARNING] No model found for threshold {count}")
    print(f"[INFO] Found {len(model_paths)} models")
    if not model_paths:
        sys.exit("[ERROR] No checkpoints found. Abort.")
    return model_paths


@torch.no_grad()
def cache_logits(models_list: List[torch.nn.Module],
                 dataloader: DataLoader,
                 num_classes: int) -> Tuple[List[torch.Tensor], np.ndarray]:
    """Run every model over ``dataloader`` once and keep all logits on the CPU.

    Inputs are moved to the device of the first model's parameters, so every
    model is expected to live on that device. Models are not switched to eval
    mode here; callers are responsible for that. ``num_classes`` is unused.

    Returns:
        ``(per_model_logits, targets)``. ``per_model_logits[i]`` concatenates
        model ``i``'s outputs over all batches, in loader order; ``targets``
        holds the matching labels as a NumPy array.
    """
    device = next(models_list[0].parameters()).device
    labels: List[int] = []
    per_model: List[List[torch.Tensor]] = [[] for _ in models_list]

    for batch, target in dataloader:
        batch = batch.to(device, non_blocking=True)
        labels.extend(target.numpy())
        for bucket, net in zip(per_model, models_list):
            bucket.append(net(batch).cpu())

    stacked = [torch.cat(bucket, dim=0) for bucket in per_model]
    return stacked, np.array(labels)


def evaluate_partial_ensembles(args,
                               k: int,
                               all_logits: List[torch.Tensor],
                               all_targets: np.ndarray,
                               precisions: List[np.ndarray],
                               num_classes: int) -> Tuple[np.ndarray, np.ndarray, float]:
    """Score the ensemble formed by the first ``k`` models of a logit cache.

    Strategies (``args.ensemble``):
      * ``naive``      - sum of softmax probabilities of all k models.
      * ``model_prec`` - probabilities scaled by each model's mean class precision.
      * ``class_prec`` - probabilities scaled element-wise by per-class precision.
      * ``classwise``  - per class, softmax the k models' precisions into mixing
        weights, blend the raw logits with them, then softmax the blend.

    The precision-based strategies pair models with ``precisions[:k]`` via
    ``zip``; ``naive`` ignores precisions.

    Args:
        args: namespace whose ``ensemble`` attribute selects the strategy.
        k: number of leading models to combine.
        all_logits: cached per-model logits, rows aligned with ``all_targets``.
        all_targets: ground-truth labels.
        precisions: per-model vectors of per-class precision, same model order.
        num_classes: number of output columns.

    Returns:
        ``(all_targets, predictions, top1_accuracy)``.

    Raises:
        ValueError: for an unknown strategy.
    """
    head_logits = all_logits[:k]
    head_precs = precisions[:k]
    num_samples = head_logits[0].size(0)
    strategy = args.ensemble

    if strategy == "classwise":
        # (C, k): row c holds each model's precision on class c, softmaxed across models.
        mix = torch.softmax(torch.tensor(np.array(head_precs).T), dim=1)
        blended = torch.zeros(num_samples, num_classes)
        for j, logits in enumerate(head_logits):
            blended += logits * mix[:, j]
        final_score = F.softmax(blended, dim=1)
    elif strategy == "naive":
        final_score = torch.zeros(num_samples, num_classes)
        for logits in head_logits:
            final_score += F.softmax(logits, dim=1)
    elif strategy == "model_prec":
        final_score = torch.zeros(num_samples, num_classes)
        for logits, prec in zip(head_logits, head_precs):
            final_score += F.softmax(logits, dim=1) * float(np.mean(prec))
    elif strategy == "class_prec":
        final_score = torch.zeros(num_samples, num_classes)
        for logits, prec in zip(head_logits, head_precs):
            final_score += F.softmax(logits, dim=1) * torch.tensor(prec)
    else:
        raise ValueError(f"Unknown ensemble type: {args.ensemble}")

    predictions = final_score.argmax(dim=1).numpy()
    return all_targets, predictions, accuracy_score(all_targets, predictions)


def predict_with_precision_weighting(args,
                                     models_list: List[torch.nn.Module],
                                     precisions: List[np.ndarray],
                                     dataloader: DataLoader,
                                     num_classes: int) -> Tuple[np.ndarray, np.ndarray, float]:
    """Run a precision-weighted ensemble directly over ``dataloader`` (no logit cache).

    Strategies (``args.ensemble``) match ``evaluate_partial_ensembles``:
    ``naive``, ``model_prec``, ``class_prec`` and ``classwise``. For the
    non-classwise strategies, models are paired with ``precisions`` via
    ``zip``, so only the first ``min(len(models_list), len(precisions))``
    models contribute.

    All weights depend only on ``precisions``. They are built once before the
    loop instead of being rebuilt for every batch and model.

    Args:
        args: namespace whose ``ensemble`` attribute selects the strategy.
        models_list: models already placed on one device; inputs follow the
            first model's device. They are not switched to eval mode here.
        precisions: per-model vectors of per-class precision.
        dataloader: yields ``(inputs, labels)`` batches.
        num_classes: number of output columns.

    Returns:
        ``(y_true, y_pred, top1_accuracy)``.

    Raises:
        ValueError: for an unknown strategy, raised before any forward pass.
    """
    strategy = args.ensemble
    if strategy not in ("naive", "model_prec", "class_prec", "classwise"):
        raise ValueError(f"Unknown ensemble type: {args.ensemble}")

    device = next(models_list[0].parameters()).device

    ensemble_weights = None
    scales: list = []
    if strategy == "classwise":
        class_model_prec = torch.tensor(np.array(precisions).T, device=device)  # (C, M)
        ensemble_weights = torch.softmax(class_model_prec, dim=1)               # (C, M)
    elif strategy == "model_prec":
        scales = [float(np.mean(prec)) for prec in precisions]
    elif strategy == "class_prec":
        scales = [torch.tensor(prec, device=device) for prec in precisions]
    else:
        # naive: no scaling, but keep one entry per precision so zip() pairs models as before.
        scales = [None] * len(precisions)

    all_preds, all_targets = [], []
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device, non_blocking=True)
            if strategy == "classwise":
                final_logits = torch.zeros(x.size(0), num_classes, device=device)
                for m_idx, model in enumerate(models_list):
                    final_logits += model(x) * ensemble_weights[:, m_idx]
                final_score = F.softmax(final_logits, dim=1)
            else:
                final_score = torch.zeros(x.size(0), num_classes, device=device)
                for model, scale in zip(models_list, scales):
                    probs = F.softmax(model(x), dim=1)
                    final_score += probs if scale is None else probs * scale

            all_preds.extend(final_score.argmax(dim=1).cpu().numpy())
            # Labels are only needed on the host; no round-trip through the device.
            all_targets.extend(y.cpu().numpy())

    y_true = np.array(all_targets)
    y_pred = np.array(all_preds)
    acc = accuracy_score(y_true, y_pred)
    return y_true, y_pred, acc


def report_and_save_metrics(y_true: np.ndarray,
                            y_pred: np.ndarray,
                            class_groups: dict,
                            num_classes: int,
                            args,
                            stage: int,
                            total: int,
                            model_acc: float) -> Tuple[float, float]:
    """Print and save accuracy metrics for one ensemble stage.

    Macro and group accuracies are the mean per-class recall: the fraction of
    a class's samples predicted as that class. Classes absent from ``y_true``
    are skipped, and an empty selection yields 0.0.

    Results are written to
    ``{args.dataset}_{args.results_out}/[<_TIMESTAMP>]_{model}_{loss}_{subset}_s{stage}_t{total}_r{seed}_a{ensemble}.txt``.
    The file holds the summary accuracies, sklearn's classification report and
    ``model_acc``.

    Args:
        y_true, y_pred: label arrays of equal length.
        class_groups: dict with "many", "medium" and "few" class-id lists.
        num_classes: classes ``0 .. num_classes - 1`` enter the macro average.
        args: namespace supplying the naming fields above.
        stage, total: current stage number and stage count, for logs and the file name.
        model_acc: ensemble top-1 accuracy reported by the caller.

    Returns:
        ``(micro_accuracy, macro_accuracy)``.
    """
    micro_acc = accuracy_score(y_true, y_pred)

    def mean_recall(classes) -> float:
        # Average recall over ``classes``, ignoring classes with no true samples.
        recalls = []
        for c in classes:
            is_c = (y_true == c)
            if np.any(is_c):
                recalls.append(np.mean(y_pred[is_c] == c))
        return float(np.mean(recalls)) if recalls else 0.0

    macro_acc = mean_recall(range(num_classes))
    many_acc, medium_acc, few_acc = (
        mean_recall(class_groups[group]) for group in ("many", "medium", "few")
    )

    print(f"[Stage {stage}/{total}] Micro={micro_acc:.4f} | Macro={macro_acc:.4f} "
          f"| Many={many_acc:.4f} | Medium={medium_acc:.4f} | Few={few_acc:.4f} "
          f"| Ensemble Top1={model_acc:.4f}")

    results_dir = f"{args.dataset}_{args.results_out}"
    os.makedirs(results_dir, exist_ok=True)
    file_name = (
        f"[{_TIMESTAMP}]_{args.model}_{args.loss}_{args.subset}_s{stage}_t{total}_r{args.seed}_a{args.ensemble}.txt"
    )
    result_path = os.path.join(results_dir, file_name)

    report = classification_report(y_true, y_pred, digits=4, zero_division=0)
    lines = [
        f"Micro Accuracy (Overall): {micro_acc:.4f}\n",
        f"Macro Accuracy (Mean Per-Class): {macro_acc:.4f}\n",
        f"Many-shot accuracy: {many_acc:.4f}\n",
        f"Medium-shot accuracy: {medium_acc:.4f}\n",
        f"Few-shot accuracy: {few_acc:.4f}\n\n",
        report,
        f"\nCurrent Ensemble Top1 accuracy: {model_acc:.4f}\n\n",
    ]
    with open(result_path, "w") as f:
        f.writelines(lines)

    return micro_acc, macro_acc


def build_model(model_name: str, num_classes: int) -> torch.nn.Module:
    """Instantiate the backbone selected by ``model_name`` with a ``num_classes``-way output.

    Selection is a case-insensitive substring test, checked in this order:
    ``resnet32``, ``wrn28``, ``vit`` (timm ViT-B/16, pretrained), ``resnext50``
    (torchvision, random init), ``resnet152`` (torchvision ImageNet weights
    with its ``fc`` replaced by a ``CosineClassifier``).

    Args:
        model_name: backbone identifier, e.g. from ``--model``.
        num_classes: size of the classification head.

    Returns:
        The freshly constructed (untrained or ImageNet-initialised) model.

    Raises:
        ValueError: if no known backbone name is contained in ``model_name``.
    """
    name = model_name.lower()
    if "resnet32" in name:
        return resnet32(num_classes=num_classes)
    if "wrn28" in name:
        return wrn28(num_classes=num_classes)
    if "vit" in name:
        return create_model("vit_base_patch16_224", pretrained=True, num_classes=num_classes)
    if "resnext50" in name:
        # num_classes must be passed through; otherwise torchvision builds its default
        # 1000-way fc layer, which does not fit the (N, num_classes) ensemble code.
        return models.resnext50_32x4d(weights=None, num_classes=num_classes)
    if "resnet152" in name:
        model = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)
        model.fc = CosineClassifier(model.fc.in_features, num_classes)
        return model
    raise ValueError(f"Unknown model type: {model_name}")


def _build_optimizer_and_scheduler(args, model: torch.nn.Module):
    """Return the ``(optimizer, scheduler)`` pair for ``args.model``; the scheduler is stepped once per epoch."""
    name = args.model.lower()
    if "vit" in name:
        optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.05)
        scheduler = WarmupCosineAnnealingLR(optimizer, total_epochs=args.epochs, warmup_epochs=5, eta_min=0.0)
    elif "resnext50" in name:
        optimizer = torch.optim.SGD(model.parameters(), lr=0.025, momentum=0.9, weight_decay=5e-4, nesterov=True)
        scheduler = WarmupCosineAnnealingLR(optimizer, total_epochs=args.epochs, warmup_epochs=5, eta_min=0.0)
    elif "resnet152" in name:
        from torch.optim.lr_scheduler import SequentialLR  # local import keeps this block self-contained

        warmup_epochs = 5
        optimizer = torch.optim.SGD(
            [
                # Everything outside ``fc`` uses a 10x smaller LR than the ``fc`` head.
                {"params": [p for n, p in model.named_parameters() if not n.startswith("fc.")], "lr": 0.001},
                {"params": list(model.fc.parameters()), "lr": 0.01},
            ],
            momentum=0.9,
            weight_decay=4e-4,
            nesterov=True,
        )
        warmup = LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_epochs)
        # Keep the decay length positive: LinearLR divides by total_iters when
        # SequentialLR switches to it (e.g. args.epochs == warmup_epochs).
        decay = LinearLR(optimizer, start_factor=1.0, end_factor=0.0,
                         total_iters=max(1, args.epochs - warmup_epochs))
        # SequentialLR runs the warmup first and only then the linear decay. ChainedScheduler
        # would step both every epoch and multiply their factors. The decay would then start
        # at epoch 0, reach 0 at epoch ``epochs - 5`` and leave the final epochs at lr=0.
        scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_epochs])
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4, nesterov=True)
        scheduler = WarmupMultiStepLR(optimizer, milestones=[160, 180], gamma=0.1, warmup_epochs=5)
    return optimizer, scheduler


def train_one_stage(args,
                    model: torch.nn.Module,
                    train_loader: DataLoader,
                    class_counts: List[int],
                    class_precisions) -> Tuple[torch.nn.Module, str]:
    """Train ``model`` for ``args.epochs`` epochs and save its ``state_dict``.

    Args:
        args: namespace using ``model``, ``loss``, ``epochs``, ``dataset``,
            ``models_out`` and ``seed``.
        model: network to train; moved to CUDA when available.
        train_loader: yields ``(inputs, labels)`` batches.
        class_counts: per-class sample counts of the training subset. They are
            passed to ``BalancedSoftmaxLoss`` when ``args.loss`` contains "bsm",
            and ``max(class_counts)`` is embedded in the checkpoint name.
        class_precisions: unused; kept for interface compatibility.

    Returns:
        ``(trained_model, checkpoint_path)``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    optimizer, scheduler = _build_optimizer_and_scheduler(args, model)

    if "bsm" in args.loss.lower():
        criterion = BalancedSoftmaxLoss(samples_per_class=class_counts)
    else:
        criterion = nn.CrossEntropyLoss()

    for epoch in range(args.epochs):
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        for x, y in train_loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True).long()

            optimizer.zero_grad(set_to_none=True)
            outputs = model(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * y.size(0)
            correct += (outputs.argmax(dim=1) == y).sum().item()
            total += y.size(0)

        # Guard both averages: an empty loader previously raised ZeroDivisionError on the loss.
        mean_loss = total_loss / total if total > 0 else 0.0
        acc = 100.0 * correct / total if total > 0 else 0.0
        print(f"[Epoch {epoch+1:03d}/{args.epochs}] Loss={mean_loss:.4f} | Acc={acc:.2f}%")
        scheduler.step()

    models_folder = f"{args.dataset}_{args.models_out}"
    os.makedirs(models_folder, exist_ok=True)
    # NOTE(review): the name embeds max(class_counts), while get_model_paths searches by
    # threshold. These agree only if some class in the subset reaches the threshold; confirm.
    ckpt_name = f"[{_TIMESTAMP}]_{args.model}_{args.loss}_c{max(class_counts)}_r{args.seed}.pth"
    model_path = os.path.join(models_folder, ckpt_name)
    torch.save(model.state_dict(), model_path)
    print(f"[INFO] Saved checkpoint: {model_path}")
    return model, model_path


def get_threshold(args) -> List[int]:
    """Read the per-stage sample-count thresholds for ``args.dataset``.

    The thresholds are column ``args.subset`` of ``subsets/<dataset>.csv``,
    resolved relative to the current working directory. Empty cells are
    dropped and the remaining values are cast to int.

    Raises:
        ValueError: if the CSV has no such column.
    """
    csv_path = os.path.join("subsets", f"{args.dataset}.csv")
    table = pd.read_csv(csv_path)
    column = args.subset
    if column not in table.columns:
        raise ValueError(f"Column '{args.subset}' not found in {csv_path}")

    present = table[column].dropna()
    thresholds = present.astype(int).tolist()
    print(f"[INFO] Loaded {len(thresholds)} thresholds from '{args.subset}': {thresholds}")
    return thresholds


def run_training_pipeline(args) -> List[str]:
    """Train one fresh model per threshold and return the checkpoint paths in stage order.

    For each threshold from ``get_threshold``, every class of the training set
    is capped at that many randomly drawn samples. A new ``build_model``
    network is then trained on that subset. The seed is set once up front, so
    later stages' subset draws depend on the RNG state left by earlier stages.
    """
    set_seed(args.seed)
    full_loader = get_dataloader(args.dataset, args.model, train=True, shuffle=True, batch_size=args.batch_size)
    base_set = full_loader.dataset
    num_classes = base_set.num_classes
    print(f"[INFO] Number of classes: {num_classes}")

    saved: List[str] = []
    for stage, threshold in enumerate(get_threshold(args), start=1):
        print(f"\n[Training {stage}] Threshold: {threshold}")
        capped = create_imbalanced_subset(base_set, num_classes, threshold)
        counts = count_classes_in_subset(capped, num_classes)
        stage_loader = DataLoader(
            capped,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=get_num_workers(),
        )

        net = build_model(args.model, num_classes)
        _, ckpt_path = train_one_stage(args, net, stage_loader, counts, class_precisions=None)
        saved.append(ckpt_path)
    return saved


def get_class_precisions(conf_matrix: np.ndarray) -> np.ndarray:
    """Per-class precision from a confusion matrix (rows = true, columns = predicted).

    precision_c = TP_c / (TP_c + FP_c), where FP_c is column c's total minus
    its diagonal entry. Classes that were never predicted (0/0) get 0.0.
    """
    diag = np.diag(conf_matrix)
    false_pos = np.sum(conf_matrix, axis=0) - diag
    denom = diag + false_pos
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.nan_to_num(diag / denom)


def _predict_on_loader(model, loader, device):
    """Return ``(true_labels, argmax_predictions)`` lists over every batch of ``loader``."""
    predicted, truth = [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            predicted.extend(model(images).argmax(1).cpu().numpy())
            truth.extend(labels.cpu().numpy())
    return truth, predicted


def run_evaluation_pipeline(args, model_paths: List[str]) -> None:
    """Load the checkpoints, measure per-class precision, and evaluate growing ensembles.

    Steps:
      1. Load each checkpoint and compute its per-class precision on the
         training loader.
      2. Cache every model's test-set logits once.
      3. For k = 1..len(model_paths), evaluate and report the ensemble of the
         first k models.

    Many/medium/few groups come from the dataset's ``class_counts`` attribute
    when present, otherwise from counting the full training set.

    NOTE(review): precision is measured with the loader built by
    ``get_dataloader(..., train=True)``. If that loader applies augmentation,
    the precisions reflect augmented images; confirm this is intended.
    """
    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader = get_dataloader(args.dataset, args.model, train=True, shuffle=True, batch_size=args.batch_size)
    test_loader = get_dataloader(args.dataset, args.model, train=False, shuffle=False, batch_size=args.batch_size)
    train_set = train_loader.dataset
    num_classes = train_set.num_classes

    class_counts = getattr(train_set, "class_counts", None)
    if class_counts is None:
        every_index = list(range(len(train_set)))
        class_counts = count_classes_in_subset(Subset(train_set, every_index), num_classes)
    class_groups = get_class_groups(class_counts)

    trained_models, precisions = [], []
    print("\n[INFO] Computing per-model class precision (train set)")
    for stage, model_path in enumerate(model_paths, start=1):
        print(f"[Model {stage}] Path: {model_path}")
        model = build_model(args.model, num_classes)
        # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted directories.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model = model.to(device).eval()

        truth, predicted = _predict_on_loader(model, train_loader, device)
        cm = confusion_matrix(truth, predicted, labels=list(range(num_classes)))
        precision = get_class_precisions(cm)
        precisions.append(precision)
        trained_models.append(model)
        print(f"==> Avg Precision: {np.mean(precision) * 100:.2f}%")

    print("\n[INFO] Caching logits on test set")
    all_logits, all_targets = cache_logits(trained_models, test_loader, num_classes)

    total = len(model_paths)
    for stage in range(1, total + 1):
        print(f"\n[Ensemble {stage}]")
        y_true, y_pred, top1_acc = evaluate_partial_ensembles(
            args, stage, all_logits, all_targets, precisions, num_classes
        )
        report_and_save_metrics(
            y_true=y_true,
            y_pred=y_pred,
            class_groups=class_groups,
            num_classes=num_classes,
            args=args,
            stage=stage,
            total=total,
            model_acc=top1_acc,
        )
        print(f"==> Ensemble Test Acc: {top1_acc * 100:.2f}%")


def _build_arg_parser() -> argparse.ArgumentParser:
    """Declare the command-line options for training and ensemble evaluation."""
    parser = argparse.ArgumentParser(description="Train and evaluate ensembles on long-tailed datasets")
    add = parser.add_argument
    add("--dataset", type=str, default="cifar100lt", help="Dataset name")
    add("--subset", type=str, default="decay_0.9", help="CSV column name for thresholds")
    add("--batch_size", type=int, default=64, help="Batch size")
    add("--model", type=str, default="resnet32", help="Backbone model")
    add("--epochs", type=int, default=180, help="Training epochs")
    add("--loss", type=str, default="bsm", help="Loss identifier for checkpoint naming")
    add("--models_out", type=str, default="saved_models", help="Directory to save checkpoints")
    add("--models_in", type=str, default=None, help="Directory to load checkpoints (skip training)")
    add("--results_out", type=str, default="results", help="Directory to save evaluation results")
    add("--seed", type=int, default=40, help="Random seed")
    add(
        "--ensemble",
        type=str,
        default="classwise",
        choices=["naive", "model_prec", "class_prec", "classwise"],
        help="Ensemble weighting strategy",
    )
    return parser


def main() -> None:
    """Parse the CLI, obtain checkpoints, then evaluate the ensembles.

    With ``--models_in``, existing checkpoints are located per threshold and
    training is skipped. Otherwise one model per threshold is trained first.
    """
    args = _build_arg_parser().parse_args()
    pprint.pprint(vars(args))

    model_paths = (
        [str(p) for p in get_model_paths(args, get_threshold(args))]
        if args.models_in
        else run_training_pipeline(args)
    )
    run_evaluation_pipeline(args, model_paths)


if __name__ == "__main__":
    main()
