import torch
torch.backends.cuda.matmul.allow_tf32 = True
import torch.nn as nn
import transformers
from utils import get_local_dir, get_local_run_dir, disable_dropout, init_distributed, get_open_port, get_lora_model
import os
import hydra
import torch.multiprocessing as mp
from omegaconf import OmegaConf, DictConfig
import trainers
import wandb
import json
import socket
from typing import Optional, Set
import resource
from peft import PeftModel


# Register a custom OmegaConf resolver named "get_local_run_dir" so config values can be
# computed by utils.get_local_run_dir. The lambda forwards exactly three positional
# arguments (exp_name, local_dirs, create_dir) to that helper unchanged.
OmegaConf.register_new_resolver("get_local_run_dir", lambda exp_name, local_dirs, create_dir: get_local_run_dir(exp_name, local_dirs, create_dir))


def worker_main(rank: int, world_size: int, config: DictConfig, policy: nn.Module, reference_model: Optional[nn.Module] = None):
    """Run training inside a single worker process.

    Steps, in order: initialise the distributed backend for FSDP trainers, stub out
    wandb when ``config.debug`` is set, start a wandb run on rank 0 when enabled,
    then build the trainer class named by ``config.trainer`` and call ``train()``.

    Args:
        rank: index of this worker; rank 0 is the only one that starts a wandb run.
        world_size: total number of worker processes.
        config: the resolved run configuration.
        policy: the model to be trained.
        reference_model: optional second model handed to the trainer unchanged.
    """
    uses_fsdp = 'FSDP' in config.trainer
    if uses_fsdp:
        init_distributed(rank, world_size, port=config.fsdp_port)

    if config.debug:
        # Debug runs must not talk to wandb: swap its entry points for no-ops.
        def _noop(*args, **kwargs):
            return None

        wandb.init = _noop
        wandb.log = _noop

    if rank == 0 and config.wandb.enabled:
        os.environ['WANDB_CACHE_DIR'] = get_local_dir(config.local_dirs)
        wandb_kwargs = dict(
            entity=config.wandb.entity,
            project=config.wandb.project,
            config=OmegaConf.to_container(config),
            dir=get_local_dir(config.local_dirs),
            name=config.exp_name,
        )
        wandb.init(**wandb_kwargs)

    # The trainer is looked up by name in the `trainers` module.
    trainer_cls = getattr(trainers, config.trainer)
    print(f'Creating trainer on process {rank} with world size {world_size}')
    trainer_cls(
        policy,
        config,
        config.seed,
        config.local_run_dir,
        reference_model=reference_model,
        rank=rank,
        world_size=world_size,
    ).train()
    # NOTE: the final `trainer.save()` call (previously guarded by `config.create_dir`)
    # is currently disabled:
    # if config.create_dir:
        # trainer.save()


def _two_device_ids() -> list:
    """Return the device ids used when ``lora.two_device`` is enabled.

    The ids are the comma-separated entries of the ``CUDA_VISIBLE_DEVICES``
    environment variable (whitespace stripped, empty entries dropped). The first
    entry hosts the policy and the second hosts the reference model.

    Raises:
        ValueError: if the variable is unset or lists fewer than two devices.
    """
    raw = os.environ.get('CUDA_VISIBLE_DEVICES', '')
    device_ids = [entry.strip() for entry in raw.split(',') if entry.strip()]
    if len(device_ids) < 2:
        raise ValueError(
            'lora.two_device requires CUDA_VISIBLE_DEVICES to list at least two devices, '
            f'got {raw!r}'
        )
    # NOTE(review): these are the ids as written in CUDA_VISIBLE_DEVICES, but torch device
    # ordinals index the *visible* devices (0..n-1). With e.g. CUDA_VISIBLE_DEVICES=2,3 the
    # resulting 'cuda:2' is likely an invalid ordinal. Left unchanged because the trainer
    # may rely on the same convention -- confirm and switch to logical indices if not.
    return device_ids


def _load_causal_lm(config: DictConfig, archive: str, dtype: torch.dtype, model_kwargs: dict) -> nn.Module:
    """Load a causal LM from ``archive``.

    If ``archive`` contains '.pt' it is treated as a checkpoint file: the base model is
    loaded from ``config.model.name_or_path`` and the checkpoint's 'state' entry is then
    loaded into it. Otherwise ``archive`` itself is passed to ``from_pretrained``.
    """
    is_checkpoint_file = '.pt' in archive
    model_path = config.model.name_or_path if is_checkpoint_file else archive
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_path, cache_dir=get_local_dir(config.local_dirs), low_cpu_mem_usage=True, torch_dtype=dtype, **model_kwargs)
    if is_checkpoint_file:
        # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
        model.load_state_dict(torch.load(archive, map_location='cpu')['state'])
    return model


def _build_policy(config: DictConfig, model_kwargs: dict, device_ids: Optional[list], mode: str = 'merged') -> nn.Module:
    """Build the trainable policy model.

    Args:
        config: the resolved run configuration.
        model_kwargs: extra keyword arguments forwarded to ``from_pretrained``.
        device_ids: ids from ``_two_device_ids()`` when two-device mode is on, else None.
        mode: 'merged' loads ``config.model.archive`` as a full model (optionally wrapping
            it with LoRA afterwards); any other value loads the base model and attaches
            ``config.model.archive`` as a trainable PEFT adapter.
    """
    policy_dtype = getattr(torch, config.model.policy_dtype)
    merged = mode == 'merged'

    if merged:
        policy = _load_causal_lm(config, config.model.archive, policy_dtype, model_kwargs)
    else:
        policy = transformers.AutoModelForCausalLM.from_pretrained(
            config.model.name_or_path, cache_dir=get_local_dir(config.local_dirs), low_cpu_mem_usage=True, torch_dtype=policy_dtype, **model_kwargs)
        policy = PeftModel.from_pretrained(policy, config.model.archive, is_trainable=True)

    if device_ids is not None:
        policy.to(torch.device(f'cuda:{device_ids[0]}'))
    disable_dropout(policy)

    # In merged mode the LoRA wrapper is applied after device placement and dropout removal.
    if merged and config.lora.enable:
        policy = get_lora_model(policy, config)
    return policy


@hydra.main(version_base=None, config_path="config", config_name="config")
def main(config: DictConfig):
    """Main entry point for training. Validates config, creates/initializes model(s), and kicks off worker process(es).

    Raises:
        ValueError: if the config has missing keys, or if ``lora.two_device`` is set but
            ``CUDA_VISIBLE_DEVICES`` does not list at least two devices.
    """

    # Resolve hydra references, e.g. so we don't re-compute the run directory
    OmegaConf.resolve(config)

    missing_keys: Set[str] = OmegaConf.missing_keys(config)
    if missing_keys:
        raise ValueError(f"Got missing keys in config:\n{missing_keys}")

    # Round eval_every down to a multiple of batch_size.
    remainder = config.eval_every % config.batch_size
    if remainder != 0:
        print('WARNING: eval_every must be divisible by batch_size')
        print('Setting eval_every to', config.eval_every - remainder)
        config.eval_every = config.eval_every - remainder

    if 'FSDP' in config.trainer and config.fsdp_port is None:
        free_port = get_open_port()
        print('no FSDP port specified; using open port for FSDP:', free_port)
        config.fsdp_port = free_port

    print(OmegaConf.to_yaml(config))

    if config.create_dir:
        config_path = os.path.join(config.local_run_dir, 'config.yaml')
        with open(config_path, 'w') as f:
            OmegaConf.save(config, f)

    print('=' * 80)
    print(f'Writing to {socket.gethostname()}:{config.local_run_dir}')
    print('=' * 80)

    os.environ['XDG_CACHE_HOME'] = get_local_dir(config.local_dirs)

    # Resolve the two-device layout up front so a bad environment fails before any
    # (slow) model loading instead of after it.
    device_ids = _two_device_ids() if config.lora.two_device else None

    print('building policy')
    # NOTE(review): device_map is skipped only when lora.enable AND lora.two_device are both
    # set, while device placement below depends on lora.two_device alone.
    if config.lora.enable and config.lora.two_device:
        model_kwargs = {}
    else:
        model_kwargs = {'device_map': 'balanced'} if config.trainer in ['BasicTrainer', 'BasicEvaluator'] else {}
    policy = _build_policy(config, model_kwargs, device_ids)

    print('building reference model')
    reference_model_dtype = getattr(torch, config.model.reference_dtype)
    reference_model = _load_causal_lm(config, config.model.archive_reference, reference_model_dtype, model_kwargs)
    if device_ids is not None:
        reference_model.to(torch.device(f'cuda:{device_ids[1]}'))
    disable_dropout(reference_model)

    if 'FSDP' in config.trainer:
        # One worker process per visible CUDA device.
        world_size = torch.cuda.device_count()
        print('starting', world_size, 'processes for FSDP training')
        # Raise the soft open-file limit to the hard limit before spawning workers.
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
        print(f'setting RLIMIT_NOFILE soft limit to {hard} from {soft}')
        mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, policy, reference_model), join=True)
    else:
        print('starting single-process worker')
        worker_main(0, 1, config, policy, reference_model)


# Script entry point: main() is called without arguments here; it is wrapped by the
# @hydra.main decorator on its definition.
if __name__ == '__main__':
    main()