import torch
import lightning as L
import torchmetrics
import yaml
from .similarities import Similarities_manager
from .augments import CutMix_with_prob


class Light_Net(L.LightningModule):
    """Lightning wrapper that trains ``network`` and tracks semantic-similarity metrics.

    Besides loss and top-1 / top-5 accuracy, every validation epoch asks a
    ``Similarities_manager`` for WordNet-, GloVe- and weight-based similarity scores,
    logs them, and the values are forwarded to ``grapher`` once per training epoch.

    Args:
        network: model being trained; called as ``network(images)`` and handed to the
            similarity manager.
        loss_fn: callable ``loss_fn(logits, targets)``. It receives the raw label tensor
            when ``model_info['AUG']['CUTMIX']`` is truthy, otherwise class indices.
        optimizer: already-constructed optimizer (its ``param_groups[0]['lr']`` is logged).
        scheduler: already-constructed LR scheduler driven by ``lr_scheduler_step``.
        model_info: config dict; must contain ``['AUG']['CUTMIX']`` and ``['AUG']['CUTMIX_P']``.
            It is dumped to ``model_info.yml`` in the logger directory at train start.
        dataset: object exposing ``label_order``; ``label_order[0].shape[0]`` is used as the
            number of classes. Also forwarded to ``Similarities_manager``.
        wn_data_path, glove_data_path, glove_embeddings_data_file: forwarded to
            ``Similarities_manager``.
        grapher: object providing ``add_data(...)`` and ``save_matrices(...)``.

    NOTE(review): labels (``lbl``) are assumed to be one-hot / soft vectors of shape
    (batch, num_classes); every accuracy computation takes ``argmax`` over dim 1.
    """

    def __init__(self, network, loss_fn, optimizer, scheduler, model_info, dataset, wn_data_path, glove_data_path,
                 glove_embeddings_data_file, grapher):
        super().__init__()
        self.network = network
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.grapher = grapher
        self.model_info = model_info
        self.scheduler = scheduler
        self.num_of_classes = dataset.label_order[0].shape[0]

        self.sim_manager = Similarities_manager(dataset, wn_data_path, glove_data_path, glove_embeddings_data_file)
        self.cutmix_augmenter = CutMix_with_prob(num_classes=self.num_of_classes,
                                                 p=model_info['AUG']['CUTMIX_P'])

        # Latest similarity scores; refreshed in on_validation_epoch_end and read by on_train_epoch_end.
        self.cosine_value = 0
        self.structural_value = 0
        self.confusion_value = 0
        self.weight_sim_score = 0
        self.mse_value = 0
        self.mae_value = 0
        self.cosine_value_glove = 0
        self.mse_value_glove = 0
        self.mae_value_glove = 0
        self.structural_value_glove = 0
        self.weight_min_sim_score = 0
        self.weight_max_sim_score = 0
        self.dm_value_weights = 0
        self.dm_value_wn = 0
        self.dm_value_glove = 0

        self.weight_confusion_product_sum = 0
        self.cf_value_weights = 0
        self.cf_value_wordnet = 0
        self.cf_value_glove = 0

        # Per-batch predictions / labels collected during validation; cleared at validation epoch end.
        self.eval_results = {'preds': [], 'flags': []}

        # Loss of the most recent train / validation batch.
        self.loss_train = torch.Tensor([0])
        self.loss_test = torch.Tensor([0])
        # Test accuracy of the most recent validation epoch. The test metric object is reset at the
        # end of validation, but on_train_epoch_end (which runs afterwards) still reports this value.
        self.acc_test_value = 0.0
        self.acc_train = torchmetrics.classification.Accuracy(task="multiclass", num_classes=self.num_of_classes)
        self.acc_test = torchmetrics.classification.Accuracy(task="multiclass", num_classes=self.num_of_classes)
        # "rank 5" metrics are plain accuracies fed with get_top_5 output, which turns them into top-5 accuracy.
        self.rank_5_train = torchmetrics.classification.Accuracy(task="multiclass", num_classes=self.num_of_classes)
        self.rank_5_test = torchmetrics.classification.Accuracy(task="multiclass", num_classes=self.num_of_classes)

    def get_top_5(self, logits, lbl):
        """Collapse top-5 predictions into one label per sample.

        For each sample, return the true class (``argmax`` of ``lbl``) if it is among the
        ``k`` highest logits, otherwise the top-1 predicted class. Comparing the result
        with the true class via an ordinary accuracy metric therefore yields top-k accuracy.

        Args:
            logits: network outputs, indexed as (batch, classes).
            lbl: label tensor, indexed as (batch, classes); argmax over dim 1 is the true class.

        Returns:
            1-D tensor of class indices with one entry per sample. It stays 1-D even for a
            batch of one, and is on the same device as the inputs.
        """
        # Never request more candidates than there are classes (topk raises otherwise).
        k = min(5, logits.shape[1])
        top_idx = torch.topk(logits, k, dim=1).indices
        target = torch.argmax(lbl, dim=1)
        # Vectorised membership test: stays on the inputs' device instead of a hard-coded 'cuda'.
        hit = (top_idx == target.unsqueeze(1)).any(dim=1)
        return torch.where(hit, target, top_idx[:, 0])

    def configure_optimizers(self):
        """Return the injected optimizer and an epoch-interval scheduler config monitoring 'Loss/test'."""
        return [self.optimizer], [{"scheduler": self.scheduler, "monitor": 'Loss/test', "interval": "epoch"}]

    def lr_scheduler_step(self, scheduler, metric):
        """Step ``scheduler``, passing the latest validation-batch loss to schedulers that accept it.

        ``metric`` (the value Lightning supplies) is ignored; ``self.loss_test`` is used instead.
        """
        try:
            scheduler.step(metrics=self.loss_test)  # as only ROP uses additional metrics in step
        except TypeError:
            # Schedulers whose step() has no ``metrics`` keyword reject the call above.
            scheduler.step()

    def on_train_start(self):
        """save model info in yml file"""
        # getattr covers both "no logger" (self.logger is None) and loggers without a log_dir.
        log_dir = getattr(self.logger, 'log_dir', None)
        if log_dir is None:
            print("[INFO] No logging - model_info not saved")
            return
        with open(f'{log_dir}/model_info.yml', 'w') as fh:
            yaml.dump(self.model_info, fh)

    def network_step(self, batch):
        """Run the network on one ``(images, labels, meta)`` batch.

        Returns:
            ``(loss, preds, preds_5, lbl)``: the scalar loss, one-hot top-1 predictions with
            ``lbl.shape[-1]`` classes, the ``get_top_5`` labels, and the unchanged label tensor.
        """
        im, lbl, meta = batch
        logits = self.network(im)
        # With CutMix enabled the loss receives the full (possibly mixed) label tensor, otherwise class indices.
        targets = lbl if self.model_info['AUG']['CUTMIX'] else torch.argmax(lbl, dim=1)
        loss = self.loss_fn(logits, targets)
        # softmax is monotonic, so argmax over the raw logits selects the same top-1 class.
        preds = torch.nn.functional.one_hot(torch.argmax(logits, dim=1), num_classes=lbl.shape[-1])
        preds_5 = self.get_top_5(logits, lbl)
        return loss, preds, preds_5, lbl

    def training_step(self, batch, batch_idx):
        """Optionally apply CutMix, update train metrics, log step values and return the loss."""
        # add cutmix for training only
        if self.model_info['AUG']['CUTMIX']:
            batch = self.cutmix_augmenter.transform(batch)
        loss, preds, preds_5, flag = self.network_step(batch)
        target = torch.argmax(flag, dim=1)
        # Keep a detached copy for reporting so the stored value holds no reference to the autograd graph.
        self.loss_train = loss.detach()
        self.acc_train(torch.argmax(preds, dim=1), target)
        self.rank_5_train(preds_5, target)

        self.log('Acc/train/step', self.acc_train, sync_dist=True)
        self.log("Loss/train/step", self.loss_train, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Update test metrics and collect predictions / labels for the epoch-end similarity metrics."""
        loss, preds, preds_5, flag = self.network_step(batch)
        target = torch.argmax(flag, dim=1)
        self.loss_test = loss
        self.acc_test(torch.argmax(preds, dim=1), target)
        self.rank_5_test(preds_5, target)

        self.eval_results['preds'].append(preds)
        self.eval_results['flags'].append(flag)

        return self.loss_test

    def on_train_end(self):
        """Save the similarity manager's matrices once training is finished."""
        matrices = self.sim_manager.return_matrices()
        self.grapher.save_matrices(matrices, binary_output=True)

    def on_train_epoch_end(self):
        """Report the epoch's train / latest test values to ``grapher`` and the logger, then reset train metrics."""
        acc_train = self.acc_train.compute().item()
        acc_train_5 = self.rank_5_train.compute().item()

        self.grapher.add_data(train_data=[self.loss_train.item(), acc_train],
                              test_data=[self.loss_test.item(), self.acc_test_value],
                              lr=self.optimizer.param_groups[0]['lr'],
                              cosine=self.cosine_value,
                              weights=self.weight_sim_score,
                              dm_value_net=self.dm_value_weights,
                              dm_value_wordnet=self.dm_value_wn,
                              dm_value_glove=self.dm_value_glove,
                              weights_min=self.weight_min_sim_score,
                              weights_max=self.weight_max_sim_score,
                              cosine_glove=self.cosine_value_glove,
                              mse_glove=self.mse_value_glove,
                              mae_glove=self.mae_value_glove,
                              structural_glove=self.structural_value_glove,
                              structural=self.structural_value,
                              mse=self.mse_value,
                              mae=self.mae_value,
                              confusion_weights_prod_sum=self.weight_confusion_product_sum,
                              confusion_weights_cos=self.cf_value_weights,
                              confusion_wordnet=self.cf_value_wordnet,
                              confusion_glove=self.cf_value_glove)

        self.log('Acc/train', acc_train, sync_dist=True)
        self.log('Acc/train_5', acc_train_5, sync_dist=True)
        self.log('Loss/train', self.loss_train, sync_dist=True)
        print(f'\n[INFO] Training -- Loss: {self.loss_train} Accuracy: {acc_train}\n')

        # torchmetrics accumulate until reset(); without this the "epoch" accuracy would span all epochs.
        self.acc_train.reset()
        self.rank_5_train.reset()

        matrices = self.sim_manager.return_matrices()
        self.grapher.save_matrices(matrices, binary_output=True)

    def on_validation_epoch_end(self):
        """Compute and log test accuracy plus all similarity metrics, then clear per-epoch state."""
        acc_test = self.acc_test.compute().item()
        acc_test_5 = self.rank_5_test.compute().item()
        # Cache for on_train_epoch_end, then reset so each validation run (including the sanity
        # check) is measured on its own batches only.
        self.acc_test_value = acc_test
        self.acc_test.reset()
        self.rank_5_test.reset()

        self.log('Acc/test', acc_test, sync_dist=True)
        self.log('Acc/test_5', acc_test_5, sync_dist=True)
        self.log('Loss/test', self.loss_test, sync_dist=True)
        self.log('LR', self.optimizer.param_groups[0]['lr'], sync_dist=True)

        # compute wordnet similarity
        self.cosine_value, self.structural_value, self.weight_sim_score, self.mse_value, self.mae_value, \
        self.weight_min_sim_score, self.weight_max_sim_score, self.cosine_value_glove, self.structural_value_glove, \
        self.mse_value_glove, self.mae_value_glove = self.sim_manager.get_semantic_sim(self.network)

        # compute mean from CM diag
        self.confusion_value = self.sim_manager.confusion_mean(y_true=self.eval_results['flags'],
                                                               preds=self.eval_results['preds'],
                                                               num_classes=self.num_of_classes)
        # compute DM value for each class-similarity source ('N' network, 'G' GloVe, 'WN' WordNet)
        self.dm_value_weights = self.sim_manager.get_dm_metric(model=self.network, y_true=self.eval_results['flags'],
                                                               y_pred=self.eval_results['preds'],
                                                               num_of_classes=self.num_of_classes, CSM_type='N')

        self.dm_value_glove = self.sim_manager.get_dm_metric(model=self.network, y_true=self.eval_results['flags'],
                                                             y_pred=self.eval_results['preds'],
                                                             num_of_classes=self.num_of_classes, CSM_type='G')

        self.dm_value_wn = self.sim_manager.get_dm_metric(model=self.network, y_true=self.eval_results['flags'],
                                                          y_pred=self.eval_results['preds'],
                                                          num_of_classes=self.num_of_classes, CSM_type='WN')

        # compute confusion matrix - based params
        self.weight_confusion_product_sum = self.sim_manager.get_cf_wg_prod_sum()  # sum of product
        self.cf_value_weights = self.sim_manager.get_confusion_metric(model=self.network, CSM_type='N')
        self.cf_value_glove = self.sim_manager.get_confusion_metric(model=self.network, CSM_type='G')
        self.cf_value_wordnet = self.sim_manager.get_confusion_metric(model=self.network, CSM_type='WN')

        self.clear_dict(self.eval_results)

        self.log('Similarity/Cosine_WordNet', self.cosine_value, sync_dist=True)
        self.log('Similarity/Structural_WordNet', self.structural_value, sync_dist=True)
        self.log('Similarity/MSE_WordNet', self.mse_value, sync_dist=True)
        self.log('Similarity/MAE_WordNet', self.mae_value, sync_dist=True)
        self.log('Similarity/Cosine_GLOVE', self.cosine_value_glove, sync_dist=True)
        self.log('Similarity/Structural_GLOVE', self.structural_value_glove, sync_dist=True)
        self.log('Similarity/MSE_GLOVE', self.mse_value_glove, sync_dist=True)
        self.log('Similarity/MAE_GLOVE', self.mae_value_glove, sync_dist=True)
        self.log('Similarity/Weights', self.weight_sim_score, sync_dist=True)
        self.log('Similarity/Weights_MIN', self.weight_min_sim_score, sync_dist=True)
        self.log('Similarity/Weights_MAX', self.weight_max_sim_score, sync_dist=True)
        self.log('Similarity/DM_value_WordNet', self.dm_value_wn, sync_dist=True)
        self.log('Similarity/DM_value_GLOVE', self.dm_value_glove, sync_dist=True)
        self.log('Similarity/DM_value_Network', self.dm_value_weights, sync_dist=True)
        self.log('Similarity/Confusion_diag', self.confusion_value, sync_dist=True)
        self.log('Similarity/Confusion_Weights_sum', self.weight_confusion_product_sum, sync_dist=True)
        self.log('Similarity/Confusion_Weights_cos', self.cf_value_weights, sync_dist=True)
        self.log('Similarity/Confusion_Weights_GLOVE', self.cf_value_glove, sync_dist=True)
        self.log('Similarity/Confusion_Weights_WordNet', self.cf_value_wordnet, sync_dist=True)

        print(f'\n[INFO] Testing -- Loss: {self.loss_test}, Accuracy: {acc_test}, '
              f'WordNet_Cos_sim: {self.cosine_value}, '
              f'WordNet_Structural_sim: {self.structural_value}, '
              f'WordNet_MSE: {self.mse_value}, '
              f'WordNet_MAE: {self.mae_value}, '
              f'WordNet_DM: {self.dm_value_wn}, \n'
              f'Weights_sim: {self.weight_sim_score}, '
              f'Weights_min_sim: {self.weight_min_sim_score}, '
              f'Weights_max_sim: {self.weight_max_sim_score}, '
              f'Net_DM: {self.dm_value_weights},\n'
              f'Confusion_diag: {self.confusion_value}, '
              f'GLOVE_Cos_sim: {self.cosine_value_glove}, '
              f'GLOVE_Structural_sim: {self.structural_value_glove}, '
              f'GLOVE_MSE : {self.mse_value_glove}, '
              f'GLOVE_MAE: {self.mae_value_glove}, '
              f'GLOVE_DM: {self.dm_value_glove}')

    def clear_dict(self, dictionary):
        """Replace every value of ``dictionary`` with a fresh empty list, in place (keys are kept)."""
        for key in dictionary:
            dictionary[key] = []
