import argparse
import json
import numpy as np
import time
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
import torch.backends.cudnn as cudnn
from utils.tools import adjust_learning_rate, label_accuracy_score, eval_image
from dataset.Sen2Fire_Dataset import Sen2FireDataSet
from ptflops import get_model_complexity_info

from models.FCN import FCN8s
from models.SegNet import SegNet
from models.unet import unet
import random
from tqdm import tqdm

# Display names of the segmentation classes, indexed by class id.
name_classes = np.array(('non-fire', 'fire'), dtype=str)
# Tiny constant added to metric denominators so they never divide by zero.
epsilon = 1.0e-14

class F1Loss(nn.Module):
    """Soft (differentiable) F1 loss for dense multi-class prediction.

    Softmax probabilities are used in place of hard decisions to build
    per-sample, per-class soft TP/FP/FN counts. The loss is one minus the
    F1 score averaged over every (sample, class) pair.

    Args:
        epsilon: Constant added to each denominator to avoid division by zero.
    """

    def __init__(self, epsilon=1e-7):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, logits, targets):
        """Return the scalar loss ``1 - mean(F1)``.

        Args:
            logits: Raw class scores laid out as (N, C, H, W); softmax is
                taken over dim 1.
            targets: Integer class ids laid out as (N, H, W), in a dtype
                accepted by ``torch.nn.functional.one_hot``.
        """
        eps = self.epsilon
        n_cls = logits.size(1)

        # Class probabilities flattened to (N, C, H*W).
        soft_pred = torch.softmax(logits, dim=1)
        soft_pred = soft_pred.view(logits.shape[0], n_cls, -1)

        # One-hot ground truth moved to channels-first and flattened alike.
        truth = torch.nn.functional.one_hot(targets, n_cls)
        truth = truth.permute(0, 3, 1, 2).float()
        truth = truth.view(targets.shape[0], n_cls, -1)

        # Soft confusion counts, summed over the pixel axis.
        true_pos = (soft_pred * truth).sum(dim=2)
        false_pos = (soft_pred * (1 - truth)).sum(dim=2)
        false_neg = ((1 - soft_pred) * truth).sum(dim=2)

        prec = true_pos / (true_pos + false_pos + eps)
        rec = true_pos / (true_pos + false_neg + eps)
        f1_scores = 2 * prec * rec / (prec + rec + eps)

        return 1 - f1_scores.mean()



def init_seeds(seed=1234):
    """Seed Python's, NumPy's and PyTorch's (CPU and CUDA) random generators.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def get_arguments():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace: One attribute per option declared below, in
        declaration order.
    """
    option_table = (
        # Model
        ("--model", dict(type=str, default='unet',
                         help="model name.")),
        ("--num_layer", dict(type=int, default=64,
                             help="basic layer number.")),
        ("--LAM", dict(action='store_true',
                       help="whether apply LEO attention map.")),
        ("--num_layers_lam", dict(type=int, default=1,
                                  help="number of LAM layer.")),
        ("--num_components_lam", dict(type=int, default=3,
                                      help="number of LAM components.")),
        # Dataset
        ("--data_dir", dict(type=str, default='./dataset/Sen2Fire/',
                            help="dataset path.")),
        ("--train_list", dict(type=str, default='./dataset/train.txt',
                              help="training list file.")),
        ("--val_list", dict(type=str, default='./dataset/val.txt',
                            help="val list file.")),
        ("--test_list", dict(type=str, default='./dataset/test.txt',
                             help="test list file.")),
        ("--num_classes", dict(type=int, default=2,
                               help="number of classes.")),
        ("--mode", dict(type=int, default=5,
                        help="input type (0-all_bands, 1-all_bands_aerosol,...).")),
        ("--batch_size", dict(type=int, default=8,
                              help="number of images in each batch.")),
        ("--num_workers", dict(type=int, default=1,
                               help="number of workers for dataloader.")),
        # Training
        ("--learning_rate", dict(type=float, default=1e-4,
                                 help="base learning rate.")),
        ("--weight_decay", dict(type=float, default=5e-4,
                                help="regularisation parameter for L2-loss.")),
        ("--weight", dict(type=float, default=10,
                          help="ce weight.")),
        ("--num_steps", dict(type=int, default=10000,
                             help="number of training steps.")),
        ("--num_steps_stop", dict(type=int, default=10000,
                                  help="number of training steps for early stopping.")),
        ("--patience", dict(type=int, default=3,
                            help="patience for early stopping (number of validations to wait without improvement).")),
        ("--lamda_F1loss", dict(type=float, default=1,
                                help="lamda_F1loss.")),
        # Result
        ("--snapshot_dir", dict(type=str, default='./Exp/',
                                help="where to save snapshots of the model.")),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


# Name of each input configuration, indexed by the integer --mode value.
modename = [
    'all_bands',                    # 0
    'all_bands_aerosol',            # 1
    'rgb',                          # 2
    'rgb_aerosol',                  # 3
    'swir',                         # 4
    'swir_aerosol',                 # 5
    'nbr',                          # 6
    'nbr_aerosol',                  # 7
    'ndvi',                         # 8
    'ndvi_aerosol',                 # 9
    'rgb_swir_nbr_ndvi',            # 10
    'rgb_swir_nbr_ndvi_aerosol',    # 11
]

# Number of input channels for each mode id. Every "*_aerosol" mode has
# exactly one channel more than the mode listed just before it.
input_dim_lookup = {
    0: 12, 1: 13,
    2: 3, 3: 4,
    4: 3, 5: 4,
    6: 3, 7: 4,
    8: 3, 9: 4,
    10: 6, 11: 7,
}

# Index of the 'fire' entry in name_classes; model selection and the saved
# test metrics are based on this class only.
_FIRE_CLASS = 1


def _log(log_file, message):
    """Print *message* and append it (newline-terminated) to *log_file*."""
    print(message)
    log_file.write(message + '\n')
    # Flush so the on-disk log stays current if the run is interrupted.
    log_file.flush()


def _build_model(args, input_dim):
    """Instantiate the network selected by ``args.model`` (case-insensitive).

    Raises:
        ValueError: If ``args.model`` is not one of unet / fcn8s / segnet.
    """
    builders = {'unet': unet, 'fcn8s': FCN8s, 'segnet': SegNet}
    try:
        builder = builders[args.model.lower()]
    except KeyError:
        raise ValueError(f"Unsupported model type: {args.model}") from None
    return builder(n_classes=args.num_classes, n_channels=input_dim,
                   layer_num=args.num_layer, LAM=args.LAM,
                   num_layers_lam=args.num_layers_lam,
                   num_components_lam=args.num_components_lam)


def _evaluate(model, loader, interp, num_classes, device):
    """Run *model* over *loader* and return per-class and overall metrics.

    Each batch is a 4-tuple of which only the first two items (image, label)
    are used. ``eval_image`` is assumed to return per-class TP/FP/TN/FN
    counts that can be added to a (num_classes, 1) array, plus the number of
    valid samples -- TODO confirm against utils.tools.

    Returns:
        dict: 'Precision', 'Recall', 'F1' and 'IoU' as (num_classes, 1)
        arrays, and 'OA' (sum of true positives / number of valid samples).
    """
    model.eval()
    tp_all = np.zeros((num_classes, 1))
    fp_all = np.zeros((num_classes, 1))
    fn_all = np.zeros((num_classes, 1))
    n_valid_all = 0

    with torch.no_grad():
        for image, label, _, _ in tqdm(loader):
            label = label.squeeze().numpy()
            logits = model(image.float().to(device))
            # Softmax first, then upsample, then take the arg-max class.
            _, pred = torch.max(interp(nn.functional.softmax(logits, dim=1)), 1)
            pred = pred.squeeze().cpu().numpy()

            tp, fp, _tn, fn, n_valid = eval_image(
                pred.reshape(-1), label.reshape(-1), num_classes)
            tp_all += tp
            fp_all += fp
            fn_all += fn
            n_valid_all += n_valid

    precision = tp_all / (tp_all + fp_all + epsilon)
    recall = tp_all / (tp_all + fn_all + epsilon)
    return {
        'Precision': precision,
        'Recall': recall,
        'F1': 2.0 * precision * recall / (precision + recall + epsilon),
        'IoU': tp_all / (tp_all + fp_all + fn_all + epsilon),
        'OA': np.sum(tp_all) * 1.0 / n_valid_all,
    }


def _log_metrics(log_file, metrics):
    """Log the fire-class metrics followed by the class-averaged summary."""
    label = name_classes[_FIRE_CLASS]
    for key in ('Precision', 'Recall', 'IoU', 'F1'):
        value = metrics[key][_FIRE_CLASS].item() * 100
        _log(log_file, '===>%s %s: %.2f' % (label, key, value))

    # mIoU is the mean of the per-class IoU (the F1 array was averaged here
    # before, which reported mean F1 twice).
    mean_iou = np.mean(metrics['IoU'])
    mean_f1 = np.mean(metrics['F1'])
    _log(log_file, '===> mIoU: %.2f mean F1: %.2f OA: %.2f'
         % (mean_iou * 100, mean_f1 * 100, metrics['OA'] * 100))


def _run_experiment(args, snapshot_dir, model_tag, log_file):
    """Train with periodic validation, then test the best checkpoint.

    Args:
        args: Parsed command-line options.
        snapshot_dir: Existing directory receiving checkpoints and results.
        model_tag: Display name of the model used in log messages.
        log_file: Open, writable text file for the training log.
    """
    cudnn.enabled = True
    cudnn.benchmark = True
    init_seeds()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create network
    input_dim = input_dim_lookup[args.mode]
    model = _build_model(args, input_dim)

    input_size_train = (512, 512)
    input_res = (input_dim, input_size_train[1], input_size_train[0])
    macs, params = get_model_complexity_info(
        model, input_res, as_strings=False, print_per_layer_stat=False, verbose=False)
    FLOPS = macs * 2
    # NOTE(review): 'parameter_GB' is the value ptflops returns as `params`
    # divided by 1024**3; if that is a parameter count, this is not a byte
    # size. Key and label are kept for compatibility with existing readers.
    model_complexity_info = {'GFLOP': FLOPS / (1e9), 'parameter_GB': params / (1024 ** 3)}
    np.savez(os.path.join(snapshot_dir, 'model_complexity_info.npz'), **model_complexity_info)

    _log(log_file, f"[INFO] Total FLOP: {FLOPS/(1e9)}, GFLOPs")
    # The console and file variants of this line historically differ by a
    # comma; both are reproduced unchanged.
    print(f"[INFO] Total trainable parameters: {params/(1024**3)}, GB")
    log_file.write(f"[INFO] Total trainable parameters: {params/(1024**3)} GB\n")
    _log(log_file, f"[INFO] Mode: {modename[args.mode]}")
    _log(log_file, f"[INFO] Start training {model_tag}")

    model.train()
    # Honour the CPU fallback selected above instead of requiring CUDA.
    model = model.to(device)

    pin_memory = device.type == 'cuda'
    train_loader = data.DataLoader(
        Sen2FireDataSet(args.data_dir, args.train_list,
                        max_iters=args.num_steps_stop * args.batch_size, mode=args.mode),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=pin_memory)
    val_loader = data.DataLoader(
        Sen2FireDataSet(args.data_dir, args.val_list, mode=args.mode),
        batch_size=1, shuffle=False, num_workers=args.num_workers, pin_memory=pin_memory)
    test_loader = data.DataLoader(
        Sen2FireDataSet(args.data_dir, args.test_list, mode=args.mode),
        batch_size=1, shuffle=False, num_workers=args.num_workers, pin_memory=pin_memory)

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate,
                           weight_decay=args.weight_decay)

    # interpolation for the probability maps and labels
    interp = nn.Upsample(size=(input_size_train[1], input_size_train[0]), mode='bilinear')

    class_weights = torch.tensor([1, args.weight], dtype=torch.float32, device=device)
    L_ce = nn.CrossEntropyLoss(weight=class_weights)
    L_f1 = F1Loss()
    lamda = args.lamda_F1loss

    # Dedicated checkpoint path, kept separate from the display name so the
    # final load never points at a file that was not written.
    best_model_path = os.path.join(snapshot_dir, 'best_model.pth')
    best_saved = False
    F1_best = -1
    patience_counter = 0

    for batch_index, train_data in enumerate(train_loader):
        if batch_index == args.num_steps_stop:
            break
        tem_time = time.time()
        adjust_learning_rate(optimizer, args.learning_rate, batch_index, args.num_steps)
        model.train()
        optimizer.zero_grad()

        patches, labels, _, _ = train_data
        patches = patches.to(device)
        labels = labels.to(device).long()

        pred_interp = interp(model(patches))

        # Segmentation loss: blend of weighted cross-entropy and soft F1.
        L_seg_value = (1 - lamda) * L_ce(pred_interp, labels) + lamda * L_f1(pred_interp, labels)
        L_seg_value.backward()
        optimizer.step()

        t = time.time() - tem_time

        if (batch_index + 1) % 10 == 0:
            # Batch metrics are only needed for this log line, so the
            # device-to-host copy is done here rather than on every step.
            _, predict_labels = torch.max(pred_interp, 1)
            lbl_pred = predict_labels.detach().cpu().numpy()
            lbl_true = labels.detach().cpu().numpy()
            metrics_batch = []
            for lt, lp in zip(lbl_true, lbl_pred):
                _, _, mean_iu, _ = label_accuracy_score(lt, lp, n_class=args.num_classes)
                metrics_batch.append(mean_iu)
            batch_miou = np.nanmean(metrics_batch, axis=0)
            batch_oa = np.sum(lbl_pred == lbl_true) * 1. / lbl_true.size

            # The value labelled 'CE_loss' is the combined segmentation loss.
            _log(log_file,
                 'Iter %d/%d Time: %.2f Batch_OA = %.2f Batch_mIoU = %.2f CE_loss = %.3f'
                 % (batch_index + 1, args.num_steps, t, batch_oa * 100,
                    batch_miou * 100, L_seg_value.item()))

        # evaluation per 500 iterations
        if (batch_index + 1) % 500 == 0:
            _log(log_file, 'Validating..........')
            val_metrics = _evaluate(model, val_loader, interp, args.num_classes, device)
            _log_metrics(log_file, val_metrics)

            fire_f1 = val_metrics['F1'][_FIRE_CLASS].item()
            if fire_f1 > F1_best:
                F1_best = fire_f1
                patience_counter = 0
                _log(log_file, 'Save Best F1 Model')
                torch.save(model.state_dict(), best_model_path)
                best_saved = True
            else:
                patience_counter += 1
                _log(log_file, f"Early Stop Patience: {patience_counter}/{args.patience}")
                if patience_counter >= args.patience:
                    _log(log_file, "Early stopping triggered.")
                    break

    if best_saved:
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        # Training ended before the first validation (fewer than 500 steps),
        # so no checkpoint exists; test the weights currently in memory.
        _log(log_file, 'No validation checkpoint was saved; testing the final weights.')

    _log(log_file, 'Testing..........')
    test_metrics = _evaluate(model, test_loader, interp, args.num_classes, device)
    _log_metrics(log_file, test_metrics)

    # All four entries refer to the fire class; each keeps shape (1,).
    test_info = {
        'Precision': test_metrics['Precision'][_FIRE_CLASS] * 100,
        'Recall': test_metrics['Recall'][_FIRE_CLASS] * 100,
        'IoU': test_metrics['IoU'][_FIRE_CLASS] * 100,
        'F1': test_metrics['F1'][_FIRE_CLASS] * 100,
    }
    np.savez(os.path.join(snapshot_dir, 'test_info.npz'), **test_info)


def main():
    """Entry point: prepare the snapshot directory, then train and test.

    Writes ``args.json``, ``Training_log.txt``, ``model_complexity_info.npz``,
    ``best_model.pth`` (when at least one validation ran) and
    ``test_info.npz`` into a snapshot directory derived from the options and
    the current time.

    Raises:
        ValueError: If ``--mode`` or ``--model`` is not supported.
    """
    args = get_arguments()
    # Validate before creating any directory or file.
    if args.mode not in input_dim_lookup:
        raise ValueError(f"Unsupported input mode: {args.mode}")

    model_tag = args.model.lower() + ('_LAM' if args.LAM else '')
    snapshot_dir = os.path.join(
        args.snapshot_dir,
        'input_' + modename[args.mode],
        model_tag,
        'layer_num_' + str(args.num_layer),
        'weight_' + str(args.weight) + '_time' + time.strftime('%m%d_%H%M'),
    ) + '/'
    os.makedirs(snapshot_dir, exist_ok=True)

    with open(os.path.join(snapshot_dir, 'args.json'), 'w', encoding='utf-8') as args_file:
        json.dump(vars(args), args_file, indent=2, ensure_ascii=False)

    # The context manager closes the log even if training or testing raises.
    with open(os.path.join(snapshot_dir, 'Training_log.txt'), 'w') as log_file:
        _run_experiment(args, snapshot_dir, model_tag, log_file)
    
# Run the training/evaluation pipeline only when executed as a script.
if __name__ == "__main__":
    main()