import copy
import json
import os
import random

import numpy as np
import pyrallis
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from loaders.image_loader import load_images
from models.base import load_model
from optimization.optimizer import load_optimizer
from optimization.scheduler import load_lr_scheduler
from loss.task_loss import load_task_loss
from metrics.accuracy import accuracy
from utils.train_util import AverageMeter, ProgressMeter
from utils.log_utils import log_scalar_dict, create_experiment_dir

from options import ModelBaseTrainConfig

@pyrallis.wrap()
def main(cfg: ModelBaseTrainConfig):
    """Train the base model described by ``cfg`` and save its weights.

    Builds the model, data loaders, loss, optimizer and LR scheduler from the
    config. It then runs ``cfg.optim.warmup_epochs`` linear-warmup epochs
    followed by ``cfg.epochs`` regular epochs, evaluating on the test split
    after every epoch. Two checkpoints are written into the experiment
    directory:

    - ``model.pt``: weights after the final epoch.
    - ``best_model.pt``: a snapshot of the weights from the epoch with the
      highest test accuracy.

    Args:
        cfg: Training configuration; supplied by the ``pyrallis.wrap`` decorator.
    """
    init_seed(2025)
    # ----------------------------------------
    # basic configuration
    # ----------------------------------------
    use_cuda = not cfg.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # ----------------------------------------
    # logging configuration
    # ----------------------------------------
    if not cfg.logging.disable_logging:
        logger = SummaryWriter(log_dir=os.path.join(cfg.logging.log_dir, "runs", cfg.logging.exp_name, cfg.logging.sub_exp_name))
        logger.add_text("Config", json.dumps(pyrallis.encode(cfg), indent=4))
    else:
        logger = None

    exp_dir_path = create_experiment_dir(cfg.logging.log_dir, cfg.logging.exp_name, 'base', cfg.logging.sub_exp_name)

    # ----------------------------------------
    # model base configuration
    # ----------------------------------------
    model = load_model(cfg.task.original_model_name, num_classes=cfg.task.num_classes).to(device)

    # ----------------------------------------
    # data loader configuration
    # ----------------------------------------
    train_loader = load_images(cfg.task.data_dir, cfg.task.task_name, data_type='train', batch_size=cfg.batch_size, path_prefix=cfg.task.path_prefix)
    test_loader = load_images(cfg.task.data_dir, cfg.task.task_name, data_type='test', batch_size=cfg.batch_size, path_prefix=cfg.task.path_prefix)

    # ----------------------------------------
    # optimization configuration
    # ----------------------------------------
    criterion = load_task_loss(cfg.task)
    optimizer = load_optimizer(model.parameters(), cfg.optim)
    scheduler = load_lr_scheduler(optimizer, cfg.optim)
    print("SMOOTH FACTOR", cfg.optim.smooth_factor)

    # ----------------------------------------
    # each epoch
    # ----------------------------------------
    best_acc = None
    best_epoch = None
    best_param = None

    last_acc = None
    last_param = None

    warmup_epochs = cfg.optim.warmup_epochs
    # Capture each group's target LR once. Warmup must scale from these values;
    # reading the current (already scaled) LR each epoch would compound the
    # scaling and never return to the base LR.
    base_lrs = [group['lr'] for group in optimizer.param_groups]

    print('\n')
    for epoch in tqdm(range(cfg.epochs + warmup_epochs), desc='Epoch'):
        if epoch < warmup_epochs:
            # Linear warmup: base_lr * (epoch + 1) / warmup_epochs, which
            # reaches exactly base_lr on the last warmup epoch.
            warmup_scale = (epoch + 1) / warmup_epochs
            for group, base_lr in zip(optimizer.param_groups, base_lrs):
                group['lr'] = base_lr * warmup_scale

        loss, smooth_loss, acc = train(train_loader, model, criterion, optimizer, cfg.optim.smooth_factor, device)
        eval_loss, eval_acc = test(test_loader, model, criterion, device)

        loss_acc_dict = {
            'original model training loss': loss.avg,
            'original model smooth loss': smooth_loss.avg,
            'original model training acc': acc.avg,
            'original model test loss': eval_loss.avg,
            'original model test acc': eval_acc.avg
        }

        log_scalar_dict(loss_acc_dict,
                        title='model_base_train',
                        iteration=epoch,
                        logger=logger)
        # ----------------------------------------
        # save best model
        # ----------------------------------------
        last_acc = eval_acc.avg
        # Intentionally aliases the live weights: by the time it is saved
        # after the loop it holds the final-epoch parameters.
        last_param = model.state_dict()
        if best_acc is None or best_acc < last_acc:
            best_acc = last_acc
            best_epoch = epoch
            # state_dict() tensors share storage with the parameters, so a
            # deep copy is required or later optimizer steps would overwrite
            # this "best" snapshot.
            best_param = copy.deepcopy(model.state_dict())

        if epoch >= warmup_epochs:
            scheduler.check_and_step(epoch)

    torch.save(last_param, os.path.join(exp_dir_path, 'model.pt'))
    torch.save(best_param, os.path.join(exp_dir_path, 'best_model.pt'))

    # logger is None when logging is disabled.
    if logger is not None:
        logger.add_text("train_model_base", f"BEST ACC: {best_acc}")
        logger.add_text("train_model_base", f"BEST EPOCH: {best_epoch}")
        logger.add_text("train_model_base", f"LAST ACC: {last_acc}")
        logger.add_text("train_model_base", f"MODEL DIR: {exp_dir_path}")
        # Flush pending events to disk.
        logger.close()

    print('BEST ACC', best_acc)
    print('BEST EPOCH', best_epoch)
    print('LAST ACC', last_acc)
    print('MODEL DIR', exp_dir_path)


def train(train_loader, model, criterion, optimizer, smooth_factor, device):
    """Run one training epoch over ``train_loader``.

    Each step minimizes ``criterion(outputs, labels)``. When
    ``smooth_factor > 0`` it adds the Conv2d weight-smoothness penalty from
    ``smoothness_loss_conv``, using ``smooth_factor`` as its ``alpha``.

    Args:
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        model: Network to train; it is put into train mode.
        criterion: Task loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        smooth_factor: Weight of the smoothness penalty; ``<= 0`` disables it.
        device: Device the batches are moved to.

    Returns:
        A tuple ``(rec_loss_meter, smooth_loss_meter, acc_meter)`` of
        AverageMeters. Each is updated per batch, weighted by batch size, with
        the task loss, the smoothness penalty and the top-1 accuracy.
    """
    model.train()

    rec_loss_meter = AverageMeter('Loss', ':.4e')
    smooth_loss_meter = AverageMeter('SmoothLoss', ':.4e')
    acc_meter = AverageMeter('Acc', ':6.2f')
    progress = ProgressMeter(total=len(train_loader), step=20, prefix='Train',
                             meters=[rec_loss_meter, acc_meter])

    # The set of conv layers is fixed for the epoch; collect it once rather
    # than walking model.modules() on every batch.
    conv_layers = [module for module in model.modules() if isinstance(module, nn.Conv2d)]

    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        batch_size = inputs.size(0)

        outputs = model(inputs)
        rec_loss = criterion(outputs, labels)

        # Optional smoothness regularizer on the conv kernels.
        if smooth_factor > 0:
            smooth_loss = smoothness_loss_conv(conv_layers, alpha=smooth_factor)
        else:
            smooth_loss = torch.zeros((), device=rec_loss.device)

        loss = rec_loss + smooth_loss

        # Only top-1 is tracked. NOTE(review): top-5 is also requested, which
        # presumably needs >= 5 classes — confirm against `accuracy`.
        acc, _acc5 = accuracy(outputs, labels, topk=(1, 5))

        rec_loss_meter.update(rec_loss.item(), batch_size)
        smooth_loss_meter.update(smooth_loss.item(), batch_size)
        acc_meter.update(acc.item(), batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        progress.display(i)

    return rec_loss_meter, smooth_loss_meter, acc_meter


def test(test_loader, model, criterion, device):
    """Evaluate ``model`` over ``test_loader`` with gradient tracking disabled.

    Args:
        test_loader: Iterable yielding ``(inputs, labels)`` batches.
        model: Network to evaluate; it is put into eval mode.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        A tuple ``(loss_meter, acc_meter)`` of AverageMeters. Each is updated
        per batch, weighted by batch size, with the loss and top-1 accuracy.
    """
    model.eval()

    loss_meter = AverageMeter('Loss', ':.4e')
    acc_meter = AverageMeter('Acc', ':6.2f')
    progress = ProgressMeter(total=len(test_loader), step=20, prefix='Test',
                             meters=[loss_meter, acc_meter])

    for step, (batch_x, batch_y) in enumerate(test_loader):
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        n = batch_x.size(0)

        with torch.no_grad():
            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)
            top1, _ = accuracy(logits, batch_y, topk=(1, 5))

            loss_meter.update(batch_loss.item(), n)
            acc_meter.update(top1.item(), n)

            progress.display(step)

    return loss_meter, acc_meter

def init_seed(seed=2025):
    """Seed all random number generators and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU, the current
    CUDA device and all CUDA devices). It then sets cuDNN to deterministic
    mode with benchmark autotuning off.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    
    
def smoothness_loss_conv(conv_layers, alpha=1.0):
    """Smoothness penalty on Conv2d weights.

    For each ``nn.Conv2d`` in ``conv_layers``, the weight of shape
    ``[out_channels, in_channels, kH, kW]`` is reduced to the per-kernel mean
    of shape ``[out_channels, in_channels]``. The layer loss is the sum of
    squared differences between neighbouring entries of that matrix along both
    axes. The layer losses are averaged over the Conv2d layers actually
    processed and scaled by ``alpha``.

    Args:
        conv_layers: Iterable of modules. Entries that are not ``nn.Conv2d``
            are ignored and do not count toward the average.
        alpha: Weight applied to the averaged penalty.

    Returns:
        torch.Tensor: Scalar penalty. A zero tensor is returned when there are
        no Conv2d layers, so callers can still call ``.item()`` /
        ``backward()`` on it.
    """
    total = None
    num_conv = 0

    for layer in conv_layers:
        if not isinstance(layer, nn.Conv2d):
            continue

        # Mean of each (out, in) kernel -> [out_channels, in_channels].
        kernel_means = layer.weight.mean(dim=(2, 3))

        # Squared differences between adjacent rows (output channels).
        row_diff = kernel_means[1:] - kernel_means[:-1]
        row_smoothness = torch.sum(row_diff ** 2)

        # Squared differences between adjacent columns (input channels).
        col_diff = kernel_means[:, 1:] - kernel_means[:, :-1]
        col_smoothness = torch.sum(col_diff ** 2)

        layer_loss = row_smoothness + col_smoothness
        total = layer_loss if total is None else total + layer_loss
        num_conv += 1

    if num_conv == 0:
        return torch.tensor(0.0)

    # Average over the Conv2d layers that actually contributed.
    return alpha * (total / num_conv)

if __name__ == "__main__":
    # Script entry point: main() is called with no arguments; the
    # @pyrallis.wrap() decorator supplies cfg.
    main()
