import sys, os

from zmq import device

sys.path.append(os.path.abspath("../../")) 
sys.path.append(os.path.abspath("../")) 
sys.path.append(os.path.abspath("../../../")) 

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from src.models.gidnet.selection_layer import SelectionLayer
from src.models.gidnet.generator_Nls import Generator
import numpy as np
import commons.semantic_loss as semloss
from commons.utils import  count_number_hyperbolic_materials
from utils.data import simulate_material

def evaluate_one_hot(decoded_material, n_layer, n_mat):
    """Return the mean squared L2-norm of every ``n_mat``-wide material group.

    ``decoded_material`` is a 1-D vector or a 2-D batch of rows. In each row the
    last ``n_layer`` columns are ignored (presumably per-layer continuous values
    such as thicknesses -- TODO confirm). The remaining columns are split into
    groups of ``n_mat`` entries. For a perfectly one-hot encoding every group
    has squared norm 1, so the result is 1.0.

    Args:
        decoded_material: tensor of shape ``(D,)`` or ``(batch, D)``.
        n_layer: number of trailing columns to exclude.
        n_mat: number of entries per material group.

    Returns:
        The mean squared norm over all groups, as a Python float.
    """
    if decoded_material.dim() < 2:
        decoded_material = decoded_material.unsqueeze(0)

    # Use an explicit end index: ``[:, :-n_layer]`` would select nothing
    # (and yield NaN) when n_layer == 0.
    n_onehot = decoded_material.shape[1] - n_layer
    values = decoded_material[:, :n_onehot].reshape(-1, n_mat).square().sum(dim=1)
    return values.mean().item()

class GidNet:
    """Optimize a latent point so that its decoded material matches a target response.

    A latent point is optimized with Adam. Each step decodes the point into a
    material, runs it through ``simulator``, and minimizes the mean squared
    error to the desired response. An optional semantic loss can be added.
    """

    def __init__(self, n_layer, n_mat, n_seeds, n_movements, lambda1, onehot_weight, encoder, decoder, simulator, device):
        self.n_layer = n_layer
        # NOTE(review): the n_seeds / n_movements arguments are currently ignored
        # and forced to 1 (single-point optimisation). Kept to preserve behavior.
        self.n_seeds = 1
        self.encoder = encoder
        self.decoder = decoder
        self.simulator = simulator
        self.device = device
        self.n_movements = 1
        self.lambda1 = lambda1            # weight of the selection-weights term in ``loss``
        self.onehot_weight = onehot_weight  # weight of the one-hot term in ``loss``
        self.best_srmse = float('inf')
        self.best_srmse_onehot = None
        self.best_srmse_one_hot_accuracy = None
        self.best_srmse_s_weights = None
        self.best_srmse_epoch = None
        self.best_srmse_loss = None
        self.local_bests_list = []
        self.latent_space_point_list = []
        self.n_mat = n_mat

        self.result: torch.Tensor
        self.seeds: torch.Tensor

    def get_best_srmse(self):
        """Return the stored best SRMSE (``inf`` until something records one)."""
        return self.best_srmse

    def get_result(self):
        """Return the stored best one-hot material (``None`` until recorded)."""
        return self.best_srmse_onehot

    def get_best_srmse_one_hot_accuracy(self):
        return self.best_srmse_one_hot_accuracy

    def get_best_srmse_s_weights(self):
        return self.best_srmse_s_weights

    def get_best_srmse_epoch(self):
        return self.best_srmse_epoch

    def get_best_srmse_loss(self):
        return self.best_srmse_loss

    def get_local_bests_list(self):
        return self.local_bests_list

    def compute_knn(self, X_train, y_train, y, k):
        """Return the rows of ``X_train`` whose targets are the ``k`` nearest to ``y``.

        Distances are row-wise Euclidean norms of ``y_train - y``, so ``y`` is
        broadcast against every row of ``y_train``. ``k`` is clamped to the
        number of rows so that ``topk`` cannot fail on small training sets.
        """
        dist = torch.norm(y_train - y, dim=1, p=2)
        k = min(k, dist.shape[0])
        nearest = dist.topk(k, largest=False)
        return X_train[nearest.indices]

    def _init_models(self):
        """Create fresh selection/generator modules and reset all best-so-far state."""
        self.selection_layer = SelectionLayer(self.n_seeds, self.device)
        self.generator = Generator(self.n_layer, self.n_movements, self.device)
        self.best_srmse = float('inf')
        self.best_srmse_onehot = None
        self.best_srmse_one_hot_accuracy = None
        self.best_srmse_s_weights = None
        self.best_srmse_epoch = None
        self.best_srmse_loss = None
        self.local_bests_list = []

    def evaluate_one_hot(self, one_hot):
        """Return the mean squared norm of each ``n_mat``-wide group of a 1-D material vector.

        The last ``n_layer`` entries are excluded before grouping.
        """
        # Explicit end index: ``[:-n_layer]`` would be empty when n_layer == 0.
        n_onehot = one_hot.shape[0] - self.n_layer
        # Group by self.n_mat (previously hard-coded to 5, which disagreed with
        # the configured number of materials).
        values = one_hot[:n_onehot].reshape(-1, self.n_mat).square().sum(dim=1)
        return torch.mean(values).item()

    def loss(self, y, y_desired, decoded_mat, weights):
        """Combined loss over a batch of candidate points.

        Args:
            y: simulated responses, one row per candidate.
            y_desired: target response, broadcast against every row of ``y``.
            decoded_mat: decoded materials, one row per candidate; the last
                ``n_layer`` columns are excluded from the one-hot term.
            weights: selection weights. The reduction assumes they are 1-D
                (TODO confirm).

        Returns:
            ``(mean_total_loss, l2_per_row, total_loss_per_row, best_row_index)``.
        """
        # Per-row L2 distance to the target.
        l2_y = torch.norm(y - y_desired, p=2, dim=1)

        # Negative squared norm of the softmaxed weights, broadcast to every row.
        w_y = torch.square(nn.functional.softmax(weights, dim=0)).sum(dim=0).mul(-1).expand(l2_y.shape[0])

        # One-hot constraint: 1 minus the summed squared group norms per row.
        # NOTE(review): for n_layer > 1 a perfect one-hot gives 1 - n_layer,
        # not 0 -- confirm this is intended.
        n_onehot = decoded_mat.shape[1] - self.n_layer
        oh_c = 1 - decoded_mat[:, :n_onehot].reshape(-1, self.n_mat).square().sum(dim=1).reshape(-1, self.n_layer).sum(dim=1)

        total_loss = (l2_y * 1.5) + (w_y * self.lambda1) + (oh_c * self.onehot_weight)
        best_idx = torch.argmin(total_loss)

        return torch.mean(total_loss), l2_y, total_loss, best_idx

    def train(self, x_train, y_train, y_desired, n_epoch, lr, x_desired, quiet=True, semloss=None):
        """Optimize a random latent point towards ``y_desired`` for ``n_epoch`` Adam steps.

        Args:
            x_train: training materials; the nearest neighbours of
                ``y_desired`` (in response space) become the seeds.
            y_train: simulated responses of ``x_train``.
            y_desired: target response.
            n_epoch: number of optimisation steps.
            lr: Adam learning rate.
            x_desired: currently unused; kept for interface compatibility.
            quiet: when False, print per-epoch diagnostics and the first loss.
            semloss: optional semantic-loss callable. When given, the mean of
                ``-log(wmc)`` over samples is added to the loss.

        Returns:
            The history list. NOTE(review): history logging is currently
            disabled, so the list is always empty.
        """
        self._init_models()
        initial_loss = 0.0

        # Move the seeds to the working device *before* encoding, and encode
        # them only once (the shape is reused for the latent point).
        self.seeds = self.compute_knn(x_train, y_train, y_desired, self.n_seeds).to(self.device)
        y_desired = y_desired.to(self.device)
        encoded_seeds = self.encoder(self.seeds)

        # Random latent point on self.device (previously hard-coded "cuda:0").
        self.point = nn.Parameter(torch.rand(encoded_seeds.shape, device=self.device), requires_grad=True)
        optimizer = torch.optim.Adam([{"params": self.point}], lr=lr)

        # Encoded seeds with a leading batch dimension; the current
        # single-point optimisation does not use them further.
        self.seeds = torch.reshape(encoded_seeds, (1, encoded_seeds.shape[0], encoded_seeds.shape[1]))

        history = []
        for epoch in range(n_epoch):
            optimizer.zero_grad()

            decoded_material = self.decoder(self.point)
            y = self.simulator(decoded_material)
            loss = (y - y_desired).square().mean()

            if epoch == 0:
                # Store a float: holding the tensor would keep the epoch-0 graph
                # alive and alias any later update of ``loss``.
                initial_loss = loss.item()

            if not quiet:
                # NOTE(review): 10 and 7 are hard-coded arguments; confirm they
                # match the current n_layer / material configuration.
                count, _ = count_number_hyperbolic_materials(decoded_material, 10, 7, pattern_len=3)
                print(f"Epoch: {epoch}, Count: {count}, Loss: {loss.item()}")
                print(self.selection_layer.get_params())

            if semloss is not None:
                # Squash into [0.005, 0.995] so log(wmc) cannot blow up at exactly 0/1.
                epsilon_material = 0.005 + 0.99 * decoded_material
                n_vars = self.n_layer * self.n_mat
                # (batch, n_vars) -> (n_vars, batch). This must be a transpose:
                # reshape would interleave samples whenever batch > 1.
                reshaped_mat = epsilon_material[:, :n_vars].t().contiguous()
                _, wmc = semloss(probabilities=reshaped_mat, output_wmc_per_sample=True)
                loss = loss + (-torch.log(wmc)).mean()

            # retain_graph kept from the original implementation; the graph is
            # rebuilt every epoch, so it is presumably unnecessary -- confirm.
            loss.backward(retain_graph=True)
            optimizer.step()

        if not quiet:
            print(f"First loss: {initial_loss}")
        return history