# NOTE: `train` is imported before any of the sys.path tweaks below, so it
# has to be resolvable from the default search path (e.g. the script's own
# directory).
from train import train
import sys
import os
# Make modules stored in the sibling `models/` directory importable.
sys.path.append(os.path.join(os.path.dirname(__file__), 'models'))
 

import os
# Print the absolute path of the script that is currently running.
print(f"当前运行的脚本路径: {os.path.abspath(__file__)}")

# Add this script's directory to the search path; the `utils` import on the
# next line comes after it, so keep this order.
sys.path.append(os.path.dirname(__file__))  
# NOTE(review): the class name ends in "p" (ModelFactoryp) -- confirm that this
# matches the definition in utils/model_factory.py and is not a typo.
from utils.model_factory import ModelFactoryp

import json
import copy
from datetime import datetime
import random
import argparse
import numpy as np
import csv
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import os
# Process-wide runtime configuration, applied once at import time.
# NOTE(review): KMP_DUPLICATE_LIB_OK presumably works around the
# "duplicate OpenMP runtime" abort -- confirm it is still required.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Use the "jshtml" HTML representation for matplotlib animations.
plt.rcParams["animation.html"] = "jshtml"
# Thread budget shared by the OMP/MKL environment variables and torch below.
cpu_num = 4 
# NOTE(review): numpy and torch are already imported at this point, so these
# two environment variables may be set too late to influence them -- verify.
os.environ["OMP_NUM_THREADS"] = str(cpu_num) 
os.environ["MKL_NUM_THREADS"] = str(cpu_num) 
# Apply the same thread limit through torch's own API.
torch.set_num_threads(cpu_num)


def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds PyTorch (``torch.manual_seed`` plus both CUDA seeding calls),
    NumPy's global generator and Python's ``random`` module, exports
    ``PYTHONHASHSEED`` and switches cuDNN to deterministic,
    non-benchmarking mode while leaving it enabled.

    Args:
        seed: Integer seed handed unchanged to each generator.
    """
    # PyTorch first: the default generator, then the CUDA generators.
    torch_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in torch_seeders:
        seed_fn(seed)

    np.random.seed(seed)
    random.seed(seed)

    # NOTE(review): the interpreter reads PYTHONHASHSEED at start-up, so this
    # assignment presumably only matters for child processes -- confirm intent.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Trade cuDNN auto-tuning for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = True

def save_results_to_csv(config: dict, results: dict, filename='1results.csv'):
    """Append one experiment record (config + results) to a CSV file.

    The row is the union of ``config`` and ``results``; when both contain
    the same key, the value from ``results`` wins. A header row is written
    whenever the target file is new *or* exists but is still empty, so the
    file never ends up with data rows and no header.

    The header of an already populated file is not re-read: rows are
    appended as-is, so callers are expected to use a consistent set of
    keys per file.

    Args:
        config: Mapping of configuration names to values.
        results: Mapping of result names to values.
        filename: Path of the CSV file to append to.
    """
    row = {**config, **results}
    with open(filename, mode='a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=list(row))
        # Append mode positions the stream at end-of-file, so offset 0 means
        # "nothing written yet". Checking here (instead of os.path.isfile
        # before opening) also covers pre-existing empty files and avoids the
        # check-then-open race.
        if f.tell() == 0:
            writer.writeheader()
        writer.writerow(row)

def count_parameters(model):
    """Count and report the trainable parameters of ``model``.

    Only parameters with ``requires_grad`` set are included. A complex
    parameter contributes two per element (real and imaginary part), so
    the total is expressed in real-valued numbers.

    Args:
        model: Object exposing ``parameters()`` that yields tensors.

    Returns:
        int: The real-equivalent number of trainable parameters. The same
        figure is also printed, both raw and in millions.
    """
    trainable = (p for p in model.parameters() if p.requires_grad)
    total = sum(
        p.numel() * (2 if p.is_complex() else 1)
        for p in trainable
    )
    print(f'Total trainable parameters (real-equivalent): {total:,} ({total / 1e6:.2f}M)')
    return total

def main(cfg):
    """Run one training experiment described by ``cfg`` and record the outcome.

    Steps, in order:
      1. Build the model factory from ``configs/model_registry.json`` next
         to this script.
      2. Create the ``log/`` (per model/nu/forcing) and ``model/``
         directories relative to the current working directory.
      3. Seed the RNGs and derive a run name, stored on ``cfg.model_name``.
      4. Redirect ``sys.stdout`` into the run's log file for the remainder
         of the run.
      5. Create the model, count its parameters, call ``train`` and append
         the configuration plus errors to the results CSV.

    If model creation raises, the error is logged and the function returns
    without training.

    Args:
        cfg: Namespace of hyper-parameters. Read here: ``model_version``,
            ``nu``, ``f``, ``seed``, ``lr``, ``channel``, ``modes``,
            ``n_layers``, ``dt`` and ``step_forward``. ``model_name`` is
            overwritten. The whole namespace is forwarded to ``train``.

    Returns:
        None.
    """
    config_path = os.path.join(os.path.dirname(__file__), 'configs', 'model_registry.json')
    model_factory = ModelFactoryp(config_path)

    model_version = cfg.model_version
    nu, f = cfg.nu, cfg.f

    # exist_ok avoids the check-then-create race when several runs start
    # at the same time.
    log_dir = f'log/log_model_{model_version}_nu_{nu}_f_{f}'
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs('model', exist_ok=True)

    setup_seed(cfg.seed)
    started = datetime.now()

    # Month_day_hour_minute_second, not zero-padded (kept so that run names
    # stay in the established format).
    timestring = f'{started.month}_{started.day}_{started.hour}_{started.minute}_{started.second}'

    cfg.model_name = f'seed-{cfg.seed}-{cfg.model_version}-lr-{cfg.lr}-channel-{cfg.channel}-modes-{cfg.modes}-n_layers-{cfg.n_layers}-dt-{cfg.dt}-step_forward-{cfg.step_forward}-{timestring}'
    # The log file shares its stem with the run name.
    logfile = f'{log_dir}/{cfg.model_name}.txt'

    # Everything printed from here on goes to the log file. UTF-8 is forced
    # because the messages below are not ASCII and the platform default
    # encoding may be unable to represent them.
    original_stdout = sys.stdout
    log_handle = open(logfile, 'w', encoding='utf-8')
    sys.stdout = log_handle
    try:
        print('--------args----------')
        print(f"收到的命令行指令: {' '.join(sys.argv)}")
        for k, v in vars(cfg).items():
            print('%s: %s' % (k, v))
        print('--------args----------\n')

        try:
            print(f"正在创建模型: {model_version}")

            model_kwargs = {}

            # Only these two versions receive the spectral hyper-parameters;
            # every other version is created with the factory's defaults.
            if model_version in ('FNO', 'SINO-FNO'):
                model_kwargs.update({
                    'modes1': cfg.modes,
                    'modes2': cfg.modes,
                    'width': cfg.channel
                })

            print(f"模型创建: {model_version}")
            net = model_factory.create_model(model_version, **model_kwargs)
            print(f"模型创建成功: {model_version}")

        except Exception as e:
            # Best-effort boundary: record the failure and stop this run.
            # The message goes to the log; the traceback goes to stderr.
            print(f"创建模型失败: {e}")
            import traceback
            traceback.print_exc()
            return

        param_count = count_parameters(net)
        sys.stdout.flush()
        print(param_count)
        val_error, test_error, train_error = train(cfg, net)

        results = {
            'param_count': param_count,
            'val_error': val_error,
            'test_error': test_error,
            'train_error': train_error,
        }
        save_results_to_csv(vars(cfg), results)

        final_time = datetime.now()
        print(f'FINAL_TIME_{final_time.month}_{final_time.day}_{final_time.hour}_{final_time.minute}_{final_time.second}')
    finally:
        # Always hand stdout back and release the log file, including on the
        # early return above and on exceptions raised by training.
        sys.stdout = original_stdout
        log_handle.close()

def build_parser():
    """Create the command-line parser for all training hyper-parameters.

    Kept separate from the entry point so the option set and defaults can
    be inspected without starting a training run.

    Returns:
        argparse.ArgumentParser: Parser whose parsed namespace is passed
        to ``main``.
    """
    parser = argparse.ArgumentParser(description='Hyper-parameters of training')

    # Data and experiment identification.
    parser.add_argument('--data_path', type=str, default='/data/',
                        help='path of data')
    parser.add_argument('--nu', type=float, default=1e-4,  # alt. value noted in original: 1e-5
                        help='nu')
    parser.add_argument('--f', type=str, default='li',  # alt. value noted in original: kf
                        help='forcing')
    parser.add_argument('--model_name', type=str, default='',
                        help='model name')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='Used device')
    parser.add_argument('--seed', type=int, default=0,
                        help='seed')

    # Model hyper-parameters.
    parser.add_argument('--model_version', type=str, default='FNO',
                        help='model_version')
    parser.add_argument('--channel', type=int, default=12,
                        help='channel')
    parser.add_argument('--modes', type=int, default=12,
                        help='modes')
    # The default stays the int 1 on purpose: argparse applies `type` only to
    # string defaults, and the value is embedded verbatim in run/log names.
    parser.add_argument('--dt', type=float, default=1,
                        help='dt, 1/dt should be int')
    parser.add_argument('--step_forward', type=str, default='rk4',
                        help='step_forward')
    parser.add_argument('--n_layers', type=int, default=1,
                        help='n_layers')

    # Hyper-parameters used during training.
    parser.add_argument('--num_train', type=int, default=5,
                        help='number of training data')
    parser.add_argument('--num_step1', type=int, default=2,
                        help='num_step1')
    parser.add_argument('--num_step2', type=int, default=2,
                        help='num_step2')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='lr of optim')
    # Help text previously repeated 'lr of optim' (copy-paste slip).
    parser.add_argument('--weight_decay', type=float, default=0.00,
                        help='weight decay of optim')
    parser.add_argument('--num_iterations', type=int, default=20000,
                        help='num_iterations of optim')

    return parser


if __name__ == "__main__":
    cfg = build_parser().parse_args()
    main(cfg)