import torch
from tqdm import tqdm
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
import wandb
import os


from src.dataset import KGEDataset, OgbKGEDataset, OgbKGEDatasetNBF 
from src.config import get_arm_model
from src.criterion import LogLoss
from src.torchkge_evaluation import LinkPredictionEvaluator
from src.config import load_pretrained_model
from src.models import get_nbf_wrapper
# Dataset classes selectable from the configuration.
# The outer key is the dataset name (config['dataset']['class']); a dict value
# is further keyed by the model type (config['model_type']), with its own
# "default" entry for model types that have no dedicated class.
# Dataset names without an entry of their own resolve to KGEDataset.
DATASET_MAPPING = {
    "OGBLBioKG": {
        "nbf": OgbKGEDatasetNBF,
        "default": OgbKGEDataset
    },
    "default": KGEDataset
}


class Engine:
    """Build a dataset, model and optimiser from a config and run training.

    The constructor resolves the device, instantiates the dataset and calls
    ``setup()``; ``train()`` then runs the epoch loop with validation,
    learning-rate scheduling, checkpointing and early stopping.
    """

    def __init__(self, config):
        """Store ``config``, pick the device, load the dataset and call ``setup()``.

        Args:
            config: Mapping of run settings. It is mutated: the resolved
                ``torch.device`` is written back under ``config['device']``.
        """
        self.config = config
        # Any non-empty USE_GPU value (including "0") selects the first CUDA device.
        self.device = torch.device("cuda:0" if os.environ.get("USE_GPU") else "cpu")
        self.config['device'] = self.device
        print("DEVICE", self.device)
        self.dataset = self._get_dataset()

        self.setup()

    def _get_dataset(self):
        """Instantiate the dataset class selected by the configuration.

        The dataset name (``config['dataset']['class']``) is looked up in
        ``DATASET_MAPPING``. A dict entry is resolved by
        ``config.get('model_type')`` with the entry's ``"default"`` as
        fallback; an entry that is already a class is used directly. Unknown
        dataset names use the mapping's top-level ``"default"`` (``KGEDataset``).

        Returns:
            The dataset instance built as ``dataset_class(self.config)``.
        """
        fallback = DATASET_MAPPING.get('default', KGEDataset)
        entry = DATASET_MAPPING.get(self.config['dataset']['class'], fallback)

        if isinstance(entry, dict):
            dataset_class = entry.get(
                self.config.get('model_type'),
                entry.get('default', KGEDataset),
            )
        else:
            # The entry maps straight to a class (e.g. the "default" entry).
            dataset_class = entry

        return dataset_class(self.config)

    def _get_model(self, number_of_entities, number_of_relations):
        """Create (or load) the model and place it on a device.

        Args:
            number_of_entities: Entity count passed to the model factory.
            number_of_relations: Relation count passed to the model factory.

        Returns:
            The model returned by ``load_pretrained_model`` when
            ``config['use_pretrained_model']`` is truthy, otherwise by
            ``get_arm_model``.
        """
        if self.config.get('use_pretrained_model'):
            model = load_pretrained_model(
                self.config,
                number_of_entities,
                number_of_relations,
                self.dataset.kg_train
            )
        else:
            model = get_arm_model(self.config, number_of_entities, number_of_relations, self.dataset.kg_train)

        # .get() keeps this consistent with _get_dataset when the key is absent.
        if self.config.get('model_type') == "nbf":
            # NOTE(review): the NBF sub-model is always moved to CUDA, regardless
            # of self.device — confirm this is intended before changing it.
            model.nbf_model.cuda(torch.device("cuda"))
        else:
            print("MODEL", model)
            model.to(self.device)

        return model

    def setup(self):
        """Create the model, loss, optimiser and LR scheduler from the config.

        Reads ``prediction_smoothing``, ``label_smoothing``, ``lr``,
        ``weight_decay``, ``factor`` and ``lr_patience`` from ``self.config``.
        """
        number_of_entities = self.dataset.n_entities
        number_of_relations = self.dataset.n_relations
        # Training also scores inverse triples, so the criterion is given twice
        # the relation count.
        number_of_inverse_relations = number_of_relations

        self.model = self._get_model(number_of_entities, number_of_relations)

        self.criterion = LogLoss(self.config['prediction_smoothing'], self.config['label_smoothing'],
                                 number_of_entities, number_of_relations+number_of_inverse_relations)
        self.optimizer = AdamW(self.model.parameters(), lr=self.config['lr'], weight_decay=self.config['weight_decay'])
        # mode="max": the scheduler is stepped with the validation MRR in train().
        self.scheduler = ReduceLROnPlateau(self.optimizer, factor=self.config['factor'], min_lr=1e-6, patience=self.config['lr_patience'], mode="max")

    def train_epoch(self):
        """Run one pass over the training split.

        For every batch the loss of the triple and of its inverse (first and
        third elements swapped, relation id offset by ``kg_train.n_rel``) are
        summed and optimised in a single step.

        Returns:
            The sum of the per-batch losses as a float.
        """
        self.model.train()
        epoch_loss = 0.0

        # Create a new dataloader for this epoch
        train_dataloader = self.dataset.get_dataloader(
            batch_size=self.config['batch_size'],
            split='train',
        )

        for batch in tqdm(train_dataloader):
            batch = [b.to(self.device) for b in batch]
            triple = batch[0], batch[1], batch[2]
            self.optimizer.zero_grad()
            predictions = self.model(triple)
            loss = self.criterion.log_loss(predictions=predictions, labels=triple)
            inverse_triple = batch[2], batch[1] + self.dataset.kg_train.n_rel, batch[0]
            inverse_predictions = self.model(inverse_triple)
            inverse_loss = self.criterion.log_loss(predictions=inverse_predictions, labels=inverse_triple)
            loss = loss + inverse_loss
            loss.backward()
            self.optimizer.step()
            epoch_loss += loss.item()
        return epoch_loss

    def _save_model(self):
        """Write the model's ``state_dict`` to ``config['model_path']``."""
        model_path = self.config['model_path']
        target_dir = os.path.dirname(model_path)
        # A bare file name has no directory part; os.makedirs('') would raise.
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        torch.save(self.model.state_dict(), model_path)
        print("SAVED", model_path)

    def train(self):
        """Train for up to ``config['epochs']`` epochs with validation each epoch.

        After every epoch the model is evaluated on the ``'valid'`` split, the
        scheduler is stepped with the validation MRR, and the metrics are
        logged to wandb (when ``config['wandb']`` is truthy) or printed. A
        validation MRR that is not lower than the best so far resets the
        patience counter and, if ``config['save_model']`` is set, saves the
        model. Training stops early once the counter exceeds
        ``config['max_patience']``. The wandb run opened here is finished on
        both normal completion and early stopping.
        """
        use_wandb = self.config['wandb']
        if use_wandb:
            wandb.init()

        current_patience = 0
        best_mrr_till_now = 0.0
        for epoch in range(self.config['epochs']):
            epoch_loss = self.train_epoch()
            self.model.eval()
            with torch.no_grad():
                evaluator = LinkPredictionEvaluator(self.config, self.device, self.model, self.dataset, self.config['test_batch_size'], split='valid')
                metrics = evaluator.get_link_prediction_metrics()
                self.scheduler.step(metrics['mrr'])
                metrics['lr'] = self.optimizer.param_groups[0]['lr']
                metrics['loss'] = epoch_loss
                metrics['epoch'] = epoch
                if use_wandb:
                    wandb.log(metrics)
                else:
                    print(metrics)

                if metrics['mrr'] < best_mrr_till_now:
                    current_patience += 1
                else:
                    # Ties count as "not worse": patience resets and the model is re-saved.
                    current_patience = 0
                    best_mrr_till_now = metrics['mrr']
                    print("BEST MRR TILL NOW", best_mrr_till_now)
                    if self.config['save_model']:
                        self._save_model()

            if current_patience > self.config['max_patience']:
                print("STOPPED EARLY")
                break

        # Close the run on every non-exceptional exit, not only on early stop.
        if use_wandb:
            wandb.finish()
