import torch
import torch.distributed as dist
import os
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
import torch.distributed._shard.checkpoint as dist_cp
import torch.distributed as dist

from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM


def setup_distributed_environment(rank, world_size):
    """Initialise the NCCL process group and select this process's CUDA device.

    The ``RANK`` and ``WORLD_SIZE`` environment variables, when set, take
    precedence over the corresponding arguments.

    Args:
        rank: Fallback global rank, used when ``RANK`` is not set.
        world_size: Fallback process count, used when ``WORLD_SIZE`` is not
            set.
    """
    rank = int(os.getenv("RANK", rank))
    world_size = int(os.getenv("WORLD_SIZE", world_size))
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)

    # The CUDA device index is per node, whereas ``rank`` is global. Using the
    # global rank directly breaks on multi-node jobs, where it can exceed the
    # number of GPUs visible on this host.
    local_rank = os.getenv("LOCAL_RANK")
    if local_rank is not None:
        device_index = int(local_rank)
    else:
        device_count = torch.cuda.device_count()
        # On a single node rank < device_count, so this equals the old value.
        device_index = rank % device_count if device_count > 0 else rank
    torch.cuda.set_device(device_index)

def load_fsdp_model(fsdp_model, checkpoint_path):
    """Load a checkpoint file into ``fsdp_model``.

    The file at ``checkpoint_path`` is read with ``torch.load`` (mapped to
    the CPU) and the loaded object is handed to
    ``fsdp_model.load_state_dict``.

    Args:
        fsdp_model: Object exposing ``load_state_dict``.
        checkpoint_path: Location of the checkpoint file to read.
    """
    fsdp_model.load_state_dict(
        torch.load(checkpoint_path, map_location="cpu")
    )


def load_sharded_checkpoint(fsdp_model, checkpoint_dir):
    """Load the calling rank's shard file into ``fsdp_model``.

    The shard is looked up as ``__<rank>_0.distcp`` inside
    ``checkpoint_dir``, where ``<rank>`` comes from ``dist.get_rank()``, so
    each process is assumed to know which shard it needs.

    NOTE(review): the shard is read with ``torch.load`` and used directly as
    a state dict; confirm this matches how the ``.distcp`` files were
    written.

    Args:
        fsdp_model: Object exposing ``load_state_dict``.
        checkpoint_dir: Directory containing the per-rank shard files.
    """
    shard_path = f"{checkpoint_dir}/__{dist.get_rank()}_0.distcp"
    shard_state = torch.load(shard_path, map_location="cpu")
    fsdp_model.load_state_dict(shard_state)
    
def load_sharded_model_single_gpu(model, model_path):
    """Load a distributed checkpoint into ``model`` without a process group.

    The model's current state dict is wrapped under the ``"model"`` key,
    passed to ``dist_cp.load_state_dict`` with ``no_dist=True``, and then
    loaded back into the model. The result of ``model.load_state_dict`` is
    printed.

    Args:
        model: Module whose state dict is populated from the checkpoint.
        model_path: Path given to ``dist_cp.FileSystemReader``.

    Returns:
        The same ``model`` object that was passed in.
    """
    wrapped_state = {"model": model.state_dict()}
    checkpoint_reader = dist_cp.FileSystemReader(model_path)

    dist_cp.load_state_dict(
        state_dict=wrapped_state,
        storage_reader=checkpoint_reader,
        no_dist=True,
    )

    load_report = model.load_state_dict(wrapped_state["model"])

    print(f"Sharded state checkpoint loaded from {model_path}")
    print(load_report)
    return model

def convert_fsdp_checkpoint(hf_model, fsdp_model_path, consolidated_model_path, tokenizer):
    """Consolidate an FSDP sharded checkpoint into Hugging Face format.

    A model is instantiated from the config of ``hf_model``, filled from the
    sharded checkpoint, and written together with the tokenizer to
    ``consolidated_model_path`` via ``save_pretrained``.

    Args:
        hf_model: Transformers path passed to the ``from_pretrained`` calls
            that fetch the config and the tokenizer.
        fsdp_model_path: Path to the FSDP checkpoint, for example
            ``/x/checkpoint-xxx/pytorch_model_x``.
        consolidated_model_path: Output path for the converted checkpoint
            and the tokenizer files.
        tokenizer: Ignored. Kept only so existing call sites keep working;
            the tokenizer is always reloaded from ``hf_model``.
    """
    config = AutoConfig.from_pretrained(hf_model, trust_remote_code=True)
    # Deliberately a separate name: the ``tokenizer`` argument was never
    # used, and rebinding it hid that fact.
    hf_tokenizer = AutoTokenizer.from_pretrained(hf_model, trust_remote_code=True)
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
    model = load_sharded_model_single_gpu(model, fsdp_model_path)

    # Announce the save before it happens, so the message is not misleading
    # if one of the writes below fails.
    print(f"saving consolidated model to {consolidated_model_path}")
    hf_tokenizer.save_pretrained(consolidated_model_path)
    # Save the FSDP sharded checkpoint in HF format.
    model.save_pretrained(consolidated_model_path)