import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from dataset_5_layer.data_utils.utils import evaluate_one_hot_gidnet, evaluate_single_srmse, scaler,unscaler
from dataset_5_layer.data_utils.real_simulator_tmm import srmse_evaluate
import numpy as np

def evaluate_one_hot(decoded_material, n_layer, n_mat):
    """Return the mean squared L2 norm of the per-layer material vectors.

    The trailing ``n_layer`` columns of ``decoded_material`` are dropped and
    the remaining columns are regrouped into rows of ``n_mat`` entries. For
    each row the sum of squared entries is computed (1.0 for an exact one-hot
    row), and the average over all rows is returned.

    Args:
        decoded_material: 1-D or 2-D tensor; a 1-D tensor is treated as a
            batch containing a single sample.
        n_layer: number of trailing columns to ignore
            (presumably one thickness value per layer -- TODO confirm).
        n_mat: size of each group of columns that is scored.

    Returns:
        The mean score as a Python float.
    """
    # Promote a single sample to a batch of one so the slicing below works.
    batch = decoded_material
    if batch.dim() < 2:
        batch = batch.unsqueeze(0)

    material_part = batch[:, :-n_layer]
    per_group = material_part.reshape(-1, n_mat)
    squared_norms = torch.square(per_group).sum(dim=1)
    return squared_norms.mean().item()

class Generator(nn.Module):
    """Small MLP mapping a ``6 * n_layer`` vector to a ``3 * n_layer`` vector.

    The network is ``Linear -> ReLU -> Linear`` and is moved to ``device`` at
    construction time.
    """

    def __init__(self, n_layer, n_movements, device='cpu'):
        super().__init__()

        self.n_layer = n_layer
        self.n_movements = n_movements

        in_features = 6 * n_layer
        out_features = 3 * n_layer
        stack = [
            nn.Linear(in_features, out_features),
            nn.ReLU(),
            nn.Linear(out_features, out_features),
        ]
        self.gen = nn.Sequential(*stack).to(device)

    def _init_weights(self):
        """Re-draw every Linear weight from N(0, 0.02); biases are untouched.

        NOTE(review): nothing in this class calls this method.
        """
        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0, std=0.02)

    def get_movements_number(self):
        """Return the ``n_movements`` value given at construction."""
        return self.n_movements

    def forward(self, inputs):
        """Apply the MLP to the last dimension of ``inputs``.

        ``inputs`` is expected to be batch x n_movements x 2 * latent
        dimension, i.e. its last dimension must equal ``6 * n_layer``.
        """
        return self.gen(inputs)
    

class SelectionLayer(nn.Module):
    """Learnable convex combination of a fixed set of seed vectors.

    One scalar weight per seed is stored; the forward pass turns the weights
    into softmax coefficients and returns the weighted sum of the seeds.
    """

    def __init__(self, num_seeds, device='cpu'):
        super().__init__()
        self.device = device

        # Equal starting weights, so the initial softmax is uniform.
        uniform_start = [1 / num_seeds for _ in range(num_seeds)]
        self.weights = nn.Parameter(torch.tensor(uniform_start), requires_grad=True)
        self.to(device)

    def get_params(self):
        """Return the raw (pre-softmax) weight parameter."""
        return self.weights

    def forward(self, inputs):
        """Blend the seeds along dimension 1.

        ``inputs`` has shape batch size x num_seeds x latent dimension; the
        result has shape batch size x latent dimension.
        """
        # Shape (num_seeds, 1) so it broadcasts over batch and latent dims.
        coefficients = torch.softmax(self.weights, dim=0)[:, None]
        weighted = coefficients * inputs
        return weighted.sum(dim=1)



class GidNet():
    """Gradient-based search in a latent space driven by a simulator.

    A ``SelectionLayer`` blends a set of seed latent points into one starting
    point, a ``Generator`` proposes ``n_movements`` offsets around it, and the
    resulting latent points are passed through ``decoder`` and ``simulator``.
    The L2 distance between the simulated output and a target is minimised
    with Adam (see ``train``).

    NOTE(review): the ``best_srmse*`` attributes and ``local_bests_list`` are
    reset by ``_init_models`` but never updated anywhere in this class.
    """

    def __init__(self, n_layer, n_mat, n_seeds, n_movements, lambda1, lambda2, onehot_weight, encoder, decoder, simulator, device, metamat_config):
        """Store the configuration and the externally supplied models.

        Args:
            n_layer: number of layers; sizes the generator and the one-hot metric.
            n_mat: number of entries per layer group in the decoded output.
            n_seeds: number of seed points blended by the selection layer.
            n_movements: number of candidate points generated per epoch.
            lambda1, lambda2, onehot_weight: stored only; not read by the
                currently active code paths of this class.
            encoder: stored only; not called in this class.
            decoder: callable mapping latent points to decoded materials.
            simulator: callable mapping decoded materials to a response.
            device: device used for the trainable modules and the tensors
                created in ``train``.
            metamat_config: forwarded to ``evaluate_single_srmse`` when logging.
        """
        self.n_layer = n_layer
        self.n_mat = n_mat
        self.n_seeds = n_seeds
        self.encoder = encoder
        self.decoder = decoder
        self.simulator = simulator
        self.device = device
        self.n_movements = n_movements
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.onehot_weight = onehot_weight
        self.metamat_config = metamat_config
        self.best_srmse = float('inf')
        self.best_srmse_onehot = None
        self.best_srmse_one_hot_accuracy = None
        self.best_srmse_s_weights = None
        self.best_srmse_epoch = None
        self.best_srmse_loss = None
        self.local_bests_list = []

        self.enc_train_min = None
        self.enc_train_max = None

        # Type declarations only: ``seeds`` is assigned in ``train``;
        # ``result`` is never assigned in this class.
        self.result: torch.Tensor
        self.seeds: torch.Tensor

    def set_enc_train_min_max(self, enc_train_min, enc_train_max):
        """Store the given min/max values on the instance."""
        self.enc_train_min = enc_train_min
        self.enc_train_max = enc_train_max

    def get_best_srmse(self):
        """Return ``best_srmse`` (``inf`` until something updates it)."""
        return self.best_srmse

    def get_result(self):
        """Return ``best_srmse_onehot``."""
        return self.best_srmse_onehot

    def get_best_srmse_one_hot_accuracy(self):
        """Return ``best_srmse_one_hot_accuracy``."""
        return self.best_srmse_one_hot_accuracy

    def get_best_srmse_s_weights(self):
        """Return ``best_srmse_s_weights``."""
        return self.best_srmse_s_weights

    def get_best_srmse_epoch(self):
        """Return ``best_srmse_epoch``."""
        return self.best_srmse_epoch

    def get_best_srmse_loss(self):
        """Return ``best_srmse_loss``."""
        return self.best_srmse_loss

    def get_local_bests_list(self):
        """Return ``local_bests_list``."""
        return self.local_bests_list

    def compute_knn(self, X_train, y_train, y, k):
        """Return the ``k`` rows of ``X_train`` whose ``y_train`` rows are closest to ``y``.

        Closeness is the L2 norm of ``y_train - y`` taken along dimension 1;
        the rows are returned in order of increasing distance. ``y_train`` is
        moved to the CPU first, so ``y`` must be a CPU tensor as well.
        """
        dist = torch.norm(y_train.cpu() - y, dim=1, p=2)
        knn = dist.topk(k, largest=False)
        seeds = X_train[knn.indices]

        return seeds

    def _init_models(self):
        """Create fresh trainable modules and reset the best-result bookkeeping."""
        self.selection_layer = SelectionLayer(num_seeds=self.n_seeds, device=self.device)
        self.generator = Generator(self.n_layer, self.n_movements, self.device)
        self.best_srmse = float('inf')
        self.best_srmse_onehot = None
        self.best_srmse_one_hot_accuracy = None
        self.best_srmse_s_weights = None
        self.best_srmse_epoch = None
        self.best_srmse_loss = None
        self.local_bests_list = []

    def loss(self, y, y_desired, decoded_mat, weights):
        """Compute the per-candidate reconstruction loss.

        Args:
            y: simulated responses, one row per candidate.
            y_desired: target response, broadcast against the rows of ``y``.
            decoded_mat: unused; only needed by the disabled one-hot term.
            weights: unused; only needed by the disabled weight term.

        Returns:
            A tuple ``(mean_loss, l2_y, total_loss, best_idx)`` holding the
            mean of the per-candidate losses, the per-candidate L2 distances,
            the per-candidate total loss (currently equal to ``l2_y``), and
            the index of the candidate with the smallest total loss.
        """
        # One L2 distance per candidate (row of y).
        l2_y = torch.norm(y - y_desired, p=2, dim=1)

        #w_y = weights.square().sum(dim=0).mul(-1)

        #one-hot constraint
        #oh_c =  1 - decoded_mat[:, :-self.n_layer].reshape(-1, self.n_mat).square().sum(dim=1).reshape(-1, self.n_layer).sum(dim=1)

        # One total-loss value per candidate; the one-hot term is disabled.
        total_loss = (l2_y * 1) # + (oh_c * self.onehot_weight)

        # Index of the best candidate according to the total loss.
        best_idx = torch.argmin(total_loss)

        return torch.mean(total_loss), l2_y, total_loss, best_idx

    def train(self, y_desired, x_desired, initial_seeds, n_epoch, lr, semloss=None, LOG=False):
        """Optimise the selection layer and generator towards ``y_desired``.

        Fresh modules are created on every call. Each epoch blends the seeds
        into a single point, generates ``n_movements`` candidates around it,
        decodes and simulates them, and takes one Adam step on the mean loss.

        Args:
            y_desired: target response; moved to ``self.device``.
            x_desired: reference design, used only for the logged SRMSE.
            initial_seeds: 2-D tensor (n_seeds x latent dimension); moved to
                ``self.device``.
            n_epoch: number of optimisation steps.
            lr: Adam learning rate.
            semloss: optional callable invoked as
                ``semloss(probabilities=..., output_wmc_per_sample=True)``;
                ``-log`` of its second return value is added to the loss.
            LOG: when true, print per-epoch SRMSE values.

        Returns:
            A list with one tuple per epoch describing the best candidate:
            ``(rec_loss, sem_loss, onehot, total_loss, latent_point,
            decoded_point)``. ``sem_loss`` is the string ``"None"`` when no
            ``semloss`` is given.
        """
        self._init_models()

        use_semloss = semloss is not None

        y_desired = y_desired.to(self.device)

        optimizer = torch.optim.Adam([
            {"params": self.generator.parameters()},
            {"params": self.selection_layer.parameters()}],
            lr=lr
        )

        # Reshape to (batch_size=1, n_seed, latentspace_dim). The seeds must
        # live on the same device as the selection layer, otherwise the
        # weighted sum in its forward pass fails with a device mismatch.
        self.seeds = torch.reshape(
            initial_seeds, (1, initial_seeds.shape[0], initial_seeds.shape[1])
        ).to(self.device)

        history = []
        for epoch in range(n_epoch):
            optimizer.zero_grad()
            # Blend the seeds into the point from where exploration starts: (1, latent_dim)
            s = self.selection_layer(self.seeds)

            # Random noise of size (1, n_movements, latent_dim)
            z = torch.randn(size=[s.shape[0], self.n_movements, s.shape[1]]).to(self.device)

            # Copy of the blended seed used as generator input.
            tmp_s = torch.clone(s)

            # Repeat the seed once per movement: (1, n_movements, latent_dim)
            tmp_s = tmp_s.expand(s.shape[0], self.n_movements, s.shape[1])

            # Concatenate the noise to each seed copy and call the generator
            out = self.generator(torch.cat((tmp_s, z), dim=2))

            # Repeat the seed once per movement: (1, n_movements, latent_dim)
            s = s.expand(s.shape[0], self.n_movements, s.shape[1])

            # The generator output is an offset added to the blended seed.
            out = torch.add(s, out)

            # Drop the batch dimension: (n_movements, latent_dim)
            out = torch.reshape(out, [self.n_movements, out.shape[2]])

            decoded_material = self.decoder(out)

            y = self.simulator(decoded_material)
            loss, y_loss, total_losses, best_mat_idx = self.loss(y, y_desired, decoded_material, self.selection_layer.get_params())

            # Squeeze values into [0.005, 0.995] so the log below stays finite
            # when they are used as probabilities.
            epsilon_material = 0.005 + 0.99 * decoded_material

            multi_sem_loss = None
            if use_semloss:
                # NOTE(review): this is a reshape, not a transpose, so rows of
                # the result are not per-variable slices of the candidates.
                # Confirm the layout that `semloss` expects.
                reshaped_mat = epsilon_material[:,:self.n_layer * self.n_mat].reshape(self.n_layer * self.n_mat, epsilon_material.shape[0])
                _, wmc = semloss(probabilities=reshaped_mat, output_wmc_per_sample=True)

                multi_sem_loss = -torch.log(wmc)
                loss += multi_sem_loss.mean()

            # NOTE(review): retain_graph=True is kept from the original; it is
            # only required if part of the graph is shared between epochs
            # (e.g. seeds that still carry gradient history) -- confirm.
            loss.backward(retain_graph=True)

            # Scale generated materials back to nanometer range.
            # NOTE(review): the return value is discarded, so this only has an
            # effect if `unscaler` modifies its argument in place -- confirm.
            unscaler(decoded_material, self.n_layer)

            latent_point = np.copy(out[best_mat_idx].detach().cpu().numpy())
            decoded_point = np.copy(decoded_material[best_mat_idx].detach().cpu().numpy())
            rec_loss = y_loss[best_mat_idx].item()
            sem_loss = multi_sem_loss[best_mat_idx].item() if use_semloss else "None"
            onehot = evaluate_one_hot(decoded_material[best_mat_idx], n_layer=self.n_layer, n_mat=self.n_mat)

            # Log only the best point
            history.append((rec_loss, sem_loss, onehot, total_losses[best_mat_idx].item(), latent_point, decoded_point))

            optimizer.step()

            if LOG:
                # NOTE(review): 2001 is a hard-coded sample count; confirm it
                # matches the second dimension of the simulator output.
                print(f"Epoch: {epoch+1}, SRMSE: {np.sqrt((rec_loss * rec_loss) / 2001)}, ")
                real_srmse = evaluate_single_srmse(torch.tensor(decoded_point), x_desired, self.metamat_config)
                print(f"Epoch {epoch+1}, Real SRMSE: {real_srmse}")

                #print(f"Selection layer weights: {self.selection_layer.get_params()}")
                #print(f"SOFTMAX Selection layer weights: {torch.softmax(self.selection_layer.get_params(), dim=0)}")
                print()

        return history