import torch
# Allow TensorFloat-32 for CUDA matrix multiplications (process-wide torch backend flag).
torch.backends.cuda.matmul.allow_tf32 = True
import torch.nn as nn
import transformers
from utils import get_local_dir, get_local_run_dir, disable_dropout, init_distributed, get_open_port
import os
import hydra
import torch.multiprocessing as mp
from omegaconf import OmegaConf, DictConfig
import trainers
import wandb
import json
import socket
from typing import Optional, Set
import resource

# Register a custom OmegaConf resolver named "get_local_run_dir" so config values can be
# computed by calling utils.get_local_run_dir(exp_name, local_dirs) during resolution.
OmegaConf.register_new_resolver("get_local_run_dir", lambda exp_name, local_dirs: get_local_run_dir(exp_name, local_dirs))

def _load_archive_state(archive_path):
    """Load a checkpoint from ``archive_path`` onto the CPU and return its ``'state'`` entry.

    The checkpoint must be a mapping with ``'step_idx'``, ``'metrics'`` and ``'state'``
    keys; the step index and metrics are printed before the weights are returned.
    """
    # NOTE(review): torch.load unpickles the file, so only point this at trusted archives.
    checkpoint = torch.load(archive_path, map_location='cpu')
    step, metrics = checkpoint['step_idx'], checkpoint['metrics']
    print(f'loading pre-trained weights at step {step} from {archive_path} with metrics {json.dumps(metrics, indent=2)}')
    return checkpoint['state']


@hydra.main(version_base=None, config_path="../config", config_name="config")
def main(config: DictConfig):
    """Main entry point for training. Validates config, creates/initializes model(s), and kicks off worker process(es).

    Steps, in order:
      1. Resolve the config, reject missing keys, round ``eval_every`` down to a multiple of
         ``batch_size``, pick an FSDP port if needed, and save the config to ``local_run_dir``.
      2. Build the policy (skipped for ``reward_loss``), the optional reward/value models and
         the optional reference model.
      3. Optionally load archived weights into whichever of those models were built.
      4. Spawn one worker per CUDA device for FSDP trainers, otherwise run a single worker
         in this process.

    Raises:
        ValueError: if the resolved config still has missing (``???``) keys.
    """

    # Resolve, record, and load configurations. To add new configurations, modify the .yaml files in the config folder
    OmegaConf.resolve(config)
    missing_keys: Set[str] = OmegaConf.missing_keys(config)

    if missing_keys:
        raise ValueError(f"Got missing keys in config:\n{missing_keys}")

    # Round eval_every down to the nearest multiple of batch_size.
    remainder = config.eval_every % config.batch_size
    if remainder != 0:
        print('WARNING: eval_every must be divisible by batch_size')
        print('Setting eval_every to', config.eval_every - remainder)
        config.eval_every = config.eval_every - remainder

    if 'FSDP' in config.trainer and config.fsdp_port is None:
        free_port = get_open_port()
        print('no FSDP port specified; using open port for FSDP:', free_port)
        config.fsdp_port = free_port

    print(OmegaConf.to_yaml(config))
    config_path = os.path.join(config.local_run_dir, 'config.yaml')
    with open(config_path, 'w') as f:
        OmegaConf.save(config, f)

    print('=' * 80)
    print(f'Writing to {socket.gethostname()}:{config.local_run_dir}')
    print('=' * 80)

    # Load pretrained model from hugging face
    os.environ['XDG_CACHE_HOME'] = get_local_dir(config.local_dirs)
    print('building policy')
    policy_dtype = getattr(torch, config.model.policy_dtype)
    policy = None
    if config.loss.name not in {'reward_loss'}:
        print(f"config.model.name_or_path: {config.model.name_or_path}")
        policy = transformers.AutoModelForCausalLM.from_pretrained(
            config.model.name_or_path,
            cache_dir=get_local_dir(config.local_dirs),
            torch_dtype=policy_dtype
        )
        disable_dropout(policy)
        if config.enable_gradient_checkpoint:
            policy.gradient_checkpointing_enable()

    # Reward model (and, for PPO, a value model loaded from the same checkpoint name).
    reward_model = None
    value_model = None
    if 'reward_model' in config and config.reward_model is not None:
        print(f"config.reward_model.name_or_path: {config.reward_model.name_or_path}")
        reward_model = transformers.AutoModelForSequenceClassification.from_pretrained(
            config.reward_model.name_or_path, cache_dir=get_local_dir(config.local_dirs), torch_dtype=policy_dtype, num_labels=1
        )
        disable_dropout(reward_model)
        if config.loss.name in {'ppo'}:
            value_model = transformers.AutoModelForSequenceClassification.from_pretrained(
                config.reward_model.name_or_path, cache_dir=get_local_dir(config.local_dirs), torch_dtype=policy_dtype, num_labels=1
            )
            disable_dropout(value_model)
            if config.enable_gradient_checkpoint:
                value_model.gradient_checkpointing_enable()

    # Reference model: a second copy of config.model.name_or_path, loaded in
    # config.model.reference_dtype, built only for the losses listed below.
    reference_model = None
    if config.loss.name in {'dpo', 'ipo', 'advdpo', 'ppo'}:
        print('building reference model')
        reference_model_dtype = getattr(torch, config.model.reference_dtype)
        reference_model = transformers.AutoModelForCausalLM.from_pretrained(
            config.model.name_or_path, cache_dir=get_local_dir(config.local_dirs), torch_dtype=reference_model_dtype)
        disable_dropout(reference_model)

    # Load finetuned model, e.g, the starting point for rlhf training
    if config.model.archive is not None:
        if policy is None and reference_model is None:
            # With reward_loss no policy/reference model is built, so there is nothing
            # to load the archive into; previously this path crashed on None.load_state_dict.
            print(f'WARNING: config.model.archive is set but loss {config.loss.name} builds no policy or reference model; skipping archive load')
        else:
            archived_state = _load_archive_state(config.model.archive)
            if policy is not None:
                policy.load_state_dict(archived_state, strict=False)
            if reference_model is not None:
                reference_model.load_state_dict(archived_state, strict=False)
            print('loaded pre-trained weights')

    # reward_model is only non-None when config.reward_model exists and is not null,
    # so checking it first avoids dereferencing a null config node.
    if reward_model is not None and config.reward_model.archive is not None:
        archived_state = _load_archive_state(config.reward_model.archive)
        reward_model.load_state_dict(archived_state, strict=False)
        if value_model is not None:
            value_model.load_state_dict(archived_state, strict=False)
        print('loaded pre-trained weights')

    # Start training. The dataset is loaded in later training code, and the information about the dataset is contained in config
    if 'FSDP' in config.trainer:
        world_size = torch.cuda.device_count()
        print('starting', world_size, 'processes for FSDP training')
        # Raise the soft open-file limit to the hard limit before spawning workers.
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
        print(f'setting RLIMIT_NOFILE soft limit to {hard} from {soft}')
        mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, policy, reference_model, reward_model, value_model), join=True)
    else:
        print('starting single-process worker')
        worker_main(0, 1, config, policy, reference_model, reward_model, value_model)

def worker_main(rank: int, world_size: int, config: DictConfig, policy: nn.Module, reference_model: Optional[nn.Module] = None, reward_model: Optional[nn.Module] = None, value_model: Optional[nn.Module] = None):
    """Per-process training routine: set up distributed state and wandb, build the trainer, train, save.

    Args:
        rank: index of this worker; only rank 0 initializes wandb.
        world_size: total number of workers.
        config: run configuration; ``config.trainer`` is rewritten in place to drop the
            ``'FSDP.'`` prefix, and ``config.is_distributed`` is set for FSDP trainers.
        policy: policy model handed to the trainer (or passed as ``policy_ref`` for
            ``AdversarialBONTrainer`` subclasses).
        reference_model: optional reference model forwarded to the trainer.
        reward_model: optional reward model; it is the primary model for
            ``RewardTrainer`` / ``AdversarialBONTrainer`` subclasses.
        value_model: optional value model, forwarded only to ``PPOTrainer`` subclasses.

    Raises:
        ValueError: if the configured trainer class matches none of the known trainer bases.
    """
    uses_fsdp = 'FSDP' in config.trainer
    if uses_fsdp:
        init_distributed(rank, world_size, port=config.fsdp_port)
        config.is_distributed = True

    # Drop the 'FSDP.' prefix so the remaining name can be looked up on the trainers module.
    config.trainer = config.trainer.replace('FSDP.', '')

    # Setup wandb: in debug mode both init and log become no-ops.
    if config.debug:
        def _noop(*args, **kwargs):
            return None

        wandb.init = _noop
        wandb.log = _noop

    if rank == 0 and config.wandb.enabled:
        os.environ['WANDB_CACHE_DIR'] = get_local_dir(config.local_dirs)
        wandb_settings = {
            'entity': config.wandb.entity,
            'project': config.wandb.project,
            'config': OmegaConf.to_container(config),
            'dir': get_local_dir(config.local_dirs),
            'name': config.exp_name,
        }
        wandb.init(**wandb_settings)

    # Choose the trainer which represents the training method
    trainer_cls = getattr(trainers, config.trainer)
    print(f'Creating trainer on process {rank} with world size {world_size}')

    def build(model, **extra):
        # Shared constructor call: every trainer takes the same leading positional
        # arguments and the same trailing rank/world_size keywords.
        return trainer_cls(model, config, config.seed, config.local_run_dir, **extra, rank=rank, world_size=world_size)

    # The checks run in this fixed order; the first matching base class wins.
    if issubclass(trainer_cls, trainers.PPOTrainer):
        print("initializing PPOTrainer")
        trainer = build(policy, reference_model=reference_model, reward_model=reward_model, value_model=value_model)
    elif issubclass(trainer_cls, trainers.RewardTrainer):
        print("initializing RewardTrainer")
        trainer = build(reward_model, reference_model=reference_model)
    elif issubclass(trainer_cls, trainers.AdversarialBONTrainer):
        print("initializing AdversarialBONTrainer")
        trainer = build(reward_model, reference_model=reference_model, policy_ref=policy)
    elif issubclass(trainer_cls, trainers.AdversarialDPOTrainer):
        print("initializing AdversarialDPOTrainer")
        trainer = build(policy, reference_model=reference_model)
    elif issubclass(trainer_cls, trainers.AdversarialTrainer):
        print("initialzing AdversarialTrainer")
        trainer = build(policy, reference_model=reference_model, reward_model=reward_model)
    elif issubclass(trainer_cls, trainers.BasicTrainer):
        print("initializing BasicTrainer")
        trainer = build(policy, reference_model=reference_model)
    else:
        raise ValueError("Trainer not implemented yet")

    # Start training, then persist the result.
    trainer.train()
    trainer.save()

# Script entry point: main() is wrapped by @hydra.main, so it is called with no arguments
# here and receives its config from hydra.
if __name__ == '__main__':
    main()