import copy

import torch
import torch.nn.functional as F

from .abstract import TargetNetwork
# from all2.nn import RLNetwork


# class FixedTarget(TargetNetwork):
#     def __init__(self, update_frequency):
#         self._encoder = None
#         self._decoder = None
#         self._target_enc = None
#         self._target_dec = None
#         self._updates = 0
#         self._update_frequency = update_frequency

#     def __call__(self, *inputs):
#         with torch.no_grad():
#             return self._target_enc(self._target_dec(*inputs))

#     def encode(self, *inputs):
#         with torch.no_grad():
#             return self._target_enc(*inputs)

#     def decode(self, *inputs):
#         with torch.no_grad():
#             return self._target_dec(*inputs)

#     def init(self, encoder, decoder):
#         self._encoder = RLNetwork(encoder)
#         self._decoder = decoder
#         self._target_enc = RLNetwork(copy.deepcopy(encoder))
#         self._target_dec = copy.deepcopy(decoder)

#     def update(self):
#         self._updates += 1
#         if self._should_update():
#             self._target_enc.load_state_dict(self._encoder.state_dict())
#             self._target_dec.load_state_dict(self._decoder.state_dict())

#     def _should_update(self):
#         return self._updates % self._update_frequency == 0

import copy

import torch

from .abstract import TargetNetwork


class FixedTarget(TargetNetwork):
    """Frozen copy of a mixture-density network, refreshed periodically.

    ``init`` deep-copies four source modules (backbone, mixing head, mean
    head, covariance head). Every public forward method runs the *copies*
    under ``torch.no_grad()``. ``update`` overwrites the copies with the
    source weights once every ``update_frequency`` calls.
    """

    # Added to the exponentiated Cholesky diagonal so every diagonal entry is
    # at least 0.1, keeping L (and therefore L @ L^T) away from singular.
    _DIAG_OFFSET = 1e-1

    def __init__(self, update_frequency, n_mixture, n_actions):
        """Create an empty target; call :meth:`init` before using it.

        Args:
            update_frequency: number of :meth:`update` calls between hard
                copies of the source weights into the target modules.
            n_mixture: number of mixture components per feature row.
            n_actions: dimensionality of each component mean / covariance.
        """
        self._source_backbone = None
        self._source_mixing_head = None
        self._source_mean_head = None
        self._source_covariance_head = None
        self._target_backbone = None
        self._target_mixing_head = None
        self._target_mean_head = None
        self._target_covariance_head = None
        self._updates = 0
        self._update_frequency = update_frequency
        self.n_mixture = n_mixture
        self.n_actions = n_actions

    def get_features(self, states):
        """Run the target backbone on ``states`` without tracking gradients."""
        with torch.no_grad():
            return self._target_backbone(states)

    def get_mixing(self, features):
        """Return mixture weights of shape ``(-1, n_mixture)``.

        The mixing head output is reshaped to ``(-1, n_mixture)`` and
        softmax-normalised over the last dimension, so each row sums to 1.
        """
        with torch.no_grad():
            logits = self._target_mixing_head(features).view(-1, self.n_mixture)
            return F.softmax(logits, dim=-1)

    def get_means(self, features):
        """Return component means reshaped to ``(-1, n_mixture, n_actions)``."""
        with torch.no_grad():
            raw = self._target_mean_head(features)
            return raw.view(-1, self.n_mixture, self.n_actions)

    def augment_features(self, features1, features2):
        """Concatenate ``[f1, f2, f1 - f2, f1 * f2]`` along the last dimension."""
        with torch.no_grad():
            diff = features1 - features2
            hadamard = features1 * features2
            return torch.cat((features1, features2, diff, hadamard), dim=-1)

    def get_covariances(self, features1, features2):
        """Return covariance matrices of shape ``(B, 1, n_actions, n_actions)``.

        The covariance head predicts ``n_actions * (n_actions + 1) // 2``
        values per row. These fill a lower-triangular matrix L (row-major
        ``tril_indices`` order). Each diagonal entry ``x`` is replaced by
        ``exp(x / 2) + 0.1``, which is strictly positive, so ``L @ L^T`` is
        symmetric positive definite.
        """
        with torch.no_grad():
            n = self.n_actions
            n_params = n * (n + 1) // 2
            augmented = self.augment_features(features1, features2)
            params = self._target_covariance_head(augmented).view(-1, 1, n_params)
            # new_zeros inherits dtype *and* device from params; a bare
            # torch.zeros(...) would silently force float32.
            chol = params.new_zeros((params.shape[0], 1, n, n))
            # Build index tensors on the same device as the data to avoid
            # a host->device copy on every call.
            rows, cols = torch.tril_indices(n, n, device=params.device)
            chol[:, :, rows, cols] = params
            diag = torch.arange(n, device=params.device)
            # Exponential on the diagonal to ensure positive definiteness.
            chol[:, :, diag, diag] = (
                torch.exp(chol[:, :, diag, diag] / 2) + self._DIAG_OFFSET
            )
            return chol @ chol.transpose(-1, -2)

    def __call__(self, inputs):
        """Evaluate the target network on a batch of paired inputs.

        The backbone output is indexed as ``features[:, 0, :]`` and
        ``features[:, 1, :]``. NOTE(review): this presumes the backbone
        yields ``(B, 2, D)``, i.e. two items per sample. Confirm with caller.

        Returns:
            A tuple ``(mixing, means, covariances)``:

            - ``mixing``: the two items' mixture weights stacked on dim 1.
            - ``means``: the two items' component means stacked on dim 1.
            - ``covariances``: the shared pairwise covariances from
              :meth:`get_covariances`.
        """
        with torch.no_grad():
            features = self.get_features(inputs)
            features1 = features[:, 0, :]
            features2 = features[:, 1, :]
            mixing = torch.stack(
                (self.get_mixing(features1), self.get_mixing(features2)), dim=1
            )
            means = torch.stack(
                (self.get_means(features1), self.get_means(features2)), dim=1
            )
            covariances = self.get_covariances(features1, features2)
            return mixing, means, covariances

    def init(self, backbone, mixing_head, mean_head, covariance_head):
        """Register the source modules and create deep-copied targets.

        The source modules are stored by reference, so later training of
        them is visible to :meth:`update`. The targets are independent
        ``copy.deepcopy`` clones.
        """
        self._source_backbone = backbone
        self._source_mixing_head = mixing_head
        self._source_mean_head = mean_head
        self._source_covariance_head = covariance_head
        self._target_backbone = copy.deepcopy(backbone)
        self._target_mixing_head = copy.deepcopy(mixing_head)
        self._target_mean_head = copy.deepcopy(mean_head)
        self._target_covariance_head = copy.deepcopy(covariance_head)

    def update(self):
        """Count one update; hard-copy source weights every ``update_frequency`` calls."""
        self._updates += 1
        if self._should_update():
            for target, source in self._module_pairs():
                target.load_state_dict(source.state_dict())

    def _module_pairs(self):
        """Return ``(target, source)`` module pairs in a fixed order."""
        return (
            (self._target_backbone, self._source_backbone),
            (self._target_mixing_head, self._source_mixing_head),
            (self._target_mean_head, self._source_mean_head),
            (self._target_covariance_head, self._source_covariance_head),
        )

    def _should_update(self):
        """True when the update counter is a multiple of ``update_frequency``."""
        return self._updates % self._update_frequency == 0

