import os
import sys

os.environ['JAX_PLATFORMS'] = 'cpu'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import jax
import jax.numpy as jnp
from flax import nnx
from layers import Cache
import model
import json
from snapshot import Snapshot
from safetensors import safe_open


def save_merged_model(path: str, instance):
    """Persist *instance* through a ``Snapshot`` rooted at the directory of *path*.

    The snapshot entry name is the file name of *path* with its extension
    stripped, e.g. ``a/b/init_large.safetensors`` is saved as ``init_large``
    in a ``Snapshot('a/b')``.
    """
    directory, filename = os.path.split(path)
    entry_name, _extension = os.path.splitext(filename)
    Snapshot(directory).save(entry_name, instance)


class WeightAdapter:
    """Load a flat safetensors checkpoint and graft its weights into a Reasoner.

    Checkpoints are read from ``<path>/<name>.safetensors``. Dotted tensor keys
    such as ``decoders.3.norm.scale`` are unflattened into nested dicts, and
    all-digit key components are converted to ``int``. For example,
    ``decoders`` is indexed by ``3``, not ``'3'``.
    """

    def __init__(self, path: str):
        """Remember the checkpoint directory *path*, creating it if absent."""
        self.__path = path
        if not os.path.exists(path):
            # makedirs(exist_ok=True) also creates missing parent directories
            # (plain os.mkdir would raise). It also tolerates the directory
            # appearing between the exists() check and this call.
            os.makedirs(path, exist_ok=True)

    @staticmethod
    def _unflatten(flat_dict: dict) -> dict:
        """Nest a ``{'a.0.b': v}``-style mapping into ``{'a': {0: {'b': v}}}``.

        Every dotted component made only of digits becomes an ``int`` key.
        All other components stay strings.
        """
        state_dict = {}
        for key, value in flat_dict.items():
            parts = [int(part) if part.isdigit() else part for part in key.split('.')]
            current = state_dict
            for part in parts[:-1]:
                current = current.setdefault(part, {})
            current[parts[-1]] = value
        return state_dict

    def load_as_dict(self, name: str) -> dict:
        """Read ``<path>/<name>.safetensors`` into a nested dict of tensors.

        Tensors are fetched with ``framework="flax"``, presumably yielding JAX
        arrays (TODO confirm). The file is closed before the keys are nested.
        """
        with safe_open(f'{self.__path}/{name}.safetensors', framework="flax") as f:
            flat_dict = {key: f.get_tensor(key) for key in f.keys()}
        return self._unflatten(flat_dict)

    def get_reasoner(self, name: str, replace_decoders: dict, instance: "model.Reasoner", accu_steps: int):
        """Return a copy of *instance* with pretrained weights from *name* grafted in.

        The pretrained checkpoint supplies:
          * ``featurizer.embedding``, copied into both ``featurizer.embedding``
            and ``featurizer.embedding_ema``;
          * ``norm``, copied into ``output_norm``;
          * ``decoders[pretrain_idx]``, copied into ``decoders[reasoner_idx]``
            for every ``reasoner_idx -> pretrain_idx`` pair in *replace_decoders*.

        ``featurizer.step`` is replaced by a ``Cache`` holding *accu_steps* as
        an int8 scalar.

        Raises:
            KeyError: if a required key is missing from the checkpoint or the
                model state.
        """
        pretrained_state_dict = self.load_as_dict(name)
        graphdef, state = nnx.split(instance)
        state['featurizer']['embedding'] = pretrained_state_dict['featurizer']['embedding']
        # The EMA copy starts out equal to the pretrained embedding.
        state['featurizer']['embedding_ema'] = pretrained_state_dict['featurizer']['embedding']
        state['output_norm'] = pretrained_state_dict['norm']
        # NOTE(review): int8 caps the stored step count at 127; confirm accu_steps stays small.
        state['featurizer']['step'] = Cache(jnp.array(accu_steps, dtype=jnp.int8))

        for reasoner_idx, pretrain_idx in replace_decoders.items():
            state['decoders'][reasoner_idx] = pretrained_state_dict['decoders'][pretrain_idx]

        return nnx.merge(graphdef, state)


# Number of leading decoder layers copied one-to-one (reasoner index ==
# pretrained index) from the pretrained checkpoint for each supported size.
_DECODER_COUNT_BY_SIZE = {
    'large': 16,
    'middle': 12,
    'small': 6,
}


def _decoder_replace_map(size: str) -> dict:
    """Return the identity decoder map ``{i: i for i < count}`` for *size*.

    Raises:
        KeyError: if *size* is not a key of ``_DECODER_COUNT_BY_SIZE``.
    """
    return {idx: idx for idx in range(_DECODER_COUNT_BY_SIZE[size])}


def main(path: str, base: str, size: str):
    """Build a fresh Reasoner, graft pretrained weights into it, and save it.

    Args:
        path: Directory holding the pretrained ``<base>.safetensors`` file.
        base: Checkpoint name, without extension, inside *path*.
        size: One of ``'large'``, ``'middle'`` or ``'small'``. It selects how
            many leading decoders are copied from the checkpoint.

    Prints a message and exits with status 1 when *size* is unsupported.
    The result is saved via ``save_merged_model`` to
    ``snapshot/cfg_completion/reasoner/init_<size>.safetensors``.
    """
    # Reject unsupported sizes before any expensive model setup.
    try:
        reasoner_decoder_replace_map = _decoder_replace_map(size)
    except KeyError:
        print(f'{size} is not supported.')
        sys.exit(1)

    key = jax.random.key(0)

    # Load configuration. The file is closed before the model is built.
    with open('configs/reasoner_cfg.json', 'r') as file:
        config = json.load(file)

    # NOTE(review): the 'model_large' section is used for every `size`; only
    # the number of copied decoders varies. Confirm this is intended.
    model_cfg = config['model_large']
    latent_feature = int(model_cfg['Feature'])
    attn_feature = int(model_cfg['ATTN Feature'])
    ffn_feature = int(model_cfg['FFN Feature'])
    num_head = int(model_cfg['Head Count'])
    reasoner_decoder_count = int(model_cfg['Decoder Count'])
    init_scalar = float(model_cfg['Init Scalar'])
    max_len = int(model_cfg['Max Length'])
    rope_base = float(model_cfg['RoPE Base'])
    accu_steps = 1

    reasoner = model.Reasoner(
        feature=latent_feature,
        attn_feature=attn_feature,
        ffn_feature=ffn_feature,
        num_head=num_head,
        decoder_count=reasoner_decoder_count,
        is_causal=True,
        init_scalar=init_scalar,
        vocab_size=32,
        ema_interval=accu_steps,
        key=key,
        dtype=jnp.bfloat16,
    )
    reasoner.train(
        rope_base=rope_base,
        max_len=max_len,
    )

    adapter = WeightAdapter(path)
    reasoner = adapter.get_reasoner(base, reasoner_decoder_replace_map, reasoner, accu_steps)
    save_merged_model(f'snapshot/cfg_completion/reasoner/init_{size}.safetensors', reasoner)


if __name__ == '__main__':
    # Usage: <script> <large|middle|small> <pretrain_model_path> <pretrain_snapshot_name>
    # Previously both the size and the model path were read from argv[1],
    # so no invocation could work. Each argument now has its own position.
    if len(sys.argv) != 4:
        print(f'usage: {sys.argv[0]} <large|middle|small> <pretrain_model_path> <pretrain_snapshot_name>')
        sys.exit(1)
    model_size = sys.argv[1]
    pretrain_model_path = sys.argv[2]
    pretrain_snapshot_name = sys.argv[3]
    main(pretrain_model_path, pretrain_snapshot_name, model_size)
