import hydra
import os
import torch
from src.model_stacklayer_llama_v9 import CustomLlamaConfig, LlamaMem
from src.dataloader import DistributedDataLoader
from datetime import datetime
from src.trainer import Trainer
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, AutoConfig

from src.utils import print0, set_seed
if int(os.environ.get("RANK", 0)) == 0:
    import wandb

def ddp_setup():
    """Join the NCCL process group, then pin this process to its local GPU.

    The GPU index is read from the ``LOCAL_RANK`` environment variable after
    the process group has been initialised.
    """
    init_process_group(backend="nccl")
    device_index = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(device_index)


def load_model_state(model, optimizer_state_dict, rank, train_config):
    """Build the optimizer for ``model`` and restore saved state when available.

    Args:
        model: module exposing ``configure_optimizers(train_config)``. It should
            already be on its target device, because ``Optimizer.load_state_dict``
            places restored state next to each parameter.
        optimizer_state_dict: saved optimizer state, or a falsy value to start
            from a fresh optimizer.
        rank: device passed to ``Tensor.to`` for the restored state tensors.
        train_config: forwarded unchanged to ``model.configure_optimizers``.

    Returns:
        The configured optimizer. If there is no saved state, or restoring it
        raises, a freshly configured optimizer is returned instead. The failure
        is reported as a warning, not re-raised.
    """
    if not optimizer_state_dict:
        return model.configure_optimizers(train_config)

    try:
        optimizer = model.configure_optimizers(train_config)
        optimizer.load_state_dict(optimizer_state_dict)

        for state in optimizer.state.values():
            for key, value in state.items():
                # ``load_state_dict`` keeps the ``step`` counter on CPU on purpose,
                # unless the optimizer is fused/capturable, in which case it has
                # already moved it. Pushing it to the GPU makes the non-fused
                # update paths read it back with a host-device sync.
                if key == "step":
                    continue
                if torch.is_tensor(value):
                    # Replacing values of existing keys does not resize the
                    # dict, so this is safe while iterating over it.
                    state[key] = value.to(rank)
        print0("Successfully loaded optimizer state")
        return optimizer
    except Exception as e:
        # Best-effort resume: keep training with a fresh optimizer instead of aborting.
        print0(f"Warning: Failed to load optimizer state: {str(e)}")
        return model.configure_optimizers(train_config)


def verify_model_state(model, rank):
    """Check that every parameter of ``model`` is on the expected CUDA device.

    Args:
        model: module whose ``named_parameters()`` are checked.
        rank: expected CUDA device. For an ``int``, the parameter's device
            index must equal it. For any other value, the parameter only has to
            be on some CUDA device.

    Raises:
        RuntimeError: naming the first parameter that is not on the expected device.
    """
    for name, param in model.named_parameters():
        # Compare the index too: ``is_cuda`` alone would accept a parameter
        # that sits on a different GPU than the one this process owns.
        on_expected_device = param.is_cuda and (
            not isinstance(rank, int) or param.device.index == rank
        )
        if not on_expected_device:
            raise RuntimeError(f"Parameter {name} is not on CUDA device {rank}")
    print0(f"Verified all model parameters are on device {rank}")


def _global_means_and_clear(monitor_data, keys, device):
    """All-reduce per-rank sums of ``monitor_data[k]`` for each key, then clear them.

    The sample count is ``len(monitor_data[keys[0]])``, summed across ranks,
    and is shared by every key. This issues one ``all_reduce``, so every rank
    must call it with the same keys in the same order.

    Returns:
        One global mean per key, or ``None`` when no rank recorded any samples.
        The listed monitor lists are emptied in both cases.
    """
    series = [monitor_data[k] for k in keys]
    # Pin the dtype. Otherwise ``torch.tensor`` infers int64 on a rank whose
    # lists are empty and float32 on a rank with float samples, and all_reduce
    # requires the same dtype on every rank.
    sync_tensor = torch.tensor(
        [float(sum(values)) for values in series] + [float(len(series[0]))],
        dtype=torch.float64,
        device=device,
    )
    torch.distributed.all_reduce(sync_tensor, op=torch.distributed.ReduceOp.SUM)
    for values in series:
        values.clear()

    total_samples = sync_tensor[-1].item()
    if total_samples <= 0:
        return None
    return [total / total_samples for total in sync_tensor[:-1].tolist()]


@hydra.main(config_path="configs/", config_name="train", version_base=None)
def main(cfg):
    """Train a LlamaMem model with DDP, using the Hydra config ``configs/train``.

    Reads ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` from the environment
    (as set by a distributed launcher such as torchrun).
    """
    print0(f"=====>Script arguments:\n{cfg}")
    ddp_setup()

    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")

    set_seed(cfg.seed)

    # Model hyper-parameters.
    model_config = LlamaMem.get_default_config()
    model_config.model_type = cfg.model.model_type
    model_config.model_name = cfg.model.model_name
    model_config.vocab_size = cfg.model.vocab_size
    model_config.block_size = cfg.model.sequence_length
    model_config.max_length = cfg.model.max_length

    model_config.stack_slots = cfg.model.stack_slots
    model_config.n_layer = cfg.model.n_layer
    model_config.n_head = cfg.model.n_head
    model_config.n_embd = cfg.model.n_embd
    model_config.use_memory = cfg.model.use_memory
    model_config.log_freq = cfg.train.log_freq
    model_config.beta_coeff = cfg.model.beta_coeff
    model_config.seq_length = cfg.model.sequence_length
    model_config.num_mem_heads = cfg.model.num_mem_heads
    model_config.batch_size = cfg.train.batch_size

    # Trainer hyper-parameters.
    train_config = Trainer.get_default_config()
    train_config.max_iters = cfg.train.max_iters
    train_config.num_workers = 0
    train_config.ckpt_dir = cfg.train.ckpt_dir
    train_config.max_ckpts_to_keep = cfg.train.max_ckpts_to_keep
    train_config.log_freq = cfg.train.log_freq
    train_config.save_freq = cfg.train.save_freq
    train_config.dtype = cfg.train.dtype
    train_config.learning_rate = cfg.train.learning_rate
    train_config.learning_rate_decay_frac = cfg.train.lr_decay_frac
    train_config.warmup_iters = cfg.train.warmup_iters
    train_config.batch_size = cfg.train.batch_size
    train_config.num_mem_heads = cfg.model.num_mem_heads
    train_config.stack_slots = cfg.model.stack_slots

    # LOCAL_RANK is the GPU index on this node; RANK is the process's global index.
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    is_master = global_rank == 0
    B = cfg.train.batch_size
    T = cfg.model.sequence_length

    wandb_model_config = dict(cfg)

    if is_master and cfg.wandb.enable:
        # NOTE(review): this hard-coded name overrides cfg.wandb.name; only the
        # WANDB_NAME environment variable can change the run name.
        cfg.wandb.name = "stack_v7"
        wandb_name = f"{os.environ.get('WANDB_NAME', cfg.wandb.name)}"
        print("running...:", wandb_name)
        wandb.init(
            project=cfg.wandb.project,
            name=f"{wandb_name}_{current_time}",
            config=dict(wandb_model_config),
        )
        wandb.config.update({"world_size": world_size})

    # NOTE(review): the loaders are given LOCAL_RANK together with the global
    # WORLD_SIZE. If DistributedDataLoader uses this argument only as the shard
    # index, multi-node runs would read duplicate shards and should pass
    # global_rank instead. Confirm against src.dataloader.
    train_loader = DistributedDataLoader(cfg.input_bin, B, T, local_rank, world_size)
    valid_dataset = DistributedDataLoader(cfg.input_val_bin, B, T, local_rank, world_size)

    iter_num = 0
    optimizer_state_dict = None

    try:
        # Fail with a clear message instead of leaving ``model`` unbound.
        if cfg.model.model_type != "llama":
            raise ValueError(f"Unsupported model_type: {cfg.model.model_type!r}")

        print(cfg.model.model_name)
        base_config = AutoConfig.from_pretrained(cfg.model.model_name).to_dict()
        config = CustomLlamaConfig(
            use_memory=model_config.use_memory,
            stack_slots=model_config.stack_slots,
            num_mem_heads=model_config.num_mem_heads,
            **base_config,
        )

        tokenizer = AutoTokenizer.from_pretrained(cfg.model.model_name)
        tokenizer.pad_token = tokenizer.eos_token

        if cfg.pretrain.snapshot_path:
            print0(f"Loading model from checkpoint: {cfg.pretrain.snapshot_path}")
            model, optimizer_state_dict, iter_num = LlamaMem.from_ckpt(
                cfg.pretrain.snapshot_path,
                config=config,
                tokenizer=tokenizer,
                rank=local_rank,
                load_memory=cfg.pretrain.load_mem,
                resume_training=True,
            )
            print0(f"Successfully loaded checkpoint at iteration {iter_num}")
        else:
            print0(f"Initializing new model from {cfg.model.model_name}")
            model = LlamaMem.from_config(
                config=config,
                tokenizer=tokenizer,
            )
        print("model parameter size:")
        print(sum(p.numel() for p in model.parameters()))
        print("billion parameter size:")
        print(sum(p.numel() for p in model.parameters()) / 1e9, "B")
    except Exception as e:
        raise RuntimeError(f"Failed to initialize model: {str(e)}") from e

    model = model.to(local_rank)
    verify_model_state(model, local_rank)

    # Build the optimizer after the move so restored state lands on this GPU.
    optimizer = load_model_state(model, optimizer_state_dict, local_rank, train_config)

    model = DDP(model, find_unused_parameters=False)

    trainer = Trainer(
        train_config,
        model,
        optimizer,
        train_loader=train_loader,
        local_rank=local_rank,
        grad_accum_steps=1,
        iter_num=iter_num,
    )

    val_steps = 100

    def batch_end_callback(trainer):
        """Every ``log_freq`` iterations: validate, then log losses and monitor stats.

        Every rank must run this, because it issues collective all_reduce
        calls in a fixed order.
        """
        if trainer.iter_num % trainer.log_freq != 0:
            return

        valid_loader = valid_dataset
        valid_loader.reset()

        # Remember the mode so validation does not leave the model in eval().
        was_training = trainer.model.training
        trainer.model.eval()
        with torch.no_grad():
            val_loss = 0.0
            for i in range(val_steps):
                valid_loader.set_epoch(i)
                x, y = valid_loader.next_batch()
                x = x.to(trainer.device)
                y = y.to(trainer.device)

                _, loss, _ = model(x, y)
                val_loss += loss.item()
            val_loss /= val_steps

            # Average both losses over all ranks.
            val_loss_tensor = torch.tensor(val_loss, device=trainer.device)
            train_loss_tensor = torch.tensor(trainer.loss.item(), device=trainer.device)
            torch.distributed.all_reduce(val_loss_tensor, op=torch.distributed.ReduceOp.SUM)
            torch.distributed.all_reduce(train_loss_tensor, op=torch.distributed.ReduceOp.SUM)

            n_ranks = torch.distributed.get_world_size()
            val_loss = val_loss_tensor.item() / n_ranks
            train_loss = train_loss_tensor.item() / n_ranks
        trainer.model.train(was_training)

        log_to_wandb = is_master and cfg.wandb.enable

        if is_master:
            print0(f"iter {trainer.iter_num}: train loss {train_loss:.5f}, valid loss {val_loss:.5f}")

        if log_to_wandb:
            wandb.log({
                "train_loss": train_loss,
                "val_loss": val_loss,
                "iter": trainer.iter_num,
                "learning_rate": trainer.optimizer.param_groups[0]['lr']
            }, step=trainer.iter_num)

        core = trainer.model.module

        # Push/pop/no-op statistics of per-layer memory stacks.
        for detect_layer in (0, 4, 8, 12):
            if len(core.model.layers) <= detect_layer:
                continue
            layer = core.model.layers[detect_layer]
            if not hasattr(layer, 'mem_stack'):
                continue
            means = _global_means_and_clear(
                layer.mem_stack.monitor_data,
                ("push_weights", "pop_weights", "noop_weights"),
                trainer.device,
            )
            if means is not None and log_to_wandb:
                global_push, global_pop, global_noop = means
                wandb.log({
                    "action/push_" + str(detect_layer): global_push,
                    "action/pop_" + str(detect_layer): global_pop,
                    "action/noop_" + str(detect_layer): global_noop
                }, step=trainer.iter_num)

        # Router score statistics.
        if hasattr(core, "router"):
            means = _global_means_and_clear(
                core.router.monitor_data,
                ("router_0_scores", "router_1_scores"),
                trainer.device,
            )
            if means is not None and log_to_wandb:
                global_router_0_scores, global_router_1_scores = means
                wandb.log({
                    "router/scores_0": global_router_0_scores,
                    "router/scores_1": global_router_1_scores,
                }, step=trainer.iter_num)

        if hasattr(core, "outer_stack"):
            # Exit-layer statistics are recorded on the model's own monitor_data.
            means = _global_means_and_clear(
                core.monitor_data,
                ("exit_layer_num", "max_exit_layer_num"),
                trainer.device,
            )
            if means is not None and log_to_wandb:
                global_exit_layer_num, global_max_exit_layer_num = means
                wandb.log({
                    "outer/iter": global_exit_layer_num,
                    "outer/max_iter": global_max_exit_layer_num,
                }, step=trainer.iter_num)

            # Push/pop/no-op statistics of the outer stack module.
            means = _global_means_and_clear(
                core.outer_stack.monitor_data,
                ("push_weights", "pop_weights", "noop_weights"),
                trainer.device,
            )
            if means is not None and log_to_wandb:
                global_push, global_pop, global_noop = means
                wandb.log({
                    "outer_action/push": global_push,
                    "outer_action/pop": global_pop,
                    "outer_action/noop": global_noop
                }, step=trainer.iter_num)

    trainer.set_callback("on_batch_end", batch_end_callback)

    try:
        trainer.run(current_time, iter_num)
    finally:
        # Tear down the process group even if training raises.
        destroy_process_group()


# Script entry point. ``main`` is decorated with ``hydra.main``, which builds
# ``cfg`` from the ``train`` config under ``configs/`` and passes it in.
if __name__ == "__main__":
    main()
