import os

import torch

from typing import Dict, Tuple

from framework import dataset
from framework.task import task, args, SimpleTask
from framework.helpers import TrainingHelper
import framework

from .helpers import gpt2_init, LMEvalMixin, LMTaskMixin


from layers import Transformer, LanguageModel, TransformerFFN, MoSA, PreLNSATransformerLayer, SymmetricRoutingAttention, RopePartheadSparseSlidingWindowAttention, PartheadRoutingAttention, FSA


@args
def a(parser: framework.helpers.ArgumentParser):
    """Register the command-line options used by the C4 sparse-attention LM tasks.

    Every option is forwarded unchanged to ``parser.add_argument``, in the order
    listed below.
    """
    int_or_none = parser.int_or_none_parser

    # (flag, keyword arguments passed verbatim to parser.add_argument)
    option_specs = (
        # Model size and generic transformer options.
        ("-state_size", dict(default=512)),
        ("-dropout", dict(default=0.0)),
        ("-transformer.n_layers", dict(default=12)),
        ("-transformer.ff_multiplier", dict(default=4.0)),
        ("-transformer.universal.group_size", dict(default=2)),
        ("-transformer.p_drop_layer", dict(default=0.0)),
        ("-transformer.drop_layer_max", dict(default="none", parser=int_or_none)),
        ("-transformer.drop_layer_ignore_firstlast", dict(default="none", parser=int_or_none)),
        ("-transformer.n_heads", dict(default=8)),
        ("-transformer.head_projection_size", dict(default="none", parser=int_or_none)),
        # Language-model options.
        ("-lm.orthemb_loss", dict(default=0.0)),
        ("-lm.unroll", dict(default=512)),
        ("-lm.norm_out_classifier", dict(default=False)),
        ("-lm.force_in_norm", dict(default=False)),
        # Sparse / mixture-of-attention options.
        ("-sa_moe.sw_heads", dict(type=int, default=None)),
        ("-sa_moe.sparsity_type", dict(type=str, default=None)),
        ("-sa_moe.num_sparse_heads", dict(type=int, default=None)),
        ("-sa_moe.total_heads", dict(type=int, default=None)),
        ("-sa_moe.detach_router_input_weight", dict(type=float, default=0.0)),
        ("-sa_moe.normalise", dict(default=False)),
        ("-sa_moe.baseline_dense_heads", dict(default=0, type=int)),
        ("-sa_moe.shared_dense_heads", dict(default=0, type=int)),
        ("-sa_moe.sparsity", dict(default=1, type=float)),
        ("-sa_moe.include_first", dict(default=1, type=int)),
        ("-sa_moe.symmetric", dict(default=False, type=bool)),
        ("-sa_moe.kv_ratio", dict(default=1.0, type=float)),
        ("-sa_moe.qkv_hidden", dict(default=None, type=int)),
        ("-sa_moe.custom_kernel", dict(default=False, type=bool)),
        ("-sa_moe.strided", dict(default=0)),
        ("-sa_moe.noise_std", dict(default=0.0, type=float)),
        ("-sa_moe.expert_heads", dict(type=int, default=None)),
        # Separate-stream, tokenizer and dataset options.
        ("-sep_stream.norm_before_gate", dict(default=False)),
        ("-sep_stream.rmsnorm", dict(default=False)),
        ("-sentencepiece.n_pieces", dict(default=8000)),
        ("-lmds.n_validation_tokens", dict(default=2000000)),
    )

    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)

class C4DatasetMixin:
    """Mixin that creates the C4 train and validation datasets for an LM task.

    Configuration is read from ``self.helper.args``. The host class is expected to
    provide ``self.valid_sets`` and a ``create_datasets`` further along the MRO,
    which is called at the end.
    """

    def create_datasets(self):
        self.batch_dim = 1
        cfg = self.helper.args
        cache_dir = os.environ.get("DATA_PATH", "./cache/")

        # Only limit the training tokens when a step limit is configured.
        train_token_limit = None
        if cfg.stop_after is not None:
            # NOTE(review): the extra 100 steps of slack look like the "magic number
            # for backward compatibility" mentioned in the original comment — confirm.
            train_token_limit = cfg.lm.unroll * cfg.batch_size * (cfg.stop_after + 100)

        self.train_set = dataset.C4(
            cfg.lm.unroll,
            split="train",
            n_tokens=cfg.sentencepiece.n_pieces,
            cache_dir=cache_dir,
            token_limit=train_token_limit,
        )
        self.valid_sets.val = dataset.C4(
            cfg.lm.unroll,
            split="validation",
            n_tokens=cfg.sentencepiece.n_pieces,
            cache_dir=cache_dir,
            token_limit=cfg.lmds.n_validation_tokens,
        )

        super().create_datasets()


@task()
class C4FlopmatchedMosaTransformer(C4DatasetMixin, LMEvalMixin, LMTaskMixin, SimpleTask):
    """FLOP-matched C4 language model that mixes dense and MoSA attention heads.

    Each layer keeps ``sa_moe.shared_dense_heads`` dense heads. It receives as many
    MoSA heads as fit into the FLOP estimate of a dense baseline layer with
    ``sa_moe.baseline_dense_heads`` heads (see ``create_layer``).

    Notation used by the FLOP helpers:
    - ``h``: model width (``state_size``)
    - ``h_prim``: per-head projection size
    - ``n``: number of heads
    - ``s``: sequence length (``lm.unroll``)
    - ``k``: number of tokens handled by each MoSA head
    """
    PAD_QUANTUM = 64

    def transformer_flops(self, h, h_prim, n, s):
        """Estimate the FLOPs of a full dense transformer layer.

        Equals ``dense_trafo_layer_flops`` plus an extra ``2 * n * s**2`` term
        (presumably the attention softmax). Not used inside this class.
        """
        return 8 * h * h_prim * n * s + 4 * h_prim * n * s**2 + 2 * n * s * s + 16 * s * h * h

    def dense_attention_flops(self, h, h_prim, n, s):
        """Estimate the FLOPs of ``n`` dense attention heads over ``s`` tokens."""
        return 8 * h * h_prim * n * s + 4 * h_prim * n * s**2

    def mosa_attention_flops(self, h, h_prim, n, s, k):
        """Estimate the FLOPs of ``n`` MoSA heads that each handle ``k`` of ``s`` tokens.

        Only ``2 * h * n * s`` scales with the full sequence length (presumably the
        router). All other terms scale with ``k``.
        """
        return 2 * h * n * s + n * k * h_prim + 8 * h * h_prim * n * k + \
            4 * h_prim * n * k * k

    def ff_flops(self, h, s):
        """Estimate the feed-forward FLOPs, assuming a hidden width of ``4 * h``.

        The real FFN uses ``transformer.ff_multiplier``. This term appears on both
        sides of the head-count search, so it only affects the printed FLOP ratio.
        """
        return 16 * s * h * h

    def dense_trafo_layer_flops(self, h, h_prim, n, s):
        """Estimate the FLOPs of a dense-attention layer plus its FFN."""
        return self.dense_attention_flops(h, h_prim, n, s) + self.ff_flops(h, s)

    def mosa_trafo_layer_flops(self, h, h_prim, n, s, k):
        """Estimate the FLOPs of a MoSA-only layer plus its FFN."""
        return self.mosa_attention_flops(h, h_prim, n, s, k) + self.ff_flops(h, s)

    def parthead_flops(self, h, h_prim, n_dense, n_mosa, s, k):
        """Estimate the FLOPs of a layer with ``n_dense`` dense and ``n_mosa`` MoSA heads plus FFN."""
        return self.dense_attention_flops(h, h_prim, n_dense, s) + self.mosa_attention_flops(h, h_prim, n_mosa, s, k) + self.ff_flops(h, s)

    def highest_numheads_parthead(self, h, h_prim, n_dense, s, k, flops_limit):
        """Find the largest number of MoSA heads whose layer cost stays within ``flops_limit``.

        Returns:
            ``(n_mosa, flops)``, where ``flops`` is the estimate for that many MoSA
            heads. Returns ``(None, None)`` if even 9999 MoSA heads stay within the limit.

        Raises:
            ValueError: if the ``n_dense`` dense heads plus the FFN already exceed
                ``flops_limit``. Previously this returned ``-1`` heads.
        """
        best_flops = 0
        for n_mosa in range(10000):
            flops = self.parthead_flops(h, h_prim, n_dense, n_mosa, s, k)
            if flops > flops_limit:
                if n_mosa == 0:
                    raise ValueError(
                        f"{n_dense} dense heads alone exceed the FLOP budget "
                        f"({flops} > {flops_limit}); reduce sa_moe.shared_dense_heads")
                return n_mosa - 1, best_flops
            best_flops = flops
        return None, None

    def __init__(self, helper: TrainingHelper):
        framework.task.SimpleTask.__init__(self, helper)
        LMTaskMixin.__init__(self)

    def create_layer(self, layer_num) -> torch.nn.Module:
        """Build one pre-LN transformer layer with MoSA attention and a standard FFN.

        Head allocation depends on ``sa_moe.sparsity``:
        - ``1``: dense baseline with ``baseline_dense_heads`` dense heads and no MoSA heads.
        - ``0``: ``shared_dense_heads`` dense heads; the remaining baseline heads become
          MoSA heads, and ``self.k`` is the full sequence length.
        - otherwise: ``shared_dense_heads`` dense heads plus as many MoSA heads as fit
          the baseline FLOP budget, with ``self.k = unroll // sparsity``.

        Sets ``self.dense_heads``, ``self.sparse_heads`` and ``self.k`` as side effects.
        """
        cfg = self.helper.args
        baseline_flops = self.dense_trafo_layer_flops(cfg.state_size,
                                                      cfg.transformer.head_projection_size,
                                                      cfg.sa_moe.baseline_dense_heads,
                                                      cfg.lm.unroll)
        if cfg.sa_moe.sparsity == 1:
            self.dense_heads = cfg.sa_moe.baseline_dense_heads
            self.sparse_heads = 0
            self.k = cfg.lm.unroll
            print('DENSE BASELINE')
        elif cfg.sa_moe.sparsity == 0:
            # Dense computation, but routed through the sparse interface.
            self.k = cfg.lm.unroll
            self.dense_heads = cfg.sa_moe.shared_dense_heads
            self.sparse_heads = cfg.sa_moe.baseline_dense_heads - self.dense_heads
        else:
            self.k = cfg.lm.unroll // cfg.sa_moe.sparsity
            self.dense_heads = cfg.sa_moe.shared_dense_heads
            self.sparse_heads, parthead_flops = self.highest_numheads_parthead(
                cfg.state_size,
                cfg.transformer.head_projection_size,
                cfg.sa_moe.shared_dense_heads,
                cfg.lm.unroll,
                self.k,
                baseline_flops,
            )
            if self.sparse_heads is None:
                raise ValueError(
                    "FLOP budget allows more than 9999 MoSA heads; check the sa_moe settings")
            print('='*20)
            print('sparsity: ', cfg.sa_moe.sparsity)
            print('num_sparse_heads: ', self.sparse_heads)
            print('flops ratio: ', parthead_flops / baseline_flops)

        return PreLNSATransformerLayer(
            attention=MoSA(
                h=cfg.state_size,
                h_prim=cfg.transformer.head_projection_size,
                mosa_heads=self.sparse_heads,
                hybrid_heads=self.dense_heads,
                max_seq_len=cfg.lm.unroll,
                sparsity=cfg.sa_moe.sparsity,
                hybrid_type='dense',
                include_first=cfg.sa_moe.include_first
            ),
            ffn=TransformerFFN(
                d_model=cfg.state_size,
                d_ff=int(cfg.state_size * cfg.transformer.ff_multiplier),
                d_out=cfg.state_size,
            ),
            d_model=cfg.state_size,
        )

    def create_inner_model(self) -> torch.nn.Module:
        """Stack ``transformer.n_layers`` layers built by ``create_layer``."""
        return Transformer(
            self.create_layer,
            n_layers=self.helper.args.transformer.n_layers
        )

    def create_model(self) -> torch.nn.Module:
        """Wrap the layer stack in a ``LanguageModel``; optionally apply GPT-2 init.

        The class used to define ``create_model`` twice. The second definition shadowed
        the first, so ``gpt2_init`` was never applied. Both are merged here.
        """
        self.validation_started_on = None
        model = LanguageModel(
            self.create_inner_model(),
            n_tokens=len(self.train_set.vocabulary),
            d_model=self.helper.args.state_size,
            n_layers=self.helper.args.transformer.n_layers,
            tied=self.helper.args.tied_embedding,
            in_norm=self.helper.args.lm.force_in_norm,
            out_norm=not self.helper.args.sep_stream.norm_before_gate,
        )
        if self.helper.args.gpt2_init:
            gpt2_init(model, self.helper.args.transformer.n_layers)
        return model




@task()
class C4FlopmatchedFsaTransformer(C4DatasetMixin, LMEvalMixin, LMTaskMixin, SimpleTask):
    """FLOP-matched C4 language model that mixes dense heads with sparse heads via ``FSA``.

    Head allocation uses the MoSA FLOP model: each layer keeps
    ``sa_moe.shared_dense_heads`` dense heads and receives as many sparse heads as fit
    the FLOP estimate of a dense baseline layer with ``sa_moe.baseline_dense_heads``
    heads (see ``create_layer``).

    Notation used by the FLOP helpers:
    - ``h``: model width (``state_size``)
    - ``h_prim``: per-head projection size
    - ``n``: number of heads
    - ``s``: sequence length (``lm.unroll``)
    - ``k``: number of tokens handled by each sparse head
    """
    PAD_QUANTUM = 64

    def transformer_flops(self, h, h_prim, n, s):
        """Estimate the FLOPs of a full dense transformer layer.

        Equals ``dense_trafo_layer_flops`` plus an extra ``2 * n * s**2`` term
        (presumably the attention softmax). Not used inside this class.
        """
        return 8 * h * h_prim * n * s + 4 * h_prim * n * s**2 + 2 * n * s * s + 16 * s * h * h

    def dense_attention_flops(self, h, h_prim, n, s):
        """Estimate the FLOPs of ``n`` dense attention heads over ``s`` tokens."""
        return 8 * h * h_prim * n * s + 4 * h_prim * n * s**2

    def mosa_attention_flops(self, h, h_prim, n, s, k):
        """Estimate the FLOPs of ``n`` sparse heads that each handle ``k`` of ``s`` tokens.

        Only ``2 * h * n * s`` scales with the full sequence length (presumably the
        router). All other terms scale with ``k``.
        """
        return 2 * h * n * s + n * k * h_prim + 8 * h * h_prim * n * k + \
            4 * h_prim * n * k * k

    def ff_flops(self, h, s):
        """Estimate the feed-forward FLOPs, assuming a hidden width of ``4 * h``.

        The real FFN uses ``transformer.ff_multiplier``. This term appears on both
        sides of the head-count search, so it only affects the printed FLOP ratio.
        """
        return 16 * s * h * h

    def dense_trafo_layer_flops(self, h, h_prim, n, s):
        """Estimate the FLOPs of a dense-attention layer plus its FFN."""
        return self.dense_attention_flops(h, h_prim, n, s) + self.ff_flops(h, s)

    def mosa_trafo_layer_flops(self, h, h_prim, n, s, k):
        """Estimate the FLOPs of a sparse-only layer plus its FFN."""
        return self.mosa_attention_flops(h, h_prim, n, s, k) + self.ff_flops(h, s)

    def parthead_flops(self, h, h_prim, n_dense, n_mosa, s, k):
        """Estimate the FLOPs of a layer with ``n_dense`` dense and ``n_mosa`` sparse heads plus FFN."""
        return self.dense_attention_flops(h, h_prim, n_dense, s) + self.mosa_attention_flops(h, h_prim, n_mosa, s, k) + self.ff_flops(h, s)

    def highest_numheads_parthead(self, h, h_prim, n_dense, s, k, flops_limit):
        """Find the largest number of sparse heads whose layer cost stays within ``flops_limit``.

        Returns:
            ``(n_mosa, flops)``, where ``flops`` is the estimate for that many sparse
            heads. Returns ``(None, None)`` if even 9999 sparse heads stay within the limit.

        Raises:
            ValueError: if the ``n_dense`` dense heads plus the FFN already exceed
                ``flops_limit``. Previously this returned ``-1`` heads.
        """
        best_flops = 0
        for n_mosa in range(10000):
            flops = self.parthead_flops(h, h_prim, n_dense, n_mosa, s, k)
            if flops > flops_limit:
                if n_mosa == 0:
                    raise ValueError(
                        f"{n_dense} dense heads alone exceed the FLOP budget "
                        f"({flops} > {flops_limit}); reduce sa_moe.shared_dense_heads")
                return n_mosa - 1, best_flops
            best_flops = flops
        return None, None

    def __init__(self, helper: TrainingHelper):
        framework.task.SimpleTask.__init__(self, helper)
        LMTaskMixin.__init__(self)

    def create_layer(self, layer_num) -> torch.nn.Module:
        """Build one pre-LN transformer layer with ``FSA`` attention and a standard FFN.

        Head allocation depends on ``sa_moe.sparsity``:
        - ``1``: dense baseline with ``baseline_dense_heads`` dense heads and no sparse heads.
        - ``0``: ``shared_dense_heads`` dense heads; the remaining baseline heads become
          sparse heads, and ``self.k`` is the full sequence length.
        - otherwise: ``shared_dense_heads`` dense heads plus as many sparse heads as fit
          the baseline FLOP budget, with ``self.k = unroll // sparsity``.

        Sets ``self.dense_heads``, ``self.sparse_heads`` and ``self.k`` as side effects.
        """
        cfg = self.helper.args
        baseline_flops = self.dense_trafo_layer_flops(cfg.state_size,
                                                      cfg.transformer.head_projection_size,
                                                      cfg.sa_moe.baseline_dense_heads,
                                                      cfg.lm.unroll)
        if cfg.sa_moe.sparsity == 1:
            self.dense_heads = cfg.sa_moe.baseline_dense_heads
            self.sparse_heads = 0
            self.k = cfg.lm.unroll
            print('DENSE BASELINE')
        elif cfg.sa_moe.sparsity == 0:
            # Dense computation, but routed through the sparse interface.
            self.k = cfg.lm.unroll
            self.dense_heads = cfg.sa_moe.shared_dense_heads
            self.sparse_heads = cfg.sa_moe.baseline_dense_heads - self.dense_heads
        else:
            self.k = cfg.lm.unroll // cfg.sa_moe.sparsity
            self.dense_heads = cfg.sa_moe.shared_dense_heads
            self.sparse_heads, parthead_flops = self.highest_numheads_parthead(
                cfg.state_size,
                cfg.transformer.head_projection_size,
                cfg.sa_moe.shared_dense_heads,
                cfg.lm.unroll,
                self.k,
                baseline_flops,
            )
            if self.sparse_heads is None:
                raise ValueError(
                    "FLOP budget allows more than 9999 sparse heads; check the sa_moe settings")
            print('='*20)
            print('sparsity: ', cfg.sa_moe.sparsity)
            print('num_sparse_heads: ', self.sparse_heads)
            print('flops ratio: ', parthead_flops / baseline_flops)

        return PreLNSATransformerLayer(
            attention=FSA(
                h=cfg.state_size,
                h_prim=cfg.transformer.head_projection_size,
                mosa_heads=self.sparse_heads,
                hybrid_heads=self.dense_heads,
                max_seq_len=cfg.lm.unroll,
                sparsity=cfg.sa_moe.sparsity,
                hybrid_type='dense',
                include_first=cfg.sa_moe.include_first
            ),
            ffn=TransformerFFN(
                d_model=cfg.state_size,
                d_ff=int(cfg.state_size * cfg.transformer.ff_multiplier),
                d_out=cfg.state_size,
            ),
            d_model=cfg.state_size,
        )

    def create_inner_model(self) -> torch.nn.Module:
        """Stack ``transformer.n_layers`` layers built by ``create_layer``."""
        return Transformer(
            self.create_layer,
            n_layers=self.helper.args.transformer.n_layers
        )

    def create_model(self) -> torch.nn.Module:
        """Wrap the layer stack in a ``LanguageModel``; optionally apply GPT-2 init.

        The class used to define ``create_model`` twice. The second definition shadowed
        the first, so ``gpt2_init`` was never applied. Both are merged here.
        """
        self.validation_started_on = None
        model = LanguageModel(
            self.create_inner_model(),
            n_tokens=len(self.train_set.vocabulary),
            d_model=self.helper.args.state_size,
            n_layers=self.helper.args.transformer.n_layers,
            tied=self.helper.args.tied_embedding,
            in_norm=self.helper.args.lm.force_in_norm,
            out_norm=not self.helper.args.sep_stream.norm_before_gate,
        )
        if self.helper.args.gpt2_init:
            gpt2_init(model, self.helper.args.transformer.n_layers)
        return model



@task()
class C4SlidingWindowAndSparseTransformer(C4DatasetMixin, LMEvalMixin, LMTaskMixin, framework.task.SimpleTask):
    """C4 language model whose layers combine sliding-window heads with sparse heads.

    Each layer uses ``RopePartheadSparseSlidingWindowAttention``, configured with
    ``sa_moe.sw_heads`` sliding-window heads and ``sa_moe.expert_heads`` sparse heads
    of type ``sa_moe.sparsity_type``.
    """
    PAD_QUANTUM = 64

    def __init__(self, helper: TrainingHelper):
        framework.task.SimpleTask.__init__(self, helper)
        LMTaskMixin.__init__(self)

    def create_layer(self, layer_num) -> torch.nn.Module:
        """Build one pre-LN transformer layer with mixed sliding-window / sparse attention."""
        return PreLNSATransformerLayer(
            attention=RopePartheadSparseSlidingWindowAttention(
                n_sw_heads=self.helper.args.sa_moe.sw_heads,
                n_sparse_heads=self.helper.args.sa_moe.expert_heads,
                sparsity=self.helper.args.sa_moe.sparsity,
                sparsity_type=self.helper.args.sa_moe.sparsity_type,
                h=self.helper.args.state_size,
                h_prim=self.helper.args.transformer.head_projection_size,
                max_seq_len=self.helper.args.lm.unroll,
            ),
            ffn=TransformerFFN(
                d_model=self.helper.args.state_size,
                d_ff=int(self.helper.args.state_size * self.helper.args.transformer.ff_multiplier),
                d_out=self.helper.args.state_size,
            ),
            d_model=self.helper.args.state_size,
        )

    def get_regularizers(self) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Add the routing auxiliary loss when ``sa_moe.sparsity_type == 'routing'``.

        During training, the aux losses of every ``SymmetricRoutingAttention`` module
        are summed and added to ``loss``. Outside training the aux loss is ``0``;
        previously that path raised UnboundLocalError because ``rloss`` was only
        assigned when training.
        """
        loss, logs = super().get_regularizers()
        if self.helper.args.sa_moe.sparsity_type == 'routing':
            rloss = 0
            if self.training:
                for m in self.model.modules():
                    if isinstance(m, SymmetricRoutingAttention):
                        rloss += m.get_aux_loss()

            # alpha = self.helper.args.sa_moe.router_weight
            alpha = 1
            logs["aux_loss"] = rloss
            logs["clf_loss"] = loss
            loss = loss + rloss * alpha
        return loss, logs

    def create_inner_model(self) -> torch.nn.Module:
        """Stack ``transformer.n_layers`` layers built by ``create_layer``."""
        return Transformer(
            self.create_layer,
            n_layers=self.helper.args.transformer.n_layers
        )

    def create_model(self) -> torch.nn.Module:
        """Wrap the layer stack in a ``LanguageModel``; optionally apply GPT-2 init.

        The class used to define ``create_model`` twice. The second definition shadowed
        the first, so ``gpt2_init`` was never applied. Both are merged here.
        """
        self.validation_started_on = None
        model = LanguageModel(
            self.create_inner_model(),
            n_tokens=len(self.train_set.vocabulary),
            d_model=self.helper.args.state_size,
            n_layers=self.helper.args.transformer.n_layers,
            tied=self.helper.args.tied_embedding,
            in_norm=self.helper.args.lm.force_in_norm,
            out_norm=not self.helper.args.sep_stream.norm_before_gate,
        )
        if self.helper.args.gpt2_init:
            gpt2_init(model, self.helper.args.transformer.n_layers)
        return model




###
# Routing transformer
###

@task()
class C4FlopmatchedPartheadRoutingTransformer(C4DatasetMixin, LMEvalMixin, LMTaskMixin, framework.task.SimpleTask):
    """FLOP-matched C4 language model that mixes dense heads with routing-attention heads.

    ``sa_moe.sparsity`` is used as the number of routing clusters. Each layer keeps
    ``sa_moe.shared_dense_heads`` dense heads and receives as many routing heads as fit
    the FLOP estimate of a dense baseline layer with ``sa_moe.baseline_dense_heads``
    heads (see ``create_layer``).

    Notation used by the FLOP helpers:
    - ``h``: model width
    - ``h_prim``: per-head projection size
    - ``n``: number of heads
    - ``s``: sequence length
    """
    PAD_QUANTUM = 64

    def transformer_flops(self, h, h_prim, n, s):
        """Estimate the FLOPs of a full dense transformer layer.

        Equals ``dense_trafo_layer_flops`` plus an extra ``2 * n * s**2`` term
        (presumably the attention softmax). Not used inside this class.
        """
        return 8 * h * h_prim * n * s + 4 * h_prim * n * s**2 + 2 * n * s * s + 16 * s * h * h

    def dense_attention_flops(self, h, h_prim, n, s):
        """Estimate the FLOPs of ``n`` dense attention heads over ``s`` tokens."""
        return 8 * h * h_prim * n * s + 4 * h_prim * n * s**2

    def routing_attention_flops(self, h, h_prim, n, s, num_clusters):
        """Estimate the FLOPs of ``n`` routing heads splitting ``s`` tokens into ``num_clusters``.

        Does not include layer-norm cost. Attention is computed within clusters of
        ``s // num_clusters`` tokens.
        """
        k = s // num_clusters

        # Symmetric routing shares Q and K, so only 3 of the 4 projections are counted.
        QKV_map_flops = 6 * h * h_prim * n * s
        routing_flops = 2 * h_prim * n * s
        attn_flops = 4 * h_prim * n * k * k * num_clusters
        return QKV_map_flops + routing_flops + attn_flops

    def ff_flops(self, h, s):
        """Estimate the feed-forward FLOPs, assuming a hidden width of ``4 * h``.

        The real FFN uses ``transformer.ff_multiplier``. This term appears on both
        sides of the head-count search, so it only affects the printed FLOP ratio.
        """
        return 16 * s * h * h

    def dense_trafo_layer_flops(self, h, h_prim, n, s):
        """Estimate the FLOPs of a dense-attention layer plus its FFN."""
        return self.dense_attention_flops(h, h_prim, n, s) + self.ff_flops(h, s)

    def routing_trafo_layer_flops(self, h, h_prim, n, s, num_clusters):
        """Estimate the FLOPs of a routing-only layer plus its FFN."""
        return self.routing_attention_flops(h, h_prim, n, s, num_clusters) + self.ff_flops(h, s)

    def parthead_flops(self, h, h_prim, n_dense, n_mosa, s, num_clusters):
        """Estimate the FLOPs of a layer with ``n_dense`` dense and ``n_mosa`` routing heads plus FFN."""
        return self.dense_attention_flops(h, h_prim, n_dense, s) + self.routing_attention_flops(h, h_prim, n_mosa, s, num_clusters) + self.ff_flops(h, s)

    def highest_numheads_parthead(self, h, h_prim, n_dense, s, sparsity, flops_limit):
        """Find the largest number of routing heads whose layer cost stays within ``flops_limit``.

        ``sparsity`` is forwarded as the number of routing clusters.

        Returns:
            ``(n_routing, flops)``, where ``flops`` is the estimate for that many routing
            heads. Returns ``(None, None)`` if even 9999 routing heads stay within the limit.

        Raises:
            ValueError: if the ``n_dense`` dense heads plus the FFN already exceed
                ``flops_limit``. Previously this returned ``-1`` heads.
        """
        best_flops = 0
        for n_routing in range(10000):
            flops = self.parthead_flops(h, h_prim, n_dense, n_routing, s, sparsity)
            if flops > flops_limit:
                if n_routing == 0:
                    raise ValueError(
                        f"{n_dense} dense heads alone exceed the FLOP budget "
                        f"({flops} > {flops_limit}); reduce sa_moe.shared_dense_heads")
                return n_routing - 1, best_flops
            best_flops = flops
        return None, None

    def __init__(self, helper: TrainingHelper):
        framework.task.SimpleTask.__init__(self, helper)
        LMTaskMixin.__init__(self)

    def get_regularizers(self) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Add the summed aux loss of all ``SymmetricRoutingAttention`` modules to ``loss``.

        The aux loss is only collected while training; otherwise it is ``0``.
        Previously that path raised UnboundLocalError because ``rloss`` was only
        assigned when training.
        """
        loss, logs = super().get_regularizers()
        rloss = 0
        if self.training:
            for m in self.model.modules():
                if isinstance(m, SymmetricRoutingAttention):
                    rloss += m.get_aux_loss()

        alpha = 1
        logs["aux_loss"] = rloss
        logs["clf_loss"] = loss
        loss = loss + rloss * alpha
        return loss, logs

    def create_layer(self, layer_num) -> torch.nn.Module:
        """Build one pre-LN layer with ``PartheadRoutingAttention`` and a standard FFN.

        With ``sa_moe.sparsity == 1`` this is the dense baseline. Otherwise the layer
        gets ``shared_dense_heads`` dense heads plus as many routing heads as fit the
        baseline FLOP budget.

        Sets ``self.dense_heads``, ``self.sparse_heads`` and ``self.k`` as side effects.

        Raises:
            ValueError: if ``sa_moe.sparsity`` yields fewer than one routing cluster.
        """
        cfg = self.helper.args
        sparsity = cfg.sa_moe.sparsity
        baseline_flops = self.dense_trafo_layer_flops(cfg.state_size,
                                                      cfg.transformer.head_projection_size,
                                                      cfg.sa_moe.baseline_dense_heads,
                                                      cfg.lm.unroll)
        # The attention layer receives an integer cluster count, so the FLOP budget is
        # computed with that same integer rather than the raw float sparsity.
        num_clusters = int(sparsity)
        if sparsity == 1:
            self.dense_heads = cfg.sa_moe.baseline_dense_heads
            self.sparse_heads = 0
            self.k = cfg.lm.unroll
            print('DENSE BASELINE')
        else:
            if num_clusters < 1:
                raise ValueError(
                    f"sa_moe.sparsity={sparsity} gives {num_clusters} routing clusters; "
                    "at least 1 is required")
            self.k = int(cfg.lm.unroll // sparsity)
            self.dense_heads = cfg.sa_moe.shared_dense_heads
            self.sparse_heads, parthead_flops = self.highest_numheads_parthead(
                cfg.state_size,
                cfg.transformer.head_projection_size,
                cfg.sa_moe.shared_dense_heads,
                cfg.lm.unroll,
                num_clusters,
                baseline_flops,
            )
            if self.sparse_heads is None:
                raise ValueError(
                    "FLOP budget allows more than 9999 routing heads; check the sa_moe settings")
            print('='*20)
            print('sparsity: ', sparsity)
            print('num_sparse_heads: ', self.sparse_heads)
            print('flops ratio: ', parthead_flops / baseline_flops)

        return PreLNSATransformerLayer(
            attention=PartheadRoutingAttention(
                h=cfg.state_size,
                n_dense_heads=self.dense_heads,
                num_clusters=num_clusters,
                n_routing_heads=self.sparse_heads,
                h_prim=cfg.transformer.head_projection_size,
                sparsity_dest=sparsity,
                max_seq_len=cfg.lm.unroll,
            ),
            ffn=TransformerFFN(
                d_model=cfg.state_size,
                d_ff=int(cfg.state_size * cfg.transformer.ff_multiplier),
                d_out=cfg.state_size,
            ),
            d_model=cfg.state_size,
        )

    def create_inner_model(self) -> torch.nn.Module:
        """Stack ``transformer.n_layers`` layers built by ``create_layer``."""
        return Transformer(
            self.create_layer,
            n_layers=self.helper.args.transformer.n_layers,
        )

    def create_model(self) -> torch.nn.Module:
        """Wrap the layer stack in a ``LanguageModel``; optionally apply GPT-2 init.

        The class used to define ``create_model`` twice. The second definition shadowed
        the first, so ``gpt2_init`` was never applied. Both are merged here.
        """
        self.validation_started_on = None
        model = LanguageModel(
            self.create_inner_model(),
            n_tokens=len(self.train_set.vocabulary),
            d_model=self.helper.args.state_size,
            n_layers=self.helper.args.transformer.n_layers,
            tied=self.helper.args.tied_embedding,
            in_norm=self.helper.args.lm.force_in_norm,
            out_norm=not self.helper.args.sep_stream.norm_before_gate,
        )
        if self.helper.args.gpt2_init:
            gpt2_init(model, self.helper.args.transformer.n_layers)
        return model