import os
import sys
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from timeit import default_timer as timer
from logger import Logger, AverageMeter, time_to_str
from models.kind_vit2 import create_kind_fusion_vit
#from feature_extration.src.fusion_dataset import create_fusion_dataloader
from feature_extration.process_data import Paraser, get_cell
#from train_vit_gnn_fusion import VITGNNFusionProcessor
from feature_extration.vitgnn_dataset import VitgnnDataset


def create_fusion_args():
    """Build the command-line parser for fusion-model training.

    Defaults for the data-path options are read from the ``Paraser``
    configuration of ``feature_extration.process_data``. That parser is run
    with ``parse_known_args`` so options it does not know about are ignored.

    The usage text is multi-line, so the parser uses
    ``RawDescriptionHelpFormatter`` to keep its layout in ``--help`` output.

    Returns:
        argparse.ArgumentParser: the configured parser; the caller invokes
        ``parse_args()`` on it.
    """
    # Pull the shared data configuration from process_data first.
    base_parser = Paraser()
    base_args = base_parser.parser.parse_known_args()[0]

    # Usage notes shown at the top of --help.
    description = """
KIND多任务训练脚本 - 训练支持知识分流的Vision Transformer

基本用法：
    python train.py                    # 使用默认256维类别知识空间 (128+64+64)
    python train.py --task_cls_sizes 160,80,80  # 使用320维类别知识空间
    python train.py --gene_size 1024 --task_cls_sizes 128,96,96  # 自定义基因和类别知识维度
    """

    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Data options inherited from process_data.
    parser.add_argument('--data_root', type=str, default=base_args.data_root, help='数据根目录')
    parser.add_argument('--train_list', type=str, default=base_args.train_list, help='训练列表文件')
    parser.add_argument('--test_list', type=str, default=base_args.test_list, help='测试列表文件')
    parser.add_argument('--graph_save_root', type=str, default=base_args.graph_save_root, help='图数据保存根目录')
    parser.add_argument('--lef_path', type=str, default=base_args.lef_path, help='LEF文件路径')
    parser.add_argument('--place_def_root', type=str, default=base_args.place_def_root, help='Place DEF根目录')
    parser.add_argument('--route_def_root', type=str, default=base_args.route_def_root, help='Route DEF根目录')

    # Model options.
    parser.add_argument('--model_name', type=str, default='deit_base_patch16_224', help='基础VIT模型名称')
    parser.add_argument('--in_channels', type=int, default=96, help='输入通道数')
    parser.add_argument('--img_size', type=int, default=256, help='图像尺寸')
    parser.add_argument('--patch_size', type=int, default=16, help='Patch大小')
    parser.add_argument('--drop_rate', type=float, default=0.05, help='Dropout率')
    parser.add_argument('--drop_path_rate', type=float, default=0.1, help='Drop path率')

    # Training options.
    parser.add_argument('--batch_size', type=int, default=8, help='批量大小')
    parser.add_argument('--warmup_epochs', type=int, default=5, help='warmup轮数')
    # Per-task loss weights.
    parser.add_argument('--congestion_weight', type=float, default=1.0, help='Congestion任务权重')
    parser.add_argument('--drc_weight', type=float, default=1.0, help='DRC任务权重')
    parser.add_argument('--ir_drop_weight', type=float, default=10.0, help='IR drop任务权重')
    parser.add_argument('--epochs', type=int, default=100, help='训练轮数')
    parser.add_argument('--lr', type=float, default=1e-3, help='学习率')
    parser.add_argument('--weight_decay', type=float, default=0.01, help='权重衰减')
    parser.add_argument('--max_grad_norm', type=float, default=1.0, help='梯度裁剪最大范数')

    # KIND ViT specific options.
    parser.add_argument('--gene_size', type=int, default=6, help='基因知识维度')
    parser.add_argument('--embed_dim', type=int, default=9, help='Transformer嵌入维度')
    parser.add_argument('--depth', type=int, default=3, help='Transformer层数')
    parser.add_argument('--num_heads', type=int, default=3, help='注意力头数')
    parser.add_argument('--task_cls_sizes', type=str, default='1,1,1',
                        help='各任务类别知识维度，逗号分隔')

    # Runtime / bookkeeping options.
    parser.add_argument('--device', type=str, default='cuda', help='设备类型')
    parser.add_argument('--num_workers', type=int, default=4, help='数据加载进程数')
    parser.add_argument('--save_dir', type=str, default='./results', help='模型保存目录')
    parser.add_argument('--log_dir', type=str, default='./results', help='日志保存目录')
    parser.add_argument('--save_freq', type=int, default=10, help='模型保存频率')
    parser.add_argument('--eval_freq', type=int, default=5, help='验证频率')

    # The flag defaults to True, so on its own it could never be turned off;
    # the paired --no_use_fusion_features switch writes False to the same dest.
    parser.add_argument('--use_fusion_features', action='store_true', default=True, help='使用VIT-GNN融合特征')
    parser.add_argument('--no_use_fusion_features', dest='use_fusion_features', action='store_false',
                        help='不使用VIT-GNN融合特征')
    parser.add_argument('--vitgnndataroot', type=str, default='./generated_feature',
                        help='预生成的VIT-GNN特征数据根目录')
    parser.add_argument('--label', type=str, default='all',
                        help='训练任务标签：thermal/congestion/drc/ir_drop 为单任务，其他取值（默认all）为三任务联合训练')

    parser.add_argument('--pretrain', type=int, default=1, help='是否使用预训练，1则预训练父代模型')
    parser.add_argument('--save_model', type=int, default=0, help='是否更新checkpoint')
    parser.add_argument('--gene_knowledge_path', type=str, default="./results/gene_knowledge_200.pth",help='预训练基因知识路径（默认不使用）')
    parser.add_argument('--cls_knowledge_path', type=str, default="",help='预训练类别知识路径，支持跨任务迁移（默认不使用）')

    return parser


class MultiTaskLoss(nn.Module):
    """Summed-MSE loss for the congestion / DRC / IR-drop prediction tasks.

    ``args.label`` selects the mode:

    * ``"thermal"``: one loss over the whole prediction and target tensors.
    * ``"congestion"``, ``"drc"``, ``"ir_drop"``: the prediction is compared
      against target channel 0, 1 or 2 respectively.
    * anything else: all three channels are scored and combined with
      ``task_weights``.
    """

    # Target channel index holding the ground truth of each task.
    _TASK_CHANNEL = {'congestion': 0, 'drc': 1, 'ir_drop': 2}

    def __init__(self, task_weights=None):
        """Store the per-task weights (all 1.0 when none are supplied)."""
        super().__init__()
        default_weights = {'congestion': 1.0, 'drc': 1.0, 'ir_drop': 1.0}
        self.task_weights = task_weights or default_weights
        # Squared errors are summed, not averaged.
        self.mse_loss = nn.MSELoss(reduction='sum')

    def forward(self, predictions, targets, args):
        """Return a dict of loss tensors; the key ``'total'`` is always present.

        Single-task modes also expose the loss under the task's own name;
        the multi-task mode exposes all three task losses plus their
        weighted sum.
        """
        mode = args.label

        if mode == "thermal":
            return {'total': self.mse_loss(predictions, targets)}

        if mode in self._TASK_CHANNEL:
            channel = self._TASK_CHANNEL[mode]
            task_loss = self.mse_loss(predictions, targets[:, [channel], :, :])
            return {mode: task_loss, 'total': task_loss}

        # Multi-task: score every channel, then accumulate the weighted sum
        # in congestion -> drc -> ir_drop order.
        losses = {}
        weighted_sum = None
        for task, channel in self._TASK_CHANNEL.items():
            task_loss = self.mse_loss(predictions[:, [channel], :, :],
                                      targets[:, [channel], :, :])
            losses[task] = task_loss
            contribution = self.task_weights[task] * task_loss
            weighted_sum = contribution if weighted_sum is None else weighted_sum + contribution
        losses['total'] = weighted_sum
        return losses


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps,
                                    num_cycles=0.5, last_epoch=-1):
    """Create a ``LambdaLR`` with linear warmup followed by cosine decay.

    The learning-rate multiplier rises linearly from 0 to 1 over the first
    ``num_warmup_steps`` scheduler steps, then follows
    ``0.5 * (1 + cos(2 * pi * num_cycles * progress))`` clamped at 0, where
    ``progress`` is the fraction of the post-warmup steps already taken.

    Args:
        optimizer: optimizer whose parameter groups are scheduled.
        num_warmup_steps: number of steps in the linear warmup.
        num_training_steps: total number of steps, warmup included.
        num_cycles: number of cosine periods over the decay phase; the
            default 0.5 is a single half-wave from 1 down to 0.
        last_epoch: forwarded unchanged to ``LambdaLR``.

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """
    # Both spans are floored at 1 so neither division can hit zero.
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def _multiplier(step):
        still_warming_up = step < num_warmup_steps
        if still_warming_up:
            return float(step) / warmup_span

        fraction_done = float(step - num_warmup_steps) / decay_span
        cosine = np.cos(np.pi * float(num_cycles) * 2.0 * fraction_done)
        return max(0.0, 0.5 * (1.0 + cosine))

    return optim.lr_scheduler.LambdaLR(optimizer, _multiplier, last_epoch)


def train_one_epoch(model, dataloader, criterion, optimizer, scheduler, device, epoch, log, args):
    """Run one optimisation pass over ``dataloader`` and log the losses.

    For every batch: forward pass, loss computation, gradient reset,
    backward pass on the ``'total'`` loss, optimizer step and, when a
    scheduler is given, one scheduler step (the schedule advances per batch).

    Args:
        model: network called as ``model(feature, args.label)``.
        dataloader: iterable yielding ``(feature, label)`` pairs.
        criterion: callable returning a dict of loss tensors with a
            ``'total'`` entry.
        optimizer: optimizer stepped once per batch.
        scheduler: optional LR scheduler stepped once per batch, or ``None``.
        device: device the batch tensors are moved to.
        epoch: epoch index, used only in the log line.
        log: object with a ``write`` method receiving the summary line.
        args: parsed arguments; ``args.label`` selects the task.

    Returns:
        The ``avg`` attribute of the total-loss meter after the pass.
    """
    model.train()

    # One running meter per reported loss.
    meters = {name: AverageMeter() for name in ('congestion', 'drc', 'ir_drop', 'total')}

    began = timer()
    for inputs, targets in dataloader:
        inputs = inputs.to(device).float()
        targets = targets.to(device).float()

        batch_losses = criterion(model(inputs, args.label), targets, args)

        optimizer.zero_grad()
        batch_losses['total'].backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # Only the keys the criterion returned are updated; other meters
        # keep their initial state.
        for name, value in batch_losses.items():
            meters[name].update(value.item(), 1)
    duration = timer() - began

    fields = [
        f'TRAIN Epoch {epoch:3d}',
        f'Congestion: {meters["congestion"].avg:.6f}',
        f'drc: {meters["drc"].avg:.6f}',
        f'IR Drop: {meters["ir_drop"].avg:.6f}',
        f'Total: {meters["total"].avg:.6f}',
        f'Time: {time_to_str(duration, "sec")}',
    ]
    log.write(' | '.join(fields) + '\n')
    return meters['total'].avg

def test_one_epoch(model, dataloader, criterion, optimizer, scheduler, device, epoch, log, args):
    """Evaluate the model over one pass of ``dataloader`` and log the losses.

    The model is put in eval mode and the pass runs under
    ``torch.no_grad()``: nothing here calls ``backward()``, so building
    autograd graphs would only cost memory and time.

    Args:
        model: network called as ``model(feature, args.label)``.
        dataloader: iterable yielding ``(feature, label)`` pairs.
        criterion: callable returning a dict of loss tensors with a
            ``'total'`` entry.
        optimizer: unused; kept so the signature mirrors ``train_one_epoch``.
        scheduler: unused; kept so the signature mirrors ``train_one_epoch``.
        device: device the batch tensors are moved to.
        epoch: epoch index, used only in the log line.
        log: object with a ``write`` method receiving the summary line.
        args: parsed arguments; ``args.label`` selects the task.

    Returns:
        The ``avg`` attribute of the total-loss meter after the pass.
    """
    model.eval()

    # One running meter per reported loss.
    loss_meters = {
        'congestion': AverageMeter(),
        'drc': AverageMeter(),
        'ir_drop': AverageMeter(),
        'total': AverageMeter()}
    start_time = timer()
    with torch.no_grad():
        for feature, label in dataloader:
            feature = feature.to(device).float()
            label = label.to(device).float()
            pred = model(feature, args.label)
            losses = criterion(pred, label, args)
            for key, loss_value in losses.items():
                loss_meters[key].update(loss_value.item(), 1)
    elapsed_time = timer() - start_time
    # Labelled TEST so evaluation lines can be told apart from training
    # lines in the shared log file.
    message = f'TEST Epoch {epoch:3d} | ' \
              f'Congestion: {loss_meters["congestion"].avg:.6f} | ' \
              f'drc: {loss_meters["drc"].avg:.6f} | ' \
              f'IR Drop: {loss_meters["ir_drop"].avg:.6f} | ' \
              f'Total: {loss_meters["total"].avg:.6f} | ' \
              f'Time: {time_to_str(elapsed_time, "sec")}'
    log.write(message + '\n')
    return loss_meters['total'].avg


def _compute_knowledge_sizes(embed_dim, patch_size, task_num=3):
    """Split the embedding width into gene and per-task class knowledge sizes.

    About one third of ``embed_dim`` is divided evenly between the tasks as
    class knowledge (rounded per task); the remainder is gene knowledge.
    Every returned size is scaled by the number of points in a patch
    (``patch_size ** 2``).

    Returns:
        tuple: ``(scaled_embed_dim, gene_size, one_cls_size)``.
    """
    patch_points = patch_size ** 2
    one_cls_size = round(embed_dim / 3 / task_num)
    gene_size = embed_dim - one_cls_size * task_num
    return embed_dim * patch_points, gene_size * patch_points, one_cls_size * patch_points


def _build_param_groups(model, lr, weight_decay):
    """Build AdamW parameter groups with a boosted rate for task layers.

    Parameters whose name contains ``"task"`` are trained at ``lr * 100``
    (their names are printed); all other parameters are trained at
    ``lr * 0.001``. Empty groups are omitted.
    """
    task_params, task_names, other_params = [], [], []
    for name, param in model.named_parameters():
        if "task" in name:
            task_params.append(param)
            task_names.append(name)
        else:
            other_params.append(param)

    param_groups = []
    if task_params:
        param_groups.append({
            'params': task_params,
            'lr': lr * 100,
            'weight_decay': weight_decay,
        })
        print("高学习率层:")
        for name in task_names:
            print(f"  - {name}")
    if other_params:
        param_groups.append({
            'params': other_params,
            'lr': lr * 0.001,
            'weight_decay': weight_decay,
        })
    return param_groups


def main():
    """Train the KIND fusion ViT and optionally save the results.

    Steps: parse arguments, open the log, build the datasets and loaders,
    derive the knowledge sizes, create the model (loading stored gene/class
    knowledge when ``--pretrain 0``), train with AdamW under a warmup +
    cosine schedule, evaluate every ``--eval_freq`` epochs, and, when
    ``--save_model`` is set, write the checkpoint, the gene/class knowledge
    files and the task configuration to ``--save_dir``.

    Note that ``args.embed_dim`` is rewritten in place to the patch-scaled
    width before the model is created.
    """
    parser = create_fusion_args()
    args = parser.parse_args()

    # Fall back to CPU whenever CUDA is unavailable, whatever --device says.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Output directories and a timestamped log file.
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    log_file = os.path.join(args.log_dir, f'{time.strftime("%Y%m%d_%H%M%S")}.log')
    log = Logger()
    log.open(log_file, mode='w')
    log.write(f'Training started at {time.strftime("%Y-%m-%d %H:%M:%S")}\n')
    log.write(f'Arguments: {args}\n')

    train_dataset = VitgnnDataset(args.data_root, args.vitgnndataroot, args.train_list, args)
    test_dataset = VitgnnDataset(args.data_root, args.vitgnndataroot, args.test_list, args)
    # NOTE(review): the training loader is not shuffled and --num_workers is
    # not passed to either loader; confirm this is intentional.
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    args.embed_dim, gene_size, one_cls_size = _compute_knowledge_sizes(
        args.embed_dim, args.patch_size, task_num=3)

    # The channel count comes from the data, not from --in_channels.
    input_channels = train_dataset.features.shape[1]

    # Class-knowledge width of each task.
    task_configs = {
        'congestion': one_cls_size,
        'drc': one_cls_size,
        'ir_drop': one_cls_size,
    }
    total_cls_size = sum(task_configs.values())
    if args.label == "thermal":
        # A single thermal task takes the whole class-knowledge budget.
        task_configs = {
            'thermal': total_cls_size
        }
    log.write(f'Task configurations: {task_configs}\n')
    log.write(f'Total class knowledge size: {total_cls_size} dims\n')

    model = create_kind_fusion_vit(
        patch_size=args.patch_size,
        img_size=args.img_size,
        in_chans=input_channels,
        embed_dim=args.embed_dim,
        depth=args.depth,
        num_heads=args.num_heads,
        gene_size=gene_size,
        task_configs=task_configs,
        drop_rate=args.drop_rate,
        drop_path_rate=args.drop_path_rate,
        no_embed_class=True,
    )

    model = model.to(device)
    log.write(f'Model created: {model.__class__.__name__}\n')
    log.write(f'Model parameters: {sum(p.numel() for p in model.parameters()):,}\n')

    # --pretrain 0 starts from previously saved knowledge files; a path that
    # is empty or does not exist is skipped silently.
    if args.pretrain == 0:
        print("load parameters")
        if args.gene_knowledge_path and os.path.exists(args.gene_knowledge_path):
            print("loading gene path")
            model.load_gene_knowledge(args.gene_knowledge_path)
        if args.cls_knowledge_path and os.path.exists(args.cls_knowledge_path):
            model.load_cls_knowledge(args.cls_knowledge_path, target_task=args.label)

    # Loss function.
    task_weights = {
        'congestion': args.congestion_weight,
        'drc': args.drc_weight,
        'ir_drop': args.ir_drop_weight
    }
    criterion = MultiTaskLoss(task_weights)
    log.write(f'lr: {args.lr}\n')

    # Optimizer.
    optimizer = optim.AdamW(_build_param_groups(model, args.lr, args.weight_decay))

    # The scheduler is stepped once per batch, so its horizon is in batches.
    epochs = args.epochs
    steps_per_epoch = len(train_dataloader)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, steps_per_epoch * args.warmup_epochs, steps_per_epoch * epochs)

    # Training loop.
    for epoch in range(0, epochs):
        log.write(f"----------------epoch={epoch}-------------\n")
        train_one_epoch(model, train_dataloader, criterion, optimizer, scheduler, device, epoch, log, args)
        if epoch == 0 or epoch % args.eval_freq == 0:
            test_one_epoch(model, test_dataloader, criterion, optimizer, scheduler, device, epoch, log, args)

    # Save the final state once training has finished.
    if args.save_model:
        torch.save({
            'epoch': args.epochs - 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'args': args,
        }, os.path.join(args.save_dir, 'final_model.pth'))
        log.write('Final model saved after training completion\n')

        # Gene and class knowledge are stored separately from the checkpoint.
        model.save_gene_knowledge(os.path.join(args.save_dir, 'gene_knowledge.pth'))
        model.save_cls_knowledge(os.path.join(args.save_dir, 'cls_knowledge.pth'))
        log.write('Gene knowledge saved to gene_knowledge.pth\n')
        log.write('Class knowledge saved to cls_knowledge.pth\n')

        # Record the gene size the model was actually built with, not the
        # unused --gene_size default, so the file agrees with the weights.
        task_config_data = {
            'task_configs': task_configs,
            'gene_size': gene_size,
            'total_cls_size': sum(task_configs.values()),
            'task_names': list(task_configs.keys()),
            'cls_dimensions': list(task_configs.values())
        }
        torch.save(task_config_data, os.path.join(args.save_dir, 'task_config.pth'))
        log.write('Task configuration saved to task_config.pth\n')
        log.write(f'Task configs: {task_configs} (total: {sum(task_configs.values())} dims)\n')

    # Logged on every run, whether or not anything was saved.
    log.write(f'Training completed at {time.strftime("%Y-%m-%d %H:%M:%S")}\n')


# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
