import os
import argparse
from functools import partial

import wandb
import torch
from torch.nn import MSELoss
from torch.optim import Adam
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool

from models import EGNNModel, HEGNNModel, TFNModel, MACEModel, FastEGNNModel, EquiformerModel
from models import EGNNModel_cpl_global, EGNNModel_cpl_local
from models import TFNModel_cpl_global, TFNModel_cpl_local
from utils import fix_seed, run_experiment
from datasets import TetrahedronDataset


# os.environ['WANDB_MODE'] = 'offline'

# Command-line interface for the tetrahedron experiment. Parsed once at import
# time into the module-level ``args`` namespace used by the ``__main__`` block.
parser = argparse.ArgumentParser(description='ICML25-Tet')

# Model
parser.add_argument('--exp_name', type=str, default='simple-exp', help='str type, name of the experiment (default: simple-exp)')
parser.add_argument('--model', type=str, default='EGNN', help='which model (default: EGNN)')
parser.add_argument('--hidden_dim', type=int, default=64, help='hidden_dim (default: 64)')
parser.add_argument('--num_layer', type=int, default=4, help='number of layers of gnn (default: 4)')


# Data
parser.add_argument('--data_directory', type=str, required=True, help='data directory (required)')
parser.add_argument('--label_type', type=str, required=True, help='label type to predict (required)')
# argparse applies ``type`` only to *string* defaults; a bare ``1e8`` would be
# passed through unchanged as the float 100000000.0. Use a real int instead.
parser.add_argument('--max_train_samples', type=int, default=int(1e8), help='maximum amount of train samples (default: 1e8)')
parser.add_argument('--max_test_samples', type=int, default=int(1e8), help='maximum amount of valid and test samples (default: 1e8)')


# Training
parser.add_argument('--seed', type=int, default=43, help='random seed (default: 43)')
parser.add_argument('--epochs', type=int, default=300, help='epochs (default: 300)')
parser.add_argument('--batch_size', type=int, default=100, help='int type, batch size for training (default: 100)')
parser.add_argument('--learning_rate', type=float, default=5e-4, help='learning rate (lr) of optimizer (default: 5e-4)')
parser.add_argument('--weight_decay', type=float, default=1e-12, help='weight decay of optimizer (default: 1e-12)')
# NOTE(review): --times and --early_stop are not read anywhere else in this
# script as written; confirm whether they are meant to be honoured.
parser.add_argument('--times', type=int, default=1, help='experiment repeat times (default: 1)')
parser.add_argument('--early_stop', type=int, default=100, help='early stop (default: 100)')


# Log
parser.add_argument('--log_directory', type=str, default='./logs/nbody', help='directory to generate the json log file (default: ./logs/nbody)')
parser.add_argument('--eval_interval', type=int, default=5, help='how many epochs to wait before logging eval (default: 5)')


# Device
parser.add_argument('--device', type=str, default='cpu', help='device (default: cpu)')


args = parser.parse_args()

class GlobalModel(torch.nn.Module):
    """Wrap a node-level model so it yields one prediction per graph.

    The wrapped model is applied to a (possibly batched) ``Data`` object, and
    its output is mean-pooled per graph with ``global_mean_pool``, using
    ``data.batch`` as the node-to-graph assignment.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        # Stored as an attribute of an nn.Module, so it is registered as a
        # submodule: its parameters show up in ``self.parameters()`` and in
        # ``state_dict()`` under the ``main_model.`` prefix.
        self.main_model = model

    def forward(self, data: Data) -> torch.Tensor:
        # Call the wrapped submodule. The previous version read a module-level
        # ``model`` global, which only worked if the script happened to bind one.
        # Assumes the wrapped model returns one row per node (first dim aligned
        # with ``data.batch``) - TODO confirm for every model in the registry.
        node_out = self.main_model(data)
        return global_mean_pool(node_out, data.batch)

if __name__ == '__main__':
    # Seed first so dataset construction and model initialisation are reproducible.
    fix_seed(seed=args.seed)

    # All three splits share directory, label and device settings; they differ
    # only in partition name and sample cap.
    make_split = partial(
        TetrahedronDataset,
        data_dir=args.data_directory,
        label_type=args.label_type,
        device='cpu',
    )
    dataset_train = make_split(max_samples=args.max_train_samples, partition='train')
    dataset_valid = make_split(max_samples=args.max_test_samples, partition='valid')
    dataset_test = make_split(max_samples=args.max_test_samples, partition='test')

    make_loader = partial(
        DataLoader,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=4,
    )
    # Only the training split is shuffled; evaluation order stays fixed.
    train_loader = make_loader(dataset=dataset_train, shuffle=True)
    valid_loader = make_loader(dataset=dataset_valid, shuffle=False)
    test_loader = make_loader(dataset=dataset_test, shuffle=False)

    run_name = f'{args.model}-{args.label_type}-{args.num_layer}layer'
    wandb.init(project='ICML25-Tet', name=run_name)

    # --model value -> constructor. Model-specific hyperparameters are pre-bound
    # here; the arguments shared by every model are supplied in the call below.
    model_registry = dict(
        EGNN=EGNNModel,
        HEGNN=partial(HEGNNModel, max_ell=2),
        TFN=partial(TFNModel, max_ell=2, require_vel=False),
        MACE=partial(MACEModel, max_ell=2, correlation=4, require_vel=False),
        FastEGNN=FastEGNNModel,
        Equiformer=EquiformerModel,
        EGNN_cpl_global=EGNNModel_cpl_global,
        EGNN_cpl_local=EGNNModel_cpl_local,
        TFN_cpl_global=partial(TFNModel_cpl_global, max_ell=2, require_vel=False),
        TFN_cpl_local=partial(TFNModel_cpl_local, max_ell=2),
    )
    build_model = model_registry[args.model]  # unknown names raise KeyError
    model = build_model(
        num_layer=args.num_layer,
        hidden_dim=args.hidden_dim,
        node_input_dim=1,
        edge_attr_dim=1,
        device=args.device,
    )
    global_model = GlobalModel(model)

    loss_func = MSELoss()
    optimizer = Adam(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )

    experiment_kwargs = dict(
        model=global_model,
        train_loader=train_loader,
        valid_loader=valid_loader,
        test_loader=test_loader,
        num_epochs=args.epochs,
        optimizer=optimizer,
        loss_func=loss_func,
        eval_interval=args.eval_interval,
        # NOTE(review): hard-coded, so --early_stop is not honoured here (and
        # --times is never used); confirm this is intentional.
        early_stop=int(1e6),
        device=args.device,
    )
    run_experiment(**experiment_kwargs)

    wandb.finish()