from fairseq import checkpoint_utils, models, optim, utils
from fairseq.models.roberta import RobertaModel

# import sys
# sys.path.append('/home/mengjiao/Documents/workspace_2023/python/byte_subword')


from byte_roberta import ByteWordRobertaModel

import torch

from fairseq.trainer import Trainer

# from fairseq.fairseq.models.roberta import RobertaModel
# from byte_aggregate.byte_roberta import ByteWordRobertaModel
# import torch

from torchsummary import summary

import logging

# Module-level logger. In the visible part of this script it is referenced
# only by commented-out debugging statements.
logger = logging.getLogger(__name__)

# Directory that holds the pretrained checkpoints. The trailing slash is
# significant: the script builds file paths by plain string concatenation.
model_path = "/home/mengjiao/Documents/workspace_2023/python/byte_subword/roberta_glue/pretrained_model/"

# Name (relative to model_path) of the checkpoint directory that is read, and
# of the checkpoint file that is written at the end of the script.
load_model_name = "roberta.large"
save_model_name = "byte_roberta.large"

# NOTE(review): the two values below are not used by any live statement in
# the visible script; they are kept because other code may rely on them.
model_name = "byteword_roberta"
user_dir = "/home/mengjiao/Documents/workspace_2023/python/byte_subword/byte_aggregate"

# Load the original pretrained RoBERTa model from model_path + load_model_name,
# using the 'model.pt' checkpoint file.
# NOTE(review): `roberta` is referenced only by commented-out debugging code in
# the visible script; the load is kept so the script's behaviour is unchanged.
roberta = RobertaModel.from_pretrained(
    f"{model_path}{load_model_name}",
    checkpoint_file="model.pt",
)
# print(roberta.model.cuda())
# exit()
# logger.info(roberta.task.state_dict())
# exit()

# trainer = Trainer(roberta.model.args, roberta.task, roberta.model, criterion='label_smoothed_cross_entropy', quantizer=None)

# Read the raw checkpoint dictionary of the original model. Its
# "optimizer_history" and "extra_state" entries are reused further down when
# the new checkpoint dictionary is assembled.
state = checkpoint_utils.load_checkpoint_to_cpu(
    f"{model_path}{load_model_name}/model.pt"
)
# print(state["criterion"])
# print(state.keys())
# print(state["extra_state"])
# # # print(state["optimizer_history"])
# # #  print(state["model"])
# exit()



# Build the byte/word RoBERTa variant from the same pretrained directory that
# the original model was loaded from.
byte_roberta = ByteWordRobertaModel.from_pretrained(model_path + load_model_name)

# State dict for the embedding module, stored as a separate file in model_path.
embed_model_name = "byte_embed"
# map_location="cpu": without it torch.load restores every tensor onto the
# device it was saved from, which fails on a CPU-only host if the file was
# written from a GPU. The tensors are only used as the source of a
# load_state_dict() call, so holding them on the CPU is sufficient.
byte_embed_weights = torch.load(
    model_path + embed_model_name,
    map_location="cpu",
)

# print(byte_embed_weights['byte_embedding.weight'][100])

# print(byte_roberta.model.encoder.sentence_encoder.embed_tokens.byte_embedding.weight[100])



# print(byte_roberta.model.cuda())
# print(byte_roberta.model.state_dict().keys())

# for name, param in byte_roberta.named_parameters():
#     if param.requires_grad:
#         print(name)

# state_b = checkpoint_utils.load_checkpoint_to_cpu(model_path + save_model_name)

# model = state_b["model"]
# print(model.keys())


# exit()


# for name, param in byte_roberta.named_parameters():
#     if param.requires_grad:
#         print(name)
# exit()

# for name, param in param.items():
#     print(name, param.requires_grad)
# exit()

# Replace the parameters of the sentence encoder's token-embedding module with
# the separately stored weights loaded above.
_embed_module = byte_roberta.model.encoder.sentence_encoder.embed_tokens
_embed_module.load_state_dict(byte_embed_weights)

# print(byte_roberta.model.encoder.sentence_encoder.embed_tokens.byte_embedding.weight[100])

# exit()

# Model weights that go into the new checkpoint. The LM head weight entry is
# dropped from this dictionary before saving; a missing key raises KeyError.
param = byte_roberta.model.state_dict()
param.pop("encoder.lm_head.weight")

# Task state is optional: fall back to an empty dict when the model carries
# no task object.
if byte_roberta.task is not None:
    _task_state = byte_roberta.task.state_dict()
else:
    _task_state = {}

# Checkpoint dictionary: args and weights come from the byte/word model, while
# the optimizer history and extra state are carried over unchanged from the
# original checkpoint.
state_dict = {
    "args": byte_roberta.model.args,
    "task_state": _task_state,
    "model": param,
    "optimizer_history": state["optimizer_history"],
    "extra_state": state["extra_state"],
}


# print(state_dict["model"]['encoder.sentence_encoder.embed_tokens.byte_embedding.weight'][100])
# exit()

# logger.info(byte_roberta.task.state_dict())

# logger.info(f"Saving checkpoint to" +  model_path + load_model_name)
# call state_dict on all ranks in case it needs internal communication
# state_dict = utils.move_to_cpu(state_dict)
# state_dict["extra_state"].update(extra_state)

# Write the assembled checkpoint dictionary to model_path + save_model_name.
checkpoint_utils.torch_persistent_save(
    state_dict, filename=f"{model_path}{save_model_name}"
)



# print(state_dict["model"])
# logger.info(state_dict["args"])
# logger.info(state_dict["model"])
# exit()

# input = torch.Tensor([100, 58, 3323]).to(torch.long)
# r_emb = roberta.model.encoder.sentence_encoder.embed_tokens(input)
# br_emb = byte_roberta.model.encoder.sentence_encoder.embed_tokens(input)

# print(r_emb)
# print(br_emb)

# exit()




# logger.info(f"Finished saving checkpoint to {os.path.abspath(filename)}")




# print(byte_roberta.model.cuda())
# exit()

# torch.save(byte_roberta.state_dict(), model_path + save_model_name)
# torch.save(byte_roberta, model_path + save_model_name)