import argparse
import json
import numpy as np
import time
import os
import torch
import torch.nn as nn
import torch.nn.functional as F  # [KD] 需要
import torch.optim as optim
from torch.utils import data
import torch.backends.cudnn as cudnn
from utils.tools import adjust_learning_rate, label_accuracy_score, eval_image
from dataset.Sen2Fire_Dataset import Sen2FireDataSet
from ptflops import get_model_complexity_info

from models.FCN import FCN8s
from models.SegNet import SegNet
from models.unet import unet
import random
from tqdm import tqdm

# Human-readable class names, indexed by class id when metrics are reported.
name_classes = np.asarray(["non-fire", "fire"], dtype=str)
# Tiny constant added to metric denominators so empty classes do not divide by zero.
epsilon = 1e-14

class F1Loss(nn.Module):
    """Soft (differentiable) macro-F1 loss computed from class probabilities.

    Softmax probabilities stand in for hard predictions, so the true-positive,
    false-positive and false-negative counts are fractional and the loss can be
    back-propagated. F1 is computed per sample and per class, then averaged.

    Args:
        epsilon: Constant added to every denominator to avoid division by zero.
    """

    def __init__(self, epsilon=1e-7):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, logits, targets):
        """Return ``1 - mean(F1)``.

        Args:
            logits: Raw scores; the class axis is dim 1 and the trailing axes
                are flattened together (used with (B, C, H, W) in this file).
            targets: Integer class-index tensor accepted by ``one_hot``; its
                one-hot form is permuted as a 4-D (B, H, W, C) tensor.
        """
        eps = self.epsilon
        n_cls = logits.size(1)

        # Per-class probabilities, spatial axes flattened: (B, C, N).
        predicted = torch.softmax(logits, dim=1).view(logits.shape[0], n_cls, -1)

        # One-hot ground truth moved to channel-first and flattened the same way.
        onehot = F.one_hot(targets, n_cls).permute(0, 3, 1, 2).float()
        onehot = onehot.view(targets.shape[0], n_cls, -1)

        # Fractional confusion counts, one value per (sample, class).
        true_pos = (predicted * onehot).sum(dim=2)
        false_pos = (predicted * (1 - onehot)).sum(dim=2)
        false_neg = ((1 - predicted) * onehot).sum(dim=2)

        prec = true_pos / (true_pos + false_pos + eps)
        rec = true_pos / (true_pos + false_neg + eps)
        soft_f1 = 2 * prec * rec / (prec + rec + eps)

        return 1 - soft_f1.mean()

# =========================
# [KD] Distillation loss (soft-label KL divergence)
# =========================
class DistillationLoss(nn.Module):
    """Soft-label knowledge-distillation loss (temperature-scaled KL divergence).

    Computes ``T^2 * KL(softmax(teacher / T) || softmax(student / T))`` with the
    class axis at dim 1, averaged over every class distribution in the input
    (every sample for (B, C) logits, every pixel of every sample for
    (B, C, H, W) logits).

    Args:
        temperature: Softening temperature ``T``; must be positive.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """

    def __init__(self, temperature: float = 4.0):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.T = temperature

    def forward(self, student_logits, teacher_logits):
        """Return the distillation loss as a scalar tensor.

        Args:
            student_logits: Student scores, class axis at dim 1.
            teacher_logits: Teacher scores with the same shape.
        """
        T = self.T
        log_p_student = F.log_softmax(student_logits / T, dim=1)
        p_teacher = F.softmax(teacher_logits / T, dim=1)

        # reduction="batchmean" divides by the batch size only, so for dense
        # (B, C, H, W) logits the KL would be *summed* over all H*W pixels and
        # swamp the per-pixel segmentation losses. Normalise by the number of
        # class distributions instead; for (B, C) inputs this equals batchmean.
        n_distributions = max(student_logits.numel() // max(student_logits.size(1), 1), 1)
        kl_sum = F.kl_div(log_p_student, p_teacher, reduction="sum")

        # Multiply by T^2 as in Hinton et al. so gradient magnitudes stay
        # comparable across temperatures.
        return kl_sum / n_distributions * (T * T)

def init_seeds(seed=1234):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def get_arguments():
    """Parse the command-line options of the distillation training script.

    Returns:
        argparse.Namespace: Options parsed from ``sys.argv``.
    """
    # One (flag, add_argument keyword arguments) entry per option, registered
    # in this order so the --help listing is unchanged.
    option_table = [
        # Model
        ("--model", dict(type=str, default='unet',
                         help="model name.")),
        ("--num_layer", dict(type=int, default=64,
                             help="basic layer number (student).")),
        ("--LAM", dict(action='store_true',
                       help="whether apply LEO attention map.")),
        ("--num_layers_lam", dict(type=int, default=1,
                                  help="number of LAM layer.")),
        ("--num_components_lam", dict(type=int, default=3,
                                      help="number of LAM components.")),

        # [KD] teacher model
        ("--teacher_model", dict(type=str, default="",
                                 help="path to teacher .pth (required for distillation).")),
        ("--teacher_num_layer", dict(type=int, default=48,
                                     help="teacher basic layer number (e.g., 48).")),

        # [KD] distillation coefficients
        ("--kd_weight", dict(type=float, default=1.0,
                             help="weight for KD loss. 0 to disable KD.")),
        ("--kd_temperature", dict(type=float, default=4.0,
                                  help="temperature for KD.")),

        # Dataset
        ("--data_dir", dict(type=str, default='./dataset/Sen2Fire/',
                            help="dataset path.")),
        ("--train_list", dict(type=str, default='./dataset/train.txt',
                              help="training list file.")),
        ("--val_list", dict(type=str, default='./dataset/val.txt',
                            help="val list file.")),
        ("--test_list", dict(type=str, default='./dataset/test.txt',
                             help="test list file.")),
        ("--num_classes", dict(type=int, default=2,
                               help="number of classes.")),
        ("--mode", dict(type=int, default=5,
                        help="input type (0-all_bands, 1-all_bands_aerosol,...).")),
        ("--batch_size", dict(type=int, default=8,
                              help="number of images in each batch.")),
        ("--num_workers", dict(type=int, default=1,
                               help="number of workers for dataloader.")),

        # Training
        ("--learning_rate", dict(type=float, default=1e-4,
                                 help="base learning rate.")),
        ("--weight_decay", dict(type=float, default=5e-4,
                                help="regularisation parameter for L2-loss.")),
        ("--weight", dict(type=float, default=10,
                          help="ce weight.")),
        ("--num_steps", dict(type=int, default=10000,
                             help="number of training steps.")),
        ("--num_steps_stop", dict(type=int, default=10000,
                                  help="number of training steps for early stopping.")),
        ("--patience", dict(type=int, default=3,
                            help="patience for early stopping (number of validations to wait without improvement).")),
        ("--lamda_F1loss", dict(type=float, default=1,
                                help="lamda_F1loss.")),

        # Result
        ("--snapshot_dir", dict(type=str, default='./Exp/',
                                help="where to save snapshots of the model.")),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

# Names of the input-band configurations; the list index is the --mode value.
modename = [
    'all_bands',                   # 0
    'all_bands_aerosol',           # 1
    'rgb',                         # 2
    'rgb_aerosol',                 # 3
    'swir',                        # 4
    'swir_aerosol',                # 5
    'nbr',                         # 6
    'nbr_aerosol',                 # 7
    'ndvi',                        # 8
    'ndvi_aerosol',                # 9
    'rgb_swir_nbr_ndvi',           # 10
    'rgb_swir_nbr_ndvi_aerosol',   # 11
]

# Number of input channels passed to the network for each --mode value.
input_dim_lookup = {
    0: 12,
    1: 13,
    2: 3,
    3: 4,
    4: 3,
    5: 4,
    6: 3,
    7: 4,
    8: 3,
    9: 4,
    10: 6,
    11: 7,
}

def build_model(model_name, n_classes, n_channels, layer_num, LAM, num_layers_lam, num_components_lam):
    """Instantiate the segmentation network selected by ``model_name``.

    Args:
        model_name: One of ``"unet"``, ``"fcn8s"`` or ``"segnet"`` (exact match).
        n_classes, n_channels, layer_num, LAM, num_layers_lam,
        num_components_lam: Forwarded unchanged, by keyword, to the selected
            model's constructor.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Every architecture takes the same constructor keywords.
    ctor_kwargs = dict(
        n_classes=n_classes,
        n_channels=n_channels,
        layer_num=layer_num,
        LAM=LAM,
        num_layers_lam=num_layers_lam,
        num_components_lam=num_components_lam,
    )

    if model_name == "unet":
        return unet(**ctor_kwargs)
    if model_name == "fcn8s":
        return FCN8s(**ctor_kwargs)
    if model_name == "segnet":
        return SegNet(**ctor_kwargs)

    raise ValueError(f"Unsupported model type: {model_name}")

_FIRE = 1  # index of 'fire' in name_classes; the class whose metrics are reported and tracked


def _log(log_file, message):
    """Print ``message`` and write it, newline-terminated, to ``log_file``."""
    print(message)
    log_file.write(message + '\n')


def _resolve_teacher_path(args):
    """Return the teacher checkpoint path selected by ``--teacher_model``.

    ``--teacher_model`` may be a direct path to a checkpoint file, or a run
    directory name that is looked up under
    ``<snapshot_dir>/input_<mode>/<model>/layer_num_<teacher_num_layer>/`` and
    expected to contain ``best_model.pth``.

    Raises:
        FileNotFoundError: If no checkpoint file exists at the resolved path.
    """
    if os.path.isfile(args.teacher_model):
        return args.teacher_model

    # NOTE(review): this directory never carries the '_LAM' suffix that the
    # student's snapshot directory gets, although the teacher is built with
    # args.LAM -- confirm this matches where LAM teachers are stored.
    teacher_path = os.path.join(
        args.snapshot_dir,
        'input_' + modename[args.mode],
        args.model.lower(),
        'layer_num_' + str(args.teacher_num_layer),
        args.teacher_model,
        'best_model.pth',
    )
    if not os.path.isfile(teacher_path):
        raise FileNotFoundError(
            f"Teacher checkpoint not found: {teacher_path}. Pass --teacher_model "
            "(a checkpoint file, or a run directory containing best_model.pth) "
            "or disable distillation with --kd_weight 0."
        )
    return teacher_path


def _evaluate(model, loader, interp, device, num_classes):
    """Run ``model`` over ``loader`` and return aggregated segmentation metrics.

    Returns:
        dict with ``'OA'`` (overall accuracy) and ``'precision'``, ``'recall'``,
        ``'f1'``, ``'iou'`` arrays of shape ``(num_classes, 1)``.
    """
    model.eval()
    TP_all = np.zeros((num_classes, 1))
    FP_all = np.zeros((num_classes, 1))
    FN_all = np.zeros((num_classes, 1))
    n_valid_sample_all = 0

    for image, label, _, _ in tqdm(loader):
        label = label.squeeze().numpy()
        image = image.float().to(device)

        with torch.no_grad():
            logits = model(image)
            _, pred = torch.max(interp(F.softmax(logits, dim=1)), 1)
        pred = pred.squeeze().cpu().numpy()

        TP, FP, _tn, FN, n_valid_sample = eval_image(pred.reshape(-1), label.reshape(-1), num_classes)
        TP_all += TP
        FP_all += FP
        FN_all += FN
        n_valid_sample_all += n_valid_sample

    # Per-class metrics; epsilon keeps classes with no pixels from dividing by zero.
    precision = TP_all * 1.0 / (TP_all + FP_all + epsilon)
    recall = TP_all * 1.0 / (TP_all + FN_all + epsilon)
    f1 = 2.0 * precision * recall / (precision + recall + epsilon)
    iou = TP_all * 1.0 / (TP_all + FP_all + FN_all + epsilon)

    return {
        'OA': np.sum(TP_all) * 1.0 / n_valid_sample_all,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'iou': iou,
    }


def _report_metrics(log_file, metrics):
    """Log the fire-class metrics followed by the class-averaged summary line."""
    fire_name = name_classes[_FIRE]
    for title, key in (('Precision', 'precision'), ('Recall', 'recall'), ('IoU', 'iou'), ('F1', 'f1')):
        _log(log_file, '===>' + fire_name + ' %s: %.2f' % (title, metrics[key][_FIRE].item() * 100))

    # mIoU is the mean of the per-class IoU values (not of the F1 values).
    mIoU = np.mean(metrics['iou'])
    mF1 = np.mean(metrics['f1'])
    _log(log_file, '===> mIoU: %.2f mean F1: %.2f OA: %.2f' % (mIoU * 100, mF1 * 100, metrics['OA'] * 100))


def _run(args, snapshot_dir, teacher_path, f):
    """Train the student (optionally distilled from a teacher), then test it.

    Args:
        args: Parsed command-line options.
        snapshot_dir: Output directory for checkpoints and result files.
        teacher_path: Teacher checkpoint path, or None when KD is disabled.
        f: Open, writable training-log file.
    """
    cudnn.enabled = True
    cudnn.benchmark = True
    init_seeds()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create student network
    input_dim = input_dim_lookup[args.mode]
    base_name = args.model.lower()  # 'unet' / 'fcn8s' / 'segnet'
    model_name = base_name + '_LAM' if args.LAM else base_name

    student = build_model(base_name, args.num_classes, input_dim, args.num_layer,
                          args.LAM, args.num_layers_lam, args.num_components_lam)

    # [KD] Create the frozen teacher network (same architecture, layer_num=teacher_num_layer).
    teacher = None
    if args.kd_weight > 0:
        teacher = build_model(base_name, args.num_classes, input_dim, args.teacher_num_layer,
                              args.LAM, args.num_layers_lam, args.num_components_lam)
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        state = torch.load(teacher_path, map_location="cpu")
        teacher.load_state_dict(state, strict=True)
        teacher.to(device)
        teacher.eval()
        for p in teacher.parameters():
            p.requires_grad = False

    # Complexity statistics for the student.
    input_size_train = (512, 512)
    input_res = (input_dim, input_size_train[1], input_size_train[0])
    macs, params = get_model_complexity_info(student, input_res, as_strings=False, print_per_layer_stat=False, verbose=False)
    FLOPS = macs * 2
    # NOTE(review): 'parameter_GB' is the parameter count divided by 1024**3, not
    # a size in bytes; the key is kept because saved files may be read elsewhere.
    model_complexity_info = {'GFLOP': FLOPS/(1e9), 'parameter_GB': params/(1024**3)}
    np.savez(os.path.join(snapshot_dir, 'model_complexity_info.npz'), **model_complexity_info)

    _log(f, f"[INFO] Total FLOP (student): {FLOPS/(1e9)}, GFLOPs")
    print(f"[INFO] Total trainable parameters (student): {params/(1024**3)}, GB")
    f.write(f"[INFO] Total trainable parameters (student): {params/(1024**3)} GB\n")
    _log(f, f"[INFO] Mode: {modename[args.mode]}")
    _log(f, f"[INFO] Start training {model_name}")
    if teacher is not None:
        print(f"[KD] Teacher loaded from: {args.teacher_model} | teacher_num_layer={args.teacher_num_layer} | T={args.kd_temperature} | kd_weight={args.kd_weight}")
        f.write(f"[KD] Teacher: {args.teacher_model} | teacher_num_layer={args.teacher_num_layer} | T={args.kd_temperature} | kd_weight={args.kd_weight}\n")

    student.train()
    student = student.to(device)

    train_loader = data.DataLoader(
        Sen2FireDataSet(args.data_dir, args.train_list, max_iters=args.num_steps_stop*args.batch_size, mode=args.mode),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True
    )

    val_loader = data.DataLoader(
        Sen2FireDataSet(args.data_dir, args.val_list, mode=args.mode),
        batch_size=1, shuffle=False, num_workers=args.num_workers, pin_memory=True
    )

    test_loader = data.DataLoader(
        Sen2FireDataSet(args.data_dir, args.test_list, mode=args.mode),
        batch_size=1, shuffle=False, num_workers=args.num_workers, pin_memory=True
    )

    optimizer = optim.Adam(student.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

    # interpolation for the probability maps and labels
    interp = nn.Upsample(size=(input_size_train[1], input_size_train[0]), mode='bilinear')

    F1_best = -1
    patience_counter = 0
    early_stop_triggered = False
    best_path = os.path.join(snapshot_dir, 'best_model.pth')

    class_weights = [1, args.weight]
    L_ce = nn.CrossEntropyLoss(weight=torch.Tensor(class_weights).to(device))
    L_f1 = F1Loss()
    L_kd = DistillationLoss(temperature=args.kd_temperature) if args.kd_weight > 0 else None

    # ============== TRAINING ==============
    for batch_index, train_data in enumerate(train_loader):
        if batch_index == args.num_steps_stop or early_stop_triggered:
            break
        tem_time = time.time()
        adjust_learning_rate(optimizer, args.learning_rate, batch_index, args.num_steps)
        student.train()  # validation below switches the student to eval mode
        optimizer.zero_grad()

        patches, labels, _, _ = train_data
        patches = patches.to(device)
        labels = labels.to(device).long()

        # forward
        student_logits = student(patches)
        student_logits = interp(student_logits)  # align to 512x512

        # segmentation loss: blend of weighted CE and soft-F1
        lamda = args.lamda_F1loss
        L_seg_value = (1 - lamda) * L_ce(student_logits, labels) + lamda * L_f1(student_logits, labels)

        # [KD] distillation term
        if L_kd is not None:
            with torch.no_grad():
                teacher_logits = teacher(patches)          # (B,C,h,w)
                teacher_logits = interp(teacher_logits)    # align to 512x512
            L_kd_value = L_kd(student_logits, teacher_logits)
            L_total = L_seg_value + args.kd_weight * L_kd_value
        else:
            L_kd_value = torch.tensor(0.0, device=device)
            L_total = L_seg_value

        # backward
        L_total.backward()
        optimizer.step()

        # training-batch metrics
        with torch.no_grad():
            _, predict_labels = torch.max(student_logits, 1)
            lbl_pred = predict_labels.detach().cpu().numpy()
            lbl_true = labels.detach().cpu().numpy()
            metrics_batch = []
            for lt, lp in zip(lbl_true, lbl_pred):
                _, _, mean_iu, _ = label_accuracy_score(lt, lp, n_class=args.num_classes)
                metrics_batch.append(mean_iu)
            batch_miou = np.nanmean(metrics_batch, axis=0)
            batch_oa = np.sum(lbl_pred == lbl_true) * 1.0 / len(lbl_true.reshape(-1))

        t = time.time() - tem_time

        if (batch_index + 1) % 10 == 0:
            _log(f, 'Iter %d/%d Time: %.2f Batch_OA = %.2f Batch_mIoU = %.2f CE+F1 = %.3f KD = %.3f Total = %.3f'
                 % (batch_index + 1, args.num_steps, t, batch_oa*100, batch_miou*100,
                    L_seg_value.item(), float(L_kd_value.item()), L_total.item()))
            f.flush()

        # ======= validation (every 500 iterations) ======
        if (batch_index + 1) % 500 == 0:
            _log(f, 'Validating..........')
            val_metrics = _evaluate(student, val_loader, interp, device, args.num_classes)
            _report_metrics(f, val_metrics)

            # Model selection and early stopping track the fire-class F1.
            fire_f1 = val_metrics['f1'][_FIRE].item()

            if fire_f1 > F1_best:
                F1_best = fire_f1
                patience_counter = 0
                _log(f, 'Save Best F1 Model')
                torch.save(student.state_dict(), best_path)
            else:
                patience_counter += 1
                _log(f, f"Early Stop Patience: {patience_counter}/{args.patience}")
                if patience_counter >= args.patience:
                    _log(f, "Early stopping triggered.")
                    early_stop_triggered = True

    # ====== TESTING ======
    if os.path.isfile(best_path):
        student.load_state_dict(torch.load(best_path, map_location=device))
    else:
        # Training ended before the first validation pass (fewer than 500
        # iterations), so no checkpoint exists; test the weights as they are.
        _log(f, '[WARN] No best_model.pth was saved; testing the final training weights.')

    _log(f, 'Testing..........')
    test_metrics = _evaluate(student, test_loader, interp, device, args.num_classes)
    _report_metrics(f, test_metrics)

    # Fire-class results; each value stays a length-1 array as in earlier result files.
    test_info = {
        'Precision': test_metrics['precision'][_FIRE] * 100,
        'Recall': test_metrics['recall'][_FIRE] * 100,
        'IoU': test_metrics['iou'][_FIRE] * 100,
        'F1': test_metrics['f1'][_FIRE] * 100,
    }
    np.savez(os.path.join(snapshot_dir, 'test_info.npz'), **test_info)


def main():
    """Train a student segmentation model, optionally distilled, and test it.

    Parses the command line, creates a timestamped snapshot directory holding
    ``args.json`` and ``Training_log.txt``, then runs training (with periodic
    validation, best-checkpoint saving and early stopping) and the final test.

    Raises:
        FileNotFoundError: If ``--kd_weight`` is positive and the teacher
            checkpoint cannot be found.
    """
    args = get_arguments()

    # Resolve the teacher first so a missing checkpoint fails before any output is created.
    teacher_path = _resolve_teacher_path(args) if args.kd_weight > 0 else None

    snapshot_dir = os.path.join(
        args.snapshot_dir,
        'input_' + modename[args.mode],
        (args.model.lower() + '_LAM') if args.LAM else args.model.lower(),
        'layer_num_' + str(args.num_layer),
        # [KD] the KD configuration is part of the path to tell experiments apart
        f'KD_{args.kd_weight}_T{args.kd_temperature}' if args.kd_weight > 0 else 'NoKD',
        'weight_' + str(args.weight) + '_time' + time.strftime('%m%d_%H%M', time.localtime(time.time()))
    ) + '/'

    os.makedirs(snapshot_dir, exist_ok=True)

    with open(os.path.join(snapshot_dir, 'args.json'), 'w', encoding='utf-8') as args_file:
        json.dump(vars(args), args_file, indent=2, ensure_ascii=False)

    # The context manager closes the log even if training raises.
    with open(snapshot_dir + 'Training_log.txt', 'w') as log_file:
        _run(args, snapshot_dir, teacher_path, log_file)

# Run the training/testing pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()
