import os
import matplotlib.pyplot as plt
import tqdm
import json
import torch
import utils
import argparse
import numpy as np
import textblob
from textblob import TextBlob
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from modules.ctc_decode import Decoder
from torch.utils.data import DataLoader
from datasets.wym_ctc_sentences import MouthActionSentence3D
from torch.optim.lr_scheduler import CosineAnnealingLR
from openpyxl import Workbook
from augmentations import random_rotate_translate_batch, fgsm_attack, pgd_attack

import models.sentence_classification as Models

def _resolve_loss_paths(save_loss):
    """Split ``--save-loss`` into a log directory and a best-loss file path.

    ``save_loss`` may be a ``.txt`` file path, in which case its parent
    directory becomes the log directory (a bare file name maps to ``"."``).
    Otherwise it is treated as a directory and the best-loss file is
    ``<dir>/loss.txt``. (Previously a directory value was created and then
    opened for writing, which crashed.)

    Returns:
        ``(log_dir, loss_file)`` tuple of path strings.
    """
    if save_loss.endswith(".txt"):
        return os.path.dirname(save_loss) or ".", save_loss
    log_dir = save_loss or "."
    return log_dir, os.path.join(log_dir, "loss.txt")


def _read_best_val_loss(loss_file):
    """Return the best validation loss recorded in ``loss_file``, or ``inf``.

    Only the first line is inspected. If it contains ``loss:``, the text after
    the last ``:`` is parsed as a float; this matches the
    ``best evaluation loss:<value>`` line written by ``main``. A missing file or
    a line without ``loss:`` yields ``inf`` silently. An unreadable file or an
    unparsable value yields ``inf`` and prints a warning.
    """
    best = float('inf')
    if not os.path.isfile(loss_file):
        return best
    try:
        with open(loss_file) as f:
            line = f.readline()
        if 'loss:' in line:
            best = float(line.split(':')[-1].strip())
    except (OSError, ValueError) as e:
        print(f"Warning: Could not read previous loss file. Using inf. Error: {e}")
    return best


def _clamp_lengths(lengths, max_len):
    """Clip per-sample input lengths to ``max_len`` (the logit time dimension).

    Tensors stay tensors. Any other iterable of ints is returned as a list.
    Values already <= ``max_len`` are unchanged.
    """
    if torch.is_tensor(lengths):
        return torch.clamp(lengths, max=max_len)
    return [min(int(length), max_len) for length in lengths]


def _spell_correct(sentences):
    """Apply TextBlob spelling correction word by word to each prediction string.

    Each word is lower-cased for TextBlob, corrected, and upper-cased again.
    The words are re-joined with single spaces.
    """
    corrected = []
    for sentence in sentences:
        words = [str(TextBlob(word.lower()).correct()).upper() for word in sentence.split()]
        corrected.append(' '.join(words))
    return corrected


def _train_one_epoch(model, dataloader, optimizer, lr_scheduler, criterion, device, args):
    """Run one training epoch and return the mean per-batch CTC loss.

    The LR scheduler is stepped once per batch; ``main`` sizes its ``T_max``
    in iterations. With ``args.adv_train`` the model is trained on FGSM/PGD
    adversarial inputs instead of the clean batch.
    """
    model.train()
    total_loss = 0.0
    for video, text, video_len, text_len, _user, _position in tqdm(dataloader):
        video, text = video.to(device), text.to(device)

        # 1) Online random augmentation.
        if args.use_aug:
            # NOTE(review): assumes video is (B, T, N, C) with xyz in the first
            # three channels — confirm against the dataset.
            video = random_rotate_translate_batch(
                video, max_rot_deg=args.rot_deg, max_translate=args.trans_range)

        optimizer.zero_grad()

        # 2) Optionally replace the clean batch with an adversarial one
        #    (adversarial-only training; mixing clean and adversarial samples
        #    is left as a future extension).
        if args.adv_train:
            video_len_tensor = torch.tensor(video_len, device=device)
            text_len_tensor = torch.tensor(text_len, device=device)
            if args.adv_method == 'fgsm':
                train_input = fgsm_attack(model, criterion, video, text,
                                          video_len_tensor, text_len_tensor,
                                          eps=args.adv_eps, alpha=args.adv_alpha, device=device)
            else:
                train_input = pgd_attack(model, criterion, video, text,
                                         video_len_tensor, text_len_tensor,
                                         eps=args.adv_eps, alpha=args.adv_alpha,
                                         steps=args.adv_steps, device=device)
        else:
            train_input = video

        # 3) Forward pass. CTCLoss expects time-major (T, B, C) log-probs; the
        #    model presumably returns (B, T, C) — confirm.
        logits = model(train_input).transpose(0, 1)

        # CTC requires every input length <= T. Clamp in case the model
        # shortens the time axis.
        input_lengths = torch.clamp(torch.tensor(video_len, device=logits.device),
                                    max=logits.size(0))
        target_lengths = torch.tensor(text_len, device=logits.device)

        with torch.backends.cudnn.flags(enabled=False):
            loss_all = criterion(logits.log_softmax(-1), text, input_lengths, target_lengths)
        loss = loss_all.mean()

        # 4) Backward and parameter update.
        optimizer.zero_grad()
        if args.adv_train and args.adv_use_standard_backward:
            loss.backward()
        else:
            # Two-stage backward: first the gradient of the loss w.r.t. the
            # logits, then push it through the network.
            dlogits = torch.autograd.grad(loss_all, logits,
                                          grad_outputs=torch.ones_like(loss_all))[0]
            logits.backward(dlogits)
        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()
    return total_loss / max(len(dataloader), 1)


def _evaluate(model, dataloader, criterion, decoder, letters, device):
    """Evaluate ``model`` on ``dataloader`` with greedy CTC decoding.

    Prints every batch's ground truth and prediction.

    Returns:
        ``(mean_loss, predictions, targets, users, positions)``. The lists are
        accumulated across all batches in loader order.
    """
    model.eval()
    total_loss = 0.0
    pred, gt, user_order, pos = [], [], [], []
    with torch.no_grad():
        for video_t, text_t, video_len_t, text_len_t, user_t, position_t in dataloader:
            video_t, text_t = video_t.to(device), text_t.to(device)
            logits_t = model(video_t).transpose(0, 1)
            # Clamp like training so CTC and decoding never read past the last frame.
            input_len_t = _clamp_lengths(video_len_t, logits_t.size(0))
            with torch.backends.cudnn.flags(enabled=False):
                loss_all_t = criterion(logits_t.log_softmax(-1), text_t,
                                       input_len_t, text_len_t)
            total_loss += loss_all_t.mean().item()

            decoded = decoder.decode_greedy(logits_t, input_len_t)
            target = decoder.idx_to_string(text_t, text_len_t, letters)
            print('True: ', target)
            print('Pred: ', decoded, '\n')
            pred.extend(decoded)
            gt.extend(target)
            user_order.extend(user_t)
            pos.extend(position_t)
    return total_loss / max(len(dataloader), 1), pred, gt, user_order, pos


def _is_nan(value):
    """Return True if ``value`` is NaN (NaN is the only value unequal to itself)."""
    return value != value


def _save_training_logs(log_dir, train_losses, val_losses, cer_list):
    """Write per-epoch metrics to ``log_dir`` and plot the loss and CER curves.

    Writes four files:

    * ``training_log.txt`` (tab-separated)
    * ``training_log.xlsx``
    * ``loss_curve.png``
    * ``cer_curve.png``

    NaN metrics are written as ``-`` in the text log and as empty cells in the
    workbook.
    """
    epochs = range(1, len(train_losses) + 1)

    log_file = os.path.join(log_dir, "training_log.txt")
    with open(log_file, "w") as f:
        f.write("Epoch\tTrain_Loss\tVal_Loss\tCER\n")
        for epoch, train_loss, val_loss, cer in zip(epochs, train_losses, val_losses, cer_list):
            val_loss_str = "-" if _is_nan(val_loss) else f"{val_loss:.4f}"
            # NOTE(review): the "%" suffix assumes the Decoder reports CER as a
            # percentage, while main starts best_cer at 1 — confirm the scale.
            cer_str = "-" if _is_nan(cer) else f"{cer:.2f}%"
            f.write(f"{epoch}\t{train_loss:.4f}\t{val_loss_str}\t{cer_str}\n")

    excel_file = os.path.join(log_dir, "training_log.xlsx")
    wb = Workbook()
    ws = wb.active
    ws.title = "Training Log"
    ws.append(["Epoch", "Train_Loss", "Val_Loss", "CER"])
    for epoch, train_loss, val_loss, cer in zip(epochs, train_losses, val_losses, cer_list):
        ws.append([epoch,
                   train_loss,
                   None if _is_nan(val_loss) else val_loss,
                   None if _is_nan(cer) else cer])
    wb.save(excel_file)
    print(f"✅ Excel log saved to: {excel_file}")

    # Loss curves
    loss_png = os.path.join(log_dir, "loss_curve.png")
    plt.figure()
    plt.plot(epochs, train_losses, label="Train Loss")
    plt.plot(epochs, val_losses, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Loss Curve")
    plt.legend()
    plt.savefig(loss_png)
    plt.close()

    # CER curve
    cer_png = os.path.join(log_dir, "cer_curve.png")
    plt.figure()
    plt.plot(epochs, cer_list, label="CER")
    plt.xlabel("Epoch")
    plt.ylabel("CER")
    plt.title("Character Error Rate")
    plt.legend()
    plt.savefig(cer_png)
    plt.close()

    print(f"✅ Logs saved to: {log_file}")
    print(f"✅ Loss curve saved to: {loss_png}")
    print(f"✅ CER curve saved to: {cer_png}")


def main(args):
    """Train and evaluate a sentence-level point-cloud lipreading model with CTC.

    Each epoch trains on ``args.train_dataset``, then evaluates on
    ``args.test_dataset``, reporting raw and TextBlob spell-corrected CER/WER.
    Checkpoints written inside ``args.save_model``:

    * ``model_best_cer.pth`` whenever the CER improves
    * ``model.pth`` whenever the validation loss improves; the loss value is
      also recorded in the best-loss file derived from ``args.save_loss``
    * ``model_sentence_current.pth`` after every epoch

    ``args.save_performance`` receives a JSON dump of the predictions for the
    best WER. After training, per-epoch logs and loss/CER plots are written to
    the log directory.

    Args:
        args: ``argparse.Namespace`` as produced by ``parse_args``.

    Raises:
        ValueError: if ``args.dim * 64`` is not divisible by ``args.heads``.
    """
    os.makedirs(args.save_model, exist_ok=True)
    log_dir, loss_file = _resolve_loss_paths(args.save_loss)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(os.path.dirname(args.save_performance) or ".", exist_ok=True)

    # Reproducibility: seed the RNGs and force deterministic cuDNN kernels.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes kernel selection, which defeats
    # ``deterministic = True``; keep it off for reproducible runs.
    torch.backends.cudnn.benchmark = False

    # Register "upcoming" in TextBlob's spelling dictionary with a large count
    # (presumably so the corrector prefers it).
    textblob.en.spelling.update({"upcoming": 100000})

    device = torch.device('cuda')

    print('Loading Data...')

    train_dataset = MouthActionSentence3D(
        root=args.data_path,
        dataset=args.train_dataset,
        num_points=args.num_points,
        padding=False
    )
    test_dataset = MouthActionSentence3D(
        root=args.data_path,
        dataset=args.test_dataset,
        num_points=args.num_points,
        padding=False
    )

    print('Creating Data Loaders...')

    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.workers,
                                  pin_memory=True,
                                  collate_fn=utils.ctc_collate)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 pin_memory=True,
                                 collate_fn=utils.ctc_collate)

    Model = getattr(Models, args.model)
    num_classes = len(train_dataset.letters) + 1  # alphabet size + CTC blank

    # ---- Sanity check: d_model must split evenly across attention heads ----
    d_model = args.dim * 64  # must stay consistent with the ConformerEncoder
    print(f"[DEBUG] d_model = {d_model}, num_heads = {args.heads}")
    if d_model % args.heads != 0:
        raise ValueError(f"[Error] d_model ({d_model}) must be divisible by num_heads ({args.heads})")
    print("[DEBUG] d_model 与 num_heads 匹配，可以继续运行")

    model = Model(channel=args.channel,
                  in_planes=args.in_planes,
                  radius=args.radius,
                  nsamples=args.nsamples,
                  spatial_stride=args.spatial_stride,
                  temporal_kernel_size=args.temporal_kernel_size,
                  temporal_stride=args.temporal_stride,
                  dim=args.dim,
                  depth=args.depth,
                  heads=args.heads,
                  dim_head=args.dim_head,
                  dropout1=args.dropout1,
                  mlp_dim=args.mlp_dim,
                  num_points=args.num_points,
                  dropout2=args.dropout2,
                  num_classes=num_classes)
    # Training always starts from fresh weights; checkpoints are written to
    # args.save_model (see the docstring).
    print('Creating Model from scratch...')

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device=device)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    # Stepped once per batch, so T_max is counted in iterations.
    lr_scheduler = CosineAnnealingLR(optimizer,
                                     T_max=args.epochs * int(len(train_dataloader)),
                                     eta_min=1e-5)

    criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True).to(device)
    decoder = Decoder(train_dataset.letters)

    print('Start Training')
    train_losses, val_losses, cer_list = [], [], []

    best_val_loss = _read_best_val_loss(loss_file)
    print('Best evaluation loss is: ', best_val_loss)
    # NOTE(review): starting values of 1 assume error rates are fractions in [0, 1].
    best_wer, best_cer = 1, 1
    best_wer_corr, best_cer_corr = 1, 1

    for epoch in range(args.epochs):
        print('Epoch ', epoch)
        train_loss = _train_one_epoch(model, train_dataloader, optimizer, lr_scheduler,
                                      criterion, device, args)
        print(f'Epoch {epoch} Train Loss: {train_loss}')

        val_loss, pred, gt, user_order, pos = _evaluate(model, test_dataloader, criterion,
                                                        decoder, test_dataset.letters, device)
        pred_corr = _spell_correct(pred)

        cer = decoder.compute_cer_level_distance(pred, gt)
        wer = decoder.compute_wer_level_distance(pred, gt)
        cer_corr = decoder.compute_cer_level_distance(pred_corr, gt)
        wer_corr = decoder.compute_wer_level_distance(pred_corr, gt)

        if cer < best_cer:
            best_cer = cer
            torch.save(model.state_dict(), os.path.join(args.save_model, 'model_best_cer.pth'))

        if wer < best_wer:
            best_wer = wer
            best_performance = {
                'best_wer': best_wer,
                'best_cer': best_cer,
                'user_order': user_order,
                'positions': pos,
                'prediction': pred,
                'target': gt
            }
            with open(args.save_performance, "w") as outfile:
                json.dump(best_performance, outfile, indent=4)

        best_cer_corr = min(best_cer_corr, cer_corr)
        best_wer_corr = min(best_wer_corr, wer_corr)

        if val_loss < best_val_loss:
            print('Model Improved! Save the current Model!')
            best_val_loss = val_loss
            # The message above always promised a save; actually do it.
            torch.save(model.state_dict(), os.path.join(args.save_model, 'model.pth'))
            with open(loss_file, 'w') as f:
                f.write('best evaluation loss:' + str(best_val_loss))

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        cer_list.append(cer)
        torch.save(model.state_dict(), os.path.join(args.save_model, 'model_sentence_current.pth'))
        print(f'Epoch {epoch} Test Loss: {val_loss}')
        print(f'Word Error Rate: {wer}, Character Error Rate: {cer}')
        print(f'Corrected Word Error Rate: {wer_corr}, Corrected Character Error Rate: {cer_corr}')
        print(f'Best Word Error Rate: {best_wer}, Best Character Error Rate: {best_cer}')
        print(f'Best Corrected Word Error Rate: {best_wer_corr}, Best Corrected Character Error Rate: {best_cer_corr}')
    print('End Training!')

    _save_training_logs(log_dir, train_losses, val_losses, cer_list)

def parse_args(argv=None):
    """Parse command-line options for sentence-level point-cloud lipreading training.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list allows use from
            tests or notebooks.

    Returns:
        ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='Sentence Point Cloud Lipreading')

    # data and output locations
    parser.add_argument('--data-path', default='/home/ubuntu-user/Documents/AtticusDon/', type=str)
    parser.add_argument('--train-dataset', default='Sentences/Train.txt', type=str)
    parser.add_argument('--test-dataset', default='Sentences/Test.txt', type=str)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--model', default='DepthSpeechRecognition', type=str)
    parser.add_argument('--save-model', default='./Log/checkpoints/', type=str)
    parser.add_argument('--save-loss', default='./Log/checkpoints/loss.txt', type=str)
    parser.add_argument('--save-performance', default='./Log/checkpoints/best_performance.json', type=str)
    # input
    parser.add_argument('--num-points', default=1024, type=int, metavar='N')
    # TNet
    parser.add_argument('--channel', default=6, type=int, help='number of channels')
    # P4D
    # Was ``type=float``: an explicit ``--in-planes 3`` became 3.0 while the
    # default stayed the int 3. The value is passed to the model as a
    # (presumably integer) plane count.
    parser.add_argument('--in-planes', default=3, type=int)
    parser.add_argument('--radius', default=0.05, type=float)
    parser.add_argument('--nsamples', default=32, type=int)
    parser.add_argument('--spatial-stride', default=16, type=int)
    parser.add_argument('--temporal-kernel-size', default=3, type=int)
    parser.add_argument('--temporal-stride', default=1, type=int)
    # conformer
    parser.add_argument('--dim', default=32, type=int)
    parser.add_argument('--depth', default=5, type=int)
    parser.add_argument('--heads', default=8, type=int)
    parser.add_argument('--dim-head', default=40, type=int)
    parser.add_argument('--mlp-dim', default=160, type=int)
    parser.add_argument('--dropout1', default=0.5, type=float)
    # output
    parser.add_argument('--dropout2', default=0.5, type=float)
    # training
    parser.add_argument('-b', '--batch-size', default=3, type=int)  # At least 2 samples for each GPU
    parser.add_argument('--epochs', default=150, type=int, metavar='N')
    parser.add_argument('-j', '--workers', default=24, type=int, metavar='N')
    parser.add_argument('--lr', default=1e-2, type=float)
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)', dest='weight_decay')
    # random rotation / translation augmentation
    parser.add_argument('--use-aug', action='store_true',
                        help='Use random rotation/translation augmentation during training')
    parser.add_argument('--rot-deg', type=float, default=10.0, help='Max rotation degrees for augmentation')
    parser.add_argument('--trans-range', type=float, default=0.02, help='Max translation (meters) for augmentation')
    # adversarial training
    parser.add_argument('--adv-train', action='store_true',
                        help='Enable adversarial training (FGSM/PGD) on point coordinates')
    parser.add_argument('--adv-method', choices=['fgsm', 'pgd'], default='fgsm', help='Adversarial method')
    parser.add_argument('--adv-eps', type=float, default=0.02,
                        help='L_inf perturbation budget for adversarial training')
    parser.add_argument('--adv-steps', type=int, default=3, help='PGD steps (if adv-method=pgd)')
    parser.add_argument('--adv-alpha', type=float, default=0.005, help='PGD step size / FGSM step (alpha)')
    # Optional: use a standard loss.backward() in the training loop (needed for adversarial training)
    parser.add_argument('--adv-use-standard-backward', action='store_true',
                        help='Use standard backward (loss.backward()) when adv training is enabled. If not set, code attempts to keep original dlogits approach (may be incompatible with adv).')

    return parser.parse_args(argv)

if __name__ == "__main__":
    # Respect GPU and allocator settings the caller already exported. Fall
    # back to the script's defaults (GPUs 0 and 1, 128 MB max split size) only
    # when they are unset. These should be in place before CUDA is first used.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")
    os.environ["WANDB_MODE"] = "dryrun"
    args = parse_args()
    main(args=args)
