import os
import argparse
import torch
import torch.nn as nn
import pandas as pd
from torch import optim
from tqdm import tqdm

from config import get_config, list_configs
from utils import (
    GMM4PR,
    get_dataset, build_model, eval_acc,
    initialize_gmm_parameters,
    TemperatureScheduler, check_mode_collapse,
    build_decoder_from_flag
)


def _parse_args():
    """Parse the command-line arguments of the training script.

    Returns:
        The ``argparse.Namespace`` produced by the parser. Override options
        that were not given on the command line are ``None``.
    """
    parser = argparse.ArgumentParser(description="GMM4PR Training")

    parser.add_argument("--config", type=str, default="resnet18_on_cifar10",
                        help="Config name from config.py")
    parser.add_argument("--list-configs", action="store_true", default=False,
                        help="List all available configs and exit")

    # Optional overrides: left at None unless passed explicitly.
    parser.add_argument("--epochs", type=int, help="Override epochs")
    parser.add_argument("--K", type=int, help="Override number of components")
    parser.add_argument("--batch_size", type=int, help="Override batch size")
    parser.add_argument("--device", type=str, help="Override device")
    parser.add_argument("--clf_ckpt", type=str, help="Override classifier checkpoint path")

    return parser.parse_args()


def _apply_cli_overrides(cfg, args):
    """Copy every override that was given on the command line onto ``cfg``."""
    for name in ("epochs", "K", "batch_size", "device", "clf_ckpt"):
        value = getattr(args, name)
        if value is not None:
            setattr(cfg, name, value)


def _load_frozen_classifier(cfg, num_classes, device):
    """Build the classifier, load its checkpoint and freeze all weights.

    Args:
        cfg: Config object; ``arch`` and ``clf_ckpt`` are read.
        num_classes: Number of classes passed to ``build_model``.
        device: Device the classifier and feature extractor are moved to.

    Returns:
        ``(model, feat_extractor)``, both in eval mode with
        ``requires_grad`` disabled on every parameter.

    Raises:
        FileNotFoundError: If ``cfg.clf_ckpt`` is not an existing file.
    """
    print(f"\nLoading classifier: {cfg.arch}")
    model, feat_extractor = build_model(cfg.arch, num_classes, device)

    if not os.path.isfile(cfg.clf_ckpt):
        raise FileNotFoundError(f"Classifier not found: {cfg.clf_ckpt}")

    state = torch.load(cfg.clf_ckpt, map_location="cpu")
    # Accept either a raw state dict or one nested under a known key.
    state = state.get("state_dict", state.get("model_state", state))
    # Drop the "module." substring from the checkpoint keys.
    state = {k.replace("module.", ""): v for k, v in state.items()}

    # strict=False tolerates key mismatches, so report them instead of
    # silently continuing with a partially initialised classifier.
    load_result = model.load_state_dict(state, strict=False)
    if load_result.missing_keys:
        print(f"[warn] {len(load_result.missing_keys)} missing keys in checkpoint "
              f"(first few: {load_result.missing_keys[:5]})")
    if load_result.unexpected_keys:
        print(f"[warn] {len(load_result.unexpected_keys)} unexpected keys in checkpoint "
              f"(first few: {load_result.unexpected_keys[:5]})")

    model = model.to(device).eval()
    for p in model.parameters():
        p.requires_grad = False

    feat_extractor = feat_extractor.to(device).eval()
    for p in feat_extractor.parameters():
        p.requires_grad = False

    # Diagnostic: report whether the two modules hold the same parameter objects.
    model_params = {id(p) for p in model.parameters()}
    feat_params = {id(p) for p in feat_extractor.parameters()}
    shared = model_params & feat_params

    print(f"[check] model params: {len(model_params)}, feat_extractor params: {len(feat_params)}")
    if shared:
        print(f"[check] They share {len(shared)} parameters.")
    else:
        print("[check] No shared parameters.")

    return model, feat_extractor


def _build_lr_scheduler(optimizer, cfg):
    """Return a linear-warmup + cosine LR scheduler, or ``None`` if disabled."""
    if not cfg.use_lr_scheduler:
        return None

    from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

    warmup_scheduler = LinearLR(
        optimizer,
        start_factor=0.01,
        end_factor=1.0,
        total_iters=cfg.lr_warmup_epochs
    )

    cosine_scheduler = CosineAnnealingLR(
        optimizer,
        # Keep T_max positive even when warmup covers the whole run.
        T_max=max(cfg.epochs - cfg.lr_warmup_epochs, 1),
        eta_min=cfg.lr_min
    )

    scheduler = SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[cfg.lr_warmup_epochs]
    )

    print("\nLearning rate scheduler enabled:")
    print(f"  Warmup epochs: {cfg.lr_warmup_epochs}")
    print(f"  Initial LR: {cfg.lr}")
    print(f"  Min LR: {cfg.lr_min}")

    return scheduler


def _optimizer_step(gmm, optimizer, grad_clip):
    """Clip gradients (when ``grad_clip > 0``), step, and reset gradients."""
    if grad_clip > 0:
        nn.utils.clip_grad_norm_(gmm.parameters(), grad_clip)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)


def main():
    """Train a GMM4PR model against a frozen classifier and save the results.

    Steps: parse CLI arguments, load the config (with overrides), load the
    dataset and the frozen classifier, configure the GMM, run the training
    loop with gradient accumulation, then write the loss history, the model
    checkpoint and (if any was collected) the mode-collapse log under
    ``cfg.ckp_dir``.
    """
    args = _parse_args()

    if args.list_configs:
        list_configs()
        return

    cfg = get_config(args.config)
    _apply_cli_overrides(cfg, args)

    print(cfg)

    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    print(f"\nUsing device: {device}")

    print(f"\nLoading dataset: {cfg.dataset}")
    dataset, num_classes, out_shape = get_dataset(cfg.dataset, cfg.data_root, train=True, resize=cfg.resize)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True
    )
    print(f"Dataset: {len(dataset)} samples, {num_classes} classes, shape={out_shape}")

    model, feat_extractor = _load_frozen_classifier(cfg, num_classes, device)

    print("Evaluating clean accuracy...")
    eval_acc(model, dataset, device)

    print(f"\nInitializing GMM: K={cfg.K}, D={cfg.latent_dim}, cond={cfg.cond_mode}")
    gmm = GMM4PR(
        K=cfg.K,
        latent_dim=cfg.latent_dim,
        device=device,
        T_pi=cfg.T_pi_init,
        T_mu=cfg.T_mu_init,
        T_sigma=cfg.T_sigma_init,
        T_shared=cfg.T_shared_init
    )

    if cfg.use_y_embedding:
        gmm.set_y_embedding(
            num_cls=num_classes,
            y_dim=cfg.y_emb_dim,
            normalize=cfg.y_emb_normalize
        )

    gmm.set_regularization(
        pi_entropy=cfg.reg_pi_entropy,
        mean_diversity=cfg.reg_mean_div,
    )

    # The feature dimension is only needed when conditioning on x; it is
    # measured by pushing one batch through the feature extractor.
    uses_features = cfg.cond_mode in ("x", "xy")
    feat_dim = None
    if uses_features:
        with torch.no_grad():
            x0, _, _ = next(iter(loader))
            feat_dim = feat_extractor(x0.to(device)).view(x0.size(0), -1).size(1)
        print(f"Feature dimension: {feat_dim}")

    gmm.set_condition(
        cond_mode=cfg.cond_mode,
        cov_type=cfg.cov_type,
        cov_rank=cfg.cov_rank,
        feat_dim=feat_dim or 0,
        num_cls=num_classes,
        hidden_dim=cfg.hidden_dim
    )

    if uses_features:
        gmm.set_feat_extractor(feat_extractor)

    if cfg.use_decoder:
        decoder = build_decoder_from_flag(
            cfg.decoder_backend,
            cfg.latent_dim,
            out_shape,
            device
        )
        gmm.set_up_sampler(decoder)

    gmm.set_budget(norm=cfg.norm, eps=cfg.epsilon)

    initialize_gmm_parameters(gmm, init_mode=cfg.init_mode)

    temp_scheduler = TemperatureScheduler(
        gmm,
        initial_T_pi=cfg.T_pi_init,
        final_T_pi=cfg.T_pi_final,

        initial_T_mu=cfg.T_mu_init,
        final_T_mu=cfg.T_mu_final,

        initial_T_sigma=cfg.T_sigma_init,
        final_T_sigma=cfg.T_sigma_final,

        initial_T_shared=cfg.T_shared_init,
        final_T_shared=cfg.T_shared_final,

        warmup_epochs=cfg.warmup_epochs
    )

    optimizer = optim.Adam(
        [p for p in gmm.parameters() if p.requires_grad],
        lr=cfg.lr,
        weight_decay=cfg.weight_decay
    )

    scheduler = _build_lr_scheduler(optimizer, cfg)

    os.makedirs(cfg.ckp_dir, exist_ok=True)
    collapse_log = []
    loss_hist = {"epoch": [],
                 "loss": [],
                 "main_loss": [],
                 "reg_loss": [],
                 "pr": [],
                 "learning_rate": []
                 }

    gmm.train()
    print(f"\n{'='*60}")
    print(f"Starting training: {cfg.epochs} epochs")
    print(f"{'='*60}\n")

    use_gumbel_anneal = getattr(cfg, 'use_gumbel_anneal', False)
    # Bound before the loop so the save step below still works when
    # cfg.epochs == 0 and the loop body never runs.
    gumbel_temp = cfg.gumbel_temp_final

    for epoch in range(1, cfg.epochs + 1):
        T_pi, T_mu, T_sigma, T_shared = temp_scheduler.step(epoch)

        if use_gumbel_anneal:
            # Linear interpolation from init to final over the run; the
            # max() avoids a division by zero for a single-epoch run.
            alpha = (epoch - 1) / max(cfg.epochs - 1, 1)
            gumbel_temp = cfg.gumbel_temp_init + alpha * (cfg.gumbel_temp_final - cfg.gumbel_temp_init)
        else:
            gumbel_temp = cfg.gumbel_temp_final

        pbar = tqdm(loader, desc=f"Epoch {epoch}/{cfg.epochs} [norm={cfg.norm}, eps={cfg.epsilon:.3f}]")

        epoch_loss = 0.0
        epoch_main = 0.0
        epoch_reg = 0.0
        epoch_pr = 0.0
        epoch_pr_count = 0
        total_samples = 0
        num_processed_batches = 0

        acc_counter = 0
        optimizer.zero_grad(set_to_none=True)

        for batch_idx, (x, y, _) in enumerate(pbar):
            if batch_idx >= cfg.batch_index_max:
                break
            x, y = x.to(device), y.to(device)

            # Keep only the samples the frozen classifier predicts correctly.
            with torch.no_grad():
                model.eval()
                pred = model(x).argmax(1)
                # Boolean tensor mask: avoids copying every element to a
                # Python list just to count and index.
                mask = pred == y
            if not mask.any():
                continue
            x_clean = x[mask]
            y_clean = y[mask]
            total_samples += len(y_clean)
            num_processed_batches += 1

            # Detailed regularisation output is requested only for the first
            # processed batch of a collapse-check epoch.
            return_details = (num_processed_batches == 1 and epoch % cfg.check_collapse_every == 0)

            out = gmm.pr_loss(
                x_clean, y_clean, model,
                num_samples=cfg.num_samples,
                loss_variant=cfg.loss_variant, kappa=cfg.kappa,
                chunk_size=cfg.chunk_size,
                return_reg_details=return_details,
                gumbel_temperature=gumbel_temp
            )

            # Scale so that accumulated gradients average over the window.
            loss = out["loss"] / cfg.accumulate_grad
            loss.backward()
            acc_counter += 1

            if acc_counter % cfg.accumulate_grad == 0:
                _optimizer_step(gmm, optimizer, cfg.grad_clip)

            epoch_loss += out["loss"].item()
            epoch_main += out["main"].item()
            epoch_reg += out["reg"].item()
            epoch_pr += out["pr"] * len(y_clean)
            epoch_pr_count += len(y_clean)

            if return_details and 'reg_details' in out:
                print(f"\n[Epoch {epoch}] Regularization details:")
                for k, v in out['reg_details'].items():
                    print(f"  {k:20s}: {v:.6f}")
                print(f"  π distribution: {out['pi_probs'].cpu().numpy()}")

            pbar.set_postfix({
                "loss": f"{out['loss'].item():.4e}",
                "main": f"{out['main'].item():.4e}",
                "reg": f"{out['reg'].item():.4e}",
            })

        # Flush gradients left over from an incomplete accumulation window.
        if acc_counter % cfg.accumulate_grad != 0:
            _optimizer_step(gmm, optimizer, cfg.grad_clip)

        # max(..., 1) keeps the averages defined when no batch was processed.
        avg_loss = epoch_loss / max(num_processed_batches, 1)
        avg_main = epoch_main / max(num_processed_batches, 1)
        avg_reg = epoch_reg / max(num_processed_batches, 1)
        avg_pr = epoch_pr / max(epoch_pr_count, 1)

        current_lr = optimizer.param_groups[0]['lr']

        loss_hist["epoch"].append(epoch)
        loss_hist["loss"].append(avg_loss)
        loss_hist["main_loss"].append(avg_main)
        loss_hist["reg_loss"].append(avg_reg)
        loss_hist["pr"].append(avg_pr)
        loss_hist["learning_rate"].append(current_lr)

        print(f"\nEpoch {epoch} Summary:")
        print(f"  Loss: {avg_loss:.4f} (main={avg_main:.4f}, reg={avg_reg:.4f})")
        print(f"  Learning Rate: {current_lr:.6f}")
        print(f"  Batches: {num_processed_batches}/{len(loader)} processed")
        print(f"  Samples used: {total_samples}/{len(dataset)}")
        print(f"  Temperatures: T_pi={T_pi:.2f}, T_mu={T_mu:.2f}, T_sigma={T_sigma:.2f}, T_shared={T_shared:.2f}, T_gumbel={gumbel_temp:.2f}")

        if scheduler is not None:
            scheduler.step()

        if epoch % cfg.check_collapse_every == 0:
            stats = check_mode_collapse(gmm, loader, device)
            collapse_log.append({
                'epoch': epoch,
                'max_pi': stats['max_pi'],
                'min_pi': stats['min_pi'],
                'std_pi': stats['std_pi'],
                'entropy_ratio': stats['entropy_ratio'],
                'T_pi': T_pi,
                'T_gumbel': gumbel_temp,
                'avg_loss': avg_loss
            })

    print(f"\n{'='*60}")
    print("Training complete! Saving model...")
    print(f"{'='*60}")

    save_dir = os.path.join(cfg.ckp_dir, f"{cfg.arch}_on_{cfg.dataset}")
    os.makedirs(save_dir, exist_ok=True)

    loss_hist_path = os.path.join(save_dir, f"loss_hist_{cfg.exp_name}.csv")
    pd.DataFrame(loss_hist).to_csv(loss_hist_path, index=False)
    # Print the path actually written (no doubled separator).
    print(f"[save] loss history -> {loss_hist_path}")

    save_path = os.path.join(save_dir, f"gmm_{cfg.exp_name}.pt")
    gmm.save(
        save_path,
        extra={
            "config": cfg.to_dict(),
            "final_gumbel_temperature": gumbel_temp,
        }
    )
    print(f"✓ Model saved: {save_path}")

    if collapse_log:
        df = pd.DataFrame(collapse_log)
        log_path = os.path.join(save_dir, f"collapse_log_{cfg.exp_name}.csv")
        df.to_csv(log_path, index=False)
        print(f"✓ Collapse log saved: {log_path}")

    print(f"\n{'='*60}")
    print("DONE!")
    print(f"{'='*60}\n")


# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()