from typing import Optional, List

import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

from Pruner.mask.mask_module import OurMask
from Pruner.mask.model import prune_linear_layer_direct_to_gpu
from Pruner.mask.utils import initialize_wanda, initialize_FLAP
from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTAttention, OPTDecoder, OPTForCausalLM
from transformers.pytorch_utils import find_pruneable_heads_and_indices


class Mask_OPTAttention(OPTAttention):
    """OPT self-attention whose heads can be physically removed.

    Until :meth:`prune_params` is called this is a plain ``OPTAttention``.
    Pruning slices the q/k/v projections with the pruning helper's default
    dimension and the output projection along ``dim=1``, then updates
    ``num_heads`` / ``embed_dim`` to the reduced size.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
    ):
        super().__init__(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, is_decoder=is_decoder, bias=bias)
        # Heads removed so far; forwarded to find_pruneable_heads_and_indices.
        self.pruned_heads = set()

    def prune_params(self, zs_block):
        """Remove every head whose ``head_z`` entry is exactly zero.

        Args:
            zs_block: per-layer mask dict. Only the ``"head_z"`` entry is
                read; when it is absent (or ``None``) nothing is pruned.

        Side effects:
            Replaces ``q_proj``/``k_proj``/``v_proj``/``out_proj`` with pruned
            layers (or ``None`` when no head survives) and updates
            ``num_heads``, ``embed_dim`` and ``pruned_heads``.
        """
        head_z = zs_block.get("head_z", None)
        if head_z is None:
            # No head mask for this layer (e.g. OPT.prune_params(zs=None)).
            return

        # NOTE(review): a layer-level head mask is never read from zs_block
        # here, so it is always None — confirm this is intended.
        head_layer_z = None

        to_prune_heads = self.turn_head_z(head_z, head_layer_z)
        if not to_prune_heads:
            return

        heads, index = find_pruneable_heads_and_indices(
            to_prune_heads, self.num_heads, self.head_dim, self.pruned_heads
        )
        if len(index) == 0:
            # Nothing is kept: drop the projections entirely.
            self.k_proj = None
            self.v_proj = None
            self.q_proj = None
            self.out_proj = None
            torch.cuda.empty_cache()
        else:
            # Remember the dtype so the pruned layers can be cast back to it
            # whatever dtype the pruning helper hands back.
            orig_dtype = next(self.q_proj.parameters()).dtype
            self.q_proj = prune_linear_layer_direct_to_gpu(self.q_proj, index).to(dtype=orig_dtype)
            self.k_proj = prune_linear_layer_direct_to_gpu(self.k_proj, index).to(dtype=orig_dtype)
            self.v_proj = prune_linear_layer_direct_to_gpu(self.v_proj, index).to(dtype=orig_dtype)
            self.out_proj = prune_linear_layer_direct_to_gpu(self.out_proj, index, dim=1).to(dtype=orig_dtype)
            torch.cuda.empty_cache()
        # print(f"    Heads: {self.num_heads} -> {self.num_heads - len(heads)}")

        self.num_heads = self.num_heads - len(heads)
        self.embed_dim = self.num_heads * self.head_dim
        self.pruned_heads = self.pruned_heads.union(heads)

    def turn_head_z(self, head_z, head_layer_z):
        """Return the list of head indices whose mask value equals zero.

        Args:
            head_z: head mask tensor; singleton dimensions are squeezed away.
            head_layer_z: optional multiplier applied to ``head_z`` before the
                zero test.
        """
        # Clone so the in-place multiply never touches the caller's tensor.
        head_z = head_z.squeeze().clone()
        if head_layer_z is not None:
            head_z *= head_layer_z
        to_prune_heads = torch.where(head_z == 0)[0].view(-1).tolist()
        return to_prune_heads


class Mask_OPTDecoderLayer(OPTDecoderLayer):
    """OPT decoder layer with a maskable / prunable attention and FFN.

    ``forward`` accepts an extra ``intermediate_mask`` that is multiplied into
    the ``fc1`` output, and :meth:`prune_params` physically removes masked
    heads and FFN units. A sub-layer that was pruned away completely (its
    modules set to ``None``) is skipped in ``forward``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.ffn_dim = config.ffn_dim
        # Swap the stock attention for the prunable variant.
        self.self_attn = Mask_OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            bias=config.enable_bias,
        )

    def prune_params(self, zs_block):
        """Prune attention heads and FFN units according to ``zs_block``.

        Args:
            zs_block: per-layer mask dict. ``"head_z"`` drives head pruning,
                ``"intermediate_z"`` (optionally scaled by ``"mlp_z"``) drives
                FFN pruning. A missing entry leaves that part untouched.
        """
        if zs_block.get("head_z", None) is not None:
            self.self_attn.prune_params(zs_block)

        intermediate_z = zs_block.get("intermediate_z", None)
        if intermediate_z is None or self.fc1 is None:
            # No FFN mask given, or the FFN has already been removed.
            return
        mlp_z = zs_block.get("mlp_z", None)

        keep_dim = self.turn_mlp_z(intermediate_z, mlp_z)
        if len(keep_dim) == self.fc1.weight.shape[0]:
            # Every unit is kept: nothing to do.
            return

        if len(keep_dim) == 0:
            self.fc1 = None
            self.fc2 = None
            torch.cuda.empty_cache()
        else:
            device = self.fc1.weight.device
            keep_dim_index = torch.tensor(keep_dim, dtype=torch.long, device=device)
            # Cast the pruned layers back to the dtype the FFN had before
            # pruning, whatever dtype the pruning helper hands back.
            orig_dtype = next(self.fc1.parameters()).dtype
            self.fc1 = prune_linear_layer_direct_to_gpu(self.fc1, keep_dim_index, dim=0).to(dtype=orig_dtype)
            self.fc2 = prune_linear_layer_direct_to_gpu(self.fc2, keep_dim_index, dim=1).to(dtype=orig_dtype)
            torch.cuda.empty_cache()
        # print(f"    FFN intermediate dim: {self.ffn_dim} -> {len(keep_dim)}")

    def turn_mlp_z(self, intermediate_z, mlp_z):
        """Return the indices of FFN units whose mask value is non-zero.

        Args:
            intermediate_z: FFN unit mask; singleton dimensions are squeezed.
            mlp_z: optional multiplier applied before the non-zero test.
        """
        # Clone so the in-place multiply never touches the caller's tensor.
        intermediate_z_layer = intermediate_z.squeeze().clone()
        if mlp_z is not None:
            intermediate_z_layer *= mlp_z
        keep_intermediate_dims = torch.where(intermediate_z_layer != 0)[0].tolist()
        return keep_intermediate_dims

    def forward(self, hidden_states, attention_mask=None,
                layer_head_mask=None, intermediate_mask=None,
                past_key_value=None, use_cache=False, output_attentions=False):
        """Run the layer, optionally masking heads and FFN units.

        Args:
            hidden_states: layer input.
            attention_mask: forwarded to the attention module.
            layer_head_mask: per-head mask forwarded to the attention module.
            intermediate_mask: multiplied into the ``fc1`` output when given.
            past_key_value: cached key/value state for this layer, if any.
            use_cache: append the present key/value state to the output.
            output_attentions: append the attention weights to the output.

        Returns:
            Tuple ``(hidden_states[, attn_weights][, present_key_value])``.
            The optional entries are ``None`` when the attention sub-layer
            was pruned away.
        """
        self_attn_weights = None
        present_key_value = None

        # Self Attention (skipped when prune_params removed every head; the
        # residual stream then passes through unchanged).
        if self.self_attn.q_proj is not None:
            residual = hidden_states

            # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
            if self.do_layer_norm_before:
                hidden_states = self.self_attn_layer_norm(hidden_states)

            hidden_states, self_attn_weights, present_key_value = self.self_attn(
                hidden_states=hidden_states,
                past_key_value=past_key_value,
                attention_mask=attention_mask,
                layer_head_mask=layer_head_mask,
                output_attentions=output_attentions,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states

        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected: flatten to (tokens, features), restored at the end.
        hidden_states_shape = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))

        # Skipped when prune_params removed every FFN unit.
        if self.fc1 is not None:
            residual = hidden_states

            # 125m, 1.7B, ..., 175B applies layer norm BEFORE the FFN
            if self.do_layer_norm_before:
                hidden_states = self.final_layer_norm(hidden_states)

            hidden_states = self.fc1(hidden_states)

            if intermediate_mask is not None:
                # Deliberately in place: the result keeps hidden_states'
                # dtype even if the mask uses a different floating dtype.
                hidden_states *= intermediate_mask

            hidden_states = self.activation_fn(hidden_states)

            hidden_states = self.fc2(hidden_states)
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

            hidden_states = residual + hidden_states

        hidden_states = hidden_states.view(hidden_states_shape)

        # 350m applies layer norm AFTER the FFN
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

class OPT(OPTDecoder):
    """OPT decoder built from maskable / prunable decoder layers.

    Both forward entry points return a dict whose ``"logits"`` entry holds the
    decoder's final hidden states (after the final layer norm and
    ``project_out`` when those exist); the LM head is applied by the caller.
    """

    def __init__(self, config):
        super().__init__(config)
        # Replace the stock layers with the prunable variant.
        self.layers = nn.ModuleList([Mask_OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])

    def prune_params(self, zs=None):
        """Physically prune every layer according to the masks in ``zs``.

        Args:
            zs: dict mapping mask names to per-layer indexable masks.
                ``None`` means there is nothing to prune.
        """
        if zs is None:
            return
        for i, block in enumerate(self.layers):
            block.prune_params(self.get_zs_block(zs, i))
            torch.cuda.empty_cache()

    def get_zs_block(self, zs, block_idx):
        """Return ``{key: zs[key][block_idx]}`` (empty when ``zs`` is None)."""
        if zs is None:
            return {}
        return {key: zs[key][block_idx] for key in zs}

    def _layer_masks(self, zs, idx):
        """Return the squeezed ``(head_mask, intermediate_mask)`` of layer ``idx``.

        Either entry is ``None`` when ``zs`` is ``None`` or lacks that key.
        """
        if zs is None:
            return None, None
        zs_block = self.get_zs_block(zs, idx)
        head_mask = zs_block.get("head_z", None)
        if head_mask is not None:
            head_mask = head_mask.squeeze()
        intermediate_mask = zs_block.get("intermediate_z", None)
        if intermediate_mask is not None:
            intermediate_mask = intermediate_mask.squeeze()
        return head_mask, intermediate_mask

    def _run_decoder(
        self,
        zs,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
    ):
        """Shared decoder pass: embed, run all layers with masks, normalise.

        Returns the final hidden states as a tensor.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
        # required mask seq length can be calculated via length of past
        mask_seq_length = past_key_values_length + seq_length

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
        elif attention_mask.shape[1] != mask_seq_length:
            raise ValueError(
                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
            )
        causal_attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )
        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)

        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)

        hidden_states = inputs_embeds + pos_embeds

        # `use_cache=True` is incompatible with gradient checkpointing.
        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        # decoder layers
        for idx, decoder_layer in enumerate(self.layers):
            head_mask, intermediate_mask = self._layer_masks(zs, idx)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_attention_mask,
                layer_head_mask=head_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                intermediate_mask=intermediate_mask,
            )
            # Only the hidden states are used; attention weights and cache
            # entries are not part of this decoder's return value.
            hidden_states = layer_outputs[0]

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)

        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)

        return hidden_states

    @torch.inference_mode()
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        mask_module: Optional[OurMask] = None,
        ppl_during_train: Optional[bool] = False,
    ):
        """Run the decoder under inference mode, sampling masks if requested.

        Args:
            input_ids / inputs_embeds: exactly one of them must be given.
            attention_mask: padding mask; defaults to all ones.
            past_key_values: cached key/value states, one entry per layer.
            use_cache, output_attentions: forwarded to the layers (config
                defaults apply when ``None``).
            output_hidden_states, return_dict: accepted for interface
                compatibility; they do not affect the returned value.
            mask_module: when given, called to obtain the masks ``zs`` (and
                ``grads`` unless ``ppl_during_train`` is set).
            ppl_during_train: call ``mask_module(ppl_during_train=True)`` and
                omit ``"grads"`` from the result.

        Returns:
            ``{"logits": hidden_states}``, plus ``"grads"`` when a
            ``mask_module`` was used without ``ppl_during_train``.

        Raises:
            ValueError: on inconsistent inputs, or when ``ppl_during_train``
                is set without a ``mask_module``.
        """
        zs = None
        grads = None
        # prepare head_mask and intermediate_mask
        if ppl_during_train:
            if mask_module is None:
                raise ValueError("`ppl_during_train=True` requires a `mask_module`")
            zs, _ = mask_module(ppl_during_train=True)
        elif mask_module is not None:
            zs, grads = mask_module()

        hidden_states = self._run_decoder(
            zs,
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        if mask_module is not None and not ppl_during_train:
            return {"logits": hidden_states, "grads": grads}
        return {"logits": hidden_states}

    def instantation_test_forward(self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        zs: Optional[dict] = None,
    ):
        """Run the decoder with an explicit, already sampled mask dict ``zs``.

        Same decoder pass as :meth:`forward`, but the masks are supplied by
        the caller and the call is not wrapped in inference mode. ``zs=None``
        runs the decoder unmasked. ``output_hidden_states`` and
        ``return_dict`` are accepted for interface compatibility only.

        Returns:
            ``{"logits": hidden_states}``.
        """
        hidden_states = self._run_decoder(
            zs,
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        return {"logits": hidden_states}


class Masked_OPTForCausalLM(OPTForCausalLM):
    """``OPTForCausalLM`` whose decoder is the maskable / prunable ``OPT``."""

    def __init__(self, config):
        super().__init__(config)
        # Swap in the prunable decoder in place of the stock one.
        self.model.decoder = OPT(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Return the LM-head logits for the given inputs.

        Args:
            head_mask: kept in the signature for compatibility with
                ``OPTForCausalLM``; must be ``None`` because the ``OPT``
                decoder has no ``head_mask`` parameter.
            Remaining arguments are forwarded unchanged to the decoder, which
            applies the config defaults for ``None`` values.

        Returns:
            Contiguous logits tensor ``lm_head(decoder_hidden_states)``.

        Raises:
            ValueError: if ``head_mask`` is not ``None``.
        """
        if head_mask is not None:
            raise ValueError(
                "`head_mask` is not supported by the masked OPT decoder; "
                "head masks are produced by its mask module instead"
            )

        # The decoder returns a dict; its "logits" entry holds the final
        # hidden states that still need the LM head.
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(outputs["logits"]).contiguous()
        return logits


class Masked_OPT(nn.Module):
    """Wrapper bundling a prunable OPT decoder, its LM head and a mask module.

    ``self.model`` is the decoder; its ``"logits"`` output holds hidden states
    that :meth:`loss` turns into vocabulary logits with ``self.lm_head``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.opt_model = Masked_OPTForCausalLM.from_pretrained(cfg.file_path)
        self.model = self.opt_model.model.decoder
        self.model.cfg = cfg
        self.lm_head = self.opt_model.lm_head

        # The mask module only exists when the config has a `mask` entry.
        self.mask_module = None
        if getattr(self.cfg, "mask", None) is not None:
            self.mask_module = OurMask(cfg, device=cfg.init_device)

    def initialize_score(self, method, tokenizer, device, model_name):
        """Initialise the mask scores with ``'wanda'`` or ``'flap'``.

        ``'mean'`` and any other value leave the scores untouched.
        """
        if method == 'wanda':
            initialize_wanda(self.model, tokenizer, device, self.mask_module, model_name)
        elif method == 'flap':
            initialize_FLAP(self.model, tokenizer, device, self.mask_module, model_name)

    def _select_zs(self, device, test_mask, test_batch, sample_num):
        """Sample masks from the mask module and return the chosen one.

        With ``test_mask`` false a single sample is returned. Otherwise
        ``sample_num`` candidates are drawn, each is scored by the loss of
        ``self.model`` on ``test_batch``, and the first candidate with the
        lowest loss is returned.
        """
        if not test_mask:
            zs, _ = self.mask_module()
            return zs

        if sample_num < 1:
            raise ValueError("sample_num must be at least 1")

        test_batch = test_batch.to(device)
        candidates = []
        losses = []
        for _ in range(sample_num):
            zs = self.mask_module()[0]
            # Scoring only: no autograd graph is needed, and plain floats
            # avoid keeping loss tensors alive.
            with torch.no_grad():
                output = self.model.instantation_test_forward(input_ids=test_batch, zs=zs)
                losses.append(self.loss(output, test_batch).item())
            candidates.append(zs)
        # print(losses)
        best = min(range(len(losses)), key=losses.__getitem__)
        return candidates[best]

    def _score_grads(self, outdated_zs):
        """Compute the score gradient of every pruning module from ``outdated_zs``.

        For each module the gradient is
        ``(zs - score) / sqrt((score + 1e-8) * (1 - score + 1e-8))`` with
        ``zs`` reshaped to the mask's ``mask_shape``. The ``'layer'`` module
        reads ``outdated_zs['head_layer_z']``; every other module reads
        ``outdated_zs[f'{name}_z']``.
        """
        grads = {}
        for pruning_module in self.mask_module.pruning_modules:
            mask = self.mask_module.masks[pruning_module]
            zs_key = 'head_layer_z' if pruning_module == 'layer' else f'{pruning_module}_z'
            zs = outdated_zs[zs_key].reshape(mask.mask_shape)
            score = mask.score.data
            grads[f"{pruning_module}_grad"] = (zs - score) / torch.sqrt((score + 1e-8) * (1 - score + 1e-8))
        return grads

    def instantation(self, origin_model, device, test_mask, test_batch, sample_num=5):
        """Copy ``origin_model``, pick a mask and physically prune the copy.

        Args:
            origin_model: decoder to deep-copy; the copy becomes ``self.model``
                (moved to ``device`` in bfloat16).
            device: target device for the copy and for ``test_batch``.
            test_mask: when true, pick the best of ``sample_num`` sampled
                masks by loss on ``test_batch``; otherwise sample once.
            test_batch: token ids used to score candidates (only read when
                ``test_mask`` is true).
            sample_num: number of candidate masks to score.

        Returns:
            ``(deep copy of mask_module.masks, chosen zs)``.
        """
        self.model = copy.deepcopy(origin_model)
        torch.cuda.empty_cache()
        self.model = self.model.to(device=device, dtype=torch.bfloat16)

        zs = self._select_zs(device, test_mask, test_batch, sample_num)
        self.model.prune_params(zs)
        return copy.deepcopy(self.mask_module.masks), zs

    def sim_instantation(self, device, test_mask, test_batch, sample_num=5):
        """Pick a mask like :meth:`instantation` but leave ``self.model`` unpruned.

        Returns:
            ``(deep copy of mask_module.masks, chosen zs)``.
        """
        zs = self._select_zs(device, test_mask, test_batch, sample_num)
        return copy.deepcopy(self.mask_module.masks), zs

    def forward(self, input_ids, instantation_model, outdated_zs=None):
        """Run the decoder and return its output dict including ``"grads"``.

        Args:
            input_ids: token ids.
            instantation_model: when true the decoder runs without the mask
                module and the gradients are computed from ``outdated_zs``;
                otherwise the decoder samples masks from ``self.mask_module``
                and returns the gradients that module provides.
            outdated_zs: masks to compute gradients from; required when
                ``instantation_model`` is true.

        Raises:
            ValueError: if ``instantation_model`` is true and ``outdated_zs``
                is ``None``.
        """
        if not instantation_model:
            return self.model.forward(input_ids=input_ids, mask_module=self.mask_module)

        if outdated_zs is None:
            raise ValueError("`outdated_zs` is required when `instantation_model` is true")

        grads = self._score_grads(outdated_zs)
        logits = self.model.forward(input_ids=input_ids)
        logits["grads"] = grads
        return logits

    def sim_forward(self, input_ids, outdated_zs):
        """Run the unpruned decoder with masks ``outdated_zs`` applied.

        Returns the decoder output dict with the score gradients computed
        from ``outdated_zs`` stored under ``"grads"``.
        """
        output = self.model.instantation_test_forward(input_ids=input_ids, zs=outdated_zs)
        output["grads"] = self._score_grads(outdated_zs)
        return output

    def loss(self, outputs, batch):
        """Return the next-token cross-entropy of ``outputs`` against ``batch``.

        ``outputs["logits"]`` is passed through ``self.lm_head`` first;
        positions whose target is ``-100`` are ignored.
        """
        logits = self.lm_head(outputs["logits"]).contiguous()
        targets = self.get_targets(batch)

        loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                               targets.view(-1),
                               ignore_index=-100)
        return loss

    def get_targets(self, batch):
        """Return next-token targets: ``batch`` shifted left by one position.

        The last position of every row has no successor and is set to
        ``-100`` (the ignore index used by :meth:`loss`). ``batch`` itself is
        not modified.
        """
        targets = torch.roll(batch, shifts=-1, dims=-1)
        targets[:, -1] = -100
        return targets
