import math
import os
import random
import warnings
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn import functional as F
from rxnfp.models import SmilesClassificationModel, SmilesTokenizer
from torch.nn.modules.normalization import LayerNorm
from tqdm.auto import tqdm, trange
from transformers import Adafactor, AdamW, BertConfig, BertForSequenceClassification, get_constant_schedule, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup
from simpletransformers.config.model_args import ClassificationArgs
from transformers.models.bert.configuration_bert import BertConfig
from transformers.models.bert.modeling_bert import BertModel
from simpletransformers.classification.classification_model import (
    MODELS_WITHOUT_SLIDING_WINDOW_SUPPORT,
    MODELS_WITHOUT_CLASS_WEIGHTS_SUPPORT)
import yaml

from Parrot.models.model_layer import PositionalEncoding, TokenEmbedding, TransformerDecoderLayer, TransformerDecoder
from Parrot.models.utils import ConditionWithTempDataset, ConditionWithTextDataset

# wandb is an optional dependency: record whether the import succeeded so the
# rest of the module can check `wandb_available` instead of importing again.
try:
    import wandb
    wandb_available = True
except ImportError:
    wandb_available = False

# Special-token strings. PAD is used below as a key into the condition label
# mapping to find the index ignored by the loss.
BOS, EOS, PAD, MASK = '[BOS]', '[EOS]', '[PAD]', '[MASK]'


class ParrotConditionModel(BertForSequenceClassification):
    """BERT encoder combined with a Transformer decoder over condition tokens.

    The BERT encoder output is used as the decoder "memory"; the decoder
    output is projected by ``generator`` onto the target vocabulary of size
    ``config.tgt_vocab_size``. When ``config.use_temperature`` is set, an
    additional regression head predicts one scalar per example.

    The following attributes are read from ``config`` in addition to the usual
    BERT ones: ``num_decoder_layers``, ``nhead``, ``tgt_vocab_size``,
    ``dim_feedforward``, ``dropout``, ``d_model``, ``use_temperature``,
    ``condition_label_mapping`` and, optionally, ``output_attention``.
    """

    def __init__(
        self,
        config,
    ) -> None:
        """Build the encoder, the decoder stack and the output heads.

        Args:
            config: Model configuration carrying both the BERT settings and
                the decoder settings listed in the class docstring.
        """
        super().__init__(config)

        self.num_labels = config.num_labels
        self.config = config
        num_decoder_layers = config.num_decoder_layers
        nhead = config.nhead
        tgt_vocab_size = config.tgt_vocab_size
        dim_feedforward = config.dim_feedforward
        dropout = config.dropout
        d_model = config.d_model
        self.use_temperature = config.use_temperature
        # `output_attention` is optional in the config; default to False.
        self.output_attention = getattr(config, 'output_attention', False)

        self.bert = BertModel(config)
        activation = F.relu
        layer_norm_eps = 1e-5
        factory_kwargs = {'device': None, 'dtype': None}

        decoder_layer = TransformerDecoderLayer(
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            activation,
            layer_norm_eps,
            batch_first=True,
            norm_first=False,
            **factory_kwargs,
            output_attention=self.output_attention)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size=d_model)
        self.positional_encoding = PositionalEncoding(emb_size=d_model,
                                                      dropout=dropout)
        decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            output_attention=self.output_attention)
        self.generator = nn.Linear(d_model, tgt_vocab_size)

        # Padding positions in the target sequence do not contribute to the
        # token loss.
        self.loss_fn = torch.nn.CrossEntropyLoss(
            ignore_index=config.condition_label_mapping[1][PAD])
        if self.use_temperature:
            # Consumes the encoder sequence flattened to
            # max_position_embeddings * hidden_size features per example.
            self.memory_regression_layer = nn.Sequential(
                nn.Linear(
                    self.config.max_position_embeddings *
                    self.config.hidden_size, d_model),
                nn.ReLU(),
            )
            # Input width is 5 * d_model, so the flattened decoder output fed
            # to this layer must span exactly 5 positions.
            self.regression_layer1 = nn.Sequential(
                nn.Linear(d_model * 5, d_model), nn.ReLU())
            self.regression_layer2 = nn.Linear(2 * d_model, 1)
            self.reg_loss_fn = torch.nn.MSELoss()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        label_input=None,
        label_mask=None,
        label_padding_mask=None,
        labels=None,
        memory_key_padding_mask=None,
        temperature=None,
        text=None,
        **kwargs
    ):
        """Run encoder and decoder and compute the training loss(es).

        Args:
            input_ids: Token ids passed to the BERT encoder.
            attention_mask: Encoder attention mask; positions equal to 0 are
                treated as padding when ``memory_key_padding_mask`` is not
                given.
            token_type_ids: Optional segment ids passed to the encoder.
            label_input: Decoder input token ids.
            label_mask: Attention mask applied to the decoder input.
            label_padding_mask: Padding mask for the decoder input.
            labels: Target token ids (required); the first position is
                dropped before computing the loss.
            memory_key_padding_mask: Padding mask for the encoder output.
                Derived from ``attention_mask`` when omitted.
            temperature: Regression targets, only read when
                ``use_temperature`` is enabled.
            text: Accepted for interface compatibility; not used here.
            **kwargs: Ignored.

        Returns:
            Without temperature:
            ``(loss, logits, attention_weights, condition_position_encoding,
            memory_pool, memory_unpool)``.
            With temperature:
            ``(loss, logits, attention_weights, loss_reg, temp_out,
            condition_position_encoding, memory_pool, memory_unpool)``.
        """
        # Only derive the padding mask when there is an attention mask to
        # derive it from; `None == 0` would otherwise yield a plain bool.
        if memory_key_padding_mask is None and attention_mask is not None:
            memory_key_padding_mask = (attention_mask == 0)

        # A single encoder pass provides both outputs. Running the encoder
        # twice doubles the cost and, with dropout active, would make the
        # pooled and unpooled tensors come from different stochastic passes.
        encoder_outputs = self.bert(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids)
        memory_unpool = encoder_outputs[0]
        memory_pool = encoder_outputs[1]

        condition_position_encoding = self.positional_encoding(
            self.tgt_tok_emb(label_input))

        outs, attention_weights = self.decoder(
            condition_position_encoding,
            memory_unpool,
            tgt_mask=label_mask,
            tgt_key_padding_mask=label_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask)
        logits = self.generator(outs)

        # Targets are the label sequence shifted by one position.
        labels_out = labels[:, 1:]
        loss = self.loss_fn(logits.reshape(-1, logits.shape[-1]),
                            labels_out.reshape(-1))

        if self.use_temperature:
            # The regression head flattens the full encoder sequence, so it
            # needs the unpooled output: the pooled vector cannot be reshaped
            # to max_position_embeddings * hidden_size features per example.
            # The last decoder position is excluded from the regression input.
            temp_out = self.decode_temperature(memory_unpool,
                                               outs[:, :-1, :])

            loss_reg = self.reg_loss_fn(temp_out.reshape(-1),
                                        temperature.reshape(-1))

            return (loss, logits, attention_weights, loss_reg, temp_out,
                    condition_position_encoding, memory_pool, memory_unpool)
        return (loss, logits, attention_weights, condition_position_encoding,
                memory_pool, memory_unpool)

    def encode(self, input_ids):
        """Return the first (sequence) output of the BERT encoder."""
        return self.bert(input_ids)[0]

    def decode(self, tgt, memory, tgt_mask, memory_key_padding_mask):
        """Run the decoder on target tokens against an encoder memory.

        Args:
            tgt: Target token ids; embedded and position-encoded here.
            memory: Encoder output attended to by the decoder.
            tgt_mask: Attention mask for the target sequence.
            memory_key_padding_mask: Padding mask for ``memory``.

        Returns:
            ``(decoder_output, attention_weights)`` as produced by the
            decoder.
        """
        tgt_embedded = self.positional_encoding(self.tgt_tok_emb(tgt))
        decoder_output, attention_weights = self.decoder(
            tgt_embedded,
            memory,
            tgt_mask,
            memory_key_padding_mask=memory_key_padding_mask)
        return decoder_output, attention_weights

    def decode_temperature(self, memory, decoder_output):
        """Predict the temperature from encoder memory and decoder output.

        Only usable when the model was built with ``use_temperature``
        enabled, since the regression layers exist only in that case.

        Args:
            memory: Encoder sequence output; flattened to
                ``max_position_embeddings * hidden_size`` features per row
                (assumes the sequence is padded to
                ``max_position_embeddings`` - confirm against the dataset).
            decoder_output: Decoder output; flattened over its last two
                dimensions, which must total ``5 * d_model`` features.

        Returns:
            Tensor with one regression value per row.
        """
        flat_memory = memory.reshape(
            -1, self.config.max_position_embeddings * self.config.hidden_size)
        memory_features = self.memory_regression_layer(flat_memory)

        flat_decoder = decoder_output.reshape(
            -1, decoder_output.size(1) * decoder_output.size(2))
        decoder_features = self.regression_layer1(flat_decoder)

        combined = torch.cat([memory_features, decoder_features], dim=1)
        return self.regression_layer2(combined)


class ParrotConditionPredictionModel(SmilesClassificationModel):

    def __init__(
        self,
        model_type,
        model_name,
        tokenizer_type=None,
        tokenizer_name=None,
        weight=None,
        args=None,
        use_cuda=True,
        cuda_device=-1,
        freeze_encoder=False,
        freeze_all_but_one=False,
        **kwargs,
    ):
        """Set up config, model, tokenizer and device for condition prediction.

        Note that the base-class initializer is not called; every attribute is
        set up here.

        Args:
            model_type: Key into the supported model table (only ``"bert"``).
            model_name: Pretrained model name or path. When falsy, a model is
                built from the config without loading weights.
            tokenizer_type: Optional model-type string or tokenizer class that
                overrides the default tokenizer class.
            tokenizer_name: Tokenizer name or path; defaults to ``model_name``.
            weight: Optional class weights forwarded to ``from_pretrained``.
            args: Settings mapping. It is subscripted directly, so it must
                provide ``'decoder_args'`` and ``'use_temperature'``;
                ``'ignore_mismatched_sizes'`` and ``'output_attention'`` are
                optional.
            use_cuda: Place the model on a CUDA device.
            cuda_device: CUDA device index or device spec; ``-1`` selects the
                default CUDA device.
            freeze_encoder: Freeze every parameter whose name does not contain
                ``'classifier'``.
            freeze_all_but_one: Freeze every parameter except those whose name
                contains the last hidden-layer index, ``'classifier'`` or
                ``'pooler'``. Ignored when ``freeze_encoder`` is set.
            **kwargs: Forwarded to config/model/tokenizer loading. A
                ``sweep_config`` entry is consumed here.

        Raises:
            NotImplementedError: If ``model_type`` is not supported.
            ValueError: If sliding window or class weights are requested for a
                model type without support, or CUDA is requested but
                unavailable.
            AttributeError: If fp16 is requested but ``torch.cuda.amp`` is
                missing.
        """
        MODEL_CLASSES = {
            "bert": (BertConfig, ParrotConditionModel, SmilesTokenizer),
        }

        if model_type not in MODEL_CLASSES:
            raise NotImplementedError(
                f"Currently the following model types are implemented: {MODEL_CLASSES.keys()}"
            )

        self.args = self._load_model_args(model_name)

        decoder_args = args['decoder_args']
        try:
            self.condition_label_mapping = decoder_args[
                'condition_label_mapping']
        except (KeyError, TypeError):
            # Best effort: keep going without the mapping, but tell the user.
            print('Warning: condition_label_mapping is not set!')

        if isinstance(args, dict):
            self.args.update_from_dict(args)
        elif isinstance(args, ClassificationArgs):
            self.args = args

        if (model_type in MODELS_WITHOUT_SLIDING_WINDOW_SUPPORT
                and self.args.sliding_window):
            raise ValueError(
                "{} does not currently support sliding window".format(
                    model_type))

        if self.args.thread_count:
            torch.set_num_threads(self.args.thread_count)
        if "sweep_config" in kwargs:
            # Imported here because the module does not import it at the top.
            from simpletransformers.config.utils import (
                sweep_config_to_sweep_values)
            self.is_sweeping = True
            sweep_config = kwargs.pop("sweep_config")
            sweep_values = sweep_config_to_sweep_values(sweep_config)
            self.args.update_from_dict(sweep_values)
        else:
            self.is_sweeping = False

        if self.args.manual_seed:
            random.seed(self.args.manual_seed)
            np.random.seed(self.args.manual_seed)
            torch.manual_seed(self.args.manual_seed)
            if self.args.n_gpu > 0:
                torch.cuda.manual_seed_all(self.args.manual_seed)

        config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]

        if tokenizer_type is not None:
            if isinstance(tokenizer_type, str):
                _, _, tokenizer_class = MODEL_CLASSES[tokenizer_type]
            else:
                tokenizer_class = tokenizer_type
        if model_name:
            self.config = config_class.from_pretrained(model_name,
                                                       **self.args.config)
        else:
            self.config = config_class(**self.args.config, **kwargs)
        self.num_labels = self.config.num_labels
        # Decoder settings are carried on the model config so that
        # ParrotConditionModel can read them.
        self.config.update(decoder_args)
        self.config.update({'use_temperature': args['use_temperature']})
        if 'ignore_mismatched_sizes' in args:
            kwargs.update(
                {'ignore_mismatched_sizes': args['ignore_mismatched_sizes']})
        if 'output_attention' in args:
            self.config.update({'output_attention': args['output_attention']})
        if model_type in MODELS_WITHOUT_CLASS_WEIGHTS_SUPPORT and weight is not None:
            raise ValueError(
                "{} does not currently support class weights".format(
                    model_type))
        else:
            self.weight = weight

        if use_cuda:
            if torch.cuda.is_available():
                # -1 means "default CUDA device"; it is not a valid index for
                # torch.device, so map it to the bare "cuda" spec.
                device_spec = "cuda" if cuda_device == -1 else cuda_device
                self.device = torch.device(device_spec)
            else:
                raise ValueError(
                    "'use_cuda' set to True when cuda is unavailable."
                    " Make sure CUDA is available or set use_cuda=False.")
        else:
            self.device = "cpu"
        if model_name:
            if not self.args.quantized_model:
                if self.weight:
                    self.model = model_class.from_pretrained(
                        model_name,
                        config=self.config,
                        weight=torch.Tensor(self.weight).to(self.device),
                        **kwargs,
                    )
                else:
                    self.model = model_class.from_pretrained(
                        model_name, config=self.config, **kwargs)
            else:
                # NOTE(security): torch.load unpickles the file; only load
                # checkpoints from trusted sources.
                quantized_weights = torch.load(
                    os.path.join(model_name, "pytorch_model.bin"))
                if self.weight:
                    self.model = model_class.from_pretrained(
                        None,
                        config=self.config,
                        state_dict=quantized_weights,
                        weight=torch.Tensor(self.weight).to(self.device),
                    )
                else:
                    self.model = model_class.from_pretrained(
                        None, config=self.config, state_dict=quantized_weights)

            if self.args.dynamic_quantize:
                self.model = torch.quantization.quantize_dynamic(
                    self.model, {torch.nn.Linear}, dtype=torch.qint8)
            if self.args.quantized_model:
                self.model.load_state_dict(quantized_weights)
            if self.args.dynamic_quantize:
                self.args.quantized_model = True
        else:
            self.model = model_class(config=self.config)
        if not hasattr(self.args, 'freeze_pretrain'):
            self.args.freeze_pretrain = False
        if self.args.freeze_pretrain:
            # Only the target embedding and the output projection stay
            # trainable.
            train_layers = [
                'tgt_tok_emb.embedding.weight', 'generator.weight',
                'generator.bias'
            ]
            print(f'Frozen load parameters, training {train_layers}')
            for param_name, param in self.model.named_parameters():
                if param_name not in train_layers:
                    param.requires_grad = False
        if not hasattr(self.args, 'loss_equilibrium_constant'):
            self.args.loss_equilibrium_constant = 0.001
        self.results = {}

        if not use_cuda:
            self.args.fp16 = False

        if self.args.fp16:
            try:
                from torch.cuda import amp  # noqa: F401
            except (ImportError, AttributeError) as err:
                # A failed `from ... import` raises ImportError; the
                # AttributeError type is kept for existing callers.
                raise AttributeError(
                    "fp16 requires Pytorch >= 1.6. Please update Pytorch or turn off fp16."
                ) from err

        if tokenizer_name is None:
            tokenizer_name = model_name
        if not hasattr(self.args, 'vocab_path'):
            self.args.vocab_path = None

        bertweet_tokenizers = (
            "vinai/bertweet-base",
            "vinai/bertweet-covid19-base-cased",
            "vinai/bertweet-covid19-base-uncased",
        )
        if tokenizer_name in bertweet_tokenizers:
            self.tokenizer = tokenizer_class.from_pretrained(
                tokenizer_name,
                do_lower_case=self.args.do_lower_case,
                normalization=True,
                **kwargs,
            )
        elif not self.args.vocab_path:
            self.tokenizer = tokenizer_class.from_pretrained(
                tokenizer_name,
                do_lower_case=self.args.do_lower_case,
                **kwargs)
        else:
            # A custom vocabulary may change the vocabulary size, so the
            # embeddings are resized to match the tokenizer.
            self.tokenizer = tokenizer_class(self.args.vocab_path,
                                             do_lower_case=False)
            model_to_resize = self.model.module if hasattr(
                self.model, "module") else self.model
            model_to_resize.resize_token_embeddings(len(self.tokenizer))

        if self.args.special_tokens_list:
            self.tokenizer.add_tokens(self.args.special_tokens_list,
                                      special_tokens=True)
            self.model.resize_token_embeddings(len(self.tokenizer))

        self.args.model_name = model_name
        self.args.model_type = model_type
        self.args.tokenizer_name = tokenizer_name
        self.args.tokenizer_type = tokenizer_type

        if model_type in ["camembert", "xlmroberta"]:
            warnings.warn(
                f"use_multiprocessing automatically disabled as {model_type}"
                " fails when using multiprocessing for feature conversion.")
            self.args.use_multiprocessing = False

        if self.args.wandb_project and not wandb_available:
            warnings.warn(
                "wandb_project specified but wandb is not available. Wandb disabled."
            )
            self.args.wandb_project = None

        if freeze_encoder:
            for name, param in self.model.named_parameters():
                if 'classifier' in name:
                    continue
                param.requires_grad = False
        elif freeze_all_but_one:
            n_layers = self.model.config.num_hidden_layers
            for name, param in self.model.named_parameters():
                if str(n_layers - 1) in name:
                    continue
                elif 'classifier' in name:
                    continue
                elif 'pooler' in name:
                    continue
                param.requires_grad = False

    def load_and_cache_examples(self,
                                examples,
                                evaluate=False,
                                no_cache=False,
                                multi_label=False,
                                verbose=True,
                                silent=False,
                                **kwargs):
        """Wrap ``examples`` in the dataset class matching the current mode.

        Args:
            examples: Raw examples handed to the dataset class unchanged.
            evaluate: Selects dataset mode ``"dev"`` when true, ``"train"``
                otherwise.
            no_cache: Disable caching for this call; when falsy the value of
                ``self.args.no_cache`` is used instead.
            multi_label: Forwarded to the dataset; also forces the
                ``"classification"`` output mode.
            verbose: Accepted for interface compatibility; not used here.
            silent: Accepted for interface compatibility; not used here.
            **kwargs: Extra keyword arguments, forwarded only to
                ``ConditionWithTextDataset``.

        Returns:
            A ``ConditionWithTempDataset`` when ``self.args.use_temperature``
            is set, otherwise a ``ConditionWithTextDataset``.
        """
        model_args = self.args

        # An explicit request to skip the cache wins over the configured value.
        no_cache = no_cache or model_args.no_cache

        if not multi_label and model_args.regression:
            output_mode = "regression"
        else:
            output_mode = "classification"

        if not no_cache:
            os.makedirs(model_args.cache_dir, exist_ok=True)

        # Multiprocessing is switched off on the shared args object.
        model_args.use_multiprocessing = False
        model_args.use_multiprocessing_for_evaluation = False

        shared_options = {
            "mode": "dev" if evaluate else "train",
            "multi_label": multi_label,
            "output_mode": output_mode,
            "no_cache": no_cache,
        }

        if model_args.use_temperature:
            return ConditionWithTempDataset(examples, self.tokenizer,
                                            model_args, **shared_options)
        return ConditionWithTextDataset(examples, self.tokenizer, model_args,
                                        **shared_options, **kwargs)

    
    def _generate_square_subsequent_mask(self, sz):
        """Return an additive ``(sz, sz)`` float mask for causal attention.

        Entries on and below the diagonal are ``0.0``; entries above the
        diagonal are ``-inf``. The tensor is created on ``self.device``.
        """
        blocked = torch.full((sz, sz),
                             float('-inf'),
                             dtype=torch.float32,
                             device=self.device)
        # Keeping only the strict upper triangle zeroes the diagonal and
        # everything below it.
        return torch.triu(blocked, diagonal=1)

    def _get_inputs_dict(self, batch, device):
        """Turn a data-loader batch into keyword inputs for the model.

        Two batch layouts are handled:

        * ``batch[0]`` is a dict: list values are passed through untouched,
          every other value has dimension 1 squeezed and is moved to
          ``device``. ``batch[1]`` holds the labels, or a
          ``(labels, temperature)`` pair when ``self.args.use_temperature``
          is set.
        * otherwise ``batch`` is a sequence of tensors, where items 0, 1 and
          3 are used as input ids, attention mask and labels.

        The decoder input (labels without the last position), its causal
        mask, its padding mask and the encoder padding mask are then added.

        Returns:
            dict of model inputs.
        """
        features = batch[0]
        if isinstance(features, dict):
            tensor_fields = {}
            text_fields = {}
            for name, value in features.items():
                if isinstance(value, list):
                    text_fields[name] = value
                else:
                    tensor_fields[name] = value.squeeze(1).to(device)
            # Tensor entries first, list entries after them.
            inputs = {**tensor_fields, **text_fields}

            if self.args.use_temperature:
                inputs["labels"] = batch[1][0].to(device)
                inputs["temperature"] = batch[1][1].to(device)
            else:
                inputs["labels"] = batch[1].to(device)
                inputs["temperature"] = None
        else:
            moved = tuple(t.to(device) for t in batch)
            inputs = {
                "input_ids": moved[0],
                "attention_mask": moved[1],
                "labels": moved[3],
            }

        # The decoder is fed the label sequence without its final position.
        decoder_input = inputs['labels'][:, :-1]
        inputs['label_input'] = decoder_input
        inputs['label_mask'] = self._generate_square_subsequent_mask(
            decoder_input.shape[1])
        pad_index = self.condition_label_mapping[1]['[PAD]']
        inputs['label_padding_mask'] = (decoder_input == pad_index)
        inputs['memory_key_padding_mask'] = (inputs['attention_mask'] == 0)
        return inputs

    def _idx2condition(self, idx):
        """Map a list of label indices to their condition names.

        The first element of ``idx`` is skipped (presumably the start token
        - confirm against the caller); the remaining indices are looked up in
        ``self.condition_label_mapping[0]``.
        """
        assert isinstance(idx, list)
        index_to_condition = self.condition_label_mapping[0]
        conditions = []
        for label_index in idx[1:]:
            conditions.append(index_to_condition[label_index])
        return conditions
