import torch
# import pytorch_lightning as pl
# from lightning_transformers.task.nlp.language_modeling import (
#     LanguageModelingDataModule,
#     LanguageModelingTransformer,
# )
import transformers
# from pytorch_lightning import LightningModule
# from pytorch_lightning.utilities import rank_zero_warn
import torch.nn as nn
# from torch.distributions.gamma import Gamma
from data.utils import discretize_time
# from torch.nn import CrossEntropyLoss
# from data.dms import DMSData
# from data.constants import REF_RBD_SEQ
from models import register_model
import math, logging, copy
# from models.criterions.cross_entropy import CEWeightedLoss
# from esm import pretrained
from typing import IO, Any, Callable, Dict, Optional, Tuple, Type, Union
# from utils.args import str2bool
# from transformers import GPT2LMHeadModel
from transformers import AutoConfig, PreTrainedTokenizerBase
from models.gpt2_new import GPT2TimeNew
# from collections import defaultdict
import loralib as lora
# from models.lora.model import GPT2Config, GPT2LMModel
from models.gpt2_multihosts import GPTOutputs
from transformers.models.gpt2.modeling_gpt2_lora import LoRAGPT2LMHeadModel

class GPT2LoRATime(LoRAGPT2LMHeadModel):
    def __init__(self, config) -> None:
        """Build the LoRA GPT-2 rate model with its offset layer and load the checkpoint.

        Args:
            config: Model configuration. Besides what the parent constructor
                needs, this class reads ``pretrained_model_path`` here,
                ``hidden_size`` / ``vocab_size`` and the optional
                ``transformer_offset`` flag in ``build_offset_layer``, and
                ``lora_attn_dim`` in ``load_pretrain_weight``.
        """
        super().__init__(config)
        self.config = config
        # The offset layer must exist before `apply`, so it is initialised too.
        self.build_offset_layer(config)
        # Re-initialise every submodule; the checkpoint loaded on the next line
        # then overwrites whatever parameters it provides.
        self.apply(self._init_weights)
        self.load_pretrain_weight(config.pretrained_model_path)
    
    @classmethod
    def load_weight(cls, state_dict, model):
        if 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']
    
        state_dict_tmp = copy.deepcopy(state_dict)
        old_keys = []
        new_keys = []
        for key in state_dict_tmp:
            new_key = None
            if key.endswith(".g"):
                new_key = key[:-2] + ".weight"
            elif key.endswith(".b"):
                new_key = key[:-2] + ".bias"
            elif key.endswith(".w"):
                new_key = key[:-2] + ".weight"
            
            if key.startswith("module.transformer."):
                new_key = key[len("module.transformer."):]

            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)

        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)
        
        for n, p in model.transformer.named_parameters():
            if n not in state_dict:
                state_dict[n] = p

        model.transformer.load_state_dict(state_dict, strict=False)
    
    def set_tied(self):
        """ Make sure we are sharing the embeddings"""
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def load_pretrain_weight(self, pretrained_model_path):
        # for old_key, new_key in zip(old_keys, new_keys):
            # state_dict[new_key] = state_dict.pop(old_key)
        
        # transformer_keys = []
        transformer_state_dict = dict()
        offset_state_dict = dict()

        big_model_state_dict = torch.load(pretrained_model_path)["state_dict"]
        for key in big_model_state_dict:
            new_key = None
            if key.startswith("model.transformer."):
                new_key = key[len("model.transformer."):]
                transformer_state_dict[new_key] = big_model_state_dict[key]
            elif key.startswith("model.offset_layer.transformer."):
                new_key = key[len("model.offset_layer.transformer."):]
                offset_state_dict[new_key] = big_model_state_dict[key]
            elif key.startswith("model.offset_layer."): # offset's lm_head
                new_key = key[len("model.offset_layer."):]
                # print(new_key)
                offset_state_dict[new_key] = big_model_state_dict[key]
            elif key.startswith("model."): # model's lm_head:
                new_key = key[len("model."):]
                # print(",", new_key)
                transformer_state_dict[new_key] = big_model_state_dict[key]
            else:
                logging.info("Could not load parameters in pretrained model %s" % new_key)

        # print(offset_state_dict)
        # exit()
        # print("===========")
        
        self.load_weight(transformer_state_dict, self)
        if self.config.lora_attn_dim > 0:
            lora.mark_only_lora_as_trainable(self)


        # # debug
        # for n, p in self.transformer.named_parameters():
        #     # print(n, p.requires_grad)
        #     if "model.transformer." + n in big_model_state_dict:
        #         print(torch.allclose(p.data.cpu(), big_model_state_dict["model.transformer." + n].data.cpu()))
        #     else:
        #         print("Not in original model", n)
        # # debug
        # exit()
        
        # print(list(self.offset_layer.named_parameters()))

        if getattr(self.config, "transformer_offset", False):
            self.load_weight(offset_state_dict, self.offset_layer)
            if self.config.lora_attn_dim > 0:
                lora.mark_only_lora_as_trainable(self.offset_layer)
        else: # Just a linear layer!
            self.offset_layer.load_state_dict(offset_state_dict)
            # print(self.offset_layer.weight.requires_grad)
            for param in self.offset_layer.parameters():
                param.requires_grad = True    
            # print(self.offset_layer.weight.requires_grad)
            # exit()

        # print(old_keys)

        # # debug
        # for n, p in self.offset_layer.transformer.named_parameters():
        #     if "model.offset_layer.transformer." + n in big_model_state_dict:
        #         print("requires_grad:", p.requires_grad, "Loaded pretrained:", torch.allclose(p.data.cpu(), big_model_state_dict["model.offset_layer.transformer." + n].data.cpu()))
        #     else:
        #         print("Not in original model:", n, "requires_grad:", p.requires_grad)
        # # debug
        # exit()

    def build_offset_layer(self, config):
        if getattr(self.config, "transformer_offset", False):
            self.offset_layer = LoRAGPT2LMHeadModel(config) # Also from LoRA
        else:
            self.offset_layer = nn.Linear(config.hidden_size, config.vocab_size) # Just a linear layer

    @classmethod
    def from_config(cls, config):
        model = cls(config)
        return model
    
    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
        """
        Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method.
        """
        return {"input_ids": input_ids, "input_time": kwargs["input_time"]}

    def get_offset(self, outputs=None, **argv):
        if self.config.transformer_offset:
            outputs = self.offset_layer.forward(input_ids = argv.get("input_ids"), labels = argv.get("labels"), \
                attention_mask = argv.get("attention_mask"), output_hidden_states=True)
            offset = outputs.logits
        else:
            hidden_states = outputs.hidden_states[-1]
            offset = self.offset_layer(hidden_states)
        
        return offset
    
    def forward(self, input_time, **argv):
        """Return the parent's outputs with ``logits = rate * time + offset``.

        The parent model's logits act as a per-token rate. They are multiplied
        by the normalised time of each sequence and the result of
        ``get_offset`` is added.

        Args:
            input_time: One time value per original sample. When ``input_ids``
                holds several rows per sample (its first dimension is a
                multiple of ``input_time``'s), each time value is repeated
                that many times consecutively.
            **argv: ``input_ids`` is required; ``labels`` and
                ``attention_mask`` are optional and forwarded to the parent.

        Returns:
            The parent's output object with ``logits`` replaced.
        """
        input_ids = argv.get("input_ids")

        scaled_time = discretize_time(
            input_time,
            one_step=False,
            normalize_time_a=self.config.normalize_time_a,
            normalize_time_b=self.config.normalize_time_b,
            discrete=False,
        )
        # Rows per sample (named after beam search in the original code;
        # presumably 1 outside of generation — confirm against callers).
        rows_per_sample = input_ids.size(0) // input_time.size(0)
        scaled_time = scaled_time.unsqueeze(1).repeat(1, rows_per_sample).view(-1)

        # NOTE(review): `labels` are forwarded, so any loss on the returned
        # object would have been computed by the parent before the logits are
        # rewritten below — confirm callers recompute the loss.
        outputs = super().forward(
            input_ids=input_ids,
            labels=argv.get("labels"),
            attention_mask=argv.get("attention_mask"),
            output_hidden_states=True,
        )

        # The flat time vector gets two trailing singleton axes so it
        # broadcasts against the logits (assumed (rows, seq, vocab) — TODO confirm).
        time_scale = scaled_time.unsqueeze(-1).unsqueeze(-1)
        scaled_logits = outputs.logits * time_scale
        outputs.logits = scaled_logits + self.get_offset(outputs, **argv)
        return outputs


@register_model("gpt2_time_lora")
# class GPT2TimeNew(LightningModule): 
class GPT2TimeLoRA(GPT2TimeNew):
    def __init__(self, config, alphabet, **kwargs) -> None:
        """Forward all arguments unchanged to the ``GPT2TimeNew`` constructor."""
        super().__init__(config, alphabet, **kwargs)

    def on_fit_start(self):
        tokenizer_length = len(self.tokenizer)
        # self.model.resize_token_embeddings(tokenizer_length)

    def initialize_model(self, pretrained_model_name_or_path: str):
        """Create ``self.model`` as a ``GPT2LoRATime`` built from a GPT-2 config.

        The base configuration is always the ``"gpt2"`` one, updated with
        ``self.model_data_kwargs``; ``pretrained_model_name_or_path`` is
        accepted for interface compatibility but not used. LoRA, time
        normalisation, offset and checkpoint settings are then copied from
        ``self.config`` onto that configuration before the model is built.
        """
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path="gpt2", **self.model_data_kwargs
        )

        task_config = self.config
        # Model-config attribute name -> value taken from the task config.
        overrides = {
            # LoRA
            "lora_attn_dim": task_config.lora_dim,
            "lora_attn_alpha": task_config.lora_alpha,
            "lora_dropout": task_config.lora_dropout,
            # Time normalisation and offset layer
            "normalize_time_a": task_config.normalize_time_a,
            "normalize_time_b": task_config.normalize_time_b,
            "transformer_offset": task_config.transformer_offset,
            # Checkpoint the model loads in its constructor
            "pretrained_model_path": task_config.load_from_pretrain_checkpoint,
            "label_smooth": task_config.label_smooth,
        }
        for attribute, value in overrides.items():
            setattr(config, attribute, value)

        self.model = GPT2LoRATime.from_config(config)

    def load_pretrained_model(self, path):
        """Do nothing; ``path`` is ignored.

        The pretrained checkpoint is loaded when the model is constructed:
        ``initialize_model`` stores its location as
        ``config.pretrained_model_path`` and ``GPT2LoRATime.__init__`` loads it.
        """
        pass

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Extend the parent's command-line arguments with the LoRA options.

        Adds ``--lora_dim``, ``--lora_alpha``, ``--lora_dropout`` and
        ``--label_smooth`` to the parser returned by the parent class and
        returns that parser.
        """
        parent_parser = super(GPT2TimeLoRA, cls).add_argparse_args(parent_parser)

        # (flag, type, default, help)
        extra_options = (
            ('--lora_dim', int, 4, 'lora attn dimension'),
            ('--lora_alpha', int, 32, 'lora attn alpha'),
            ('--lora_dropout', float, 0.1, 'dropout probability for lora layers'),
            ('--label_smooth', float, 0.0, 'label smoothing'),
        )
        for flag, value_type, default, help_text in extra_options:
            parent_parser.add_argument(flag, type=value_type, default=default, help=help_text)

        return parent_parser
        