import torch
import os
import time
from config import Config
from data_utils import get_streetview_loaders
from Model.cann_vae import create_model


def train_one_epoch(model, optimizer, train_loader, config, epoch, log_interval=100):
    """Run one optimisation pass of ``model`` over ``train_loader``.

    Each batch is unpacked as ``(obs, pos, extra)``; ``obs`` and ``pos`` are moved to
    ``config.device`` and ``extra`` is ignored. The model output is splatted into
    ``model.loss_function``, whose returned dict must provide ``'loss'`` (optimised),
    ``'Reconstruction_Loss'`` and ``'VQ_Loss'``.

    Args:
        model: module called as ``model(obs, pos)``; switched to train mode.
        optimizer: optimizer stepped once per batch.
        train_loader: iterable of ``(obs, pos, extra)`` batches.
        config: object providing ``device``.
        epoch: epoch number, used only for logging.
        log_interval: print running stats every this many batches; a falsy value
            disables per-batch logging.

    Returns:
        ``(avg_loss, avg_recon, avg_vq)``: per-batch means over the epoch. All three are
        ``nan`` when ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_recon_loss = 0.0
    total_vq_loss = 0.0
    batch_count = 0

    window_start = time.time()
    for batch_idx, (obs, pos, _) in enumerate(train_loader):
        obs = obs.to(config.device)
        pos = pos.to(config.device)

        optimizer.zero_grad()
        results = model(obs, pos)
        loss_dict = model.loss_function(*results)
        loss = loss_dict['loss']
        loss.backward()
        optimizer.step()

        # Pull scalars off the device once per batch and reuse them for logging.
        recon_val = loss_dict['Reconstruction_Loss'].item()
        vq_val = loss_dict['VQ_Loss'].item()
        total_loss += loss.item()
        total_recon_loss += recon_val
        total_vq_loss += vq_val
        batch_count += 1

        if log_interval and (batch_idx + 1) % log_interval == 0:
            now = time.time()
            print(f"[Epoch {epoch} | Batch {batch_idx + 1}] Recon: {recon_val:.4f}, "
                  f"VQ: {vq_val:.4f} （耗时 {now - window_start:.2f}s）")
            # Time is reported per logging window, not cumulatively.
            window_start = now

    if batch_count == 0:
        # Avoid ZeroDivisionError; nan makes the missing data visible in the history.
        print(f"Epoch {epoch}: 训练数据为空，无法计算平均损失")
        nan = float('nan')
        return nan, nan, nan

    avg_loss = total_loss / batch_count
    avg_recon = total_recon_loss / batch_count
    avg_vq = total_vq_loss / batch_count
    print(f"Epoch {epoch} 训练完成, 重建损失: {avg_recon:.4f}, VQ损失: {avg_vq:.4f} ")
    return avg_loss, avg_recon, avg_vq


def evaluate(model, test_loader, config):
    """Compute mean reconstruction and VQ losses of ``model`` over ``test_loader``.

    Each batch is unpacked as ``(obs, pos, extra)``; ``extra`` is ignored. Runs in eval
    mode under ``torch.no_grad()``; the model is left in eval mode.

    Args:
        model: module called as ``model(obs, pos)`` whose output is splatted into
            ``model.loss_function``; the returned dict must provide
            ``'Reconstruction_Loss'`` and ``'VQ_Loss'``.
        test_loader: iterable of ``(obs, pos, extra)`` batches.
        config: object providing ``device``.

    Returns:
        ``(avg_recon, avg_vq)``: per-batch means. Both are ``nan`` when ``test_loader``
        yields no batches.
    """
    model.eval()
    recon_loss = 0.0
    vq_loss = 0.0
    count = 0

    with torch.no_grad():
        for obs, pos, _ in test_loader:
            obs = obs.to(config.device)
            pos = pos.to(config.device)
            results = model(obs, pos)
            loss_dict = model.loss_function(*results)
            recon_loss += loss_dict['Reconstruction_Loss'].item()
            vq_loss += loss_dict['VQ_Loss'].item()
            count += 1

    if count == 0:
        # Avoid ZeroDivisionError; nan makes the missing data visible in the history.
        print("测试数据为空，无法计算平均损失")
        nan = float('nan')
        return nan, nan

    avg_recon = recon_loss / count
    avg_vq = vq_loss / count
    print(f"测试 - 重建损失: {avg_recon:.4f}, VQ损失: {avg_vq:.4f}")
    return avg_recon, avg_vq


def save_model_and_visualization(model, optimizer, train_loader, config, epoch, losses, recon_losses, vq_losses,
                                 test_losses, test_mses):
    """Save a reconstruction preview image and a training checkpoint for ``epoch``.

    The preview is best-effort: any exception raised by
    ``visualize_sequence_reconstruction`` is printed and otherwise ignored, so a
    plotting problem never prevents the checkpoint from being written.

    The checkpoint is written to ``config.model_dir/{config.model_name}_{epoch}.pth``
    and contains the epoch, model and optimizer state dicts, the latest training loss
    (``None`` if ``losses`` is empty) and the full loss history.

    Args:
        model: model whose ``state_dict`` is saved; also used for the preview.
        optimizer: optimizer whose ``state_dict`` is saved.
        train_loader: loader the preview batch is drawn from.
        config: object providing ``device``, ``img_dir``, ``model_dir``, ``model_name``.
        epoch: epoch number stored in the checkpoint and used in file names.
        losses, recon_losses, vq_losses: per-epoch training loss histories.
        test_losses, test_mses: per-epoch test histories (the caller in this file
            stores test reconstruction loss and test VQ loss respectively).
    """
    try:
        vis_path = visualize_sequence_reconstruction(
            model, train_loader, config.device, epoch, config.img_dir
        )
    except Exception as e:
        # Deliberately broad: visualization must not abort checkpointing.
        print(f"[可视化失败] 原因: {repr(e)}")
    else:
        # The visualizer returns None when the loader has no batches; nothing was saved.
        if vis_path is not None:
            print(f"重建图像已保存到: {vis_path}")

    os.makedirs(config.model_dir, exist_ok=True)
    model_path = os.path.join(config.model_dir,
                              f'{config.model_name}_{epoch}.pth')
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': losses[-1] if losses else None,
        'training_history': {
            'losses': losses,
            'recon_losses': recon_losses,
            'vq_losses': vq_losses,
            'test_losses': test_losses,
            'test_mses': test_mses
        }
    }, model_path)
    print(f"模型已保存到: {model_path}")


def _frame_to_image(frame):
    """Convert one ``[C, H, W]`` frame from ``[-1, 1]`` to an ``imshow``-ready ``[0, 1]`` array.

    Single-channel frames are squeezed to ``[H, W]`` because ``imshow`` does not accept
    an ``[H, W, 1]`` array.
    """
    # float() first so half/bfloat16 tensors can be converted to numpy.
    image = frame.detach().float().cpu().permute(1, 2, 0)
    image = torch.clip((image + 1) / 2.0, 0, 1)  # map [-1, 1] -> [0, 1], clamp overshoot
    if image.size(-1) == 1:
        image = image.squeeze(-1)
    return image.numpy()


# Visualize sequence reconstruction (simple implementation).
def visualize_sequence_reconstruction(model, data_loader, device, epoch, save_dir):
    """Save a 2x3 grid comparing original and reconstructed frames of one sequence.

    Uses the first sequence of the first batch from ``data_loader``. The top row shows
    its first, middle and last frame; the bottom row shows the model's reconstruction
    of the same frames. Frames are assumed to be ``[C, H, W]`` with values in
    ``[-1, 1]``.

    Args:
        model: called as ``model(obs, pos)``; the first of its three outputs is used as
            the reconstruction (assumed to share the ``[B, S, C, H, W]`` layout of
            ``obs`` -- TODO confirm against the model).
        data_loader: iterable of batches whose first two items are ``obs`` and ``pos``.
        device: device the batch is moved to before the forward pass.
        epoch: used only in the output file name.
        save_dir: output directory; created if missing.

    Returns:
        Path of the saved ``recon_epoch_{epoch}.png``, or ``None`` if ``data_loader``
        yields no batches.

    Note:
        Leaves ``model`` in eval mode.
    """
    os.makedirs(save_dir, exist_ok=True)
    model.eval()

    # Only the first batch is needed; an empty loader means there is nothing to draw.
    first_batch = next(iter(data_loader), None)
    if first_batch is None:
        return None
    obs = first_batch[0].to(device)
    pos = first_batch[1].to(device)

    with torch.no_grad():
        recons, _, _ = model(obs, pos)

    # Deferred until there is actually something to plot.
    import matplotlib.pyplot as plt

    original_seq = obs[0]
    recon_seq = recons[0]

    # Key frames: start, middle, end of the sequence.
    seq_len = original_seq.size(0)
    indices = [0, seq_len // 2, seq_len - 1]

    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    try:
        for col, idx in enumerate(indices):
            for row, seq in enumerate((original_seq, recon_seq)):
                axes[row, col].imshow(_frame_to_image(seq[idx]))
                axes[row, col].axis('off')
        fig.tight_layout()
        save_path = os.path.join(save_dir, f'recon_epoch_{epoch}.png')
        fig.savefig(save_path)
    finally:
        # Always release the figure, even if plotting/saving fails.
        plt.close(fig)
    return save_path


if __name__ == "__main__":
    setting = 0
    print('setting = ', setting)
    config = Config(setting)

    torch.manual_seed(config.seed)
    torch.backends.cudnn.deterministic = True

    os.makedirs(config.results_dir, exist_ok=True)
    os.makedirs(config.model_dir, exist_ok=True)
    os.makedirs(config.img_dir, exist_ok=True)
    print(f"模型将保存到: {config.model_dir}")
    print(f"图像将保存到: {config.img_dir}")
    print(f"使用设备: {config.device}")

    train_loader, test_loader = get_streetview_loaders(config)
    model = create_model(config)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    losses, recon_losses, vq_losses, test_losses, test_mses = [], [], [], [], []

    # Epoch number of the checkpoint to resume from; set to None to train from scratch.
    resume_epoch = 34
    start_epoch = 1
    if resume_epoch is not None:
        checkpoint_path = os.path.join(config.model_dir, f'{config.model_name}_{resume_epoch}.pth')
        checkpoint = torch.load(checkpoint_path, map_location=config.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Continue numbering after the checkpoint's epoch so periodic saves never
        # overwrite earlier checkpoints (including the one just loaded), and carry the
        # stored history forward so the next checkpoint's history stays complete.
        start_epoch = checkpoint.get('epoch', resume_epoch) + 1
        history = checkpoint.get('training_history', {})
        losses = list(history.get('losses', []))
        recon_losses = list(history.get('recon_losses', []))
        vq_losses = list(history.get('vq_losses', []))
        test_losses = list(history.get('test_losses', []))
        test_mses = list(history.get('test_mses', []))

    # Still trains config.num_epochs epochs per run, just with continuous numbering.
    last_epoch = start_epoch + config.num_epochs - 1
    for epoch in range(start_epoch, last_epoch + 1):
        print(f"\n==== Epoch {epoch}/{last_epoch} ====")
        loss, recon_loss, vq_loss = train_one_epoch(model, optimizer, train_loader, config, epoch)
        test_recon, test_vq = evaluate(model, test_loader, config)

        losses.append(loss)
        recon_losses.append(recon_loss)
        vq_losses.append(vq_loss)
        test_losses.append(test_recon)
        # NOTE: despite its name, test_mses holds the test-set VQ loss; the name is
        # kept because it is a key in saved checkpoints.
        test_mses.append(test_vq)

        if epoch % config.save_interval == 0:
            save_model_and_visualization(
                model, optimizer, train_loader, config, epoch,
                losses, recon_losses, vq_losses, test_losses, test_mses
            )
