import torch
from tqdm import tqdm
import numpy as np
import time

from libs.data import get_batch_size
from libs.eval import calcuate_metric_ad
from libs.augmentation import make_attention_based_mask

class supmodel(torch.nn.Module):
    """Training / inference wrapper for a reconstruction-based anomaly detector.

    This base class only stores configuration and drives the loops. It relies
    on attributes that are NOT created here and must be provided elsewhere
    (presumably by a subclass -- verify against callers):

    * ``self.model`` -- backbone; must support ``model(x)``,
      ``model(x, return_attention=True)`` (returning ``(out, attention)``)
      and ``make_optimizer()``.
    * ``self.mlp_encoder2`` / ``self.mlp_decoder2`` -- heads used when
      ``fusion_num > 0``; each returns a sequence of tensors.
    * ``self.mlp2`` -- head used when ``fusion_num <= 0``.
    * ``self.save_model(path)`` -- called whenever validation AUROC improves.
    """

    # Fusion strategies understood by ``_fuse_with_neighbors``.
    _FUSION_METHODS = ("concat", "average_concat", "average")

    def __init__(self, tasktype, params, device, data_id=None, modelname=None, cat_features=None):
        """Store configuration read from ``params`` and build the early stopper.

        Args:
            tasktype: Stored as ``self.tasktype``; not used in this class.
            params: Mapping that must contain ``batch_size``, ``aug``,
                ``dataname``, ``fusion_num``, ``fusion_method`` and
                ``early_stopping_rounds``. ``fit`` additionally reads
                ``n_epochs`` (optional) and ``save_model_path``.
            device: Device that batches (and, in ``fit``, the model) are moved to.
            data_id: Accepted for interface compatibility; unused here.
            modelname: Accepted for interface compatibility; unused here.
            cat_features: Stored as ``self.cat_features``. Defaults to a fresh
                empty list (avoids sharing one mutable default between
                instances).
        """
        super().__init__()

        self.cat_features = [] if cat_features is None else cat_features
        self.device = device
        self.params = params
        self.tasktype = tasktype

        self.batch_size = self.params['batch_size']
        self.aug = self.params['aug']
        self.dataname = self.params['dataname']
        self.fusion_num = self.params['fusion_num']
        self.fusion_method = self.params['fusion_method']

        self.early_stopping = EarlyStopping_ad(
            patience=params["early_stopping_rounds"]
        )

    def _pick_batch_size(self, n_samples):
        """Return the batch size to use for ``n_samples`` rows.

        A configured ``batch_size`` of exactly 1 delegates the choice to
        ``get_batch_size``; any other value is used as-is.
        """
        if self.batch_size == 1:
            return get_batch_size(n_samples)
        return self.batch_size

    def _attend_and_encode(self, x, step, apply_mask):
        """Compute attention neighbours, optionally mask ``x``, and encode it.

        A first, gradient-free forward pass yields the attention weights from
        which ``make_attention_based_mask`` derives a mask and the per-sample
        neighbour indices. The (possibly masked) input is then passed through
        the model again; that second pass is the one gradients flow through.

        Args:
            x: Input batch, already on the target device.
            step: Third argument forwarded to ``make_attention_based_mask``
                (the epoch in ``fit``, the batch start offset in ``predict``).
            apply_mask: When true, ``x`` is multiplied by the mask.

        Returns:
            ``(x, out, all_neighbors)`` where ``x`` is the tensor actually fed
            to the model in the second pass.
        """
        with torch.no_grad():
            _, attention_weights = self.model(x, return_attention=True)
        mask, all_neighbors = make_attention_based_mask(attention_weights, self.params, step)
        if apply_mask:
            x = x * mask
        out = self.model(x)
        return x, out, all_neighbors

    def _fuse_with_neighbors(self, all_outs, all_neighbors):
        """Combine each sample's encoding with those of its neighbours.

        Args:
            all_outs: Rank-3 tensor; indexed by sample along dim 0.
            all_neighbors: ``all_neighbors[i]`` holds the batch indices of the
                neighbours of sample ``i``.

        Returns:
            Tensor stacked over samples. With K neighbours and ``all_outs`` of
            shape (B, F, D) the per-sample result is (F, (K + 1) * D) for
            "concat", (F, 2 * D) for "average_concat" and (F, D) for "average".

        Raises:
            ValueError: If ``self.fusion_method`` is not a known strategy
                (previously this surfaced as an ``UnboundLocalError``).
        """
        method = self.fusion_method
        if method not in self._FUSION_METHODS:
            raise ValueError(
                f"Unknown fusion_method {method!r}; expected one of {self._FUSION_METHODS}"
            )

        n_samples, n_feats, _ = all_outs.shape
        fused = []
        for ni in range(n_samples):
            target = all_outs[ni].unsqueeze(0)
            # Build the index on the tensor's own device so indexing never
            # depends on self.device matching where the model actually lives.
            neighbors_idx = torch.as_tensor(
                all_neighbors[ni], dtype=torch.long, device=all_outs.device
            )
            neighbors = all_outs[neighbors_idx]

            if method == "concat":
                # target followed by every neighbour
                selected = torch.cat([target, neighbors], dim=0)
            elif method == "average_concat":
                # target followed by the mean of its neighbours
                selected = torch.cat([target, neighbors.mean(dim=0, keepdim=True)], dim=0)
            else:  # "average": mean over target and neighbours together
                selected = torch.cat([target, neighbors], dim=0).mean(dim=0, keepdim=True)

            # (S, F, D) -> (F, S, D) -> (F, S * D)
            fused.append(selected.permute(1, 0, 2).reshape(n_feats, -1))
        return torch.stack(fused, dim=0)

    def _reconstruct(self, out, all_neighbors):
        """Map backbone output ``out`` to a reconstruction of the input.

        With ``fusion_num > 0`` the encoder head's outputs are stacked along
        dim 1, fused with their neighbours, decoded, and the decoder outputs
        are concatenated along dim 1. Otherwise ``self.mlp2`` is applied and
        the trailing dimension squeezed.
        """
        if self.fusion_num > 0:
            all_outs = torch.stack(self.mlp_encoder2(out), dim=1)
            fused = self._fuse_with_neighbors(all_outs, all_neighbors)
            return torch.cat(self.mlp_decoder2(fused), dim=1)
        return self.mlp2(out).squeeze(-1)

    def fit(self, X_train, y_train, X_val, y_val):
        """Train with an MSE reconstruction loss and early-stop on validation AUROC.

        Each epoch trains over ``X_train`` and then scores ``X_val`` by
        per-sample reconstruction error; ``calcuate_metric_ad`` turns errors
        and ``y_val`` into metrics. The model is saved through
        ``self.save_model(self.params['save_model_path'])`` whenever AUROC
        improves, and training stops once the early stopper's patience is
        exhausted. ``y_train`` is batched but not used by the loss.

        Args:
            X_train, y_train: Training tensors (wrapped in a ``TensorDataset``).
            X_val, y_val: Validation tensors.

        Returns:
            None.
        """
        batch_size = self._pick_batch_size(len(X_train))

        optimizer = self.model.make_optimizer()
        loss_fn = torch.nn.functional.mse_loss

        train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
        val_dataset = torch.utils.data.TensorDataset(X_val, y_val)
        del X_train, y_train, X_val, y_val

        # A trailing batch holding exactly one sample is dropped (presumably
        # because a single row has no neighbours to fuse with -- TODO confirm).
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=(len(train_dataset) % batch_size == 1),
        )
        val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

        # Kept from the original: an optimizer step before any gradient exists.
        # NOTE(review): intent is not visible here -- confirm before removing.
        optimizer.zero_grad()
        optimizer.step()

        self.model = self.model.to(self.device)

        apply_mask = self.aug == 'attention_mask_fusion'
        pbar = tqdm(range(1, self.params.get('n_epochs', 0) + 1))
        for epoch in pbar:
            pbar.set_description("EPOCH: %i" % epoch)

            for x, _ in train_loader:
                # NOTE(review): only self.model is switched between train/eval
                # mode; the heads stored on self are not toggled.
                self.model.train()
                optimizer.zero_grad()
                x = x.to(self.device)

                x, out, all_neighbors = self._attend_and_encode(x, epoch, apply_mask)
                final_cont_outs = self._reconstruct(out, all_neighbors)

                # When masking is active, x here is the masked input, so the
                # reconstruction target is the masked batch.
                loss = loss_fn(final_cont_outs, x)
                loss.backward()
                optimizer.step()

                pbar.set_postfix_str(f'dataname: {self.dataname}, Tr loss: {loss.item():.5f}')

            # evaluation
            self.model.eval()
            recon_errors = []
            y_labels = []
            with torch.no_grad():
                for x_val, y_val in val_loader:
                    x_val = x_val.to(self.device)

                    x_val, out_val, all_neighbors = self._attend_and_encode(x_val, epoch, apply_mask)
                    final_outs_val = self._reconstruct(out_val, all_neighbors)

                    error = torch.mean((final_outs_val - x_val) ** 2, dim=1)
                    recon_errors.append(error.cpu())
                    y_labels.append(y_val.cpu())
            recon_errors = torch.cat(recon_errors).numpy()
            y_labels = torch.cat(y_labels).numpy()

            mse_auc, mse_ap, mse_f1 = calcuate_metric_ad(y_labels, recon_errors)
            # These numbers are computed on the validation split passed to fit().
            print('TEST AUC-PR: %.3f, TEST AUROC: %.3f, TEST F1: %.3f' % (mse_ap, mse_auc, mse_f1))

            is_best_epoch = self.early_stopping.on_epoch_end(epoch, {"auroc": mse_auc})
            if is_best_epoch:
                self.save_model(self.params['save_model_path'])
                print("Save best model")

            if self.early_stopping.should_stop:
                print(f"Early stopping triggered after {epoch} epochs.")
                break

    def predict(self, X_test, cat_features=None):
        """Return per-sample reconstruction errors for ``X_test`` as a NumPy array.

        Args:
            X_test: Tensor of test rows, sliced into batches along dim 0.
            cat_features: Accepted for interface compatibility; unused.

        Returns:
            1-D NumPy array of mean squared reconstruction errors, in input
            order, for every batch that was processed.

        Notes:
            * Batches with fewer than ``fusion_num`` rows are skipped, so the
              result can be shorter than ``X_test``; callers aligning errors
              with labels must account for that.
            * NOTE(review): masking is applied here when ``aug ==
              'attention_mask'`` whereas ``fit`` masks on
              ``'attention_mask_fusion'`` -- confirm the difference is intended.
            * NOTE(review): the batch start offset is passed to
              ``make_attention_based_mask`` where ``fit`` passes the epoch.
        """
        self.model.eval()
        batch_size = self._pick_batch_size(len(X_test))
        apply_mask = self.aug == 'attention_mask'

        recon_errors = []
        with torch.no_grad():
            for start in range(0, X_test.shape[0], batch_size):
                x = X_test[start:start + batch_size]
                if x.shape[0] < self.fusion_num:
                    continue
                x = x.to(self.device)

                x, out_val, all_neighbors = self._attend_and_encode(x, start, apply_mask)
                final_outs_val = self._reconstruct(out_val, all_neighbors)

                error = torch.mean((final_outs_val - x) ** 2, dim=1)
                recon_errors.append(error.cpu())

        return torch.cat(recon_errors).numpy()


class EarlyStopping_ad:
    """Patience-based early stopping that tracks the maximum of one metric.

    After each epoch, ``on_epoch_end`` compares the monitored metric with the
    best value seen so far (higher is better). ``should_stop`` becomes True
    once the metric has failed to improve for ``patience`` consecutive calls;
    it is never reset afterwards.
    """

    def __init__(self, early_stopping_metric="auroc", patience=500):
        """
        Args:
            early_stopping_metric: Key looked up in the ``logs`` mapping.
            patience: Number of consecutive non-improving epochs tolerated
                before ``should_stop`` is set.
        """
        self.early_stopping_metric = early_stopping_metric
        self.patience = patience
        self.best_value = None
        self.patience_counter = 0
        self.should_stop = False

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's metric and report whether it is a new best.

        Args:
            epoch: Epoch number; accepted for interface compatibility, unused.
            logs: Mapping of metric name to value. ``None`` is treated like an
                empty mapping.

        Returns:
            True if the metric strictly improved on the previous best (or is
            the first value seen), otherwise False. When the metric is absent
            from ``logs`` the state is left untouched and False is returned.
        """
        # ``logs`` defaults to None, so guard before calling .get on it.
        current_value = None if logs is None else logs.get(self.early_stopping_metric)
        if current_value is None:
            return False

        is_best = False
        if self.best_value is None or current_value > self.best_value:
            self.best_value = current_value
            self.patience_counter = 0
            is_best = True
        else:
            self.patience_counter += 1

        if self.patience_counter >= self.patience:
            self.should_stop = True

        return is_best