import argparse
import os
import random
from typing import Any, Dict, List, Tuple

import numpy as np
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from experiments.deployment_gate.model_factory import get_resnet50


def _set_seed(seed: int) -> None:
    """
    Seed every random number generator this module draws from.

    Calls `random.seed`, `np.random.seed` and `torch.manual_seed` with `seed`,
    and additionally `torch.cuda.manual_seed_all` when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def evaluate(model: nn.Module, data_loader: DataLoader, device: str) -> Dict[str, float]:
    """
    Compute the mean cross-entropy loss and top-1 accuracy of `model` over `data_loader`.

    `data_loader` is iterated as `(images, labels)` batches. The model is moved to
    `device` and switched to eval mode; it is left in that state on return.
    Returns a dict with keys "loss" and "acc"; both are 0.0 when no samples are seen.
    """
    # A DataParallel wrapper is dropped so the bare module is the one that runs on `device`.
    net = model.module if isinstance(model, nn.DataParallel) else model
    net = net.to(device)
    net.eval()

    loss_fn = nn.CrossEntropyLoss()
    n_seen, n_right, weighted_loss = 0, 0, 0.0
    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            scores = net(batch_x)
            # The criterion returns a batch mean; weight it by batch size so the
            # final average is per-sample even when batches differ in size.
            weighted_loss += loss_fn(scores, batch_y).item() * batch_x.size(0)
            n_right += (scores.argmax(dim=1) == batch_y).sum().item()
            n_seen += batch_y.size(0)

    # Guard against division by zero for an empty loader.
    denom = max(n_seen, 1)
    return {"loss": weighted_loss / denom, "acc": n_right / denom}


def _train_one_epoch(model: nn.Module, loader: DataLoader, optimizer: optim.Optimizer, device: str) -> float:
    """
    Run one optimisation pass over `loader` and return the per-sample mean training loss.

    `loader` is iterated as `(images, labels)` batches; each batch is moved to
    `device`, scored with cross-entropy, and followed by one optimizer step.
    The model is put in train mode. Returns 0.0 when the loader yields nothing.
    """
    loss_fn = nn.CrossEntropyLoss()
    model.train()

    weighted_loss_total = 0.0
    samples_seen = 0
    for batch_inputs, batch_targets in loader:
        batch_inputs, batch_targets = batch_inputs.to(device), batch_targets.to(device)
        batch_size = batch_inputs.size(0)

        optimizer.zero_grad()
        batch_loss = loss_fn(model(batch_inputs), batch_targets)
        batch_loss.backward()
        optimizer.step()

        # The criterion returns a batch mean; re-weight by batch size for a per-sample average.
        weighted_loss_total += batch_loss.item() * batch_size
        samples_seen += batch_size

    # Avoid dividing by zero when no samples were processed.
    divisor = samples_seen if samples_seen > 0 else 1
    return weighted_loss_total / divisor


def train_source_model(config: Dict[str, Any], loaders: Dict[str, DataLoader]) -> Tuple[nn.Module, str]:
    """
    Train a source model Q on source_train, select best on source_val.
    Returns (best_model, best_checkpoint_path).

    Config keys read (with defaults): "device", "seed" (42), "num_classes" (65),
    "learning_rate" (1e-4), "weight_decay" (5e-4), "max_epochs" (20) and
    "output_dir". `loaders` must provide "source_train" and "source_val".

    The checkpoint is written to <output_dir>/checkpoints/source/source_best.pth
    whenever validation accuracy strictly improves. If no epoch runs
    (max_epochs < 1), the initial weights are saved instead, so the returned
    path always exists and always matches the returned model rather than a
    leftover file from an earlier run.

    The returned model is never wrapped in DataParallel.
    """
    device = config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    _set_seed(int(config.get("seed", 42)))

    num_classes = int(config.get("num_classes", 65))
    learning_rate = float(config.get("learning_rate", 1e-4))
    weight_decay = float(config.get("weight_decay", 5e-4))
    max_epochs = int(config.get("max_epochs", 20))

    output_dir = config.get("output_dir", "experiments/deployment_gate/results")
    ckpt_dir = os.path.join(output_dir, "checkpoints", "source")
    os.makedirs(ckpt_dir, exist_ok=True)
    best_ckpt_path = os.path.join(ckpt_dir, "source_best.pth")

    model = get_resnet50(pretrained=True, num_classes=num_classes).to(device)
    # Optional multi-GPU
    if device.startswith("cuda") and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # The bare module behind any DataParallel wrapper. Checkpoints are saved from and
    # loaded into it so the stored keys carry no "module." prefix.
    core = model.module if isinstance(model, nn.DataParallel) else model

    best_val_acc = -1.0
    checkpoint_written = False
    for epoch in range(1, max_epochs + 1):
        train_loss = _train_one_epoch(model, loaders["source_train"], optimizer, device)
        val_metrics = evaluate(model, loaders["source_val"], device)
        val_acc = val_metrics["acc"]
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(core.state_dict(), best_ckpt_path)
            checkpoint_written = True
        # Optional: print minimal progress
        print(f"[Source][Epoch {epoch}/{max_epochs}] train_loss={train_loss:.4f} val_acc={val_acc:.4f}")

    if checkpoint_written:
        # Restore the best epoch's weights; the in-memory model holds the last epoch's.
        state = torch.load(best_ckpt_path, map_location=device)
        core.load_state_dict(state)
    else:
        # No epoch ran, so nothing was saved in this call. Persist the current weights
        # instead of loading a missing (or stale) file from best_ckpt_path.
        torch.save(core.state_dict(), best_ckpt_path)

    # Return the unwrapped model to prevent double-wrapping in main.py
    return core, best_ckpt_path


def train_candidate_models(source_model_path: str, config: Dict[str, Any], loaders: Dict[str, DataLoader]) -> List[str]:
    """
    Fine-tune N candidate models from the source checkpoint on target_train.
    Returns a list of checkpoint paths for candidates.

    `source_model_path` is a state-dict checkpoint; every candidate starts from it.
    Config keys read (with defaults): "device", "seed" (42), "num_classes" (65),
    "learning_rate" (1e-4), "weight_decay" (5e-4), "num_candidates" (20),
    "include_identity_candidate" (False), "candidate_lr_choices",
    "candidate_epoch_choices" and "output_dir". `loaders` must provide "target_train".

    Each trained candidate draws its learning rate and epoch count at random from
    the configured choices. Checkpoints are written to
    <output_dir>/checkpoints/candidates/tilde_q_<idx>.pth with idx starting at 1.
    When the identity candidate is enabled it takes idx 1 and is a plain copy of
    the source checkpoint, counted towards "num_candidates".
    """
    device = config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    _set_seed(int(config.get("seed", 42)))

    num_classes = int(config.get("num_classes", 65))
    base_lr = float(config.get("learning_rate", 1e-4))
    weight_decay = float(config.get("weight_decay", 5e-4))

    num_candidates = int(config.get("num_candidates", 20))
    include_identity = bool(config.get("include_identity_candidate", False))
    lr_choices = list(config.get("candidate_lr_choices", [1e-4, 5e-5, 1e-5]))
    epoch_choices = list(config.get("candidate_epoch_choices", [3, 4, 5]))

    output_dir = config.get("output_dir", "experiments/deployment_gate/results")
    ckpt_dir = os.path.join(output_dir, "checkpoints", "candidates")
    os.makedirs(ckpt_dir, exist_ok=True)

    candidate_paths: List[str] = []
    next_idx = 1
    if include_identity and num_candidates >= 1:
        # Create an identity candidate that is exactly the source checkpoint (no fine-tuning)
        identity_ckpt = os.path.join(ckpt_dir, f"tilde_q_{next_idx}.pth")
        shutil.copyfile(source_model_path, identity_ckpt)
        candidate_paths.append(identity_ckpt)
        next_idx += 1

    pending = range(next_idx, num_candidates + 1)
    # Read the source checkpoint once rather than once per candidate. load_state_dict
    # copies values into the model's own parameters, so fine-tuning one candidate does
    # not alter this cached state. Skipped entirely when nothing is left to train.
    source_state = torch.load(source_model_path, map_location=device) if pending else None
    use_data_parallel = device.startswith("cuda") and torch.cuda.device_count() > 1

    for idx in pending:
        lr = float(random.choice(lr_choices)) if lr_choices else base_lr
        epochs = int(random.choice(epoch_choices)) if epoch_choices else 3

        # Fresh model initialized from source checkpoint
        base = get_resnet50(pretrained=False, num_classes=num_classes).to(device)
        base.load_state_dict(source_state)
        # The wrapper shares parameters with `base`, so `base` is what gets saved below.
        model = nn.DataParallel(base) if use_data_parallel else base

        # Optionally: different augmentation strengths could be added by swapping loaders/transforms here
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        for e in range(1, epochs + 1):
            train_loss = _train_one_epoch(model, loaders["target_train"], optimizer, device)
            # Lightweight eval optional; we skip for speed
            print(f"[Candidate {idx}][{e}/{epochs}] lr={lr} train_loss={train_loss:.4f}")

        ckpt_path = os.path.join(ckpt_dir, f"tilde_q_{idx}.pth")
        torch.save(base.state_dict(), ckpt_path)
        candidate_paths.append(ckpt_path)

    return candidate_paths


if __name__ == "__main__":
    # Script entry point: accepts --config for CLI compatibility, then prints a
    # short notice. The parsed arguments are not used any further here.
    parser = argparse.ArgumentParser(description="Training sanity check")
    parser.add_argument(
        "--config",
        default="configs/office_home.yaml",
        type=str,
    )
    args = parser.parse_args()
    print("This module provides training functions for source and candidates.")
