# -*- coding: utf-8 -*-
import logging
import math
import os

from torch import per_tensor_affine
from torch.utils import data

from models.model_factory import build_model
from utils import trainer_utils
from utils.metric_visualizer import AccuracyVisualizer, LossVisualizer
from utils.metrics import Accuracy, GroupWiseAccuracy
from torch.optim import *
import json
from utils.trainer_utils import create_optimizer
from torch.nn import *
from utils.losses import *
import numpy as np
from torch.utils.data import Dataset, Subset

class ERMTrainer(object):
    """Empirical-risk-minimization trainer.

    Builds the model, per-example loss and optimizer from ``option``, trains on the
    mean per-example loss plus an L2 penalty weighted by ``option.l2_reg_weight``,
    evaluates on the given test loaders on a configurable schedule, and writes
    checkpoints, prediction dumps and ``metrics.json`` into ``option.expt_dir``.
    """

    def __init__(self, option):
        self.option = option
        self._build_model()
        self._build_optimizer()
        self.accuracy_visualizer = AccuracyVisualizer(self.option.expt_dir)
        self.loss_visualizer = LossVisualizer(self.option.expt_dir)
        # Length of the per-example loss buffer that test_default allocates for each
        # split; every dataset_ix evaluated on that split must be smaller than this.
        self.max_dataset_ixs = {'Train': 49998, 'Val': 5000, 'Test': 5000}
        # model_key -> epoch -> data split (e.g. 'Val', 'Test') -> metric name -> value
        self.metrics = {}

    def _build_model(self):
        """
        Construct the model using the model factory, plus the per-example loss.

        The loss class named by ``option.loss_type`` is instantiated with
        ``reduction='none'`` so it returns one value per example. Model and loss
        are moved to the GPU when ``option.cuda`` is set.
        :return:
        """
        self.model = build_model(
            self.option,
            self.option.model_name,
            in_dims=self.option.in_dims,
            hid_dims=self.option.hid_dims)
        logging.getLogger().info(f"Model {self.model}")
        # NOTE(review): eval() executes arbitrary code taken from the config; this is
        # only safe for trusted configs. The name resolves against this module's
        # (star-)imported loss classes.
        self.loss = eval(self.option.loss_type)(reduction='none')

        if self.option.cuda:
            self.model.cuda()
            self.loss.cuda()

    def _build_optimizer(self, lr=None, weight_decay=None, named_params=None):
        """Create ``self.optim`` via ``create_optimizer``.

        :param lr: learning rate; defaults to ``option.lr``
        :param weight_decay: weight decay; defaults to ``option.weight_decay``
        :param named_params: (name, parameter) pairs; defaults to the model's
            ``named_parameters()``
        """
        if lr is None:
            lr = self.option.lr
        if weight_decay is None:
            weight_decay = self.option.weight_decay
        if named_params is None:
            named_params = self.model.named_parameters()
        self.optim = create_optimizer(self.option.optimizer_name,
                                      named_params=named_params,
                                      lr=lr,
                                      weight_decay=weight_decay,
                                      momentum=self.option.momentum,
                                      freeze_layers=self.option.freeze_layers,
                                      custom_lr_config=self.option.custom_lr_config)

    def _initialization(self):
        """Restore trainer state from ``option.load_checkpoint`` when one is set."""
        if self.option.load_checkpoint is not None:
            self.load(self.option.load_checkpoint)
            logging.getLogger().info(f"Loaded from {self.option.load_checkpoint}")

    def _mode_setting(self, is_train):
        """Record the current mode and switch ``self.model`` to train/eval accordingly."""
        self.is_train = is_train
        self.model.train(is_train)

    def prepare_batch(self, batch):
        """Move inputs to the GPU (only when ``option.cuda``) and make ``dataset_ix`` a LongTensor.

        :param batch: dict with at least 'x' (tensor) and 'dataset_ix'; modified in place
        :return: the same dict
        """
        # Respect option.cuda like _build_model does; an unconditional .cuda() broke CPU runs.
        if self.option.cuda:
            batch['x'] = batch['x'].cuda()
        batch['dataset_ix'] = torch.LongTensor(batch['dataset_ix'])
        # NOTE(review): batch['y'] is not moved here, yet it is fed to the loss together
        # with the model's logits; confirm the loader/loss handle its device placement.
        return batch

    def compute_loss(self, out, labels):
        """Return the per-example loss of ``out['logits']`` against the squeezed ``labels``."""
        batch_losses = self.loss(out['logits'], torch.squeeze(labels))
        return batch_losses

    def forward_model(self, model, batch, model_in=None):
        """Run ``model`` on ``model_in`` when given, otherwise on ``batch['x']``."""
        if model_in is not None:
            return model(model_in)
        return model(batch['x'])

    def before_train(self, train_loader, test_loaders, test_load_checkpoint=True):
        """Restore any checkpoint, optionally evaluate it (as epoch -1), then enter train mode.

        ``train_loader`` is not used here.
        """
        logging.getLogger().info("Beginning the training process...")
        self._initialization()
        if test_load_checkpoint and self.option.load_checkpoint is not None:
            logging.getLogger().info("Evaluating immediately after loading the checkpoint")
            self._after_one_epoch(-1, test_loaders, force_test=True)
        self._mode_setting(is_train=True)

    def train(self, train_loader, test_loaders=None, unbalanced_train_loader=None):
        """Train for epochs 1..``option.epochs``, evaluating/checkpointing after each epoch.

        ``unbalanced_train_loader`` is not used by this trainer.
        """
        self.before_train(train_loader, test_loaders)
        for epoch in range(1, self.option.epochs + 1):
            self._train_epoch(epoch, train_loader)
            self._after_one_epoch(epoch, test_loaders)
        self.after_all_epochs()

    def _train_epoch(self, epoch, data_loader):
        """Run one optimization pass over ``data_loader``.

        The objective is the mean per-example loss plus
        ``option.l2_reg_weight * sum(||w||^2)`` over all model parameters.
        """
        self._mode_setting(is_train=True)
        l2_weight = self.option.l2_reg_weight
        for batch in data_loader:
            batch = self.prepare_batch(batch)
            self.optim.zero_grad()
            out = self.forward_model(self.model, batch)
            loss_pred = self.loss(out['logits'], torch.squeeze(batch['y'])).mean()
            # Skip the penalty entirely when its weight is 0/None. Sum of squares equals
            # norm()**2 but has a well-defined gradient at w == 0, and it lives on each
            # parameter's own device (no hard-coded .cuda()).
            if l2_weight:
                weight_norm = sum(w.pow(2).sum() for w in self.model.parameters())
                loss_pred = loss_pred + l2_weight * weight_norm
            # The graph is not reused after this step, so it need not be retained.
            loss_pred.backward()
            self.optim.step()
            self.update_generalization_metrics('Train', batch, loss_pred)

        self.optim.zero_grad()
        self._after_train_epoch(epoch)

    def _after_train_epoch(self, epoch, split='Train'):
        """Log ``split`` in the loss visualizer, then call its ``accumulate_plot_and_reset``."""
        self.loss_visualizer.log(epoch, split)
        self.loss_visualizer.accumulate_plot_and_reset(epoch)

    def test(self, epoch, data_key, data_loader, model=None, model_key="Main"):
        """Evaluation hook (override in subclasses); delegates to ``test_default``."""
        self.test_default(epoch, data_key, data_loader, model, model_key)

    def get_keys_to_save(self):
        """Names of trainer attributes persisted by ``save``/restored by ``load``."""
        # Override this method to persist variables in the trainer instance
        return ['model', 'optim', 'metrics']

    def get_current_state(self):
        """Snapshot of the attributes in ``get_keys_to_save``.

        Objects exposing ``state_dict`` are stored as their state dict; other values
        are stored as-is. Missing or ``None`` attributes are skipped.
        """
        save_state = {}
        for key in self.get_keys_to_save():
            attr = getattr(self, key, None)
            if attr is None:
                continue
            save_state[key] = attr.state_dict() if hasattr(attr, 'state_dict') else attr
        return save_state

    def save(self, save_file):
        """Write ``get_current_state()`` to ``save_file`` with ``torch.save``, creating its directory."""
        save_dir = trainer_utils.get_dir(save_file)
        # exist_ok avoids a check-then-create race; an empty dir means "current directory".
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(self.get_current_state(), save_file)
        logging.getLogger().info(f"Saved to {save_file}")

    def load(self, ckpt_path):
        """Restore the attributes listed in ``get_keys_to_save`` from a checkpoint file.

        Attributes with ``load_state_dict`` are loaded in place; others are replaced.
        Keys missing from the checkpoint, or attributes that are unset/``None`` on
        the trainer, are left untouched.
        """
        # With CUDA disabled, map tensors to CPU so GPU-saved checkpoints still load.
        # Module/Optimizer.load_state_dict place the tensors on the existing
        # parameters' devices.
        map_location = None if self.option.cuda else 'cpu'
        save_state = torch.load(ckpt_path, map_location=map_location)
        for key in self.get_keys_to_save():
            current = getattr(self, key, None)
            if current is None or key not in save_state:
                continue
            if hasattr(current, 'load_state_dict'):
                current.load_state_dict(save_state[key])
            else:
                setattr(self, key, save_state[key])

        logging.getLogger().info(f"Loaded from {ckpt_path}")

    @staticmethod
    def _weights_init(m):
        """Initialize one module in place; intended for use with ``Module.apply``.

        Modules whose class name contains 'Conv' get weights drawn from
        N(0, sqrt(2 / n)) with n = out_channels * prod(kernel_size). Modules whose
        class name contains 'BatchNorm' get weight=1 and bias=0. Everything else is
        untouched.
        """
        classname = m.__class__.__name__
        if 'Conv' in classname:
            # Guard against container modules that merely have 'Conv' in their name.
            if getattr(m, 'weight', None) is None or not hasattr(m, 'kernel_size'):
                return
            # Multiply over every kernel dimension so Conv1d (1-tuple) works too.
            n = m.out_channels
            for k in m.kernel_size:
                n *= k
            m.weight.data.normal_(0, math.sqrt(2. / n))
        elif 'BatchNorm' in classname:
            # affine=False BatchNorm layers have no weight/bias to initialize.
            if m.weight is not None:
                m.weight.data.fill_(1.0)
            if m.bias is not None:
                m.bias.data.zero_()

    def update_generalization_metrics(self, split, batch, losses):
        """Record ``losses.mean()`` as the 'Running Loss' of ``split`` in the loss visualizer.

        ``batch`` is unused here (kept in the signature, presumably for overrides
        such as group-wise metrics).
        """
        self.loss_visualizer.update(split, 'Running Loss', losses.mean().item())

    def _dump_metrics(self):
        """Write ``self.metrics`` to ``<expt_dir>/metrics.json``, closing the file afterwards."""
        metrics_fname = os.path.join(self.option.expt_dir, 'metrics.json')
        with open(metrics_fname, 'w') as f:
            json.dump(self.metrics, f, indent=4, sort_keys=True)

    def _after_one_epoch(self, epoch, test_loaders, force_test=False):
        """Checkpoint, evaluate and persist metrics after an epoch.

        :param epoch: epoch just finished (``before_train`` passes -1 after loading a checkpoint)
        :param test_loaders: mapping split name -> loader, or ``None`` to skip evaluation
        :param force_test: evaluate regardless of ``test_epochs``/``test_every``
        """
        save_every = self.option.save_model_every
        if save_every and epoch % save_every == 0:
            self.save(os.path.join(self.option.expt_dir, f'ckpt_epoch_{epoch}.pt'))

        test_epochs = self.option.test_epochs
        test_every = self.option.test_every
        should_test = (force_test
                       or (test_epochs is not None and epoch in test_epochs)
                       or (test_every is not None and epoch % test_every == 0))
        # train() defaults test_loaders to None; iterating it would raise TypeError.
        if should_test and test_loaders:
            for test_key in test_loaders:
                self.test(epoch, test_key, test_loaders[test_key])

        self._dump_metrics()

    def after_all_epochs(self):
        """
        Saves all the metrics
        :return:
        """
        self._dump_metrics()

    def gather_gt_scores(self, logits, y):
        """Return ``logits[i, y[i]]`` for every row i, as an (N, 1) tensor."""
        return logits.gather(1, y.view(-1, 1))

    @staticmethod
    def _flatten_concat(chunks):
        """Flatten every array in ``chunks`` and join them into one 1-D array.

        Equivalent to folding ``np.append`` over ``chunks`` starting from
        ``np.array([])``: inputs are flattened and promoted against float64, and an
        empty ``chunks`` yields an empty float64 array.
        """
        return np.concatenate([np.array([])] + [np.ravel(chunk) for chunk in chunks])

    def test_default(self, epoch, data_key, data_loader, model=None, model_key="Main"):
        """Evaluate ``model`` (default ``self.model``) on one data split.

        Logs loss, accuracy and mean-per-class accuracy (MPA) to the visualizers,
        saves predictions every ``option.save_predictions_every`` epochs and at the
        final epoch, and stores accuracy/MPA in
        ``self.metrics[model_key][epoch][data_key]``.

        :param epoch: epoch number, used for logging, file names and metric keys
        :param data_key: split name; must be a key of ``self.max_dataset_ixs``
        :param data_loader: iterable of batch dicts with 'x', 'y' and 'dataset_ix'
        :param model: model to evaluate; defaults to ``self.model``
        :param model_key: label used in chart names and in ``self.metrics``
        """
        logging.getLogger().info(f"\nEpoch {epoch}: Testing with data split: {data_key} model: {model_key}")
        if model is None:
            model = self.model
        # NOTE(review): only self.model is switched to eval mode; a different `model`
        # argument keeps whatever mode the caller left it in.
        self._mode_setting(is_train=False)

        ################################################################################################################
        # Initialize the metrics holders
        ################################################################################################################
        accuracy_metric = Accuracy(self.option.num_classes)
        # Per-example losses indexed by dataset_ix; -1000 marks examples not visited in this pass.
        losses = torch.ones(self.max_dataset_ixs[data_key]).float() * -1000
        chart_name = f'{data_key}_{model_key}'

        ################################################################################################################
        # Go through the data items, computing the metrics and collecting the predictions.
        # Chunks are concatenated once at the end; np.append per batch re-copied the whole
        # growing array every iteration. np.array(...) copies, so later reuse of the batch
        # tensors' memory cannot alter what was collected.
        ################################################################################################################
        gt_chunks, ix_chunks, logit_chunks = [], [], []
        with torch.no_grad():  # evaluation only: no autograd graph is needed
            for batch in data_loader:
                batch = self.prepare_batch(batch)
                labels = batch['y']
                out = self.forward_model(model, batch)
                batch_losses = self.compute_loss(out, torch.squeeze(labels))
                pred_ys = (torch.sigmoid(out['logits']).detach().cpu().numpy() > 0.5).astype(int)
                gt_ys = batch['y'].long().squeeze().cpu().numpy()
                gt_chunks.append(np.array(gt_ys, dtype=np.float64))
                ix_chunks.append(np.array(batch['dataset_ix'].squeeze().cpu().numpy(), dtype=np.float64))
                logit_chunks.append(np.array(out['logits'].detach().cpu().numpy(), dtype=np.float64))
                accuracy_metric.update(pred_ys, gt_ys)

                # Store the per-example results
                losses[batch['dataset_ix']] = batch_losses.detach().cpu()
                self.loss_visualizer.update(chart_name, f'{model_key} Loss', batch_losses.detach().mean().item())

        pred_gt_ys = self._flatten_concat(gt_chunks)
        data_ixs = self._flatten_concat(ix_chunks)
        all_logits = self._flatten_concat(logit_chunks)

        ################################################################################################################
        # Update unnormalized and per-class accuracies
        ################################################################################################################
        accuracy = accuracy_metric.get_accuracy() * 100
        mean_per_class_accuracy = accuracy_metric.get_mean_per_class_accuracy() * 100
        self.accuracy_visualizer.update(chart_name, f'{model_key} Accuracy', accuracy)
        self.accuracy_visualizer.update(chart_name, f'{model_key} MPA', mean_per_class_accuracy)
        self.accuracy_visualizer.log(epoch, chart_name)
        self.loss_visualizer.log(epoch, chart_name)

        ################################################################################################################
        # Save the predictions periodically and at the final epoch. Training epochs run
        # 1..option.epochs, so the final epoch is option.epochs (the previous check used
        # epochs - 1). Both cases write a torch pickle, hence the '.pt' extension.
        ################################################################################################################
        save_every = self.option.save_predictions_every
        is_periodic = save_every is not None and epoch % save_every == 0
        if is_periodic or epoch == self.option.epochs:
            save_file = os.path.join(self.option.expt_dir, f'preds_{chart_name}_epoch_{epoch}.pt')
            torch.save({
                'logits': all_logits,
                'losses': losses,
                'accuracy_metrics': self.accuracy_visualizer.metrics,
                'y': pred_gt_ys,
                'dataset_ix': data_ixs
            }, save_file)
            logging.getLogger().info(f"Saved to {save_file}")

        ################################################################################################################
        # Plot everything
        ################################################################################################################
        self.accuracy_visualizer.accumulate_plot_and_reset(epoch)

        ################################################################################################################
        # Update metrics
        ################################################################################################################
        curr_metric_entry = {
            'accuracy': accuracy,
            'MPA': mean_per_class_accuracy,
        }
        self.metrics.setdefault(model_key, {}).setdefault(epoch, {})[data_key] = curr_metric_entry
