# Policy modules that create learnable parameters for a base model's MLP weights.
    
import copy
import abc
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class Policy(nn.Module):
    """Directly parameterised policy with one learnable vector per MLP weight.

    For every entry of ``base_params`` whose key contains ``'mlp'`` a
    bfloat16 vector of length ``min(v.shape)`` is allocated on ``gpu`` and
    initialised to ``init_val`` plus Gaussian noise with std 0.01. All of the
    vectors are trained jointly by a single Adam optimizer.

    Args:
        base_params: mapping from parameter name to tensor; only the shapes of
            the MLP entries are used.
        gpu: device on which the learnable vectors are allocated.
        init_val: mean of the initial values.
        lr: Adam learning rate.
        **kwargs: accepted and ignored.
    """

    def __init__(self, base_params, gpu, init_val, lr, **kwargs):
        super().__init__()
        self.learnable_params = {}
        self.num_params = 0
        mlp_entries = [
            (name, tensor) for name, tensor in base_params.items()
            if 'mlp' in name
        ]
        for name, tensor in mlp_entries:
            noise = torch.randn(
                min(tensor.shape), device=gpu, dtype=torch.bfloat16)
            param = torch.nn.Parameter(
                data=noise * 0.01 + init_val, requires_grad=True)
            self.learnable_params[name] = param
            self.num_params += param.numel()
        print(f'#params={self.num_params}')
        # The dict above is a plain attribute; wrapping the same tensors in a
        # ParameterList is what registers them with nn.Module
        # (parameters(), state_dict(), .to()).
        self.learnable_params_list = list(self.learnable_params.values())
        self.learnable_params_module_list = nn.ParameterList(
            self.learnable_params_list)
        self.optimizer = torch.optim.Adam(self.learnable_params_list, lr=lr)

    def get_learnable_params(self):
        """Return the name -> parameter dict (the live tensors, not copies)."""
        return self.learnable_params

    def update(self, max_grad_norm):
        """Clip gradients to ``max_grad_norm``, take an Adam step, clear grads."""
        params = self.learnable_params_list
        torch.nn.utils.clip_grad_norm_(params, max_grad_norm)
        self.optimizer.step()
        self.optimizer.zero_grad()


class RandomFourierFeatures(nn.Module):
    """Random Fourier feature map: ``[cos(2*pi*x@B), sin(2*pi*x@B)]``.

    ``B`` is a fixed (non-trainable) ``input_dim x embed_dim`` Gaussian matrix
    with standard deviation ``std``, stored as a buffer so it follows the
    module across devices and is saved in the state_dict.

    Args:
        std: standard deviation of the random projection matrix.
        input_dim: size of the last dimension of the inputs.
        embed_dim: number of random frequencies. The output width is
            ``2 * embed_dim`` because cosine and sine features are
            concatenated; ``output_dim`` reports that true width.
    """

    def __init__(self, std, input_dim, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.input_dim = input_dim
        # forward() concatenates cos and sin features (embed_dim each), so the
        # actual output width is twice embed_dim. Downstream layers size
        # themselves from output_dim, so it must reflect that.
        self.output_dim = 2 * embed_dim
        self.const = 2 * np.pi
        random_samples = torch.randn(size=[input_dim, embed_dim]) * std
        self.register_buffer('random_samples', random_samples)

    def forward(self, inputs):
        """Map ``(..., input_dim)`` inputs to ``(..., 2 * embed_dim)`` features."""
        freq_embeddings = self.const * torch.matmul(inputs, self.random_samples)
        return torch.cat(
            (torch.cos(freq_embeddings), torch.sin(freq_embeddings)), dim=-1)


class MLP(nn.Module):
    """Feed-forward network of ``num_layers`` Linear layers with ReLU between.

    All hidden layers have width ``hidden_dim``. With ``num_layers <= 1`` the
    network is a single ``Linear(input_dim, output_dim)``.

    ``self.layers`` is a plain Python list holding the same module objects
    that are registered inside ``self.layer_mod``. Callers can therefore
    reach individual layers (e.g. ``layers[-1]``) without the list adding
    extra state_dict entries.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        stack = []
        in_features = input_dim
        for _ in range(num_layers - 1):
            stack += [
                nn.Linear(in_features=in_features, out_features=hidden_dim),
                nn.ReLU(),
            ]
            in_features = hidden_dim
        stack.append(nn.Linear(in_features=in_features, out_features=output_dim))
        self.layers = stack
        self.layer_mod = nn.Sequential(*stack)

    def forward(self, inputs):
        """Apply the layer stack to ``inputs`` (last dim must be ``input_dim``)."""
        return self.layer_mod(inputs)


class ParamEmbedding(nn.Module):
    """Embed inputs with ``features_fn`` and optionally refine them with an MLP.

    Args:
        features_fn: feature module exposing ``output_dim``. When it is a
            ``RandomFourierFeatures`` instance, ``forward`` first rescales the
            inputs from ``[0, max(inputs)]`` to ``[-1, 1]`` (this assumes
            non-negative inputs, e.g. layer indices).
        hidden_dim: hidden width of the optional MLP.
        output_dim: output width of the MLP. When ``num_layers`` is not
            positive, no MLP is built and ``self.output_dim`` becomes the
            feature width instead.
        num_layers: number of Linear layers in the MLP; ``<= 0`` disables it.
    """

    def __init__(self, features_fn, hidden_dim, output_dim, num_layers):
        super().__init__()
        # Name kept for compatibility: this flag enables the [-1, 1]
        # rescaling in forward() (which also yields float values).
        self.convert_to_float = isinstance(features_fn, RandomFourierFeatures)
        self.features_fn = features_fn
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.use_mlp = num_layers > 0
        self.features_embed_dim = features_fn.output_dim
        if self.use_mlp:
            self.mlp = MLP(
                input_dim=self.features_embed_dim,
                hidden_dim=self.hidden_dim,
                output_dim=self.output_dim,
                num_layers=self.num_layers,
                )
        else:
            self.output_dim = self.features_embed_dim

    def forward(self, inputs):
        """Return the feature embedding of ``inputs`` (passed through the MLP if enabled)."""
        if self.convert_to_float:
            max_inputs = torch.max(inputs)
            # When max(inputs) == 0 (e.g. a single layer with index 0), the
            # division below would be 0 / 0 = NaN; divide by 1 instead.
            safe_max = torch.where(
                max_inputs == 0, torch.ones_like(max_inputs), max_inputs)
            # Maps [0, max] to [-1, 1].
            inputs = (inputs * 2 - max_inputs) / safe_max
        embeddings = self.features_fn(inputs)
        if self.use_mlp:
            embeddings = self.mlp(embeddings)
        return embeddings



class EncoderBasedPolicy(nn.Module):
    """Policy that generates one learnable vector per MLP weight from embeddings.

    Every entry of ``base_params`` whose key contains ``'mlp'`` is described by
    the layer index and projection type parsed from its name. The last four
    dot-separated fields must look like ``<layer_num>.<_>.<type>.<_>`` with
    ``<type>`` one of ``gate_proj``/``up_proj``/``down_proj``.

    The layer index is embedded by ``layer_id_encoder`` and the projection type
    by a learned ``nn.Embedding``. Their sum is fed to one shared MLP head.
    The head's output width is ``numel(decomposed_params[f'{k}.S'])``, which
    must be identical for every key.

    Args:
        base_params: mapping of parameter names; only the keys are used here.
        decomposed_params: mapping that must contain ``f'{k}.S'`` for every
            MLP key ``k``.
        layer_id_encoder: ``nn.Embedding`` or a module exposing ``output_dim``.
            It is called on the 1-D tensor of integer layer indices (one per
            key).
        s_encoder: only consulted when ``use_s_encoder`` is true, which
            ``get_learnable_params`` does not support yet.
        final_head_layers: number of Linear layers in the shared head.
        final_head_hidden_dim: hidden width of the shared head.
        gpu: device the module is moved to.
        init_val: value the head's final bias is filled with.
        lr: Adam learning rate.
        use_s_encoder: add the s-encoder width to the type-embedding width.
        s_index_embedding: build the S features from positions instead of
            standardised S values.
        **kwargs: accepted and ignored.
    """

    def __init__(
            self,
            base_params,
            decomposed_params,
            layer_id_encoder,
            s_encoder,
            final_head_layers,
            final_head_hidden_dim,
            gpu,
            init_val,
            lr,
            use_s_encoder=False,
            s_index_embedding=False,
            **kwargs,
            ):

        super().__init__()
        self.use_s_encoder = use_s_encoder
        self.s_encoder = s_encoder
        self.layer_id_encoder = layer_id_encoder
        if isinstance(layer_id_encoder, nn.Embedding):
            self.layer_id_encoder_output_dim = layer_id_encoder.embedding_dim
        else:
            self.layer_id_encoder_output_dim = layer_id_encoder.output_dim

        # The layer-type embedding is added to the layer-id embedding, so it
        # must have the same width (plus the s-encoder width when enabled).
        self.layer_type_encoding_dim = self.layer_id_encoder_output_dim
        if self.use_s_encoder:
            self.s_encoder_output_dim = s_encoder.output_dim
            self.layer_type_encoding_dim += self.s_encoder_output_dim

        self.num_layer_types = 3
        self.layer_type_encoder = nn.Embedding(
            self.num_layer_types, self.layer_type_encoding_dim)

        self.layer_type_ids_per_key = {}
        self.layer_nums_per_key = {}
        self.layer_type_to_id = dict(
            gate_proj=0,
            up_proj=1,
            down_proj=2,
        )
        self.s_vals_per_key = {}
        self.param_num_output = None

        self.learnable_params = {}
        self.num_params = 0
        self.ordered_keys = []
        self.param_num_per_key = {}
        for k, _ in base_params.items():
            if 'mlp' not in k:
                continue
            s = decomposed_params[f'{k}.S']
            # dtype of the last S seen; not read elsewhere in this class.
            self.out_dtype = s.dtype
            self.s_vals_per_key[k] = s
            self.ordered_keys.append(k)
            layer_num, _, layer_type, _ = k.split('.')[-4:]
            self.layer_nums_per_key[k] = int(layer_num)
            self.layer_type_ids_per_key[k] = self.layer_type_to_id[layer_type]
            param_num = torch.numel(s)

            # A single shared head produces every vector, so all keys must
            # need the same output length.
            if self.param_num_output is not None:
                assert param_num == self.param_num_output
            else:
                self.param_num_output = param_num
            self.param_num_per_key[k] = param_num

        # Integer layer index per key, in self.ordered_keys order.
        layer_num_ints = torch.tensor(
            [self.layer_nums_per_key[k] for k in self.ordered_keys])
        self.max_layer_num = torch.max(layer_num_ints).item()
        self.register_buffer('layer_id_emb_tensor', layer_num_ints)

        # NOTE: all entries are equal (asserted above).
        self.param_num_list = [
            self.param_num_per_key[k] for k in self.ordered_keys]

        # Layer-type id per key, in self.ordered_keys order.
        layer_type_emb_inputs = torch.LongTensor(
            [self.layer_type_ids_per_key[k] for k in self.ordered_keys])
        self.register_buffer('layer_type_emb_inputs', layer_type_emb_inputs)

        s_emb_input_floats = []
        for k in self.ordered_keys:
            s = self.s_vals_per_key[k]
            if s_index_embedding:
                num_s = torch.numel(s)
                # NOTE(review): parses as (arange / num_s) - 1, i.e. values in
                # [-1, 0); confirm arange / (num_s - 1) was not intended.
                s_float_rep = torch.arange(num_s)/num_s-1
            else:
                # Standardise S. The epsilon belongs in the denominator so a
                # constant S (std == 0) does not divide by zero.
                s_float_rep = (s - torch.mean(s)) / (torch.std(s) + 1e-8)
            s_emb_input_floats.append(s_float_rep)
        s_emb_tensor = torch.concat(s_emb_input_floats, dim=0)
        self.register_buffer('s_emb_tensor', s_emb_tensor)
        self.layer_num_enc_values = 0
        self.final_heads_dict = {}
        self.mlp = MLP(input_dim=self.layer_type_encoding_dim,
                       hidden_dim=final_head_hidden_dim,
                       output_dim=self.param_num_output,
                       num_layers=final_head_layers,
                       )

        # Re-initialises every Linear/Embedding in the module tree, including
        # those inside the passed-in encoders.
        self.apply(self._init_weights)
        # Set the final bias after _init_weights, which zeros all biases.
        with torch.no_grad():
            self.mlp.layers[-1].bias.data.fill_(init_val)
        self.to(device=gpu)

        # Collect parameters and build the optimizer only after .to(), so it
        # is guaranteed to hold the module's final parameter tensors.
        self.trainable_params = [
            p for p in self.parameters() if p.requires_grad]
        self.num_params = sum(p.numel() for p in self.trainable_params)
        self.optimizer = torch.optim.Adam(self.trainable_params, lr=lr)
        self.optimizer.zero_grad()
        print(f'#params={self.num_params}')

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_learnable_params(self, detach=False):
        """Generate the per-key vectors.

        Args:
            detach: if true, return tensors detached from the autograd graph.

        Returns:
            dict mapping each MLP key (in ``self.ordered_keys`` order) to row
            ``i`` of the head output.

        Raises:
            NotImplementedError: if ``use_s_encoder`` is set.
        """
        if self.use_s_encoder:
            raise NotImplementedError
        # One embedding per key; the rows follow self.ordered_keys.
        layer_id_emb = self.layer_id_encoder(self.layer_id_emb_tensor)
        layer_type_emb = self.layer_type_encoder(self.layer_type_emb_inputs)

        head_output = self.mlp(layer_type_emb + layer_id_emb)
        out_dict = {k: head_output[i] for i, k in enumerate(self.ordered_keys)}
        if detach:
            out_dict = {k: v.detach() for k, v in out_dict.items()}
        return out_dict

    def update(self, max_grad_norm):
        """Clip gradients to ``max_grad_norm``, take an optimizer step, clear grads."""
        torch.nn.utils.clip_grad_norm_(self.trainable_params, max_grad_norm)
        self.optimizer.step()
        self.optimizer.zero_grad()


class MultiHeadEncoderBasedPolicy(nn.Module):
    """Work-in-progress variant of EncoderBasedPolicy with one head per layer type.

    Construction is currently disabled: ``__init__`` raises
    ``NotImplementedError`` right after ``nn.Module.__init__``. The code after
    the ``raise`` is the intended implementation. It builds one MLP head per
    projection type (``gate_proj``/``up_proj``/``down_proj``), each sized to
    that type's S length, and feeds every head the sum of the layer-id
    embedding and that type's embedding.
    """

    def __init__(
            self,
            base_params,
            decomposed_params,
            layer_id_encoder,
            s_encoder,
            final_head_layers,
            final_head_hidden_dim,
            gpu,
            init_val,
            lr,
            use_s_encoder=False,
            s_index_embedding=False,
            **kwargs,
            ):

        super().__init__()
        raise NotImplementedError
        # --- Unreachable until the raise above is removed. ---
        self.use_s_encoder = use_s_encoder
        self.s_encoder = s_encoder
        self.layer_id_encoder = layer_id_encoder
        if isinstance(layer_id_encoder, nn.Embedding):
            self.layer_id_encoder_output_dim = layer_id_encoder.embedding_dim
        else:
            self.layer_id_encoder_output_dim = layer_id_encoder.output_dim

        # The layer-type embedding is added to the layer-id embedding, so it
        # must have the same width (plus the s-encoder width when enabled).
        self.layer_type_encoding_dim = self.layer_id_encoder_output_dim
        if self.use_s_encoder:
            self.s_encoder_output_dim = s_encoder.output_dim
            self.layer_type_encoding_dim += self.s_encoder_output_dim

        self.num_layer_types = 3
        self.layer_type_encoder = nn.Embedding(
            self.num_layer_types, self.layer_type_encoding_dim)

        self.layer_type_ids_per_key = {}
        self.layer_type_names_per_key = {}
        self.layer_nums_per_key = {}
        self.layer_type_to_id = dict(
            gate_proj=0,
            up_proj=1,
            down_proj=2,
        )
        self.s_vals_per_key = {}
        self.param_num_per_layer_type = dict()
        self.learnable_params = {}
        self.num_params = 0
        self.ordered_keys = []
        self.param_num_per_key = {}
        for k, _ in base_params.items():
            if 'mlp' not in k:
                continue
            s = decomposed_params[f'{k}.S']
            self.s_vals_per_key[k] = s
            self.ordered_keys.append(k)
            layer_num, _, layer_type, _ = k.split('.')[-4:]
            self.layer_nums_per_key[k] = int(layer_num)
            self.layer_type_ids_per_key[k] = self.layer_type_to_id[layer_type]
            self.layer_type_names_per_key[k] = layer_type
            param_num = torch.numel(s)
            # Keys of the same projection type share one head, so their S
            # lengths must agree.
            if layer_type in self.param_num_per_layer_type:
                assert param_num == self.param_num_per_layer_type[layer_type]
            else:
                self.param_num_per_layer_type[layer_type] = param_num
            self.param_num_per_key[k] = param_num

        # Integer layer index per key, in self.ordered_keys order.
        layer_num_ints = torch.tensor(
            [self.layer_nums_per_key[k] for k in self.ordered_keys])
        self.max_layer_num = torch.max(layer_num_ints).item()
        self.register_buffer('layer_id_emb_tensor', layer_num_ints)

        self.param_num_list = [
            self.param_num_per_key[k] for k in self.ordered_keys]

        # One row per layer type: row j embeds layer-type id j.
        layer_type_emb_inputs = torch.LongTensor(list(range(self.num_layer_types)))
        self.register_buffer('layer_type_emb_inputs', layer_type_emb_inputs)

        s_emb_input_floats = []
        for k in self.ordered_keys:
            s = self.s_vals_per_key[k]
            if s_index_embedding:
                num_s = torch.numel(s)
                # NOTE(review): parses as (arange / num_s) - 1, i.e. values in
                # [-1, 0); confirm arange / (num_s - 1) was not intended.
                s_float_rep = torch.arange(num_s)/num_s-1
            else:
                # Standardise S; epsilon in the denominator guards std == 0.
                s_float_rep = (s - torch.mean(s)) / (torch.std(s) + 1e-8)
            s_emb_input_floats.append(s_float_rep)
        s_emb_tensor = torch.concat(s_emb_input_floats, dim=0)
        self.register_buffer('s_emb_tensor', s_emb_tensor)
        self.layer_num_enc_values = 0  # TODO

        self.final_heads_dict = {}
        for layer_type, param_num in self.param_num_per_layer_type.items():
            self.final_heads_dict[layer_type] = MLP(
                input_dim=self.layer_type_encoding_dim,
                hidden_dim=final_head_hidden_dim,
                output_dim=param_num,
                num_layers=final_head_layers,
            )
        self.final_head_module_dict = nn.ModuleDict(self.final_heads_dict)

        self.apply(self._init_weights)
        # Fill the final biases only after _init_weights, which zeros every
        # Linear bias and would otherwise discard init_val.
        with torch.no_grad():
            for head in self.final_heads_dict.values():
                head.layers[-1].bias.data.fill_(init_val)
        self.to(device=gpu)

        # Build the optimizer after .to() so it holds the final parameters.
        self.trainable_params = [
            p for p in self.parameters() if p.requires_grad]
        self.num_params = sum(p.numel() for p in self.trainable_params)
        self.optimizer = torch.optim.Adam(self.trainable_params, lr=lr)
        self.optimizer.zero_grad()
        print(f'#params={self.num_params}')

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_learnable_params(self, detach=False):
        """Generate the per-key vectors using the head of each key's layer type.

        Args:
            detach: if true, return tensors detached from the autograd graph.

        Returns:
            dict mapping each key in ``self.ordered_keys`` to its vector.

        Raises:
            NotImplementedError: if ``use_s_encoder`` is set.
        """
        if self.use_s_encoder:
            raise NotImplementedError
        # One embedding per key; the rows follow self.ordered_keys.
        layer_id_emb = self.layer_id_encoder(self.layer_id_emb_tensor)
        # Row j is the embedding of layer-type id j.
        layer_type_emb = self.layer_type_encoder(self.layer_type_emb_inputs)

        outputs_per_layer_type = {}
        for layer_type, head in self.final_head_module_dict.items():
            type_id = self.layer_type_to_id[layer_type]
            # Broadcast the type embedding over all keys.
            outputs_per_layer_type[layer_type] = head(
                layer_type_emb[type_id] + layer_id_emb)

        out_dict = {}
        for i, k in enumerate(self.ordered_keys):
            layer_type = self.layer_type_names_per_key[k]
            out_dict[k] = outputs_per_layer_type[layer_type][i]
        if detach:
            out_dict = {k: v.detach() for k, v in out_dict.items()}
        return out_dict

    def update(self, max_grad_norm):
        """Clip gradients to ``max_grad_norm``, take an optimizer step, clear grads."""
        torch.nn.utils.clip_grad_norm_(self.trainable_params, max_grad_norm)
        self.optimizer.step()
        self.optimizer.zero_grad()