#%%
import torch
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer
from distributed_training import train_sae
from sae import *
from distibuted_activation_store import DistributedActivationsStore
from sae import TrainingConfig, SAEConfig
from sae.config import post_init_cfg

from datasets import load_dataset
from transformers import DataCollatorWithPadding

import pyrallis
import ast
import os

from datetime import timedelta, datetime

from torch.utils.data import DistributedSampler

import accelerate
from accelerate import (
    AutocastKwargs,
    DistributedDataParallelKwargs,
    InitProcessGroupKwargs,
    DataLoaderConfiguration
)
from accelerate import FullyShardedDataParallelPlugin
from accelerate import Accelerator, init_empty_weights, infer_auto_device_map, dispatch_model
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
)

from sae.train_utils import set_seed

# Process-wide runtime settings, applied once at import time.

# Turn off the HF tokenizers' internal parallelism (presumably to avoid
# conflicts with DataLoader worker processes -- confirm intent).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Permit TF32 for float32 matmuls: first via the cuBLAS backend flag, then via
# the global float32 matmul precision setting.
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")

@pyrallis.wrap()
def main(cfg: TrainingConfig):
    """Set up a distributed SAE training run and hand off to ``train_sae``.

    Steps, in order:
      1. Apply ``cfg.sweep_pair`` overrides onto ``cfg``.
      2. Build the ``Accelerator`` and divide ``cfg.batch_size`` and
         ``cfg.num_batches_in_buffer`` by the number of processes.
      3. Load the ``HookedTransformer`` (with bounded retries) and record its
         ``d_model`` as ``cfg.act_size``.
      4. Parse ``cfg.hook_point``, build the streaming dataloader and the
         ``DistributedActivationsStore``, and pre-fill the activation buffer.
      5. Initialise trackers, build the SAE selected by ``cfg.sae_type`` and an
         Adam optimizer, then call ``train_sae``.

    Args:
        cfg: Training configuration; it is mutated in place (batch size,
            device, act_size, hook_point, ...) and then replaced by the result
            of ``post_init_cfg``.

    Raises:
        ValueError: if ``cfg.batch_size`` is smaller than the number of
            processes, or if ``cfg.sae_type`` is not a known SAE type.
        RuntimeError: if the model cannot be loaded after all retry attempts.
    """
    import time  # only used for the back-off between model-load attempts

    # Sweep overrides are applied first so everything below sees final values.
    if cfg.sweep_pair:
        for key, value in cfg.sweep_pair.items():
            setattr(cfg, key, value)

    # Wrap any submodule ≥ 10M parameters (adjust threshold as needed)
    # auto_wrap = partial(size_based_auto_wrap_policy, min_num_params=int(1e7))

    ### init accelerate
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
    init_process_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=3600))

    dataloader_config = DataLoaderConfiguration(split_batches=True)

    # fsdp_plugin = FullyShardedDataParallelPlugin(
    #     sharding_strategy = "no_wrap",
    #     auto_wrap_policy="SIZE_BASED_WRAP",
    #     state_dict_config=FullStateDictConfig(
    #         offload_to_cpu=True, rank0_only=False
    #     ),
    #     optim_state_dict_config=FullOptimStateDictConfig(
    #         offload_to_cpu=True, rank0_only=False
    #     ),
    # )

    accelerator = accelerate.Accelerator(
        mixed_precision=None,
        log_with="wandb",
        kwargs_handlers=[ddp_kwargs, init_process_kwargs],  # [autocast_kwargs, ddp_kwargs, init_process_kwargs],
        step_scheduler_with_optimizer=True,
        # fsdp_plugin=fsdp_plugin,
        dataloader_config=dataloader_config,
    )

    # The configured batch size is treated as the global one; each process
    # gets an equal integer share of it.
    splitted_batch_size = cfg.batch_size // accelerator.num_processes
    if splitted_batch_size < 1:
        # Integer division would silently produce a per-process batch size of 0.
        raise ValueError(
            f"batch_size ({cfg.batch_size}) must be at least the number of "
            f"processes ({accelerator.num_processes})."
        )
    accelerator.print(f"[*] Number of available gpus: {accelerator.num_processes}[*], new batch size will be - {splitted_batch_size} instead of - {cfg.batch_size} [*]")
    cfg.batch_size = splitted_batch_size
    cfg.num_batches_in_buffer = cfg.num_batches_in_buffer // accelerator.num_processes
    # cfg.lr = cfg.lr / accelerator.num_processes

    accelerator.print("[*] Model initialization 🤖 [*]")
    cfg.device = accelerator.device
    print("--------------->", cfg.device)

    # Loading is retried because it can fail transiently, but the number of
    # attempts is bounded so a permanent failure surfaces as an error instead
    # of leaving this process retrying forever.
    max_load_attempts = 30
    retry_delay_seconds = 10
    for attempt in range(1, max_load_attempts + 1):
        try:
            model = HookedTransformer.from_pretrained_no_processing(
                cfg.model_name,
                device=cfg.device,
                dtype="bfloat16",
                low_cpu_mem_usage=True,
            ).to(cfg.device)
            break
        except Exception as e:
            print(f"Error loading model: {e}")
            if attempt == max_load_attempts:
                raise RuntimeError(
                    f"Failed to load model {cfg.model_name!r} after "
                    f"{max_load_attempts} attempts."
                ) from e
            time.sleep(retry_delay_seconds)

    cfg.act_size = model.cfg.d_model

    accelerator.print(f"[*] Model 🤖 successfully initialized with act_size = {cfg.act_size} 🦾 [*] ")

    accelerator.print("[*] ActivationsStore initialization 🦾 [*] ")
    ### init activation store
    # A hook point containing a comma is parsed as a Python literal; a parsed
    # list is normalised to a tuple. A tuple hook point later switches
    # train_sae into transcoder mode (see is_train_transcoder below).
    if "," in cfg.hook_point:
        cfg.hook_point = ast.literal_eval(cfg.hook_point)
        if isinstance(cfg.hook_point, list):
            cfg.hook_point = tuple(cfg.hook_point)

    ### get dataset
    dataset = load_dataset(cfg.dataset_path, split="train", streaming=True, trust_remote_code=True)

    def filtering_collate(batch):
        """Drop ``None`` examples and return the remaining ``"text"`` fields."""
        return [ex["text"] for ex in batch if ex is not None]

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.model_batch_size,
        shuffle=False,
        num_workers=16,
        collate_fn=filtering_collate,  # collator
    )
    print("---------", model.cfg.n_ctx)
    activations_store = DistributedActivationsStore(model.cfg.n_ctx, cfg, num_processes=accelerator.num_processes)
    accelerator.wait_for_everyone()
    # device_placement is disabled for the dataloader: batches are lists of
    # raw strings, not tensors.
    dataloader = accelerator.prepare(dataloader, device_placement=[False])
    iterator_dl = iter(dataloader)

    # Pre-fill the buffer before post_init_cfg, which receives the store.
    activations_store.fill_buffer(model, iterator_dl)
    cfg = post_init_cfg(cfg, activations_store)
    accelerator.wait_for_everyone()

    accelerator.print("[*] ActivationsStore successfully initialized 🦾 [*] ")

    accelerator.init_trackers(
        project_name=cfg.wandb_project,
        config=vars(cfg),
        init_kwargs={"wandb": {"name": f"{cfg.dict_size}_{cfg.sae_type}_{cfg.wandb_run_suffix}", "allow_val_change": True}},
    )

    accelerator.print("[*] SAE initialization 🦾 [*] ")

    model_config = SAEConfig.from_training_config(cfg)
    if cfg.sae_type == "vanilla":
        sae = VanillaSAE(model_config)
    elif cfg.sae_type == "topk":
        sae = TopKSAE(model_config)
    elif cfg.sae_type == "batchtopk":
        sae = BatchTopKSAE(model_config)
    elif cfg.sae_type == "jumprelu":
        sae = JumpReLUSAE(model_config)
    elif cfg.sae_type == "kronsae":
        model_config.cartesian_op = "mul"
        sae = KronSAE(model_config)
    else:
        raise ValueError(f"SAE type {cfg.sae_type} is not understood.")

    accelerator.print("[*] SAE successfully initialized 🦾 [*] ")

    # NOTE(review): the seed is set only after the SAE has been constructed, so
    # weight initialisation is reproducible only if the SAE constructors seed
    # themselves -- confirm.
    set_seed(cfg.seed)

    # optim_groups = configure_optimizers(sae, cfg)
    optimizer = torch.optim.Adam(
        params=sae.parameters(),
        lr=cfg.lr,
        betas=(0.9, 0.99),
    )
    # A tuple hook point (parsed above) means transcoder training.
    is_train_transcoder = isinstance(cfg.hook_point, tuple)

    train_sae(sae, activations_store, iterator_dl, optimizer, model, cfg, accelerator, train_transcoder=is_train_transcoder)


# Script entry point. main() is called without arguments; its `cfg` parameter
# is supplied by the @pyrallis.wrap() decorator on main.
if __name__ == "__main__":
    main()