import transformers
import os
import torch
import torch.nn as nn
import re
import logging
from nn import FixableDropout
from utils import scr


LOG = logging.getLogger(__name__)


class CastModule(nn.Module):
    """Wrap a module so tensors are dtype-converted on the way in and out.

    Every positional and keyword argument that is a ``torch.Tensor`` is
    converted to ``in_cast`` before the wrapped module runs (whatever its
    current dtype); non-tensor arguments pass through untouched. The result
    is converted to ``out_cast`` when it is a tensor or a tuple (tuple
    members are converted individually). A dtype of ``None`` disables the
    corresponding conversion.
    """

    def __init__(self, module: nn.Module, in_cast: torch.dtype = torch.float32, out_cast: torch.dtype = None):
        super().__init__()

        # The wrapped module is exposed as ``underlying``; its parameters
        # therefore show up under the ``underlying.`` prefix.
        self.underlying = module
        self.in_cast = in_cast
        self.out_cast = out_cast

    def cast(self, obj, dtype):
        """Return ``obj`` converted to ``dtype`` when it is a tensor and a dtype is given."""
        if dtype is not None and isinstance(obj, torch.Tensor):
            return obj.to(dtype)
        return obj

    def forward(self, *args, **kwargs):
        """Run the wrapped module with cast inputs and return cast outputs.

        Raises:
            RuntimeError: if the wrapped module returns anything other than
                a tensor or a tuple.
        """
        cast_args = [self.cast(value, self.in_cast) for value in args]
        cast_kwargs = {}
        for key, value in kwargs.items():
            cast_kwargs[key] = self.cast(value, self.in_cast)

        result = self.underlying(*cast_args, **cast_kwargs)

        if isinstance(result, torch.Tensor):
            return self.cast(result, self.out_cast)
        if isinstance(result, tuple):
            return tuple(self.cast(item, self.out_cast) for item in result)
        raise RuntimeError(f"Not sure how to cast type {type(result)}")

    def extra_repr(self):
        """Describe the configured dtypes for ``repr(module)``."""
        return f"in_cast: {self.in_cast}\nout_cast: {self.out_cast}"


class BertClassifier(torch.nn.Module):
    """Pretrained encoder followed by a single-logit linear head.

    ``hidden_dim`` is the input width of the linear head; it has to equal
    the width of the features the encoder produces.
    """

    def __init__(self, model_name, hidden_dim=768):
        super().__init__()
        # Names starting with "bert" are loaded through BertModel; anything
        # else goes through AutoModel.
        loader = transformers.BertModel if model_name.startswith("bert") else transformers.AutoModel
        self.model = loader.from_pretrained(model_name, cache_dir='./hugging_cache')
        self.classifier = torch.nn.Linear(hidden_dim, 1)

    @property
    def config(self):
        """Configuration object of the wrapped encoder."""
        return self.model.config

    def forward(self, *args, **kwargs):
        """Encode the inputs and return the classifier output.

        A ``labels`` keyword is dropped before the encoder is called. The
        head is applied to ``pooler_output`` when the encoder output has
        that key, otherwise to position 0 of ``last_hidden_state``.

        Returns:
            ``pred`` alone, or ``(pred, last_hidden_state)`` when a truthy
            ``output_hidden_states`` keyword was passed.
        """
        encoder_kwargs = dict(kwargs)
        encoder_kwargs.pop("labels", None)
        encoded = self.model(*args, **encoder_kwargs)

        if "pooler_output" in encoded.keys():
            features = encoded.pooler_output
        else:
            features = encoded.last_hidden_state[:, 0]
        pred = self.classifier(features)

        if kwargs.get("output_hidden_states"):
            return pred, encoded.last_hidden_state
        return pred


def replace_dropout(model):
    """Swap ``nn.Dropout`` children for ``FixableDropout`` and add a resample hook.

    Each ``nn.Dropout`` child anywhere in ``model`` is replaced in place by
    ``FixableDropout`` built with the same ``p``. Afterwards
    ``model.resample_dropout(seed=None)`` is available; it walks the module
    tree and calls ``resample(seed)`` on every descendant that has such an
    attribute, without descending below those descendants.
    """
    for parent in model.modules():
        for child_name, child in parent.named_children():
            if not isinstance(child, nn.Dropout):
                continue
            setattr(parent, child_name, FixableDropout(child.p))

    def resample(m, seed=None):
        for sub in m.children():
            if not hasattr(sub, "resample"):
                # Keep searching deeper for resample-capable modules.
                resample(sub, seed)
                continue
            sub.resample(seed)

    # Bind the helper to this model so it can be called like a method.
    model.resample_dropout = resample.__get__(model)


def get_model(config):
    """Build the model described by ``config`` and apply dropout/precision settings.

    Fields read from ``config``: ``model.class_name``, ``model.name``,
    ``model.inner_params``, ``dropout``, ``no_grad_layers``, ``half`` and
    ``alg``.

    Steps, in order:

    1. Instantiate ``BertClassifier(config.model.name)`` when the class name
       is ``"BertClassifier"``; otherwise look the class up on
       ``transformers`` and load it with ``from_pretrained``.
    2. When ``config.dropout`` is not None, overwrite ``p`` on every
       ``nn.Dropout`` and every float-valued ``dropout`` /
       ``activation_dropout`` module attribute.
    3. Check that every entry of ``config.model.inner_params`` is a
       parameter name of the model.
    4. When ``config.no_grad_layers`` is not None: optionally convert the
       model to bfloat16, tag the parent modules with ``no_grad_layers`` and,
       if ``config.half`` is set and ``config.alg != "rep"``, move the layers
       from ``no_grad_layers`` onward back to float32 behind ``CastModule``
       wrappers and rewrite the affected ``inner_params`` names.

    Side effects:
        Entries of ``config.model.inner_params`` may be replaced in place
        (an ``underlying`` component is inserted after the index of a
        wrapped layer).

    Returns:
        The loaded (and possibly modified) model.

    Raises:
        ValueError: if an entry of ``config.model.inner_params`` is not a
            parameter of the model.
        RuntimeError: if a parent module has no ``nn.ModuleList`` child when
            upcasting is requested.
    """
    if config.model.class_name == "BertClassifier":
        model = BertClassifier(config.model.name)
    else:
        ModelClass = getattr(transformers, config.model.class_name)
        LOG.info(f"Loading model class {ModelClass} with name {config.model.name}")

        # Device placement is delegated to device_map="auto".
        # NOTE: manual placement (model.to(config.device) / model.parallelize
        # driven by config.device_map) and restoring weights from
        # config.model.pt used to live here as commented-out code; neither is
        # performed by this function.
        model = ModelClass.from_pretrained(config.model.name, cache_dir='./hugging_cache', device_map="auto")

    if config.dropout is not None:
        n_reset = 0
        for m in model.modules():
            if isinstance(m, nn.Dropout):
                m.p = config.dropout
                n_reset += 1

            # Required for BART, which uses F.dropout driven by plain float
            # attributes instead of nn.Dropout children.
            for attr in ("dropout", "activation_dropout"):
                if isinstance(getattr(m, attr, None), float):
                    setattr(m, attr, config.dropout)
                    n_reset += 1

        LOG.info(f"Set {n_reset} dropout modules to p={config.dropout}")

    # Validate before any CastModule wrapping below changes parameter names.
    param_names = {n for n, _ in model.named_parameters()}
    bad_inner_params = [p for p in config.model.inner_params if p not in param_names]
    if bad_inner_params:
        raise ValueError(f"Params {bad_inner_params} do not exist in model of type {type(model)}.")

    if config.no_grad_layers is not None:
        if config.half:
            model.bfloat16()

        def upcast(mod):
            """Return the tail of ``mod``'s single ModuleList to float32 with dtype adapters."""
            modlist = None
            for child in mod.children():
                if isinstance(child, nn.ModuleList):
                    assert modlist is None, f"Found multiple modlists for {mod}"
                    modlist = child
            if modlist is None:
                raise RuntimeError("Couldn't find a ModuleList child")

            LOG.info(f"Setting {len(modlist) - config.no_grad_layers} modules to full precision, with autocasting")
            modlist[config.no_grad_layers:].to(torch.float32)
            # First full-precision layer: cast incoming tensors to float32.
            modlist[config.no_grad_layers] = CastModule(modlist[config.no_grad_layers])
            # Last layer: cast its outputs back to bfloat16.
            # NOTE(review): when no_grad_layers indexes the last layer, that
            # layer ends up wrapped twice — confirm this case cannot occur.
            modlist[-1] = CastModule(modlist[-1], in_cast=torch.float32, out_cast=torch.bfloat16)

        parents = []
        if hasattr(model, "transformer"):
            parents.append(model.transformer)
        if hasattr(model, "encoder"):
            parents.append(model.encoder)
        if hasattr(model, "decoder"):
            parents.append(model.decoder)
        if hasattr(model, "model"):
            # NOTE(review): both attributes are read unconditionally, so an
            # inner ``model`` lacking ``encoder`` or ``decoder`` raises
            # AttributeError here — confirm the supported model families.
            parents.extend([model.model.encoder, model.model.decoder])

        for t in parents:
            t.no_grad_layers = config.no_grad_layers
            if config.half and config.alg != "rep":
                upcast(t)

        if config.half and config.alg != "rep":
            # Parameters of the wrapped layers now live under "<idx>.underlying".
            # The wrapped layers are identified by index: no_grad_layers, and
            # the largest index mentioned in inner_params.
            # NOTE(review): presumably that largest index is the last layer
            # (the one wrapped via modlist[-1]) — confirm against the configs.
            idxs = [
                int(comp)
                for p in config.model.inner_params
                for comp in p.split('.')
                if comp.isdigit()
            ]
            # With no numeric component anywhere there is no maximum to take;
            # None never matches a name component, so nothing is rewritten.
            max_idx = str(max(idxs)) if idxs else None
            min_idx = str(config.no_grad_layers)
            for pidx, p in enumerate(config.model.inner_params):
                comps = p.split('.')
                if max_idx in comps:
                    index = comps.index(max_idx)
                elif min_idx in comps:
                    index = comps.index(min_idx)
                else:
                    continue
                comps.insert(index + 1, 'underlying')
                new_p = '.'.join(comps)
                LOG.info(f"Replacing config.model.inner_params[{pidx}] '{p}' -> '{new_p}'")
                config.model.inner_params[pidx] = new_p

    return model


def get_tokenizer(config):
    """Load the tokenizer named by ``config`` and adjust GPT-2 padding.

    The tokenizer class is looked up on ``transformers`` by
    ``config.model.tokenizer_class`` and loaded from
    ``config.model.tokenizer_name``, falling back to ``config.model.name``
    when that is None. For ``GPT2Tokenizer`` instances the pad token id is
    set to the EOS token id and padding is switched to the left.
    """
    from transformers import GPT2Tokenizer

    tok_name = config.model.tokenizer_name
    if tok_name is None:
        tok_name = config.model.name

    tokenizer_cls = getattr(transformers, config.model.tokenizer_class)
    tokenizer = tokenizer_cls.from_pretrained(tok_name, cache_dir='./hugging_cache')

    if not isinstance(tokenizer, GPT2Tokenizer):
        return tokenizer

    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'left'
    print('GPTTokenizer Detected, Set pad token id and left padding!!!')
    return tokenizer


if __name__ == '__main__':
    # Manual smoke test: build the classifier, push one dummy batch of five
    # token ids through it, then drop into the debugger for inspection.
    m = BertClassifier("bert-base-uncased")
    token_ids = torch.arange(5).unsqueeze(0)
    m(token_ids)
    import pdb
    pdb.set_trace()
