# coding=utf-8
"""
Image Fusion Training Script using GAN
"""
from __future__ import print_function
import argparse
import os
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
import kornia

from dataset import FusionDataset
from Net_83 import FusionNet, CLIPEvaluator
from loss import TextureDetailLoss, optimal_ir_fusion


class TrainingConfig:
    """Bundle of default training hyper-parameters exposed as plain attributes.

    Every value is assigned per instance in ``__init__``, so two instances can be
    modified independently.
    """

    def __init__(self):
        # (attribute name, default value) pairs, assigned in this order.
        defaults = (
            # data
            ('dataset_name', 'data'),
            ('batch_size', 16),
            ('test_batch_size', 1),
            # schedule length
            ('n_epochs', 100),
            # model shape
            ('input_nc', 1),
            ('output_nc', 1),
            ('ngf', 64),
            ('ndf', 64),
            # optimizer
            ('learning_rate', 0.001),
            ('beta1', 0.5),
            # system
            ('num_threads', 0),
            ('random_seed', 123),
            # loss weights
            ('lambda_weight', 150),
            ('alpha', 0.25),
            # checkpointing / LR scheduler
            ('save_interval', 30),
            ('scheduler_step', 10),
            ('scheduler_gamma', 0.5),
        )
        for attr_name, attr_value in defaults:
            setattr(self, attr_name, attr_value)


def str2bool(value):
    """Convert a command-line token to a boolean (usable as argparse ``type=``).

    Accepted spellings (case-insensitive, surrounding whitespace ignored):
    ``true``/``false``, ``1``/``0``, ``yes``/``no``, ``y``/``n``, ``t``/``f``.
    A value that is already a ``bool`` is returned unchanged, so the function
    is also safe when argparse passes through a non-string default.

    Args:
        value: the raw token (normally a ``str``).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', '1', 'yes', 'y', 't'):
        return True
    if normalized in ('false', '0', 'no', 'n', 'f'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def parse_arguments(argv=None):
    """Parse the training command-line options.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            allows programmatic use and testing.

    Returns:
        argparse.Namespace with the attributes declared below.

    NOTE(review): in this file ``main()`` does not read ``--testBatchSize``,
    ``--beta1``, ``--input_nc``, ``--output_nc``, ``--ngf``, ``--ndf``,
    ``--lamb`` or ``--alpha``.
    """
    parser = argparse.ArgumentParser(description='Image Fusion Training with PyTorch')

    # Data parameters
    parser.add_argument('--dataset', type=str, default='data', help='Dataset name')
    parser.add_argument('--batchSize', type=int, default=16, help='Training batch size')
    parser.add_argument('--testBatchSize', type=int, default=1, help='Testing batch size')

    # Training parameters
    parser.add_argument('--nEpochs', type=int, default=100, help='Number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--beta1', type=float, default=0.5, help='Beta1 for Adam optimizer')

    # Model parameters
    parser.add_argument('--input_nc', type=int, default=1, help='Input image channels')
    parser.add_argument('--output_nc', type=int, default=1, help='Output image channels')
    parser.add_argument('--ngf', type=int, default=64, help='Generator filters in first conv layer')
    parser.add_argument('--ndf', type=int, default=64, help='Discriminator filters in first conv layer')

    # System parameters
    parser.add_argument('--cuda', action='store_true', help='Use CUDA acceleration')
    parser.add_argument('--threads', type=int, default=0, help='Number of data loader threads')
    parser.add_argument('--seed', type=int, default=123, help='Random seed')

    # Loss parameters
    # float so fractional weights are accepted; the non-string default (150) is
    # not passed through ``type`` by argparse, so the default value is unchanged.
    parser.add_argument('--lamb', type=float, default=150, help='Weight on L1 term in objective')
    parser.add_argument('--alpha', type=float, default=0.25, help='Alpha parameter')

    return parser.parse_args(argv)


def setup_device_and_seed(use_cuda, seed):
    """Validate the device request, seed the RNGs and return the torch device.

    Args:
        use_cuda: run on the default CUDA device when True.
        seed: seed passed to ``torch.manual_seed`` (and additionally to
            ``torch.cuda.manual_seed`` when ``use_cuda`` is True).

    Returns:
        ``torch.device("cuda")`` if ``use_cuda`` else ``torch.device("cpu")``.

    Raises:
        RuntimeError: if ``use_cuda`` is True but CUDA is not available.
    """
    if use_cuda and not torch.cuda.is_available():
        raise RuntimeError("No GPU found, please run without --cuda")

    # Lets cuDNN benchmark and pick convolution algorithms; this can make runs
    # non-bit-reproducible even with a fixed seed.
    cudnn.benchmark = True
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed(seed)

    return torch.device("cuda" if use_cuda else "cpu")


def create_data_loader(dataset_path, dataset_name, batch_size, num_threads):
    """Build a shuffling DataLoader over ``FusionDataset(dataset_path/dataset_name)``.

    Args:
        dataset_path: directory that contains the dataset folder.
        dataset_name: folder name joined onto ``dataset_path``.
        batch_size: samples per batch.
        num_threads: number of DataLoader worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True``.
    """
    print('===> Loading datasets')
    fusion_dataset = FusionDataset(os.path.join(dataset_path, dataset_name))
    loader = DataLoader(
        dataset=fusion_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_threads,
    )
    return loader


def create_models_and_optimizers(device, learning_rate, *, betas=(0.9, 0.999),
                                 scheduler_step=10, scheduler_gamma=0.5):
    """Initialize models, loss functions, optimizer and LR scheduler.

    Args:
        device: ``torch.device`` the models are moved to.
        learning_rate: initial Adam learning rate.
        betas: Adam ``(beta1, beta2)``; defaults to the previously hard-coded values.
        scheduler_step: StepLR ``step_size`` (in scheduler steps, one per epoch here).
        scheduler_gamma: StepLR multiplicative decay factor.

    Returns:
        dict with keys ``fusion_model``, ``clip_evaluator``, ``l1_loss``,
        ``ssim_loss``, ``texture_loss``, ``optimizer``, ``scheduler``.
    """
    print('===> Building models')

    # Models
    fusion_model = FusionNet().to(device)
    # CLIPEvaluator is given the torch.device on CUDA and the string 'cpu' otherwise.
    clip_evaluator = CLIPEvaluator(device=device if device.type == 'cuda' else 'cpu').to(device)

    # Loss functions
    l1_loss = torch.nn.L1Loss()
    ssim_loss = kornia.losses.SSIMLoss(3, reduction='mean')  # window size 3
    texture_loss = TextureDetailLoss(opera='Sobel')

    # Only the fusion network's parameters are handed to the optimizer.
    optimizer = optim.Adam(fusion_model.parameters(), lr=learning_rate, betas=betas)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=scheduler_step, gamma=scheduler_gamma
    )

    print('---------- Networks initialized -------------')
    print('-----------------------------------------------')

    return {
        'fusion_model': fusion_model,
        'clip_evaluator': clip_evaluator,
        'l1_loss': l1_loss,
        'ssim_loss': ssim_loss,
        'texture_loss': texture_loss,
        'optimizer': optimizer,
        'scheduler': scheduler,
    }


def compute_loss(models, img_ir, img_vis, fused, fused_feat, contr_loss, *,
                 structure_weight=15):
    """Compute the total training loss for one batch.

    total = L1(w1*ir, w1*fused) + L1(w2*vis, w2*fused)
            + structure_weight * (SSIM(fused, ir) + SSIM(fused, vis) + texture)
            + contr_loss

    Args:
        models: dict produced by ``create_models_and_optimizers`` (reads
            ``clip_evaluator``, ``l1_loss``, ``ssim_loss``, ``texture_loss``).
        img_ir, img_vis: source image batches.
        fused: fusion-network output for the batch.
        fused_feat: not used here; kept for interface compatibility.
        contr_loss: extra loss term returned by the fusion model, added as-is.
        structure_weight: weight on the SSIM + texture terms (previously a
            hard-coded 15).

    Returns:
        ``(total_loss, quality_score)``. ``quality_score`` is computed without
        autograd tracking and is not part of ``total_loss``.
    """
    # Weights from optimal_ir_fusion (shape/range not visible from this file).
    w1, w2 = optimal_ir_fusion(img_ir, img_vis)

    # The CLIP score never enters the returned loss, so recording its autograd
    # graph would only keep the evaluator's activations alive until backward().
    with torch.no_grad():
        quality_score = models['clip_evaluator'](fused)

    l1 = models['l1_loss']
    l1_term = l1(w1 * img_ir, w1 * fused) + l1(w2 * img_vis, w2 * fused)

    ssim = models['ssim_loss']
    ssim_term = ssim(fused, img_ir) + ssim(fused, img_vis)

    texture_term = models['texture_loss'](fused, img_ir, img_vis)

    g_loss = l1_term + structure_weight * (ssim_term + texture_term)
    total_loss = g_loss + contr_loss

    return total_loss, quality_score


def train_epoch(epoch, models, data_loader, device):
    """Train the fusion model for one epoch.

    For each batch: forward through the fusion model, compute the loss, log it,
    back-propagate, step the optimizer, then pass the batch's detached fused
    features and quality scores to ``fusion_model.memory_fusion.update_memory``.

    Args:
        epoch: epoch number (used only for logging).
        models: dict produced by ``create_models_and_optimizers``.
        data_loader: iterable of batches whose items 0 and 1 are the IR and
            visible image tensors.
        device: device the batch tensors are moved to.

    Returns:
        Mean total loss over the epoch as a float, or ``float('nan')`` if the
        loader produced no batches.
    """
    fusion_model = models['fusion_model']
    optimizer = models['optimizer']
    fusion_model.train()

    loss_sum = 0.0
    num_batches = 0
    for iteration, batch in enumerate(data_loader, 1):
        img_ir, img_vis = batch[0].to(device), batch[1].to(device)

        # Forward pass
        optimizer.zero_grad()
        fused, fused_feat, contr_loss = fusion_model(img_ir, img_vis)

        # Compute loss
        loss, quality_score = compute_loss(models, img_ir, img_vis, fused, fused_feat, contr_loss)

        loss_value = loss.item()
        print(f"Epoch {epoch:2d}-Ite {iteration:3d}: Total Loss = {loss_value:.6f}, "
              f"Quality = {quality_score.mean().item():.4f}")

        # Backward pass
        loss.backward()
        optimizer.step()

        # Memory update runs after the optimizer step, outside autograd.
        with torch.no_grad():
            fusion_model.memory_fusion.update_memory(
                fused_feat.detach(),
                quality_score.detach()
            )

        loss_sum += loss_value
        num_batches += 1

    return loss_sum / num_batches if num_batches else float('nan')


def save_model(model, epoch, save_interval=30, *, save_dir="./Model/", filename="Pr_G.pth"):
    """Checkpoint ``model`` when ``epoch`` is a multiple of ``save_interval``.

    The whole module object is written with ``torch.save(model, ...)`` (not a
    ``state_dict``), always to the same path, so each save overwrites the last.

    Args:
        model: the module to save.
        epoch: current epoch number.
        save_interval: save only when ``epoch % save_interval == 0``.
        save_dir: output directory, created if missing.
        filename: checkpoint file name inside ``save_dir``.

    Returns:
        The checkpoint path if a save happened, otherwise ``None``.
    """
    if epoch % save_interval != 0:
        return None

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    torch.save(model, save_path)
    print(f"Model saved at epoch {epoch}")
    return save_path


def main():
    """Main training function.

    Parses CLI options, builds the data loader, models and optimizer, then
    trains for ``--nEpochs`` epochs. The LR scheduler steps once per epoch. A
    checkpoint is saved every 30 epochs and always after the final epoch.
    """
    opt = parse_arguments()

    # --cuda falls back to CPU when no GPU is present; say so instead of doing it silently.
    use_cuda = opt.cuda and torch.cuda.is_available()
    if opt.cuda and not use_cuda:
        print("Warning: --cuda requested but no GPU is available; training on CPU.")
    device = setup_device_and_seed(use_cuda, opt.seed)

    data_loader = create_data_loader("data/", opt.dataset, opt.batchSize, opt.threads)
    models = create_models_and_optimizers(device, opt.lr)

    # Honour --nEpochs (previously ignored in favour of a hard-coded 30).
    n_epochs = opt.nEpochs
    print("Starting training...")
    for epoch in range(1, n_epochs + 1):
        print(f'Training epoch: {epoch}')

        train_epoch(epoch, models, data_loader, device)

        # Update learning rate
        models['scheduler'].step()

        # Checkpoint every 30 epochs, and always on the last epoch so runs whose
        # length is not a multiple of 30 keep their final weights.
        # save_interval=1 makes save_model's modulo check pass unconditionally.
        save_model(models['fusion_model'], epoch,
                   save_interval=1 if epoch == n_epochs else 30)

        print(f'Completed epoch: {epoch}')

    print("Training completed!")


if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()