import os
import sys
import time
import random
import argparse
from os.path import dirname, abspath, join

import numpy as np
import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

# Bootstrap for running this file directly as a script: treat the parent of
# this file's directory as the project root and put it at the front of
# sys.path (presumably so the `model.*` / `tools.*` imports that follow
# resolve without installing the project).
script_dir = dirname(abspath(__file__))
project_root = dirname(script_dir)
if project_root not in sys.path:
    sys.path.insert(0, project_root)
# NOTE: this changes the process working directory at import time. Every
# relative path used afterwards -- including relative values passed on the
# command line -- is therefore resolved against project_root, not against the
# directory the script was launched from.
os.chdir(project_root)

from model.LSM import LSOModel
from tools.dataset_make_x8 import DatasetFromHdf5
from tools.metrics import PSNR


def set_random_seed(seed: int) -> None:
    """Seed every random number generator this script relies on.

    The same ``seed`` is handed, in order, to NumPy, Python's ``random``
    module, the torch default generator and the CUDA generators (current
    device, then all devices). cuDNN is then switched to deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Ask cuDNN to use deterministic algorithms.
    cudnn.deterministic = True


def build_model(opt) -> torch.nn.Module:
    """Construct the ``LSOModel`` described by ``opt`` and move it to CUDA.

    Args:
        opt: Parsed options object. It is forwarded whole as ``args`` and its
            ``n_bands_rgb``, ``n_bands`` and ``sf`` attributes are also passed
            explicitly.

    Returns:
        The model after ``.cuda()`` has been called on it.
    """
    constructor_kwargs = {
        "args": opt,
        "bilinear": True,
        "n_select_bands": opt.n_bands_rgb,
        "n_bands": opt.n_bands,
        "sf": opt.sf,
    }
    network = LSOModel(**constructor_kwargs)
    return network.cuda()


def count_parameters(model: torch.nn.Module) -> tuple:
    """Count the scalar parameters of ``model``.

    Args:
        model: Module whose ``parameters()`` are inspected.

    Returns:
        A ``(trainable, total)`` tuple: the number of elements in parameters
        with ``requires_grad`` set, and the number of elements in all
        parameters.
    """
    trainable_count = 0
    overall_count = 0
    for param in model.parameters():
        size = param.numel()
        overall_count += size
        if param.requires_grad:
            trainable_count += size
    return trainable_count, overall_count


def save_checkpoint(model, epoch: int, exp_dir: str) -> None:
    """Save a per-epoch checkpoint as ``model_epoch_<epoch>.pth.tar``.

    The checkpoint is a dict with keys ``"epoch"`` and ``"model"``; the model
    object itself (not only its ``state_dict``) is serialized, so the file
    layout is unchanged for existing loaders.

    The data is first written to ``<target>.tmp`` and then moved over the
    target with ``os.replace``. An interrupted write therefore never leaves a
    truncated file at the final path, matching ``save_best_checkpoint``.

    Args:
        model: Object stored under the ``"model"`` key.
        epoch: Epoch number; stored in the checkpoint and used in the file name.
        exp_dir: Output directory, created if it does not exist.
    """
    os.makedirs(exp_dir, exist_ok=True)
    model_out_path = join(exp_dir, f"model_epoch_{epoch}.pth.tar")
    temp_path = model_out_path + ".tmp"
    state = {"epoch": epoch, "model": model}
    torch.save(state, temp_path)
    # Swap the finished file into place in a single step.
    os.replace(temp_path, model_out_path)
    print(f"Checkpoint saved to {model_out_path}")


def save_best_checkpoint(model, epoch: int, exp_dir: str) -> None:
    """Write (or overwrite) ``model_best.pth.tar`` inside ``exp_dir``.

    The checkpoint is a dict with keys ``"epoch"`` and ``"model"``. It is
    written to a ``.tmp`` sibling first and then moved over the final path
    with ``os.replace``, so an existing best checkpoint is only replaced once
    the new one has been fully written.

    Args:
        model: Object stored under the ``"model"`` key.
        epoch: Epoch number stored under the ``"epoch"`` key.
        exp_dir: Output directory, created if it does not exist.
    """
    os.makedirs(exp_dir, exist_ok=True)

    final_path = join(exp_dir, "model_best.pth.tar")
    staging_path = f"{final_path}.tmp"

    payload = {"epoch": epoch, "model": model}
    torch.save(payload, staging_path)
    os.replace(staging_path, final_path)

    print(f"Best checkpoint saved to {final_path}")


def parse_args():
    """Build the command-line parser and parse the process arguments.

    ``--train_h5`` and ``--val_h5`` are required; every other option has a
    default. Options are registered in the order listed below, which is also
    the order they appear in ``--help``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    cli = argparse.ArgumentParser(
        description="Blind-review friendly training script (no hardcoded paths or dataset names)."
    )

    dataset_type_help = (
        "Downsample factor tag used by DatasetFromHdf5 to derive LRHSI. "
        "Examples: cave_x4, harvard_x4. If not recognized, LRHSI must exist in h5."
    )

    # Each entry is (flag, keyword arguments for add_argument).
    option_table = [
        # Data and model shape
        ("--train_h5", dict(type=str, required=True, help="Path to training .h5 file")),
        ("--val_h5", dict(type=str, required=True, help="Path to validation .h5 file")),
        ("--dataset_type", dict(type=str, default="x4", help=dataset_type_help)),
        ("--model", dict(type=str, default="lsm_x4", help="Kept for compatibility")),
        ("--sf", dict(type=int, default=4, help="Scale factor for the model")),
        ("--image_size", dict(type=int, default=64)),
        ("--n_bands", dict(type=int, default=31)),
        ("--n_bands_rgb", dict(type=int, default=3)),
        # Optimization
        ("--batchSize", dict(type=int, default=16)),
        ("--val_batchSize", dict(type=int, default=10)),
        ("--n_epochs", dict(type=int, default=1000)),
        ("--lr", dict(type=float, default=1e-4)),
        ("--seed", dict(type=int, default=42)),
        # Model hyper-parameters
        ("--num_token", dict(type=int, default=8)),
        ("--num_basis", dict(type=int, default=24)),
        ("--width", dict(type=int, default=64)),
        ("--patch_size", dict(type=int, default=4)),
        ("--guide_dim", dict(type=int, default=128)),
        ("--n_resblocks", dict(type=int, default=4)),
        # Data loading, output and validation schedule
        ("--num_workers_train", dict(type=int, default=4)),
        ("--num_workers_val", dict(type=int, default=2)),
        ("--exp_dir", dict(type=str, default="experiments/run1")),
        ("--val_interval", dict(type=int, default=10)),
        ("--val_after_epoch", dict(type=int, default=600)),
        ("--val_interval_after", dict(type=int, default=1)),
        ("--save_best", dict(type=int, default=1)),
    ]
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()


def _evaluate_psnr(model, data_loader) -> float:
    """Return the mean PSNR of ``model`` over every sample of ``data_loader``.

    The model is put in eval mode and run under ``torch.no_grad()``. Outputs
    are clamped to ``[0, 1]`` before scoring. Each sample in each batch is
    scored individually (channel axis moved last, converted to NumPy) and the
    per-sample scores are averaged, so the result does not depend on the
    validation batch size.

    Args:
        model: Callable taking ``(batch[0], batch[1])`` moved to CUDA.
        data_loader: Iterable of batches; ``batch[2]`` is used as reference.

    Returns:
        Mean of the per-sample PSNR values, or ``nan`` when the loader yields
        no samples.
    """
    model.eval()
    scores = []
    with torch.no_grad():
        for batch in data_loader:
            input_rgb = batch[0].cuda(non_blocking=True)
            ms = batch[1].cuda(non_blocking=True)
            ref = batch[2].cuda(non_blocking=True)
            out = model(input_rgb, ms).clamp(0.0, 1.0)

            # One host transfer per batch, then score every sample rather
            # than only the first one.
            ref_cpu = ref.cpu()
            out_cpu = out.cpu()
            for ref_img, out_img in zip(ref_cpu, out_cpu):
                ref_np = ref_img.permute(1, 2, 0).numpy()
                output_np = out_img.permute(1, 2, 0).numpy()
                # PSNR(...) is indexed with [0] exactly as before; presumably
                # that element is the scalar score -- confirm in tools.metrics.
                scores.append(PSNR(ref_np, output_np)[0])

    if not scores:
        # Nothing to average; nan never compares greater than the best score.
        return float("nan")
    return float(np.mean(scores))


def main():
    """Train the model and track the best validation PSNR.

    Steps: parse options, seed the RNGs, build the train/validation loaders,
    build the model, then train for ``opt.n_epochs`` epochs with L1 loss,
    AdamW and a StepLR schedule (step 300, gamma 0.6). Validation runs every
    ``opt.val_interval`` epochs, switching to every ``opt.val_interval_after``
    epochs once ``epoch > opt.val_after_epoch``. When the validation PSNR
    improves and ``opt.save_best == 1`` the best checkpoint is rewritten.

    Raises:
        ValueError: If ``--val_interval`` or ``--val_interval_after`` is
            smaller than 1.
    """
    opt = parse_args()
    # Fail before any data is loaded instead of hitting a modulo-by-zero
    # after the first full training epoch.
    if opt.val_interval < 1 or opt.val_interval_after < 1:
        raise ValueError(
            "--val_interval and --val_interval_after must be >= 1 "
            f"(got {opt.val_interval} and {opt.val_interval_after})"
        )
    set_random_seed(opt.seed)

    print("===> Loading datasets")
    train_set = DatasetFromHdf5(opt.train_h5, opt.dataset_type)
    val_set = DatasetFromHdf5(opt.val_h5, opt.dataset_type)

    training_data_loader = DataLoader(
        dataset=train_set,
        batch_size=opt.batchSize,
        shuffle=True,
        num_workers=opt.num_workers_train,
        pin_memory=True,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=opt.num_workers_train > 0,
    )
    val_data_loader = DataLoader(
        dataset=val_set,
        batch_size=opt.val_batchSize,
        shuffle=False,
        num_workers=opt.num_workers_val,
        pin_memory=True,
    )

    model = build_model(opt)
    trainable_params, total_params = count_parameters(model)
    print(f"Trainable params: {trainable_params:,} ({trainable_params / 1e6:.2f}M)")
    print(f"Total params: {total_params:,} ({total_params / 1e6:.2f}M)")

    loss_fn = torch.nn.L1Loss().cuda()
    optimizer = optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=300, gamma=0.6)

    best_psnr = -1.0
    for epoch in range(1, opt.n_epochs + 1):
        # Validation leaves the model in eval mode, so re-enable train mode
        # at the start of every epoch.
        model.train()
        print(f"Train_Epoch_{epoch}: lr={optimizer.param_groups[0]['lr']}")
        for batch in training_data_loader:
            input_rgb = batch[0].cuda(non_blocking=True)
            ms = batch[1].cuda(non_blocking=True)
            ref = batch[2].cuda(non_blocking=True)

            out = model(input_rgb, ms)
            loss = loss_fn(out, ref)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # The learning-rate schedule advances once per epoch.
        scheduler.step()

        val_interval = opt.val_interval_after if epoch > opt.val_after_epoch else opt.val_interval
        if epoch % val_interval != 0:
            continue

        current_psnr = _evaluate_psnr(model, val_data_loader)
        print(f"PSNR: {current_psnr:.4f}")
        if current_psnr > best_psnr:
            best_psnr = current_psnr
            if int(opt.save_best) == 1:
                print(f"PSNR improved to {best_psnr:.4f}, saving best checkpoint...")
                save_best_checkpoint(model, epoch, opt.exp_dir)

    print(f"Best PSNR: {best_psnr:.4f}")


# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()
