import train
import os
import torch
import argparse

from dataset import load_all_data, load_data_fixed_frequency
from model.Transolver.Transolver_Irregular_3D import Model as Transolver
from model.Transolver.Transolver_plus import Model as Transolver_plus
from model.Transolver.benchmark.Transolver_Structured_Mesh_3D import Model as Transolver_Structured_Mesh_3D
from model.GeoFNO.FNO3d import FNO3d
from model.GeoFNO.GNO import GNO
from model.GeoFNO.GeoFNO import GeoFNO
from model.mlp import MLP
from model.GeoFNO.FFNO import FNOFactorizedMesh3D as FFNO
from model.GeoFNO.FCNO import CNOFactorizedMesh3D as FCNO
from model.LNO.LNO import LNO, LNO_single, LNO_triple
from model.DeepONet.deeponet import DeepONet
from model.FEM import FEMHeatSolver

def get_model(args):
    """Instantiate the network selected by ``args.model``.

    Every architecture uses fixed hyper-parameters hard-coded below; only a few
    values come from ``args``:

    * learned models read ``args.in_channels`` and ``args.out_channels``;
    * ``'FEM'`` reads ``args.mode``, ``args.out_channels`` (as the number of
      time steps) and ``args.downsample_count`` (as the number of points).

    Args:
        args: parsed command-line namespace; ``args.model`` picks the model.

    Returns:
        The newly constructed model instance.

    Raises:
        NotImplementedError: if ``args.model`` is not one of the known names.
    """

    def transolver_kwargs(**extra):
        # Configuration shared by the Transolver family (dropout left at the model default).
        return dict(n_hidden=256, n_layers=8, space_dim=3,
                    fun_dim=args.in_channels, n_head=8, mlp_ratio=2,
                    out_dim=args.out_channels, slice_num=32, unified_pos=0, **extra)

    def factorized_kwargs(share_weight):
        # Configuration shared by the factorized spectral models (FFNO / FCNO).
        return dict(modes_x=12, modes_y=12, modes_z=8,
                    input_dim=args.in_channels, output_dim=args.out_channels,
                    width=32, n_layers=4, share_weight=share_weight, factor=4,
                    n_ff_layers=2, ff_weight_norm=True, layer_norm=False,
                    H=20, W=20, D=20)

    def lno_kwargs(x_dim):
        # LNO variants differ only in x_dim; model_attr is a fresh dict per call.
        return dict(x_dim=x_dim, y1_dim=args.in_channels, y2_dim=args.out_channels,
                    n_block=4, n_mode=256, n_dim=128, n_head=8, n_layer=2,
                    attn='Attention_Vanilla', act='GELU', model_attr={'time': False})

    # Factories are lambdas so only the selected model reads `args` and gets built.
    builders = {
        'Transolver': lambda: Transolver(**transolver_kwargs()),
        'Transolver_R': lambda: Transolver_Structured_Mesh_3D(
            **transolver_kwargs(H=20, W=20, D=20)),
        'Transolver_plus': lambda: Transolver_plus(**transolver_kwargs()),
        'FNO3d': lambda: FNO3d(modes1=12, modes2=12, modes3=8, width=32,
                               in_channels=args.in_channels,
                               out_channels=args.out_channels, H=20, W=20, D=20),
        'GNO': lambda: GNO(width=32, in_channel=args.in_channels,
                           out_channel=args.out_channels, r=1e-8),
        'GeoFNO': lambda: GeoFNO(modes1=12, modes2=12, modes3=8, width=32,
                                 in_channels=args.in_channels,
                                 out_channels=args.out_channels, s=20),
        'MLP': lambda: MLP(in_channels=args.in_channels, out_channels=args.out_channels,
                           hidden_channels=32, n_layers=4, n_dim=1),
        'FFNO': lambda: FFNO(**factorized_kwargs(share_weight=False)),
        'FFNO-share': lambda: FFNO(**factorized_kwargs(share_weight=True)),
        'FCNO': lambda: FCNO(**factorized_kwargs(share_weight=False)),
        'LNO': lambda: LNO(**lno_kwargs(x_dim=3)),
        'LNO_single': lambda: LNO_single(**lno_kwargs(x_dim=None)),
        # DeepONet: output can only be 1 channel.
        'DeepONet': lambda: DeepONet(branch_dim=3, trunk_dim=args.in_channels,
                                     branch_depth=2, trunk_depth=3, width=32),
        'FEM': lambda: FEMHeatSolver(mode=args.mode, num_time_steps=args.out_channels,
                                     num_points=args.downsample_count),
    }

    build = builders.get(args.model)
    if build is None:
        raise NotImplementedError("No model type found")
    return build()


# Maps the --crack_type CLI code ('1'..'6') to the crack-configuration name that is
# embedded in the checkpoint sub-directory name.
CRACK_TYPE = {
    str(code): name
    for code, name in enumerate(
        ('single', 'I-double', 'II-double', 'III-double', 'I-multi', 'II-multi'),
        start=1,
    )
}

# ---------------------------------------------------------------------------
# Command-line options, hyper-parameters and device selection
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--data_root', default='/dataset path/')
parser.add_argument('--data_type', type=str,
                    default='unstructured_data')
parser.add_argument('--crack_type', type=str, default='2', choices=['1', '2', '3', '4', '5', '6'])
parser.add_argument('--downsample_count', type=int, default=8000,
                    help='downsample points count')
parser.add_argument('--surf_downsample_count', type=int, default=8000,
                    help=' surface downsample points count')
parser.add_argument('--data_num', type=int, default=100,
                    help='data num')
parser.add_argument('--test_split', type=float, default=0.2, help='train/test split')
parser.add_argument('--gpu', default=0, type=int)

parser.add_argument('--model', default='Transolver', type=str)
parser.add_argument('--lr', default=0.0001, type=float)
parser.add_argument('--batch_size', default=1, type=int)
parser.add_argument('--epochs', default=200, type=int)
parser.add_argument('--in_channels', default=13, type=int)
parser.add_argument('--out_channels', default=1, type=int)

# --OOD selects the train/test split strategy used when loading data.
parser.add_argument("--OOD", default=None, type=str, help="OOD experiment", choices=[ '', 'high', 'low', 'mid', 'sfo'])
# Options read only by the `--OOD sfo` (fixed-frequency) experiment. For the *_num
# options, None lets the loader use all available samples.
parser.add_argument('--sfo_freq', type=int, default=None, choices=range(1, 11),
                    help='frequency index (1..10) for --OOD sfo, e.g. 7 denotes 49kHz')
parser.add_argument('--sfo_train_num', type=int, default=None,
                    help='number of training samples for --OOD sfo (default: all)')
parser.add_argument('--sfo_test_num', type=int, default=None,
                    help='number of test samples for --OOD sfo (default: all)')
parser.add_argument("--eval", action="store_true", help="evaluate model or not")
parser.add_argument('--use_surf', action="store_true", help="use surface data or not")
parser.add_argument("--mode", type=str, default='T2Q', choices=['T2Q', 'Q2T', 'T2T', 'S2Q', 'S2Q_sp'], help='task type')
parser.add_argument('--training_num', type=int, default=1, help="The number of training sessions")

args = parser.parse_args()
print(args)

# Optimisation settings handed to train.main.
hparams = {'lr': args.lr, 'batch_size': args.batch_size, 'epochs': args.epochs}

# Use the requested GPU only if that index exists and CUDA is available; otherwise CPU.
n_gpu = torch.cuda.device_count()
use_cuda = 0 <= args.gpu < n_gpu and torch.cuda.is_available()
device = torch.device(f'cuda:{args.gpu}' if use_cuda else 'cpu')

# ---------------------------------------------------------------------------
# Data loading and train/test split
#   --OOD high|low|mid : samples are grouped into consecutive blocks of
#       `block_length`; in every complete block a window at the end / start /
#       centre (sized by --test_split) is held out for testing.
#   --OOD sfo          : split produced by `load_data_fixed_frequency` for the
#       frequency index given by --sfo_freq.
#   no --OOD, or ''    : the first (1 - test_split) fraction trains, the rest tests.
# ---------------------------------------------------------------------------
print("Loading VTU data...")
if args.OOD is not None:
    print("Loading VTU data once ...")

if args.OOD in ('high', 'mid', 'low'):
    all_graphs, stats = load_all_data(
        root_dir=args.data_root,
        max_workers=1,  # OOD: splitting by blocks
        data_num=args.data_num,
        downsample_count=args.downsample_count,
        surf_downsample_count=args.surf_downsample_count,
        data_type=args.data_type,
        normalize=True,
        unit_normalize=True,
        use_surf=args.use_surf
    )
    N = len(all_graphs)
    block_length = 10

    # Half-open window [test_start, test_end) of every block that goes to the test set.
    if args.OOD == 'high':
        test_start, test_end = int((1 - args.test_split) * block_length), block_length
    elif args.OOD == 'low':
        test_start, test_end = 0, int(args.test_split * block_length)
    else:  # 'mid'
        center = block_length // 2
        half_width = int(args.test_split * block_length / 2)
        test_start, test_end = center - half_width, center + half_width

    train_graphs, test_graphs = [], []
    for i in range(0, N, block_length):
        block = all_graphs[i:i + block_length]
        if len(block) < block_length:
            # Ignore incomplete data blocks
            continue
        train_graphs.extend(block[:test_start])
        train_graphs.extend(block[test_end:])
        test_graphs.extend(block[test_start:test_end])

    print(f"Dataset size: total={N}, train={len(train_graphs)}, test={len(test_graphs)}, OOD type: {args.OOD}")

elif args.OOD == 'sfo':
    sfo_freq = getattr(args, 'sfo_freq', None)
    if sfo_freq is None:
        raise ValueError("--OOD sfo need --sfo_freq(1..10), such as: --sfo_freq 7 denote 49kHz")
    train_graphs, test_graphs, _ = load_data_fixed_frequency(
        root_dir=args.data_root,
        data_type=args.data_type,
        use_surf=args.use_surf,
        target_freq_index=sfo_freq,
        train_num=getattr(args, 'sfo_train_num', None),  # if None, choose all datas
        test_num=getattr(args, 'sfo_test_num', None),
        max_workers=8,
        downsample_count=args.downsample_count,
        surf_downsample_count=args.surf_downsample_count,
        normalize=True,
        unit_normalize=True
    )
    print(
        f"Dataset size: total={len(train_graphs) + len(test_graphs)}, train={len(train_graphs)}, test={len(test_graphs)}, OOD type: {args.OOD}")

else:
    # Regular (non-OOD) experiment: reached when --OOD is omitted (None) or ''.
    all_graphs, stats = load_all_data(
        root_dir=args.data_root,
        max_workers=8,
        data_num=args.data_num,
        downsample_count=args.downsample_count,
        surf_downsample_count=args.surf_downsample_count,
        data_type=args.data_type,
        normalize=True,
        unit_normalize=True,
        use_surf=args.use_surf
    )
    N = len(all_graphs)
    split = int((1 - args.test_split) * N)
    train_graphs = all_graphs[:split]
    test_graphs = all_graphs[split:]
    print(f"[Dataset] total={N}, train={len(train_graphs)}, test={len(test_graphs)}  (All Freq)")


# ---------------------------------------------------------------------------
# Model construction, checkpoint directory, and evaluation / training
# ---------------------------------------------------------------------------
import traceback

# Built once up front so an unknown --model raises immediately instead of inside
# the per-run try/except below; the first training run reuses this instance.
model = get_model(args)
print(f'Model: {args.model}, Training Type:{args.mode}')

# Checkpoints live in checkpoints/<model>/<sub>; <sub> encodes the task mode and,
# for known data types, the crack type, mesh type and OOD experiment (if any).
ood_suffix = f'_OOD_{args.OOD}' if args.OOD else ''
ood_note = f' OOD Type: {args.OOD}' if args.OOD else ''
if args.data_type == 'unstructured_data':
    sub = f'{args.mode}_{CRACK_TYPE[args.crack_type]}_uns' + ood_suffix
    print("[DataType] Irregular (unstructured)" + ood_note)
elif args.data_type == 'structured_data':
    sub = f'{args.mode}_{CRACK_TYPE[args.crack_type]}_s' + ood_suffix
    print("[DataType] Regular (structured)" + ood_note)
else:
    sub = f'{args.mode}'

save_dir = os.path.join('checkpoints', args.model, sub)
os.makedirs(save_dir, exist_ok=True)

if args.eval:
    # `hparams` holds only lr / batch_size / epochs, so the model name and task mode
    # are taken from `args`.
    # NOTE(review): this filename must match the one train.main saves — confirm.
    ckpt_path = os.path.join(save_dir, f'{args.model}_{hparams["epochs"]}_{args.mode}.pth')
    try:
        # torch.load unpickles the file: only load trusted checkpoints. On recent
        # PyTorch releases `weights_only` defaults to True, which may reject a
        # checkpoint that stores a whole model object rather than a state_dict.
        state = torch.load(ckpt_path, map_location='cpu')
        if isinstance(state, dict):
            model.load_state_dict(state)
            print(f"[Eval] Loaded state_dict from {ckpt_path}")
        else:
            model = state
            print(f"[Eval] Loaded whole model object from {ckpt_path}")
    except Exception as e:
        # Best effort: warn and evaluate the newly constructed model instead.
        print(f"[Warn] Could not load checkpoint: {e}")
    train.evaluate(device, model, test_graphs, mode=args.mode)
else:
    for i in range(args.training_num):
        try:
            if i > 0:
                # Every repeated run starts from a newly constructed model.
                model = get_model(args)
            train.main(device, train_graphs, test_graphs, model, hparams, save_dir, mode=args.mode)
            print(f"[Train] {args.model} | iteration {i} finished.")
        except Exception as e:
            # Continue with the remaining runs, but print the full traceback so the
            # failure is diagnosable rather than reduced to a one-line message.
            print(f"[Error] {args.model} | iteration {i}: {e}")
            traceback.print_exc()