import argparse
import os
import random
import traceback

import numpy as np
import torch

import train
from dataset import load_all_data, load_data_fixed_frequency

from model.Transolver.Transolver_Irregular_3D import Model as Transolver
from model.Transolver.Transolver_plus import Model as Transolver_plus
from model.Transolver.benchmark.Transolver_Structured_Mesh_3D import Model as Transolver_Structured_Mesh_3D
from model.GeoFNO.FNO3d import FNO3d
from model.GeoFNO.GNO import GNO
from model.GeoFNO.GeoFNO import GeoFNO
from model.mlp import MLP
from model.GeoFNO.FFNO import FNOFactorizedMesh3D as FFNO
from model.GeoFNO.FCNO import CNOFactorizedMesh3D as FCNO
from model.LNO.LNO import LNO, LNO_single, LNO_triple
from model.DeepONet.deeponet import DeepONet
from model.FEM import FEMHeatSolver


def set_seed(seed: int):
    """Seed every RNG this script relies on and force deterministic cuDNN.

    Python's ``random``, NumPy's global RNG, the PyTorch CPU RNG and all CUDA
    device RNGs are seeded with the same ``seed``.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    # Disable cuDNN autotuning and request deterministic kernels so repeated
    # runs select the same algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def build_Transolver(args):
    """Build the irregular-mesh 3D Transolver.

    Only the channel counts come from ``args``; all other hyper-parameters
    are fixed here.
    """
    hyper = dict(
        space_dim=3,
        n_layers=8,
        n_hidden=256,
        n_head=8,
        mlp_ratio=2,
        slice_num=32,
        unified_pos=0,
    )
    return Transolver(fun_dim=args.in_channels, out_dim=args.out_channels, **hyper)

def build_Transolver_R(args):
    """Build the structured-mesh 3D Transolver variant.

    Uses the same fixed hyper-parameters as ``build_Transolver`` plus
    H = W = D = 20 (presumably the structured grid size; confirm against the
    data loader).
    """
    hyper = dict(
        space_dim=3,
        n_layers=8,
        n_hidden=256,
        n_head=8,
        mlp_ratio=2,
        slice_num=32,
        unified_pos=0,
    )
    grid = dict(H=20, W=20, D=20)
    return Transolver_Structured_Mesh_3D(
        fun_dim=args.in_channels, out_dim=args.out_channels, **hyper, **grid
    )

def build_Transolver_plus(args):
    """Build Transolver++ with the same fixed hyper-parameters as ``build_Transolver``."""
    hyper = dict(
        space_dim=3,
        n_layers=8,
        n_hidden=256,
        n_head=8,
        mlp_ratio=2,
        slice_num=32,
        unified_pos=0,
    )
    return Transolver_plus(fun_dim=args.in_channels, out_dim=args.out_channels, **hyper)

def build_FNO3d(args):
    """Build a 3D FNO with 12/12/8 Fourier modes, width 32 and H = W = D = 20."""
    modes = dict(modes1=12, modes2=12, modes3=8)
    grid = dict(H=20, W=20, D=20)
    return FNO3d(
        width=32,
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        **modes,
        **grid,
    )

def build_GNO(args):
    """Build the graph neural operator baseline (width 32).

    NOTE(review): r=1e-8 is extremely small; if ``r`` is a neighbourhood
    radius this connects almost no points. Confirm it is intentional.
    """
    channels = dict(in_channel=args.in_channels, out_channel=args.out_channels)
    return GNO(width=32, r=1e-8, **channels)

def build_GeoFNO(args):
    """Build Geo-FNO with 12/12/8 Fourier modes, width 32 and ``s=20``."""
    modes = dict(modes1=12, modes2=12, modes3=8)
    return GeoFNO(
        width=32,
        s=20,
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        **modes,
    )

def build_MLP(args):
    """Build the MLP baseline: 4 layers of 32 hidden channels, ``n_dim=1``."""
    layout = dict(hidden_channels=32, n_layers=4, n_dim=1)
    return MLP(in_channels=args.in_channels, out_channels=args.out_channels, **layout)

def build_FFNO(args):
    """Build the factorized FNO (``FNOFactorizedMesh3D``) with ``share_weight=False``.

    Otherwise identical to ``build_FFNO_share``: 12/12/8 modes, width 32,
    4 layers, H = W = D = 20.
    """
    spectral = dict(modes_x=12, modes_y=12, modes_z=8)
    trunk = dict(width=32, n_layers=4, factor=4, n_ff_layers=2,
                 ff_weight_norm=True, layer_norm=False)
    grid = dict(H=20, W=20, D=20)
    return FFNO(input_dim=args.in_channels, output_dim=args.out_channels,
                share_weight=False, **spectral, **trunk, **grid)

def build_FFNO_share(args):
    """Build the factorized FNO (``FNOFactorizedMesh3D``) with ``share_weight=True``.

    Otherwise identical to ``build_FFNO``: 12/12/8 modes, width 32,
    4 layers, H = W = D = 20.
    """
    spectral = dict(modes_x=12, modes_y=12, modes_z=8)
    trunk = dict(width=32, n_layers=4, factor=4, n_ff_layers=2,
                 ff_weight_norm=True, layer_norm=False)
    grid = dict(H=20, W=20, D=20)
    return FFNO(input_dim=args.in_channels, output_dim=args.out_channels,
                share_weight=True, **spectral, **trunk, **grid)

def build_FCNO(args):
    """Build ``CNOFactorizedMesh3D`` with the same settings as ``build_FFNO``."""
    spectral = dict(modes_x=12, modes_y=12, modes_z=8)
    trunk = dict(width=32, n_layers=4, factor=4, n_ff_layers=2,
                 ff_weight_norm=True, layer_norm=False)
    grid = dict(H=20, W=20, D=20)
    return FCNO(input_dim=args.in_channels, output_dim=args.out_channels,
                share_weight=False, **spectral, **trunk, **grid)

def build_LNO(args):
    """Build the Latent Neural Operator with 3D coordinates (``x_dim=3``).

    ``model_attr`` sets ``time=False``; the other settings are fixed:
    4 blocks, 256 modes, 128 dims, 8 heads, 2 layers, vanilla attention, GELU.
    """
    return LNO(
        x_dim=3,
        y1_dim=args.in_channels,
        y2_dim=args.out_channels,
        n_block=4,
        n_mode=256,
        n_dim=128,
        n_head=8,
        n_layer=2,
        attn='Attention_Vanilla',
        act='GELU',
        model_attr={'time': False},
    )

def build_LNO_single(args):
    """Build ``LNO_single`` with ``x_dim=None``; other settings match ``build_LNO``."""
    return LNO_single(
        x_dim=None,
        y1_dim=args.in_channels,
        y2_dim=args.out_channels,
        n_block=4,
        n_mode=256,
        n_dim=128,
        n_head=8,
        n_layer=2,
        attn='Attention_Vanilla',
        act='GELU',
        model_attr={'time': False},
    )

def build_DeepONet(args):
    """Build the DeepONet baseline (width 32, branch depth 2, trunk depth 3).

    Only supports ``out_channels == 1``; ``run_one_model`` skips it otherwise.
    """
    depths = dict(branch_depth=2, trunk_depth=3)
    return DeepONet(branch_dim=3, trunk_dim=args.in_channels, width=32, **depths)

def build_FEM(args):
    """Build the FEM heat-solver baseline.

    ``args.out_channels`` is passed as the number of time steps and
    ``args.downsample_count`` as the number of points.
    """
    steps = args.out_channels
    points = args.downsample_count
    return FEMHeatSolver(mode=args.mode, num_time_steps=steps, num_points=points)



# Model name -> builder(args). Insertion order is the run order used when every
# registered model is run (parse_model_list then returns ALL_MODELS).
MODEL_REGISTRY = dict([
    ('Transolver', build_Transolver),
    ('Transolver_R', build_Transolver_R),
    ('Transolver_plus', build_Transolver_plus),
    ('FNO3d', build_FNO3d),
    ('GNO', build_GNO),
    ('GeoFNO', build_GeoFNO),
    ('MLP', build_MLP),
    ('FFNO', build_FFNO),
    ('FFNO-share', build_FFNO_share),
    ('FCNO', build_FCNO),
    ('LNO', build_LNO),
    ('LNO_single', build_LNO_single),
    ('DeepONet', build_DeepONet),
    ('FEM', build_FEM),
])
ALL_MODELS = [*MODEL_REGISTRY]

# --crack_type CLI value ('1'..'6') -> crack configuration name used in
# checkpoint directory names and hparams.
CRACK_TYPE = {
    str(code): name
    for code, name in enumerate(
        ('single', 'I-double', 'II-double', 'III-double', 'I-multi', 'II-multi'),
        start=1,
    )
}


def _floor_with_tolerance(value, eps=1e-9):
    """Return ``int(value)`` for a non-negative ``value``, tolerant of float noise.

    A plain ``int((1 - 0.9) * 10)`` is 0 because the product evaluates to
    0.9999999999999998; adding a tiny ``eps`` first yields the intended 1.
    """
    return int(value + eps)


def _split_ood_blocks(all_graphs, ood, test_split, block_length=10):
    """Split ``all_graphs`` block-wise for the 'high' / 'low' / 'mid' OOD setups.

    The sequence is cut into consecutive blocks of ``block_length`` samples.
    A trailing incomplete block is ignored. Inside each block one contiguous
    window goes to the test set and the rest to the training set:

    * ``'high'``: the last part of the block, after ``(1 - test_split) * block_length`` samples;
    * ``'low'``: the first ``test_split * block_length`` samples;
    * ``'mid'``: ``test_split * block_length / 2`` samples on each side of
      ``block_length // 2``.

    Returns:
        ``(train_graphs, test_graphs)`` as lists.

    Raises:
        NotImplementedError: ``ood`` is not one of the three types.
    """
    if ood == 'high':
        test_start = _floor_with_tolerance((1 - test_split) * block_length)
        test_end = block_length
    elif ood == 'low':
        test_start = 0
        test_end = _floor_with_tolerance(test_split * block_length)
    elif ood == 'mid':
        center = block_length // 2
        half_width = _floor_with_tolerance(test_split * block_length / 2)
        test_start, test_end = center - half_width, center + half_width
    else:
        raise NotImplementedError("Wrong OOD experiment type in 3 types")

    train_graphs, test_graphs = [], []
    for i in range(0, len(all_graphs), block_length):
        block = all_graphs[i:i + block_length]
        if len(block) < block_length:
            # Ignore incomplete data blocks
            continue
        # Two extends instead of slice concatenation, so any sliceable works.
        train_graphs.extend(block[:test_start])
        train_graphs.extend(block[test_end:])
        test_graphs.extend(block[test_start:test_end])
    return train_graphs, test_graphs


def _load_all_graphs(args, max_workers):
    """Load all samples with ``load_all_data`` using the shared CLI options.

    The statistics returned by ``load_all_data`` are discarded.
    """
    all_graphs, _stats = load_all_data(
        root_dir=args.data_root,
        max_workers=max_workers,
        data_num=args.data_num,
        downsample_count=args.downsample_count,
        surf_downsample_count=args.surf_downsample_count,
        data_type=args.data_type,
        normalize=True,
        unit_normalize=True,
        use_surf=args.use_surf,
    )
    return all_graphs


def build_datasets(args):
    """Load the samples once and split them into train / test lists.

    The split strategy is chosen by ``args.OOD``:

    * ``'high'`` / ``'low'`` / ``'mid'``: block-wise split over blocks of 10
      consecutive samples (see ``_split_ood_blocks``).
    * ``'sfo'``: ``load_data_fixed_frequency`` builds both splits and
      requires ``args.sfo_freq``.
    * anything else (``None`` or ``''``): the first ``1 - test_split``
      fraction of the samples goes to training and the rest to testing.

    Returns:
        ``(train_graphs, test_graphs)``.

    Raises:
        ValueError: ``--OOD sfo`` was given without ``--sfo_freq``.
    """
    print("Loading VTU data once ...")
    if args.OOD in ('high', 'mid', 'low'):
        # OOD: splitting by blocks. NOTE(review): a single worker is used here,
        # presumably so the sample order stays stable and blocks remain
        # contiguous. Confirm against load_all_data.
        all_graphs = _load_all_graphs(args, max_workers=1)
        N = len(all_graphs)
        train_graphs, test_graphs = _split_ood_blocks(all_graphs, args.OOD, args.test_split)
        print(f"Dataset size: total={N}, train={len(train_graphs)}, test={len(test_graphs)}, OOD type: {args.OOD}")

    elif args.OOD == 'sfo':
        if args.sfo_freq is None:
            raise ValueError("--OOD sfo need --sfo_freq(1..10), such as: --sfo_freq 7 denote 49kHz")
        train_graphs, test_graphs, _ = load_data_fixed_frequency(
            root_dir=args.data_root,
            data_type=args.data_type,
            use_surf=args.use_surf,
            target_freq_index=args.sfo_freq,
            train_num=getattr(args, 'sfo_train_num', None),  # None -> use all samples
            test_num=getattr(args, 'sfo_test_num', None),
            max_workers=8,
            downsample_count=args.downsample_count,
            surf_downsample_count=args.surf_downsample_count,
            normalize=True,
            unit_normalize=True
        )
        print(f"Dataset size: total={len(train_graphs) + len(test_graphs)}, train={len(train_graphs)}, test={len(test_graphs)}, OOD type: {args.OOD}")

    else:
        all_graphs = _load_all_graphs(args, max_workers=8)
        N = len(all_graphs)
        split = _floor_with_tolerance((1 - args.test_split) * N)
        train_graphs = all_graphs[:split]
        test_graphs = all_graphs[split:]
        print(f"[Dataset] total={N}, train={len(train_graphs)}, test={len(test_graphs)}  (All Freq)")
    return train_graphs, test_graphs

def get_device(gpu_index: int):
    """Return ``cuda:<gpu_index>`` if CUDA is available and that index exists, else CPU.

    Also prints the chosen device and the detected CUDA status.
    """
    gpu_count = torch.cuda.device_count()
    cuda_ok = torch.cuda.is_available()
    if cuda_ok and 0 <= gpu_index < gpu_count:
        device = torch.device(f'cuda:{gpu_index}')
    else:
        device = torch.device('cpu')
    print(f"[Device] Using: {device}  (CUDA available={cuda_ok}, count={gpu_count})")
    return device

def run_one_model(model_name: str, args, device, train_graphs, test_graphs):
    """Evaluate a saved checkpoint of ``model_name`` or train it from scratch.

    Args:
        model_name: key of ``MODEL_REGISTRY``; an unknown name raises KeyError.
        args: parsed CLI namespace (see ``parse_args``).
        device: torch device passed through to ``train.evaluate`` / ``train.main``.
        train_graphs: training samples from ``build_datasets``.
        test_graphs: test samples from ``build_datasets``.

    Checkpoints live under ``checkpoints/<model_name>/<sub>`` (plus a
    ``<sfo_freq**2>kHz`` sub-directory for ``--OOD sfo``).
    With ``--eval``, ``<model>_<epochs>_<mode>.pth`` is loaded on a best-effort
    basis and the model is evaluated. Otherwise the model is trained
    ``args.training_num`` times, each time from a freshly built instance. A
    failing repetition is reported with its traceback and the remaining
    repetitions still run.
    """
    if model_name == 'DeepONet' and args.out_channels != 1:
        print(f"[Skip] {model_name}: DeepONet only support out_channels=1, Current parameter={args.out_channels}")
        return

    # Look up the builder now so unknown names fail fast, but only construct a
    # model where it is actually used (once for eval, once per training run).
    build_model = MODEL_REGISTRY[model_name]
    print(f"[Model] {model_name}, Task (mode) = {args.mode}")

    ood_suffix = f'_OOD_{args.OOD}' if args.OOD else ''
    ood_note = f' OOD Type: {args.OOD}' if args.OOD else ''
    if args.data_type == 'unstructured_data':
        sub = f'{args.mode}_{CRACK_TYPE[args.crack_type]}_uns' + ood_suffix
        print("[DataType] Irregular (unstructured)" + ood_note)
    elif args.data_type == 'structured_data':
        sub = f'{args.mode}_{CRACK_TYPE[args.crack_type]}_s' + ood_suffix
        print("[DataType] Regular (structured)" + ood_note)
    else:
        sub = f'{args.mode}'

    save_dir = os.path.join('checkpoints', model_name, sub)
    if args.OOD == 'sfo':
        # Frequency index i is stored as i**2 kHz (see the --sfo_freq help text).
        save_dir = os.path.join(save_dir, str(args.sfo_freq**2) + 'kHz')
    os.makedirs(save_dir, exist_ok=True)

    hparams = {'lr': args.lr, 'batch_size': args.batch_size, 'epochs': args.epochs, 'data_num': args.data_num,
               'data_type': args.data_type, 'crack_type': CRACK_TYPE[args.crack_type], 'mode': args.mode,
               'model_name': model_name, 'in_channels': args.in_channels, 'out_channels': args.out_channels,
               'point_num': args.downsample_count if not args.use_surf else args.surf_downsample_count,
               'OOD': args.OOD}
    if args.OOD == 'sfo':
        hparams.update({'sfo_freq': args.sfo_freq, 'sfo_train_num': args.sfo_train_num, 'sfo_test_num': args.sfo_test_num})

    if args.eval:
        model = build_model(args)
        ckpt_path = os.path.join(save_dir, f'{hparams["model_name"]}_{hparams["epochs"]}_{hparams["mode"]}.pth')
        try:
            # NOTE(review): recent PyTorch releases (2.6+) default torch.load to
            # weights_only=True, so a pickled whole-model object may be rejected
            # here. Only relax that for trusted, locally produced checkpoints.
            state = torch.load(ckpt_path, map_location='cpu')
            if isinstance(state, dict):
                model.load_state_dict(state)
                print(f"[Eval] Loaded state_dict from {ckpt_path}")
            else:
                model = state
                print(f"[Eval] Loaded whole model object from {ckpt_path}")
        except Exception as e:
            # Best effort: evaluation still runs on the freshly built model
            # (presumably so non-learned baselines like FEM work without a checkpoint).
            print(f"[Warn] Could not load checkpoint: {e}")
        train.evaluate(device, model, test_graphs, mode=args.mode)
        return

    for it in range(args.training_num):
        try:
            # Fresh instance per repetition so runs never share weights.
            model = build_model(args)
            train.main(device, train_graphs, test_graphs, model, hparams, save_dir, mode=args.mode)
            print(f"[Train] {model_name} | iteration {it} finished.")
        except Exception as e:
            # Keep going with the remaining repetitions/models, but keep the
            # traceback so the failure can actually be debugged.
            print(f"[Error] {model_name} | iteration {it}: {e}")
            traceback.print_exc()

def parse_args():
    """Define and parse the command-line interface of the benchmark runner.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument

    # ---- data -----------------------------------------------------------------
    arg('--data_root', default='/dataset path/')
    arg('--data_type', type=str, default='unstructured_data',
        choices=['unstructured_data', 'structured_data'])
    arg('--crack_type', type=str, default='2', choices=['1', '2', '3', '4', '5', '6'])
    arg('--downsample_count', type=int, default=8000, help='downsample points count')
    arg('--surf_downsample_count', type=int, default=8000, help='surface downsample points count')
    arg('--data_num', type=int, default=100, help='data num')
    arg('--test_split', type=float, default=0.2, help='train/test split')
    arg('--use_surf', action="store_true", help='use surface data or not')
    arg("--OOD", default=None, type=str, help="OOD experiment",
        choices=['', 'high', 'low', 'mid', 'sfo'])

    # ---- fixed-frequency ("sfo") OOD options -----------------------------------
    arg("--sfo_freq", default=None, type=int,
        help="when --OOD sfo: fixed training frequency index (1..10 maps to 1,4,...,100 kHz)",
        choices=range(1, 11))
    arg("--sfo_train_num", default=None, type=int, help="when --OOD sfo: choose train num")
    arg("--sfo_test_num", default=None, type=int, help="when --OOD sfo: choose test num")

    # ---- task and optimisation ---------------------------------------------------
    arg('--mode', type=str, default='T2Q', choices=['T2Q', 'Q2T', 'T2T', 'S2Q', 'S2Q_sp'], help='task type')
    arg('--in_channels', type=int, default=13)
    arg('--out_channels', type=int, default=1)
    arg('--lr', type=float, default=1e-3)
    arg('--batch_size', type=int, default=1)
    arg('--epochs', type=int, default=200)
    arg('--training_num', type=int, default=1, help='repeat times for each model')

    # ---- runtime ---------------------------------------------------------------
    arg('--gpu', type=int, default=0)
    arg("--eval", action="store_true", help="evaluate mode")
    arg('--seed', type=int, default=42)

    # ---- model selection -------------------------------------------------------
    arg('--models', type=str, default='ALL',
        help='comma-separated model list, e.g., "Transolver,Transolver_plus,FNO3d". '
             'Use "ALL" to run every registered model.')
    arg('--model', default=None, type=str, help='(deprecated) single model name')

    return parser.parse_args()


def parse_model_list(args):
    """Resolve which registered models to run.

    Precedence:
      1. an explicit ``--models`` list (any value other than "ALL", case-insensitive);
      2. the deprecated ``--model``;
      3. every registered model (``ALL_MODELS``).

    Raises:
        ValueError: a requested name is not registered.
    """
    explicit = args.models and args.models.upper() != 'ALL'
    if not explicit:
        if not args.model:
            return ALL_MODELS
        if args.model in ALL_MODELS:
            return [args.model]
        raise ValueError(f"Unknown model: {args.model}. Available: {ALL_MODELS}")

    requested = [part.strip() for part in args.models.split(',') if part.strip()]
    unknown = [name for name in requested if name not in ALL_MODELS]
    if unknown:
        raise ValueError(f"Unknown model(s): {unknown}. Available: {ALL_MODELS}")
    return requested


def main():
    """CLI entry point.

    Parses the arguments, seeds the RNGs, loads the data once and runs every
    requested model on it.
    """
    args = parse_args()
    print(args)

    # Honour --seed: seed all RNGs before data loading and model construction
    # so runs with the same --seed start from the same RNG state.
    set_seed(args.seed)

    device = get_device(args.gpu)

    train_graphs, test_graphs = build_datasets(args)

    model_list = parse_model_list(args)
    print(f"[RunList] {model_list}")

    for name in model_list:
        print('=' * 50 + ' ' + name + ' ' + '=' * 50)
        run_one_model(name, args, device, train_graphs, test_graphs)
        print('=' * 50 + ' ' + name + ' ' + 'Done' + ' ' + '=' * 50)
    print('All Done')


# Run the benchmark only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
