import argparse
import numpy as np
import os
import torch
import torch.nn as nn
from torch.utils import data
import torch.backends.cudnn as cudnn
from utils.tools import eval_image
from dataset.Sen2Fire_Dataset import Sen2FireDataSet

from models.FCN import FCN8s
from models.SegNet import SegNet
from models.unet import unet


from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Display names of the segmentation classes, indexed by class id.
name_classes = np.array(('non-fire', 'fire'), dtype=str)
# Tiny constant added to metric denominators so a class with no samples
# does not trigger a division by zero.
epsilon = 1e-14

def get_arguments():
    """Parse the command line for the evaluation script.

    Options are grouped as model selection / checkpoint location, dataset
    location and input mode, data loading, and output folder.

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    # (flag, add_argument keyword arguments) in the order they are registered,
    # which is also the order they appear in --help.
    option_specs = (
        # Model
        ("--model", dict(type=str, default='unet', help="model name.")),
        ("--num_layer", dict(type=int, default=64, help="basic layer number.")),
        ("--LAM", dict(action='store_true', help="whether apply LEO attention map.")),
        ("--num_layers_lam", dict(type=int, default=1, help="number of LAM layer.")),
        ("--num_components_lam", dict(type=int, default=3, help="number of LAM components.")),
        ("--model_name", dict(type=str, default='weight_10_time0430_0007', help="trained model.")),
        ("--model_dir", dict(type=str, default='./Exp', help="trained model folder path.")),
        # Dataset
        ("--data_dir", dict(type=str, default='./dataset/Sen2Fire/', help="dataset path.")),
        ("--train_list", dict(type=str, default='./dataset/train.txt', help="training list file.")),
        ("--val_list", dict(type=str, default='./dataset/val.txt', help="val list file.")),
        ("--test_list", dict(type=str, default='./dataset/test.txt', help="test list file.")),
        ("--num_classes", dict(type=int, default=2, help="number of classes.")),
        ("--mode", dict(type=int, default=5,
                        help="input type (0-all_bands, 1-all_bands_aerosol,...).")),
        # Data loading
        ("--batch_size", dict(type=int, default=8, help="number of images in each batch.")),
        ("--num_workers", dict(type=int, default=1,
                               help="number of workers for multithread dataloading.")),
        # Results
        ("--snapshot_dir", dict(type=str, default='./Map/',
                                help="where to save detection results.")),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()

def count_parameters(model):
    """Return how many scalar parameters of *model* have ``requires_grad`` set.

    Frozen parameters (``requires_grad=False``) are excluded; a model with no
    trainable parameters yields 0.
    """
    trainable_total = 0
    for tensor in model.parameters():
        if not tensor.requires_grad:
            continue
        trainable_total += tensor.numel()
    return trainable_total

# Input-mode names, indexed by the integer given via --mode.
modename = [
    'all_bands',                  # 0
    'all_bands_aerosol',          # 1
    'rgb',                        # 2
    'rgb_aerosol',                # 3
    'swir',                       # 4
    'swir_aerosol',               # 5
    'nbr',                        # 6
    'nbr_aerosol',                # 7
    'ndvi',                       # 8
    'ndvi_aerosol',               # 9
    'rgb_swir_nbr_ndvi',          # 10
    'rgb_swir_nbr_ndvi_aerosol',  # 11
]

# Number of input channels for each --mode index (same order as `modename`):
# 12 for all_bands, 3 for rgb / swir / nbr / ndvi, 6 for rgb_swir_nbr_ndvi;
# each "_aerosol" (odd) mode has exactly one channel more than its base mode.
input_dim_lookup = dict(enumerate([12, 13, 3, 4, 3, 4, 3, 4, 3, 4, 6, 7]))

def _build_model(model_name, args, input_dim):
    """Instantiate the segmentation network selected on the command line.

    Args:
        model_name: lower-cased value of ``--model``.
        args: parsed command-line namespace; ``num_classes``, ``num_layer``,
            ``LAM``, ``num_layers_lam`` and ``num_components_lam`` are read.
        input_dim: number of input channels for the chosen ``--mode``.

    Returns:
        A freshly constructed (untrained) model.

    Raises:
        ValueError: if ``model_name`` is not ``unet``, ``fcn8s`` or ``segnet``.
    """
    # All three architectures take the same constructor keywords.
    builders = {'unet': unet, 'fcn8s': FCN8s, 'segnet': SegNet}
    try:
        builder = builders[model_name]
    except KeyError:
        raise ValueError(f"Unsupported model type: {args.model}") from None
    return builder(n_classes=args.num_classes, n_channels=input_dim, layer_num=args.num_layer,
                   LAM=args.LAM, num_layers_lam=args.num_layers_lam,
                   num_components_lam=args.num_components_lam)


def _report_metrics(TP_all, FP_all, FN_all, num_classes, snapshot_dir):
    """Compute per-class precision / recall / IoU / F1 and report them.

    The fire class (id 1) is printed to stdout. Every class is written to
    ``<snapshot_dir>/Result.txt``, which is opened once so that earlier
    classes are not overwritten.

    Args:
        TP_all, FP_all, FN_all: accumulated per-class counts of shape
            ``(num_classes, 1)``.
        num_classes: number of classes to report.
        snapshot_dir: folder that receives ``Result.txt``.
    """
    lines = []
    for i in range(num_classes):
        # Index the single column so the values are scalars, not 1-element arrays.
        tp, fp, fn = TP_all[i, 0], FP_all[i, 0], FN_all[i, 0]
        precision = tp / (tp + fp + epsilon)
        recall = tp / (tp + fn + epsilon)
        f1 = 2.0 * precision * recall / (precision + recall + epsilon)
        iou = tp / (tp + fp + fn + epsilon)

        # name_classes only names two classes; fall back to a generic label.
        class_name = name_classes[i] if i < len(name_classes) else f'class_{i}'
        report = [
            '===>' + class_name + ' Precision: %.2f' % (precision * 100),
            '===>' + class_name + ' Recall: %.2f' % (recall * 100),
            '===>' + class_name + ' IoU: %.2f' % (iou * 100),
            '===>' + class_name + ' F1: %.2f' % (f1 * 100),
        ]
        if i == 1:
            for line in report:
                print(line)
        lines.extend(report)

    with open(os.path.join(snapshot_dir, "Result.txt"), "w") as f:
        f.writelines(line + '\n' for line in lines)


def main():
    """Evaluate a trained segmentation checkpoint on the Sen2Fire test list.

    Steps:
        1. Build the network selected by the command-line arguments.
        2. Restore ``best_model.pth`` from the experiment folder.
        3. Predict every patch listed in ``--test_list``, one patch per batch.
        4. Save each prediction as a compressed ``.npz`` under the snapshot folder.
        5. Report per-class precision, recall, IoU and F1, accumulated over
           the whole test set.

    Raises:
        ValueError: if ``--model`` is not one of unet / fcn8s / segnet.
    """
    args = get_arguments()

    # Predictions are resized to this size before the per-pixel argmax;
    # nn.Upsample below receives it as (input_size[1], input_size[0]).
    input_size = (512, 512)

    cudnn.enabled = True
    cudnn.benchmark = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create network
    input_dim = input_dim_lookup[args.mode]
    model_name = args.model.lower()
    model = _build_model(model_name, args, input_dim)

    print(f"[INFO] Total trainable parameters: {count_parameters(model)}")
    print(f"[INFO] Mode: {modename[args.mode]}")
    print(f"[INFO] Start testing {model_name}")

    # Checkpoint path:
    # <model_dir>/input_<mode>/<model>[_LAM]/layer_num_<n>/<model_name>/best_model.pth
    model_type = args.model + '_LAM' if args.LAM else args.model
    layer_dir = 'layer_num_' + str(args.num_layer)
    restore_from = os.path.join(args.model_dir, 'input_' + modename[args.mode], model_type,
                                layer_dir, args.model_name, 'best_model.pth')
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only host.
    saved_state_dict = torch.load(restore_from, map_location=device)
    model.load_state_dict(saved_state_dict)

    snapshot_dir = os.path.join(args.snapshot_dir, model_type, layer_dir, args.model_name)
    os.makedirs(snapshot_dir, exist_ok=True)

    # Use the detected device instead of hard-coding CUDA.
    model = model.to(device)
    model.eval()

    # batch_size stays 1: labels are squeezed and only name[0] is used per batch.
    test_loader = data.DataLoader(
        Sen2FireDataSet(args.data_dir, args.test_list, mode=args.mode),
        batch_size=1, shuffle=False, num_workers=args.num_workers,
        pin_memory=(device.type == 'cuda'))

    # Interpolation for the probability maps (align_corners=False is the default, made explicit).
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear',
                         align_corners=False)

    # Per-class counts accumulated over the test set (TN is not needed for the metrics).
    TP_all = np.zeros((args.num_classes, 1))
    FP_all = np.zeros((args.num_classes, 1))
    FN_all = np.zeros((args.num_classes, 1))

    for image, label, _, name in tqdm(test_loader):
        image = image.float().to(device)
        label = label.squeeze().numpy()

        with torch.no_grad():
            logits = model(image)
            # Upsample class probabilities, then take the argmax over the class axis.
            _, pred = torch.max(interp(nn.functional.softmax(logits, dim=1)), 1)
        pred = pred.squeeze().cpu().numpy().astype('uint8')

        TP, FP, _, FN, _ = eval_image(pred.reshape(-1), label.reshape(-1), args.num_classes)
        TP_all += TP
        FP_all += FP
        FN_all += FN

        # NOTE(review): assumes dataset names look like '<folder>/<patch>' — confirm.
        patch_name = name[0].split('/')[1]
        np.savez_compressed(os.path.join(snapshot_dir, patch_name), label=pred)

    _report_metrics(TP_all, FP_all, FN_all, args.num_classes, snapshot_dir)

    


#%%
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()
