import math
import os

import torch
import torch.optim as optim

from DataSet.DataLoader import get_dataloader
from Model.Model import Model
from Model.Loss import LossFunction
from Model.Score import ScoreFunction, OTScoreFunction
from Model.optim_utils import Lookahead, Lamb
from utils import aucPerformance, get_logger, F1Performance

class Trainer(object):
    """Train and evaluate a ``Model`` for one run of an experiment.

    The constructor builds the model, the loss/score modules and the data
    loaders from ``model_config`` and prepares the directory where the
    training log and the model checkpoint are written.
    """

    def __init__(self, run: int, model_config: dict):
        """Build every component needed for training and evaluation.

        Args:
            run: Index of the current run; ``evaluate`` uses it as the slot
                to write into in the result containers.
            model_config: Configuration mapping. The keys ``device``,
                ``learning_rate``, ``dataset_name``, ``model``,
                ``basis_type`` and ``attention_type`` are read here; the
                whole mapping is also forwarded to ``Model``,
                ``LossFunction``, ``ScoreFunction``, ``OTScoreFunction`` and
                ``get_dataloader``.
        """
        self.run = run
        self.device = model_config['device']
        self.learning_rate = model_config['learning_rate']
        self.dataset = model_config['dataset_name']
        self.model_type = model_config['model']
        self.basis_type = model_config['basis_type']
        self.attention_type = model_config['attention_type']
        self.model = Model(model_config).to(self.device)

        self.loss_fuc = LossFunction(model_config).to(self.device)
        self.score_func = ScoreFunction(model_config).to(self.device)
        self.ot_score_func = OTScoreFunction(model_config).to(self.device)
        self.train_loader, self.test_loader = get_dataloader(model_config)

        self.save_dir = f'./Save/{self.dataset}/{self.model_type}/{self.basis_type}_{self.attention_type}/'

        # exist_ok avoids the check-then-create race when several runs start
        # at the same time and share a save directory.
        os.makedirs(self.save_dir, exist_ok=True)

    @staticmethod
    def _warmup_cosine_factor(epoch: int, epochs: int, warmup_epochs: int) -> float:
        """Return the learning-rate multiplier for ``epoch``.

        The factor rises linearly from 0 during the first ``warmup_epochs``
        epochs (it is exactly 0.0 at epoch 0, as in the original schedule)
        and then follows half a cosine period from 1 down to 0 at
        ``epoch == epochs``.

        Args:
            epoch: Epoch counter supplied by the scheduler.
            epochs: Total number of training epochs.
            warmup_epochs: Number of linear warm-up epochs.

        Returns:
            The multiplier as a plain Python float.
        """
        if epoch < warmup_epochs:
            return float(epoch) / float(max(1, warmup_epochs))
        # Guard the denominator: with epochs <= warmup_epochs the cosine
        # phase has no length and the plain division would raise
        # ZeroDivisionError on the final scheduler step.
        decay_span = max(1, epochs - warmup_epochs)
        return 0.5 * (1.0 + math.cos(math.pi * (epoch - warmup_epochs) / decay_span))

    def training(self, epochs):
        """Train ``self.model`` and checkpoint it whenever the loss improves.

        The optimizer is ``Lamb`` wrapped in ``Lookahead``; the learning rate
        follows ``_warmup_cosine_factor`` with 10 warm-up epochs. After each
        epoch the loss of the *last* batch is compared with the best value
        seen so far, and the whole model object is saved with ``torch.save``
        when it is lower. Losses are logged every 10 epochs.

        Args:
            epochs: Number of passes over ``self.train_loader``.

        Raises:
            RuntimeError: If ``self.train_loader`` yields no batches.
        """
        train_logger = get_logger(self.save_dir + 'train_log.log')

        optimizer = Lamb(
            self.model.parameters(), lr=self.learning_rate, betas=(0.9, 0.999),
            weight_decay=0, eps=1e-6)
        optimizer = Lookahead(optimizer, k=6)

        warmup_epochs = 10
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            lambda epoch: self._warmup_cosine_factor(epoch, epochs, warmup_epochs))

        checkpoint_path = self.save_dir + f'{self.dataset}_model.pth'
        info = 'Epoch:[{}]\t loss={:.4f}\t mse={:.4f}\t ortholoss={:.4f}\t featureloss={:.4f}\t attentionloss={:.4f}\t'

        self.model.train()
        print("Training Start.")
        min_loss = float('inf')
        try:
            for epoch in range(epochs):
                loss = None
                for x_input, _y_label in self.train_loader:
                    x_input = x_input.to(self.device)

                    (x_pred, masks, feature_map, feature_prototype,
                     attention_map, attention_prototype) = self.model(x_input)
                    loss, mse, ortho_loss, feature_loss, attention_loss = self.loss_fuc(
                        x_input, x_pred, masks, feature_map, feature_prototype,
                        attention_map, attention_prototype)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                scheduler.step()

                if loss is None:
                    raise RuntimeError('train_loader produced no batches; nothing to train on.')

                # Work with a plain float so the best-loss bookkeeping does
                # not keep a tensor alive between epochs.
                epoch_loss = loss.item()

                if epoch % 10 == 0:
                    # float() accepts both 0-dim tensors and Python numbers.
                    train_logger.info(info.format(
                        epoch, epoch_loss, float(mse), float(ortho_loss),
                        float(feature_loss), float(attention_loss)))

                if epoch_loss < min_loss:
                    torch.save(self.model, checkpoint_path)
                    min_loss = epoch_loss
            print("Training complete.")
        finally:
            # Detach the handlers even when training fails part-way.
            train_logger.handlers.clear()

    def evaluate(self, mse_rauc, mse_ap, mse_f1):
        """Score the test set with the saved model and record the metrics.

        For every test batch the reconstruction score from ``score_func`` is
        combined with the two ``ot_score_func`` scores (feature and
        attention), each rescaled so that its batch mean equals 0.001 times
        the batch mean of the reconstruction score. The combined scores are
        passed to ``aucPerformance`` and ``F1Performance``.

        Args:
            mse_rauc: Indexable container; receives the first value returned
                by ``aucPerformance`` at index ``self.run``.
            mse_ap: Indexable container; receives the second value returned
                by ``aucPerformance`` at index ``self.run``.
            mse_f1: Indexable container; receives the ``F1Performance``
                result at index ``self.run``.
        """
        print(self.save_dir)
        # NOTE(review): the checkpoint is a fully pickled module. Recent
        # PyTorch releases default torch.load to weights_only=True, which
        # rejects such files; pass weights_only=False there, and only load
        # checkpoints from a trusted source.
        model = torch.load(self.save_dir + f'{self.dataset}_model.pth', map_location=self.device)
        model.eval()

        prototype_weight = 0.001
        mixed_scores, test_label = [], []
        # Inference only: skip autograd bookkeeping for the forward passes.
        with torch.no_grad():
            for x_input, y_label in self.test_loader:
                x_input = x_input.to(self.device)

                (x_pred, _masks, feature_map, feature_prototype,
                 attention_map, attention_prototype) = model(x_input)

                feature_score = self.ot_score_func(feature_prototype, feature_map, 'feature')
                attention_score = self.ot_score_func(attention_prototype, attention_map, 'attention')

                mse_batch = self.score_func(x_input, x_pred)
                mse_mean = torch.mean(mse_batch)

                prototype_score_adjusted = (
                    attention_score * (prototype_weight * mse_mean / torch.mean(attention_score))
                    + feature_score * (prototype_weight * mse_mean / torch.mean(feature_score)))

                mixed_scores.append((mse_batch + prototype_score_adjusted).detach().cpu())
                test_label.append(y_label)

        mixed_scores = torch.cat(mixed_scores, dim=0).numpy()
        test_label = torch.cat(test_label, dim=0).numpy()
        mse_rauc[self.run], mse_ap[self.run] = aucPerformance(mixed_scores, test_label)
        mse_f1[self.run] = F1Performance(mixed_scores, test_label)


