import torch
import torch.nn.functional as F
from metrics import compute_metrics
from MetricsRecorder import MetricsRecorder

class Trainer(object):
    """Epoch-based training / validation loop around a classifier.

    The classifier is expected to be callable on a batch of inputs (returning
    scores) and to expose ``loss(y=..., scores=..., mp=...)`` returning
    ``(ploss, bloss, log_diff)`` and ``predict()`` returning projected outputs.
    Only ``ploss`` is back-propagated; ``bloss`` and ``log_diff`` are passed to
    ``compute_metrics`` for reporting.
    """

    # Number of consecutive saturated batches tolerated before the saturation
    # callout stops the epoch (the callout fires on the next saturated batch).
    _SATURATION_PATIENCE = 20

    def __init__(self, classifier=None, optimizer=None, scheduler=None,
                 train_loader=None, val_loader=None, device=None, compute_metrics=None, wandb_run=None, args=None):
        """Store the training components.

        Args:
            classifier: model being trained (see class docstring for the
                interface it must provide).
            optimizer: optimizer stepped once per training batch.
            scheduler: LR scheduler stepped once per batch when
                ``args.scheduler`` is truthy.
            train_loader: iterable of ``(inputs, labels)`` training batches.
            val_loader: iterable of ``(inputs, labels)`` validation batches.
            device: device that inputs, labels and predictions are moved to.
            compute_metrics: callable producing a metrics mapping that is fed
                to ``wandb_run.log`` / ``MetricsRecorder.recDict``.
            wandb_run: object with a ``log(dict)`` method (presumably a wandb
                run — confirm with the caller).
            args: run-option namespace; must support ``vars()`` (e.g. an
                ``argparse.Namespace``), so ``None`` is not actually accepted.
        """
        self.classifier = classifier
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.compute_metrics = compute_metrics
        self.run = wandb_run
        self.args = args
        self.mr = MetricsRecorder(hyperparams=vars(args))

        if self.args.log:
            # The name is not used in this class; the import presumably acts as
            # an early check that wandb is installed when logging is requested.
            import wandb

    def train(self):
        """Train for ``args.epochs`` epochs, validating every ``args.eval_step`` epochs.

        Stops early when ``step()`` reports saturation.
        """
        for e in range(self.args.epochs):
            stop = self.step()

            if self.args.do_eval and (e % self.args.eval_step == 0):
                val_metrics = self.eval()
                if self.args.log:
                    self.log(val_metrics, epoch=e)

            if stop:
                break

    def step(self):
        """Run one training epoch over ``self.train_loader``.

        Per batch: forward pass (optionally clamped), loss, one optimizer step,
        an optional scheduler step, optional per-batch train metrics, and an
        optional mid-epoch validation pass.

        Returns:
            True if the saturation callout fired (scores sat on a clamp bound
            for more than ``_SATURATION_PATIENCE`` consecutive batches), which
            ends the epoch early; False otherwise.
        """
        self.classifier.train()

        saturated = False
        streak = 0  # consecutive batches whose scores touched a clamp bound

        for i, (inputs, labels) in enumerate(self.train_loader):
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            scores = self._scores(inputs)

            # ploss is back-propagated below; bloss and log_diff are only reported.
            ploss, bloss, log_diff = self.classifier.loss(y=labels, scores=scores, mp=True)

            outputs = self.classifier.predict().to(self.device)

            if self.args.log_train:
                train_metrics = self.compute_metrics(
                    scores, outputs, labels, bloss, ploss, log_diff,
                    leaves=self._leaves_mask(), prefix="train", device=self.device)
                self.run.log(train_metrics)
            elif self.args.saturation_callout:
                train_metrics = {
                    "train/max_score": torch.max(scores).item(),
                    "train/min_score": torch.min(scores).item(),
                }

            if self.args.saturation_callout:
                if self._is_saturated(train_metrics):
                    streak += 1
                    if streak > self._SATURATION_PATIENCE:
                        saturated = True
                        break  # stop before stepping on a saturated batch
                else:
                    streak = 0

            self.optimizer.zero_grad()
            ploss.backward()
            self.optimizer.step()

            if self.args.scheduler:
                self.scheduler.step()

            if self.args.log_batch and (i % self.args.log_batch_step == 0):
                val_metrics = self.eval()
                # eval() put the classifier in eval mode; restore train mode so
                # the remaining batches of this epoch are trained correctly.
                self.classifier.train()
                # eval() returns None when args.log is off; nothing to log then.
                if val_metrics is not None:
                    self.log(val_metrics)

        return saturated

    @torch.no_grad()
    def eval(self):
        """Run the classifier over ``self.val_loader`` without gradients.

        Leaves the classifier in eval mode.

        Returns:
            dict of reduced validation metrics when ``args.log`` is truthy,
            otherwise None.
        """
        self.classifier.eval()
        self.mr.reset()

        for inputs, labels in self.val_loader:
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            scores = self._scores(inputs)
            ploss, bloss, log_diff = self.classifier.loss(y=labels, scores=scores, mp=True)
            outputs = self.classifier.predict().to(self.device)

            if self.args.log:
                self.mr.recDict(self.compute_metrics(
                    scores, outputs, labels, bloss, ploss, log_diff,
                    leaves=self._leaves_mask(), prefix="val", device=self.device))

        if self.args.log:
            return dict(self.mr.reduction())
        return None

    def log(self, metrics_dict, epoch=None):
        """Send ``metrics_dict`` to the run logger.

        Adds ``"epoch"`` (offset by ``args.epoch_offset``) when ``epoch`` is
        given and ``"lr"`` when a scheduler is in use. Mutates ``metrics_dict``.
        """
        if epoch is not None:
            metrics_dict["epoch"] = epoch + self.args.epoch_offset
        if self.args.scheduler:
            metrics_dict["lr"] = self._current_lr()
        self.run.log(metrics_dict)

    def _scores(self, inputs):
        """Forward ``inputs`` through the classifier, clamping if ``args.clamping``."""
        scores = self.classifier(inputs)
        if self.args.clamping:
            scores = torch.clamp(scores, min=self.args.min_clamp, max=self.args.max_clamp)
        return scores

    def _leaves_mask(self):
        """Return the hierarchy leaves mask if ``args.record_leaves_acc``, else None."""
        if self.args.record_leaves_acc:
            return self.classifier.hexL.hexg.leaves_mask
        return None

    def _is_saturated(self, train_metrics):
        """Return True if the batch's max or min score equals its clamp bound."""
        return (train_metrics["train/max_score"] == self.args.max_clamp
                or train_metrics["train/min_score"] == self.args.min_clamp)

    def _current_lr(self):
        """Return the scheduler's last computed learning rate(s).

        ``get_last_lr()`` is the documented accessor; ``get_lr()`` may return
        a wrong value when called outside ``scheduler.step()``. It is used only
        as a fallback for schedulers that do not provide ``get_last_lr``.
        """
        get_last_lr = getattr(self.scheduler, "get_last_lr", None)
        if callable(get_last_lr):
            return get_last_lr()
        return self.scheduler.get_lr()


class MCTrainer(Trainer):
    """Trainer variant for integer class-index labels.

    Reuses ``Trainer.__init__``, ``train`` and ``log``; overrides ``step`` and
    ``eval`` so the labels passed to ``compute_metrics`` are one-hot encoded
    with ``F.one_hot`` (width ``outputs.shape[1]``), and no leaves mask is
    passed to ``compute_metrics``.
    """

    def __init__(self, classifier=None, optimizer=None, scheduler=None,
                 train_loader=None, val_loader=None, device=None, compute_metrics=None, wandb_run=None, args=None):
        super().__init__(classifier, optimizer, scheduler, train_loader, val_loader, device, compute_metrics, wandb_run, args)

    def step(self):
        """Run one training epoch over ``self.train_loader``.

        Per batch: forward pass (optionally clamped), loss, one optimizer step,
        an optional scheduler step, optional per-batch train metrics, and an
        optional mid-epoch validation pass.

        Returns:
            True if the saturation callout fired (scores sat on a clamp bound
            for more than 20 consecutive batches), which ends the epoch early;
            False otherwise.
        """
        self.classifier.train()

        sat = False
        j = 0  # consecutive batches whose scores touched a clamp bound

        for i, (inputs, labels) in enumerate(self.train_loader):
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            scores = self.classifier(inputs)
            if self.args.clamping:
                scores = torch.clamp(scores, min=self.args.min_clamp, max=self.args.max_clamp)

            # NOTE(review): unlike Trainer.step and MCTrainer.eval, this call does
            # not pass mp=True; confirm predict() below does not depend on it.
            ploss, bloss, log_diff = self.classifier.loss(y=labels, scores=scores)

            outputs = self.classifier.predict().to(self.device)

            if self.args.log_train:
                train_metrics = self.compute_metrics(
                    scores, outputs, F.one_hot(labels, num_classes=outputs.shape[1]),
                    bloss, ploss, log_diff, prefix="train", device=self.device)
                self.run.log(train_metrics)
            elif self.args.saturation_callout:
                train_metrics = {
                    "train/max_score": torch.max(scores).item(),
                    "train/min_score": torch.min(scores).item(),
                }

            if self.args.saturation_callout:
                if (train_metrics["train/max_score"] == self.args.max_clamp
                        or train_metrics["train/min_score"] == self.args.min_clamp):
                    j += 1
                    if j > 20:
                        sat = True
                        break  # stop before stepping on a saturated batch
                else:
                    j = 0

            # only ploss is back-propagated; bloss and log_diff are reported only
            self.optimizer.zero_grad()
            ploss.backward()
            self.optimizer.step()

            if self.args.scheduler:
                self.scheduler.step()

            if self.args.log_batch and (i % self.args.log_batch_step == 0):
                val_metrics = self.eval()
                # eval() put the classifier in eval mode; restore train mode so
                # the remaining batches of this epoch are trained correctly.
                self.classifier.train()
                # eval() returns None when args.log is off; nothing to log then.
                if val_metrics is not None:
                    self.log(val_metrics)

        return sat

    @torch.no_grad()
    def eval(self):
        """Run the classifier over ``self.val_loader`` without gradients.

        Leaves the classifier in eval mode.

        Returns:
            dict of reduced validation metrics when ``args.log`` is truthy,
            otherwise None.
        """
        self.classifier.eval()
        self.mr.reset()

        for inputs, labels in self.val_loader:
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            scores = self.classifier(inputs)
            if self.args.clamping:
                scores = torch.clamp(scores, min=self.args.min_clamp, max=self.args.max_clamp)

            ploss, bloss, log_diff = self.classifier.loss(y=labels, scores=scores, mp=True)

            outputs = self.classifier.predict().to(self.device)

            if self.args.log:
                self.mr.recDict(self.compute_metrics(
                    scores, outputs, F.one_hot(labels, num_classes=outputs.shape[1]),
                    bloss, ploss, log_diff, prefix="val", device=self.device))

        if self.args.log:
            return dict(self.mr.reduction())
        return None