import torch
import torch.distributed as dist
from torch.distributed.pipelining import ScheduleGPipe
from torch.distributed.pipelining import PipelineStage
from transformers import AutoConfig, LlamaForCausalLM, AutoTokenizer 
from datasets import load_dataset
from torch.utils.data import DataLoader
import os
import torch.nn.functional as F
from transformers import DataCollatorForLanguageModeling
from tqdm import tqdm
import yaml
import datetime
from torch.profiler import profile, record_function, ProfilerActivity
from typing import Optional, Union, List
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel
from transformers.cache_utils import DynamicCache 
import torch.nn as nn


def init_process_group():
    """Join the default NCCL process group and bind this process to its GPU.

    Expects the torchrun-style environment variables ``RANK`` (global rank)
    and ``WORLD_SIZE``. ``LOCAL_RANK`` selects the GPU on this node; when it
    is unset, ``RANK`` modulo the number of visible GPUs is used instead.

    Raises:
        KeyError: if ``RANK`` or ``WORLD_SIZE`` is missing from the environment.
    """
    # The process-group rank must be the *global* rank. Passing LOCAL_RANK
    # here would give every node the ranks 0..n-1 and break multi-node runs;
    # LOCAL_RANK only identifies the GPU on this node.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))

    # Select the GPU before creating the NCCL group so its communicators use it.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank,
        timeout=datetime.timedelta(seconds=1200),
    )
    # Optional NCCL debugging knobs:
    #os.environ['NCCL_SOCKET_TIMEOUT'] = '1200'
    #os.environ['NCCL_IB_TIMEOUT'] = '1200'
    #os.environ['NCCL_DEBUG'] = 'INFO'
    # NOTE(review): these are set after set_device() has most likely already
    # initialised CUDA, so CUDA_LAUNCH_BLOCKING probably has no effect here.
    # TORCH_USE_CUDA_DSA appears to be a build-time option. If either is
    # needed, export it before launching the process.
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    os.environ['TORCH_USE_CUDA_DSA'] = '1'

def load_config(config_path="llama2_pipeline_config.yaml"):
    """Load and parse a YAML configuration file.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document (a ``dict`` when
        the top level is a mapping, which is what ``main`` expects).
    """
    # Decode explicitly as UTF-8 instead of the platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


class SplitLlama(LlamaForCausalLM):
    """``LlamaForCausalLM`` whose parts can be removed to form a pipeline stage.

    The decoder is a ``SplitLlamaModel``, which tolerates a missing
    ``embed_tokens`` or ``norm``. Setting ``lm_head`` to ``None`` turns
    ``forward`` into a pass-through that returns the decoder's hidden states
    instead of logits; this is how a non-final pipeline stage is built.
    """

    def __init__(self, config):
        super().__init__(config)
        # Replace the decoder built by the parent with the stage-friendly one.
        # NOTE(review): the parent constructor has presumably already allocated
        # a full LlamaModel and an lm_head, so the model is briefly built twice.
        self.model = SplitLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs,
    ):
        """Run the decoder and, if this stage owns ``lm_head``, the LM head.

        Args:
            input_ids: Token ids for the first stage, or the previous stage's
                hidden states when ``self.model.embed_tokens`` is ``None``.
            labels: If given (and ``lm_head`` is present), a loss is computed
                with ``self.loss_function``.
            num_logits_to_keep: Compute logits only for the last this-many
                positions; ``0`` keeps every position (``[-0:]`` is ``[0:]``).
            **kwargs: Forwarded to the decoder and to ``self.loss_function``.
            The remaining arguments are passed straight to ``self.model``.

        Returns:
            The decoder's last hidden state tensor when ``lm_head`` is ``None``.
            Otherwise a ``CausalLMOutputWithPast`` with ``loss``, ``logits`` and
            ``hidden_states``, or, when ``return_dict`` is false, the tuple
            ``([loss,] logits, *decoder_outputs[1:])``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # A stage without an LM head hands its activations to the next stage.
        # Compare with None explicitly: nn.Module truthiness is not a presence
        # test (an empty nn.Sequential, for example, is falsy).
        if self.lm_head is None:
            return hidden_states
        # Slice before projecting so only the requested positions pay for the
        # vocabulary-sized matmul.
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        # past_key_values and attentions are intentionally not returned.
        # NOTE(review): presumably this keeps non-tensor cache objects out of
        # the pipeline stage's output; confirm before re-adding them.
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
    
class SplitLlamaModel(LlamaModel):
    """``LlamaModel`` that can run with its embedding table or final norm removed.

    A pipeline stage built from this model may set ``embed_tokens`` to ``None``,
    in which case ``input_ids`` is used directly as the input hidden states.
    It may also set ``norm`` to ``None``, in which case the final normalisation
    is skipped. Otherwise ``forward`` follows ``LlamaModel.forward``.
    """

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs,
    ):
        """Run this model's (possibly partial) stack of decoder layers.

        Args:
            input_ids: Token ids when ``embed_tokens`` is present. When
                ``embed_tokens`` is ``None``, this tensor is treated as the
                input hidden states (e.g. activations from the previous stage).
            past_key_values: Optional cache. A fresh ``DynamicCache`` is created
                when ``use_cache`` is true and none is given.
            **flash_attn_kwargs: Forwarded to every decoder layer on the
                non-checkpointed path.
            The remaining arguments follow ``LlamaModel.forward``; ``None``
            values fall back to ``self.config``.

        Returns:
            ``BaseModelOutputWithPast`` (``last_hidden_state``,
            ``past_key_values`` if caching, ``hidden_states``, ``attentions``),
            or its ``to_tuple()`` when ``return_dict`` is false.

        Raises:
            ValueError: unless exactly one of ``input_ids`` and
                ``inputs_embeds`` is provided.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if inputs_embeds is None:
            # A stage without an embedding table receives hidden states, not ids.
            inputs_embeds = input_ids if self.embed_tokens is None else self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # Rotary position embeddings are computed once and shared by all layers.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        # Iterate the ModuleList directly; slicing it would build a new
        # ModuleList on every forward call.
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # NOTE(review): flash_attn_kwargs are not forwarded on the
                # checkpointed path; confirm this is intended.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Stages that do not own the final norm pass hidden states through as-is.
        if self.norm is not None:
            hidden_states = self.norm(hidden_states)

        # Add the hidden states produced after the last decoder layer.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        output = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
        return output if return_dict else output.to_tuple()


def main():
    """Train a two-stage pipeline-parallel Llama split and profile the first steps.

    Each rank builds the full ``SplitLlama`` and keeps only its own slice:
    rank 0 keeps the embeddings and decoder layers ``0..layer_idx``, and rank 1
    keeps the remaining layers, the final norm and the LM head. The slice is
    wrapped in a ``PipelineStage`` and driven by a GPipe schedule over
    WikiText-2. The first ``max_steps`` batches run under ``torch.profiler``,
    and each rank writes its summary table to ``<output_dir>/profiler_<rank>.txt``.

    Reads ``llama2_pipeline_config.yaml``. Keys used: ``seed``, ``output_dir``,
    ``training.{block_size,batch_size,learning_rate}``,
    ``compression_config.num_stages`` and
    ``compression_config.layer12.layer_idx``. Also reads the ``RANK``,
    ``WORLD_SIZE`` and ``LOCAL_RANK`` environment variables.

    Raises:
        ValueError: if the run is not exactly two stages on two ranks, or if
            ``batch_size`` is not a multiple of the number of micro-batches.
    """
    config = load_config("llama2_pipeline_config.yaml")

    # Seed identically on every rank. Every rank builds the full model before
    # pruning it, and each rank's shuffling DataLoader draws from the global
    # torch RNG. The last stage presumably relies on this to pair its labels
    # with the activations it receives from stage 0.
    seed = config.get('seed', 42)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # LOCAL_RANK is the GPU index on this node; fall back to rank modulo the
    # visible GPU count when it is not set.
    local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))
    device = torch.device(f"cuda:{local_rank}")
    is_last_stage = rank == world_size - 1

    training_config = config.get('training', {})
    block_size = training_config.get('block_size', 1024)
    batch_size = training_config.get('batch_size', 1)
    learning_rate = training_config.get('learning_rate', 1e-4)
    cp_config = config.get("compression_config")
    num_stages = cp_config.get("num_stages")
    layer_idx = cp_config.get("layer12").get("layer_idx")

    num_mbs = 4  # micro-batches per GPipe step
    max_steps = 500  # batches run (and profiled) before stopping

    # The layer split below is hard-coded for exactly two stages, one per rank.
    if num_stages != 2 or world_size != 2:
        raise ValueError(
            f"this script splits the model into exactly 2 stages on 2 ranks; "
            f"got num_stages={num_stages}, world_size={world_size}"
        )
    # Each batch is split into num_mbs micro-batches. Keep them equally sized:
    # the pipelining split needs at least num_mbs samples per batch, and uneven
    # micro-batches would presumably mismatch the stages' communication buffers.
    if batch_size % num_mbs != 0:
        raise ValueError(
            f"batch_size ({batch_size}) must be a multiple of the number of "
            f"micro-batches ({num_mbs})"
        )

    init_process_group()

    # NOTE(review): t5-base tokenizer paired with a Llama config; confirm its
    # vocabulary fits model_config.vocab_size.
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    tokenizer.pad_token = tokenizer.eos_token
    model_config = AutoConfig.from_pretrained("configs/llama_1b.json")

    # Build on CPU, prune to this stage's slice, and only then move to the GPU,
    # so the full model is never resident on every device.
    model = SplitLlama(model_config)
    if rank == 0:
        # First stage: embeddings + layers [0, layer_idx]; no norm or LM head.
        del model.model.layers[layer_idx + 1:]
        model.model.norm = None
        model.lm_head = None
    else:
        # Last stage: layers (layer_idx, end] + final norm + LM head.
        # Delete as a single slice. Deleting layers[i] in a loop re-indexes the
        # list after every deletion and removes every other layer instead.
        del model.model.layers[: layer_idx + 1]
        model.model.embed_tokens = None
    model = model.to(device)

    stage = PipelineStage(
        model,
        rank,
        num_stages,
        device,
    )

    dataset = load_dataset("/data/datasets/wikitext/wikitext-2-raw-v1")

    def tokenize_function(examples):
        # Truncate and pad every line to exactly block_size tokens so all
        # batches share one shape.
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=block_size,
            padding="max_length",
        )

    # mlm=False: causal-LM labels copied from input_ids, with padding ignored.
    # Because pad == eos here, genuine EOS tokens are ignored as well.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False
    )
    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset['train'].column_names,
        desc="Running tokenizer on dataset",
    )
    train_dataset = tokenized_datasets['train']

    # Collate as usual, but keep per-sample "indices" when the dataset has them.
    def indexed_collator(features):
        batch = data_collator(features)
        if isinstance(features[0], dict) and "indices" in features[0]:
            batch["indices"] = torch.tensor([int(f["indices"]) for f in features])
        return batch

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=indexed_collator,
        # A smaller final batch could not be split into num_mbs equal parts.
        drop_last=True,
    )

    # NOTE(review): the schedule is built without ``loss_fn``, so ``step()``
    # only runs forward passes. The loss computed below is backpropagated
    # through the last stage's local graph only. No gradients reach stage 0,
    # whose parameters therefore never change. Training both stages needs
    # ``ScheduleGPipe(stage, num_mbs, loss_fn=...)`` together with
    # ``schedule.step(..., target=labels)`` on the last stage.
    schedule = ScheduleGPipe(stage, num_mbs)

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    scaler = torch.amp.GradScaler()

    # Progress bar only on the last stage, the one that computes the loss.
    pbar = tqdm(train_dataloader, desc="step") if is_last_stage else train_dataloader

    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
    ) as prof:
        for step, batch in enumerate(pbar, start=1):
            # Only the first stage consumes tokens; only the last needs labels.
            if rank == 0:
                input_ids = batch["input_ids"].to(device).contiguous()
            if is_last_stage:
                labels = batch["labels"].to(device).contiguous()

            with record_function("forward"):
                with torch.amp.autocast(device_type='cuda'):
                    if rank == 0:
                        output = schedule.step(input_ids)
                    else:
                        output = schedule.step()

                loss = None
                if is_last_stage:
                    if isinstance(output, tuple):
                        output = output[0]
                    # The last stage returns a CausalLMOutputWithPast; accept a
                    # bare logits tensor as well.
                    logits = getattr(output, "logits", output)
                    # Next-token prediction: position t is scored against token t + 1.
                    shift_logits = logits[..., :-1, :].contiguous()
                    shift_labels = labels[..., 1:].contiguous()
                    loss = F.cross_entropy(
                        shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1),
                    )

            with record_function('backward'):
                if loss is not None:
                    optimizer.zero_grad()
                    scaler.scale(loss).backward()
                    # scaler.step() unscales the gradients (and skips the update
                    # on inf/NaN). A bare optimizer.step() would apply gradients
                    # still multiplied by the loss scale.
                    scaler.step(optimizer)
                    scaler.update()

            if step == max_steps:
                break

    if is_last_stage:
        pbar.close()

    output_dir = config.get('output_dir') or "."
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"profiler_{rank}.txt")
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
    print(f"the file has been save to {output_file}")

    dist.destroy_process_group()


    

if __name__ == "__main__":
    # Script entry point. main() reads RANK/WORLD_SIZE from the environment, so
    # presumably it is launched once per rank by a launcher such as torchrun.
    main()