import torch
import torch.nn as nn
import numpy as np
import torch.functional as F

def evaluate_one_hot(decoded_material, n_layer, n_mat):
    """Return the mean squared L2 norm of the per-layer material slots.

    The trailing ``n_layer`` columns (thickness) are discarded; the remaining
    columns are split into consecutive groups of ``n_mat`` values, each
    group's squared norm is taken, and the mean over all groups is returned
    as a Python float. A perfectly one-hot group has squared norm 1.

    A 1-D input is treated as a batch of one row.
    """
    batch = decoded_material if decoded_material.dim() >= 2 else decoded_material.unsqueeze(0)

    material_only = batch[:, :-n_layer]
    slots = material_only.reshape(-1, n_mat)
    squared_norms = (slots ** 2).sum(dim=1)
    return squared_norms.mean().item()



class Generator2(nn.Module):
    """Two-layer MLP (Linear -> Tanh -> Linear) used to produce latent offsets.

    Input features: 6 * n_layer. Output features: 3 * n_layer.
    Per the original note, inputs are batch x n_movements x 2*latent dimension.
    """

    def __init__(self, n_layer, n_movements, device='cpu'):
        super().__init__()

        self.n_movements = n_movements
        self.n_layer = n_layer

        latent_dim = 3 * n_layer
        layers = [
            nn.Linear(2 * latent_dim, latent_dim),
            nn.Tanh(),
            nn.Linear(latent_dim, latent_dim),
        ]
        # Attribute name and Sequential indices define the state_dict keys.
        self.gen = nn.Sequential(*layers).to(device)

    def _init_weights(self):
        """Re-draw every Linear weight from N(0, 0.02); biases are left as-is.

        Note: the constructor does not call this.
        """
        linears = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for linear in linears:
            nn.init.normal_(linear.weight, 0, 0.02)

    def get_movements_number(self):
        """Return the number of movements this generator was configured with."""
        return self.n_movements

    def forward(self, inputs):
        """Apply the MLP to the last dimension of ``inputs``."""
        return self.gen(inputs)
    



class SelectionLayer2(nn.Module):
    """Learnable softmax-weighted combination of seed points.

    Holds one logit per seed, all initialised to ``1/num_seeds`` (so the
    softmax starts uniform), and returns the softmax-weighted sum of the seeds.
    """

    def __init__(self, num_seeds, device='cpu'):
        super().__init__()
        self.device = device
        uniform = [1 / num_seeds for _ in range(num_seeds)]
        self.weights = nn.Parameter(torch.tensor(uniform), requires_grad=True)
        self.to(device)

    def get_params(self):
        """Return the raw (pre-softmax) per-seed logits."""
        return self.weights

    def forward(self, inputs):
        """Blend the seeds along dim 1.

        Per the original note, ``inputs`` is batch x num_seeds x latent dim and
        the result is batch x latent dim.
        """
        mixing = torch.softmax(self.weights, dim=0)
        weighted = inputs * mixing[:, None]
        return weighted.sum(dim=1)



class GidNet2:
    """Latent-space inverse designer.

    Seeds are sampled uniformly inside the per-component [min, max] box of the
    encoded training set (see ``set_enc_train_min_max``). A ``SelectionLayer2``
    blends them into one starting point, and a ``Generator2`` turns that point
    plus noise into ``n_movements`` candidate latent points. The candidates are
    decoded, simulated and scored against the desired response. Only the
    selection layer and the generator are optimized.
    """

    def __init__(self, n_layer, n_mat, n_seeds, n_movements, lambda1, lambda2, onehot_weight, encoder, decoder, simulator, device, metamat_config):
        self.n_layer = n_layer
        self.n_mat = n_mat
        self.n_seeds = n_seeds
        self.encoder = encoder
        self.decoder = decoder
        self.simulator = simulator
        self.device = device
        self.n_movements = n_movements
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.onehot_weight = onehot_weight
        self.metamat_config = metamat_config
        self._reset_best()

        # Per-component bounds of the encoded training set; required by train().
        self.enc_train_min = None
        self.enc_train_max = None

        self.result: torch.Tensor
        self.seeds: torch.Tensor

    def _reset_best(self):
        """Reset the best-result bookkeeping to its "nothing recorded" state.

        NOTE(review): nothing in this class updates these fields after a
        reset, so the getters below return these initial values unless
        external code sets them.
        """
        self.best_srmse = float('inf')
        self.best_srmse_onehot = None
        self.best_srmse_one_hot_accuracy = None
        self.best_srmse_s_weights = None
        self.best_srmse_epoch = None
        self.best_srmse_loss = None
        self.local_bests_list = []

    def set_enc_train_min_max(self, enc_train_min, enc_train_max):
        """Store the latent bounds used by train() to sample the initial seeds."""
        self.enc_train_min = enc_train_min
        self.enc_train_max = enc_train_max

    def get_best_srmse(self):
        """Return ``best_srmse``."""
        return self.best_srmse

    def get_result(self):
        """Return ``best_srmse_onehot``."""
        return self.best_srmse_onehot

    def get_best_srmse_one_hot_accuracy(self):
        """Return ``best_srmse_one_hot_accuracy``."""
        return self.best_srmse_one_hot_accuracy

    def get_best_srmse_s_weights(self):
        """Return ``best_srmse_s_weights``."""
        return self.best_srmse_s_weights

    def get_best_srmse_epoch(self):
        """Return ``best_srmse_epoch``."""
        return self.best_srmse_epoch

    def get_best_srmse_loss(self):
        """Return ``best_srmse_loss``."""
        return self.best_srmse_loss

    def get_local_bests_list(self):
        """Return ``local_bests_list``."""
        return self.local_bests_list

    def compute_knn(self, X_train, y_train, y, k):
        """Return the ``k`` rows of ``X_train`` whose targets are L2-closest to ``y``.

        ``y_train`` is moved to the CPU before computing distances, so ``y``
        is expected to be on the CPU as well (NOTE(review): confirm for GPU callers).
        """
        dist = torch.norm(y_train.cpu() - y, dim=1, p=2)
        nearest = dist.topk(k, largest=False)
        return X_train[nearest.indices]

    def _init_models(self):
        """Create a fresh selection layer and generator and clear the best-result bookkeeping."""
        self.selection_layer = SelectionLayer2(num_seeds=self.n_seeds, device=self.device)
        self.generator = Generator2(self.n_layer, self.n_movements, self.device)
        self._reset_best()

    def _assignment_columns(self, material):
        """Return the material part of ``material`` with one sample per column.

        The thickness columns (everything after ``n_layer * n_mat``) are dropped.
        The result is transposed to shape (n_layer * n_mat, batch), so column
        ``j`` is the complete assignment of sample ``j``. A plain reshape would
        only reinterpret memory and mix values from different samples.
        """
        return material[:, :self.n_layer * self.n_mat].transpose(0, 1).contiguous()

    def loss(self, y, y_desired, decoded_mat, weights):
        """Score each candidate by the L2 distance between its response and the target.

        Args:
            y: simulated responses, one row per candidate (per the original
                note, n_movements x 2400 wave points).
            y_desired: target response, broadcast against every row of ``y``.
            decoded_mat: decoded materials. Currently unused.
            weights: selection-layer parameters. Currently unused.

        Returns:
            A tuple ``(mean total loss, per-candidate L2 distance,
            per-candidate total loss, index of the best candidate)``.

        The one-hot constraint term (weight ``onehot_weight``) and the
        reconstruction-error term (weight ``lambda2``) are currently disabled,
        so the total loss equals the L2 distance.
        """
        l2_y = torch.norm(y - y_desired, p=2, dim=1)
        # Kept as a tensor separate from l2_y, matching the previous return values.
        total_loss = l2_y.clone()
        best_idx = torch.argmin(total_loss)
        return total_loss.mean(), l2_y, total_loss, best_idx

    def train(self, y_desired, n_epoch, lr, quiet=True, semloss=None, LOG=False):
        """Optimize the selection layer and generator towards ``y_desired``.

        Args:
            y_desired: target response; moved to ``self.device``.
            n_epoch: number of optimization steps.
            lr: Adam learning rate.
            quiet: currently unused.
            semloss: optional semantic-loss callable. When given, the mean of
                ``-log(wmc)`` over the candidates is added to the loss.
            LOG: currently unused.

        Returns:
            ``(history, initial_centroid)``.
            ``history`` has one tuple per epoch for that epoch's best candidate:
            ``(response L2 error, semantic loss or "None", one-hot score,
            total loss, latent point, decoded point)``; the two points are
            numpy arrays.
            ``initial_centroid`` is a detached copy of the first epoch's
            selection-layer output, or ``[]`` when ``n_epoch`` is 0.

        Raises:
            RuntimeError: if ``set_enc_train_min_max`` was never called.
        """
        if self.enc_train_min is None or self.enc_train_max is None:
            raise RuntimeError("set_enc_train_min_max() must be called before train()")

        self._init_models()

        initial_centroid = []
        latent_dim = self.n_layer * 3

        # Sample seeds uniformly inside the per-component (min, max) box of the
        # encoded training set: (n_seeds, latent_dim).
        self.seeds = self.enc_train_min + (self.enc_train_max - self.enc_train_min) * torch.rand(self.n_seeds, latent_dim).to(self.device)  # type: ignore

        y_desired = y_desired.to(self.device)

        optimizer = torch.optim.Adam([
            {"params": self.generator.parameters()},
            {"params": self.selection_layer.parameters()}],
            lr=lr
        )

        # Add a batch dimension: (1, n_seeds, latent_dim).
        self.seeds = torch.reshape(self.seeds, (1, self.seeds.shape[0], self.seeds.shape[1]))

        history = []
        for epoch in range(n_epoch):
            optimizer.zero_grad()

            # Blend the seeds into the starting point: (1, latent_dim).
            s = self.selection_layer(self.seeds)
            if epoch == 0:
                # Keep only the value, not the autograd graph behind it.
                initial_centroid = s.detach().clone()

            # One noise vector per movement: (1, n_movements, latent_dim).
            z = torch.randn(size=[s.shape[0], self.n_movements, s.shape[1]]).to(self.device)

            # Broadcast the starting point to every movement: (1, n_movements, latent_dim).
            s_expanded = s.unsqueeze(1).expand(-1, self.n_movements, -1)

            # The generator predicts an offset from the starting point given (point, noise).
            out = s_expanded + self.generator(torch.cat((s_expanded, z), dim=2))

            # Drop the batch dimension: (n_movements, latent_dim).
            out = torch.reshape(out, [self.n_movements, out.shape[2]])

            decoded_material = self.decoder(out)
            y = self.simulator(decoded_material)

            loss, y_loss, total_loss, best_mat_idx = self.loss(y, y_desired, decoded_material, self.selection_layer.get_params())

            multi_sem_loss = None
            if semloss is not None:
                # Squeeze probabilities into [0.005, 0.995] so log(wmc) stays finite.
                epsilon_material = 0.005 + 0.99 * decoded_material
                reshaped_mat = self._assignment_columns(epsilon_material)
                _, wmc = semloss(probabilities=reshaped_mat, output_wmc_per_sample=True)

                multi_sem_loss = -torch.log(wmc)
                loss = loss + multi_sem_loss.mean()

                idx = torch.argmax(multi_sem_loss)
                if torch.isinf(multi_sem_loss[idx]):
                    print("Infinite loss, exploding gradients")
                    print(f"At idx {idx}")
                    print(multi_sem_loss[idx])
                    print(epsilon_material[idx, :])

            # NOTE(review): retain_graph=True is kept from the original. The graph
            # is rebuilt every epoch, so it should be unnecessary unless
            # enc_train_min/max, the decoder or the simulator carry autograd
            # history across epochs. Confirm, then drop it to save memory.
            loss.backward(retain_graph=True)

            # Record only this epoch's best candidate.
            latent_point = np.copy(out[best_mat_idx].detach().cpu().numpy())
            decoded_point = np.copy(decoded_material[best_mat_idx].detach().cpu().numpy())
            rec_loss = y_loss[best_mat_idx].item()
            sem_loss = multi_sem_loss[best_mat_idx].item() if multi_sem_loss is not None else "None"
            onehot = evaluate_one_hot(decoded_material[best_mat_idx], n_layer=self.n_layer, n_mat=self.n_mat)
            history.append((rec_loss, sem_loss, onehot, total_loss[best_mat_idx].item(), latent_point, decoded_point))

            optimizer.step()

        return history, initial_centroid