from AbstractModels.Config import Config
from AbstractModels.Model import Model
from AbstractModels.Trainer import Trainer

import torch
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group

import numpy as np
import os

import random

from datetime import timedelta

def main() -> None:
    """Seed every RNG from the config, then run DDP training, validation, or plain training.

    The path is chosen by ``config.use_ddp()`` first and ``config.validate()``
    second; the model is only initialized for the non-DDP paths.
    """
    config: Config = Config()
    seed = config.get_seed()

    # Seed torch (CPU and CUDA), NumPy and Python's ``random`` with the same value.
    for seed_fn in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    ):
        seed_fn(seed)

    # Favour reproducible cuDNN kernels over autotuned ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Stricter option, currently disabled: torch.use_deterministic_algorithms(True)
    torch.set_float32_matmul_precision("high")

    if config.use_ddp():
        distributed_train(config)
        return

    entry = validate if config.validate() else train
    entry(config.initialize_model(), config)

def train(model: Model, config: Config) -> None:
    """Run single-process training of ``model`` using settings from ``config``.

    Args:
        model: The model to train (already initialized by the caller).
        config: Supplies the optimizer, scheduler, data loaders, criterion,
            epoch count, output directory and gradient-accumulation steps.
    """
    optimizer = config.get_optimizer()
    scheduler = config.get_scheduler(optimizer=optimizer)

    trainer = Trainer(model)

    trainer.fit(
        epochs=config.get_epochs(),
        train_loader=config.get_dataloader(train=True),
        val_loader=config.get_dataloader(train=False),
        optimizer=optimizer,
        criterion=config.get_criterion(),
        model_dir=config.get_model_dir(),
        config=config,
        scheduler=scheduler,
        # Match the DDP path (_distributed_train) so the configured gradient
        # accumulation is honoured in single-process runs as well.
        gradient_steps=config.get_gradient_steps(),
    )

def validate(model: Model, config: Config) -> tuple[float, float, float, float]:
    """Evaluate ``model`` on the validation split described by ``config``.

    Returns:
        Whatever ``Trainer.validate`` returns; annotated as a 4-tuple of floats
        (the meaning of each entry is defined by ``Trainer.validate``).
    """
    evaluator = Trainer(model)
    loader = config.get_dataloader(train=False)
    loss_fn = config.get_criterion()
    return evaluator.validate(val_loader=loader, criterion=loss_fn, config=config)

def distributed_train(config: Config) -> None:
    world_size = torch.cuda.device_count()
    print(f"Training on {world_size} GPUs")
    # mp.set_start_method('fork', force=True)
    mp.spawn(_distributed_train, args=(world_size, config), nprocs=world_size)

def ddp_setup(rank: int, world_size: int) -> None:
    """Bind this process to its GPU and join the NCCL process group.

    ``MASTER_ADDR``/``MASTER_PORT`` default to ``localhost:12355`` but are only
    set when absent. An external launcher, or a second job on the same host,
    can therefore choose its own rendezvous address instead of having it
    overwritten and colliding on a fixed port.

    Args:
        rank: Unique identifier of each process; also used as the CUDA device index.
        world_size: Total number of processes.
    """
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    torch.cuda.set_device(rank)
    init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
        timeout=timedelta(minutes=30),
    )

def seed_ddp(rank: int, seed: int) -> None:
    """Seed every RNG for one DDP worker with ``seed + rank``.

    Each rank gets a distinct but deterministic seed. The same cuDNN and
    float32-matmul settings as ``main`` are also applied inside the worker.
    """
    rank_seed = seed + rank

    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(rank_seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.set_float32_matmul_precision("high")


def _distributed_train(rank: int, world_size: int, config: Config) -> None:
    """Worker body run by ``mp.spawn`` for a single DDP rank.

    Args:
        rank: Process index supplied by ``mp.spawn``; used as the GPU id.
        world_size: Total number of worker processes.
        config: Experiment configuration shared by all ranks.
    """
    ddp_setup(rank, world_size)

    # Tear the process group down even if seeding, model setup or training
    # raises, so a failing rank does not leak the NCCL communicator.
    try:
        seed_ddp(rank, config.get_seed())

        model: Model = config.initialize_model(rank)
        optimizer = config.get_optimizer()
        scheduler = config.get_scheduler(optimizer=optimizer)

        trainer = Trainer(model)

        trainer.fit(
            epochs=config.get_epochs(),
            train_loader=config.get_dataloader(train=True, distributed=True, rank=rank, world_size=world_size),
            val_loader=config.get_dataloader(train=False, distributed=True, rank=rank, world_size=world_size),
            optimizer=optimizer,
            criterion=config.get_criterion(),
            model_dir=config.get_model_dir(),
            config=config,
            scheduler=scheduler,
            gradient_steps=config.get_gradient_steps(),
            gpu_id=rank,
            distributed_training=True,
        )
    finally:
        destroy_process_group()

# Entry-point guard: mp.spawn (used by distributed_train) starts worker
# processes that re-import this module, so main() must only run when the
# file is executed as a script.
if __name__ == '__main__':
    main()