import argparse
import json
import numpy as np
import time
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
import torch.backends.cudnn as cudnn
from utils.tools import adjust_learning_rate, label_accuracy_score, eval_image
from dataset.Sen2Fire_Dataset import Sen2FireDataSet
from ptflops import get_model_complexity_info

from models.FCN import FCN8s
from models.SegNet import SegNet
from models.unet import unet
import random
from tqdm import tqdm

from Compressed_methods.SegNet_Prune import prune_segnet_model
from Compressed_methods.unet_Prune import prune_unet_model
from Compressed_methods.FCN_Prune import prune_fcn8s_model

# Display names of the segmentation classes, looked up by class index.
name_classes = np.array(("non-fire", "fire"), dtype=str)

# Tiny constant added to metric denominators so they are never exactly zero.
epsilon = 1.0e-14

class F1Loss(nn.Module):
    """Soft F1 loss computed from class probabilities.

    The logits are turned into probabilities with a softmax over dim 1, so
    true/false positives and false negatives become differentiable sums.
    One F1 score is computed per (sample, class) pair and the loss is
    ``1 - mean(F1)`` over all of those pairs.

    Args:
        epsilon: Constant added to every denominator to avoid division by zero.
    """

    def __init__(self, epsilon=1e-7):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, logits, targets):
        """Return the scalar soft-F1 loss.

        Args:
            logits: Raw scores with the class axis at dim 1 (4-D, e.g.
                ``(B, C, H, W)``).
            targets: Integer class indices, one per spatial position of
                ``logits`` (3-D, e.g. ``(B, H, W)``).
        """
        eps = self.epsilon
        n_batch = logits.shape[0]
        n_classes = logits.size(1)

        # Class probabilities, spatial dimensions flattened: (B, C, N).
        soft_pred = torch.softmax(logits, dim=1).view(n_batch, n_classes, -1)

        # One-hot ground truth moved to channel-first layout, then flattened
        # the same way as the predictions.
        truth = nn.functional.one_hot(targets, n_classes)
        truth = truth.permute(0, 3, 1, 2).float()
        truth = truth.view(targets.shape[0], n_classes, -1)

        # Soft confusion-matrix entries, one value per (sample, class).
        true_pos = (soft_pred * truth).sum(dim=2)
        false_pos = (soft_pred * (1 - truth)).sum(dim=2)
        false_neg = ((1 - soft_pred) * truth).sum(dim=2)

        precision = true_pos / (true_pos + false_pos + eps)
        recall = true_pos / (true_pos + false_neg + eps)
        f1_scores = 2 * precision * recall / (precision + recall + eps)

        return 1 - f1_scores.mean()

def init_seeds(seed=1234):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def get_arguments():
    """Parse the command-line options of the prune / fine-tune / test script.

    Options are registered from a table, in the order they appear below.
    ``--pretrained_model`` is the only required option.

    Returns:
        argparse.Namespace: The parsed options, read from ``sys.argv``.
    """
    option_specs = (
        # Model
        ("--model", dict(type=str, default='unet',
                         help="model name.")),
        ("--num_layer", dict(type=int, default=64,
                             help="basic layer number.")),
        ("--prune_ratio", dict(type=float, default=0.3,
                               help="prune ratio (剪掉的比例).")),
        ("--pretrained_model", dict(type=str, required=True,
                                    help="name of pretrained model folder (under snapshot_dir).")),
        ("--fine_tune_steps", dict(type=int, default=1000,
                                   help="number of fine-tuning steps after pruning.")),
        # Dataset
        ("--data_dir", dict(type=str, default='./dataset/Sen2Fire/',
                            help="dataset path.")),
        ("--train_list", dict(type=str, default='./dataset/train.txt',
                              help="training list file.")),
        ("--val_list", dict(type=str, default='./dataset/val.txt',
                            help="val list file.")),
        ("--test_list", dict(type=str, default='./dataset/test.txt',
                             help="test list file.")),
        ("--num_classes", dict(type=int, default=2,
                               help="number of classes.")),
        ("--mode", dict(type=int, default=5,
                        help="input type (0-all_bands, 1-all_bands_aerosol,...).")),
        ("--batch_size", dict(type=int, default=8,
                              help="number of images in each batch.")),
        ("--num_workers", dict(type=int, default=1,
                               help="number of workers for dataloader.")),
        # Training
        ("--learning_rate", dict(type=float, default=1e-4,
                                 help="base learning rate.")),
        ("--weight_decay", dict(type=float, default=5e-4,
                                help="regularisation parameter for L2-loss.")),
        ("--weight", dict(type=float, default=10,
                          help="ce weight.")),
        ("--lamda_F1loss", dict(type=float, default=1,
                                help="lamda_F1loss.")),
        # Result
        ("--snapshot_dir", dict(type=str, default='./Exp/',
                                help="where to save snapshots of the model.")),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()

# Name of each input-band configuration, indexed by the ``--mode`` option.
modename = [
    'all_bands',
    'all_bands_aerosol',
    'rgb',
    'rgb_aerosol',
    'swir',
    'swir_aerosol',
    'nbr',
    'nbr_aerosol',
    'ndvi',
    'ndvi_aerosol',
    'rgb_swir_nbr_ndvi',
    'rgb_swir_nbr_ndvi_aerosol',
]

# Number of network input channels for each ``--mode`` index.
input_dim_lookup = {
    0: 12,
    1: 13,
    2: 3,
    3: 4,
    4: 3,
    5: 4,
    6: 3,
    7: 4,
    8: 3,
    9: 4,
    10: 6,
    11: 7,
}

def _log(log_file, message):
    """Print ``message`` and append it to ``log_file`` as one line."""
    print(message)
    log_file.write(message + '\n')


def _model_registry():
    """Return the supported architectures as ``name -> (constructor, prune_fn)``.

    Built inside a function so the model classes and pruning functions are
    only looked up when ``main`` actually runs.
    """
    return {
        'unet': (unet, prune_unet_model),
        'fcn8s': (FCN8s, prune_fcn8s_model),
        'segnet': (SegNet, prune_segnet_model),
    }


def _fine_tune(model, train_loader, optimizer, interp, device, args, log_file):
    """Fine-tune ``model`` in place for at most ``args.fine_tune_steps`` batches.

    The loss is ``(1 - lamda) * weighted cross-entropy + lamda * F1Loss`` with
    ``lamda = args.lamda_F1loss``. The running loss is printed and logged
    every 50 steps.
    """
    class_weights = [1, args.weight]
    L_ce = nn.CrossEntropyLoss(weight=torch.Tensor(class_weights).to(device))
    L_f1 = F1Loss()
    lamda = args.lamda_F1loss

    model.train()
    for step, (patches, labels, _, _) in enumerate(train_loader, start=1):
        if step > args.fine_tune_steps:
            break
        optimizer.zero_grad()
        patches = patches.to(device)
        labels = labels.to(device).long()
        # Bring the network output to the label resolution before the loss.
        pred_interp = interp(model(patches))
        L_seg_value = (1 - lamda) * L_ce(pred_interp, labels) + lamda * L_f1(pred_interp, labels)
        L_seg_value.backward()
        optimizer.step()
        if step % 50 == 0:
            _log(log_file, f"[Fine-tune] Step {step}/{args.fine_tune_steps}, Loss: {L_seg_value.item():.4f}")
            log_file.flush()


def _evaluate(model, test_loader, interp, device, num_classes):
    """Run ``model`` over ``test_loader`` and accumulate per-class counts.

    Returns:
        tuple: ``(TP_all, FP_all, FN_all, n_valid_sample_all)``. The first
        three are ``(num_classes, 1)`` arrays holding the sums of the values
        returned by ``eval_image`` for every test item; the last is the sum
        of the valid-sample counts it reports.
    """
    TP_all = np.zeros((num_classes, 1))
    FP_all = np.zeros((num_classes, 1))
    FN_all = np.zeros((num_classes, 1))
    n_valid_sample_all = 0

    model.eval()
    for image, label, _, _ in tqdm(test_loader):
        label = label.squeeze().numpy()
        image = image.float().to(device)

        with torch.no_grad():
            pred = model(image)
            _, pred = torch.max(interp(nn.functional.softmax(pred, dim=1)), 1)
        pred = pred.squeeze().cpu().numpy()

        # The true-negative counts are returned by eval_image but not needed
        # for any of the reported metrics.
        TP, FP, _, FN, n_valid_sample = eval_image(pred.reshape(-1), label.reshape(-1), num_classes)
        TP_all += TP
        FP_all += FP
        FN_all += FN
        n_valid_sample_all += n_valid_sample

    return TP_all, FP_all, FN_all, n_valid_sample_all


def main():
    """Prune a pretrained segmentation model, fine-tune it and test it.

    Steps:
        1. Parse the options and validate the model name, the input mode and
           the checkpoint path *before* anything is written to disk.
        2. Create a time-stamped run directory and store ``args.json`` there.
        3. Build the network, load ``best_model.pth`` and prune it.
        4. Save and log the complexity of the pruned model.
        5. Fine-tune, then save ``pruned_finetuned_model.pth``.
        6. Evaluate on the test list, log the metrics and save
           ``test_info.npz`` with the metrics of class index 1 (``fire``).

    Raises:
        ValueError: If ``--model`` or ``--mode`` is not supported.
        FileNotFoundError: If the pretrained checkpoint does not exist.
    """
    args = get_arguments()
    model_name = args.model.lower()

    # Fail fast on bad options so that no empty run directory is left behind.
    registry = _model_registry()
    if model_name not in registry:
        raise ValueError(f"Unsupported model type: {args.model}")
    if args.mode not in input_dim_lookup:
        raise ValueError(f"Unsupported mode: {args.mode} (expected one of {sorted(input_dim_lookup)})")
    build_model, prune_model = registry[model_name]
    input_dim = input_dim_lookup[args.mode]

    # Location of the original (unpruned) model.
    base_snapshot_dir = os.path.join(args.snapshot_dir,
                                     'input_' + modename[args.mode],
                                     model_name,
                                     'layer_num_' + str(args.num_layer))
    pretrained_path = os.path.join(base_snapshot_dir, args.pretrained_model, 'best_model.pth')
    if not os.path.exists(pretrained_path):
        raise FileNotFoundError(f"Pretrained model not found: {pretrained_path}")

    # New directory for the pruned run.
    snapshot_dir = os.path.join(args.snapshot_dir,
                                'input_' + modename[args.mode],
                                model_name,
                                'layer_num_' + str(args.num_layer) + '_Prune_' + str(args.prune_ratio),
                                'weight_' + str(args.weight) + '_time' + time.strftime('%m%d_%H%M', time.localtime(time.time())))
    os.makedirs(snapshot_dir, exist_ok=True)

    with open(os.path.join(snapshot_dir, 'args.json'), 'w', encoding='utf-8') as f_json:
        json.dump(vars(args), f_json, indent=2, ensure_ascii=False)

    # The context manager closes the log even when a later step raises.
    with open(os.path.join(snapshot_dir, 'Training_log.txt'), 'w', encoding='utf-8') as f:
        cudnn.enabled = True
        cudnn.benchmark = True
        init_seeds()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        # Create the network and load the pretrained weights.
        model = build_model(n_classes=args.num_classes, n_channels=input_dim, layer_num=args.num_layer)
        print(f"Loading pretrained model from {pretrained_path}")
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(pretrained_path, map_location='cpu')
        model.load_state_dict(state_dict)

        # Prune.
        model = prune_model(model, prune_ratio=args.prune_ratio)

        # Complexity of the pruned network.
        input_size_train = (512, 512)
        input_res = (input_dim, input_size_train[1], input_size_train[0])
        macs, params = get_model_complexity_info(
            model, input_res, as_strings=False, print_per_layer_stat=False, verbose=False
        )
        FLOPS = macs * 2
        gflops = FLOPS / (1e9)
        # NOTE(review): `params` is presumably a parameter count, so this is
        # count / 2**30 rather than a size in bytes; the "GB" labels are kept
        # unchanged because the .npz key and log text may be read elsewhere.
        params_scaled = params / (1024 ** 3)
        model_complexity_info = {'GFLOP': gflops, 'parameter_GB': params_scaled}
        np.savez(os.path.join(snapshot_dir, 'model_complexity_info.npz'), **model_complexity_info)

        _log(f, f"[INFO] Total FLOP: {gflops}, GFLOPs")
        # Console and log-file wording of this line differ; both are kept.
        print(f"[INFO] Total trainable parameters: {params_scaled}, GB")
        f.write(f"[INFO] Total trainable parameters: {params_scaled} GB\n")
        _log(f, f"[INFO] Mode: {modename[args.mode]}")
        _log(f, f"[INFO] Start training {model_name}")

        model = model.to(device)
        optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
        interp = nn.Upsample(size=(512, 512), mode='bilinear')

        # Dataset
        train_loader = data.DataLoader(
            Sen2FireDataSet(args.data_dir, args.train_list,
                            max_iters=args.fine_tune_steps * args.batch_size, mode=args.mode),
            batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True
        )
        test_loader = data.DataLoader(
            Sen2FireDataSet(args.data_dir, args.test_list, mode=args.mode),
            batch_size=1, shuffle=False, num_workers=args.num_workers, pin_memory=True
        )

        # Fine-tune
        print(f"Fine-tuning pruned model for {args.fine_tune_steps} steps...")
        _fine_tune(model, train_loader, optimizer, interp, device, args, f)

        torch.save(model.state_dict(), os.path.join(snapshot_dir, "pruned_finetuned_model.pth"))
        _log(f, "Pruned and fine-tuned model saved!")

        # === Test ===
        _log(f, 'Testing..........')
        TP_all, FP_all, FN_all, n_valid_sample_all = _evaluate(
            model, test_loader, interp, device, args.num_classes
        )

        # Per-class metrics, each of shape (num_classes, 1).
        OA = np.sum(TP_all) * 1.0 / n_valid_sample_all
        precision = TP_all * 1.0 / (TP_all + FP_all + epsilon)
        recall = TP_all * 1.0 / (TP_all + FN_all + epsilon)
        F1 = 2.0 * precision * recall / (precision + recall + epsilon)
        IoU = TP_all * 1.0 / (TP_all + FP_all + FN_all + epsilon)

        # Only class index 1 is reported individually. Reading the scalars by
        # explicit index avoids relying on variables left over from a loop and
        # avoids calling float() on 1-element arrays.
        fire = 1
        fire_metrics = {
            'Precision': float(precision[fire, 0] * 100),
            'Recall': float(recall[fire, 0] * 100),
            'IoU': float(IoU[fire, 0] * 100),
            'F1': float(F1[fire, 0] * 100),
        }
        for metric_name, value in fire_metrics.items():
            _log(f, f'===>{name_classes[fire]} {metric_name}: {value:.2f}')

        mF1 = np.mean(F1)
        mIoU = np.mean(IoU)
        _log(f, f'===> mIoU: {mIoU*100:.2f} mean F1: {mF1*100:.2f} OA: {OA*100:.2f}')

        np.savez(os.path.join(snapshot_dir, 'test_info.npz'), **fire_metrics)

# Run the prune / fine-tune / test pipeline only when executed as a script.
if __name__ == '__main__':
    main()
