# Create learnable parameters.
    
import copy
import abc
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import hydra
from omegaconf import DictConfig

from typing import Optional, List, Dict, Union
from .base import Policy

def get_mask(mode, p, convert_dtype=True):
    if mode == 0:
        return p.to(torch.bfloat16)
    if mode == 1:
        return torch.sigmoid(p).to(torch.bfloat16)
    elif mode == 2:
        prob = torch.sigmoid(p)
        mask = (prob > 0.5).to(torch.bfloat16).to(p.device)
        return ((mask - prob).detach() + prob).to(torch.bfloat16)
    elif mode == 3:
        soft_mask = get_soft_mask(p.numel(), torch.sigmoid(p))
        hard_mask = torch.round(soft_mask)
        return StraightThrough.apply(hard_mask, soft_mask)
    else:
        raise NotImplementedError

class StraightThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, hard, soft):
        return hard

    @staticmethod
    def backward(ctx, grad_output):
        return None, grad_output


def get_soft_mask(n, fraction):
    indices = torch.linspace(0, n-1, n, dtype=torch.bfloat16) + 1
    scaled_indices = indices.to(fraction.device) - fraction * n
    result = torch.clamp(scaled_indices, 0, 1)
    return 1. - result

class WeightedCombination(Policy):
    def __init__(
            self,
            base_params,
            decomposed_params,
            base_policy_cfg: Optional[Union[DictConfig, int]],
            params_paths: List[str],
            gpu,
            mode,
            norm_coeffs, 
            per_layer,
            init_values: Optional[List[float]]=None,
            **kwargs):
        # Create learnable parameters.
        nn.Module.__init__(self=self,)
        weights_dict_list: List[Dict[str, torch.Tensor]] = []
        if base_policy_cfg is None:
            base_policy = Policy(base_params=base_params, gpu=gpu, init_val=0,
                                 mode=1)
        elif isinstance(base_policy_cfg, DictConfig):
            base_policy: Policy = hydra.utils.instantiate(
                base_policy_cfg,
                base_params=base_params,
                decomposed_params=decomposed_params,
                gpu=gpu,
            )
        else:
            raise NotImplementedError
        
        with torch.no_grad():
            for i, load_ckpt in enumerate(params_paths):
                print(f'Loading checkpoint {i} at {load_ckpt}...')
                if 'learnable_params' in load_ckpt:
                    # legacy checkpoint
                    legacy_checkpoint = True
                    learnable_params = torch.load(load_ckpt)
                else:
                    state_dict = torch.load(load_ckpt, weights_only=True)
                    base_policy.load_state_dict(state_dict=state_dict)
                    learnable_params = base_policy.get_learnable_params()
                weights_dict_list.append({k: torch.detach_copy(p) for k, p in
                                        learnable_params.items()})

        self.num_weights_dict = len(weights_dict_list)
        if isinstance(mode, int):
            self.mode_list = [mode for _ in range(self.num_weights_dict)]
        else:
            self.mode_list = mode

        self.num_params_per_weight_dict = 0
        for _ in weights_dict_list[0]:
            self.num_params_per_weight_dict += 1

        self.num_params = self.num_weights_dict*self.num_params_per_weight_dict
        if init_values is None:
            init_values = torch.Tensor([
                1/self.num_weights_dict for _ in range(self.num_weights_dict)])
        else:
            assert len(init_values) == self.num_weights_dict
            init_values = torch.Tensor(init_values)
        self.learned_params_per_weight_dict = 1
        if per_layer:
            self.learned_params_per_weight_dict = (
                self.num_params_per_weight_dict)
        init_values = torch.stack(
                [init_values 
                 for i in range(self.learned_params_per_weight_dict)],
                 dim=1)
        if norm_coeffs:
            # normalize across different weight idxs (for all layers)
            init_values = init_values/torch.sum(init_values, axis=0)
        
        # num weight idxs x learned params_per_weight_idx
        self.adaptive_weights = torch.nn.Parameter(
            data=init_values, requires_grad=True,
        )

        self.parameter_keys = []
        self.original_params = {}
        for k, v in weights_dict_list[0].items():
            self.parameter_keys.append(k)
            self.original_params[k] = []
            for i, weight_dict in enumerate(weights_dict_list):
                weight_tensor = get_mask(
                    mode=self.mode_list[i], p=weight_dict[k])
                new_key = k.replace('.', '_')
                self.register_buffer(
                    f'weights_{i}_k_{new_key}', tensor=weight_tensor,
                )
                self.original_params[k].append(weight_tensor.to(device=gpu))

        self.norm = norm_coeffs
        self.per_layer = per_layer
        self.trainable_params = [self.adaptive_weights]

    # def explore(self, expl_coeff)
    def get_weight_to_combine(self, k, weights_dict_idx):
        new_key = k.replace('.', '_')
        return getattr(self, f'weights_{weights_dict_idx}_k_{new_key}')
    
    def get_coeff_per_layer(self,):
        if self.norm:
            adaptive_weights = self.adaptive_weights/self.adaptive_weights.sum(
                0)
        else:
            adaptive_weights = self.adaptive_weights
        weights_per_layer = adaptive_weights.expand(
            [self.num_weights_dict, self.num_params_per_weight_dict,])
        return weights_per_layer

    def get_learnable_params(self, detach=False):
        adaptive_coeff_per_layer = self.get_coeff_per_layer()
        output_params = {}
        for i, (k, vs) in enumerate(self.original_params.items()):
            cs_coeff = adaptive_coeff_per_layer[:, i]
            out = vs[0]*cs_coeff[0]
            for j, other_v in enumerate(vs[1:]):
                v_idx = j + 1
                out = out + other_v*cs_coeff[v_idx]
            output_params[k] = out
        return output_params
            
    
    def get_mask(self, p, convert_dtype=True):
        return p
    
    def record_state(self, metrics_to_log):
        avg_weights = self.adaptive_weights.mean(-1).detach().cpu().numpy()
        dict_to_log = {f'adaptive_weight/mean_across_params_w{i}': w
                       for i, w in enumerate(avg_weights.tolist())}
        metrics_to_log.update(**dict_to_log)
