import os
import torch
from torch import optim, nn, utils, Tensor
from torch.utils.data import TensorDataset, DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl
from utils import count_parameters, normalized_mse_loss
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import sys
sys.path.append('/home/user/SL2equivariance') 
from generate_data import normalize_pair

class PLModel(pl.LightningModule):
    """PyTorch Lightning wrapper around ``net`` with optional equivariance diagnostics.

    Trains ``net`` with ``loss_fn`` using Adam. During validation, if an
    ``equiv_function`` is supplied, it additionally logs how well the network
    commutes with the transformation applied by ``equiv_function``.

    Args:
        net: The wrapped model. During validation it is called as
            ``net(x, eval_mode=...)``, so its forward must accept that keyword.
        loss_fn: Loss called as ``loss_fn(pred, target)``; this is the
            optimized training objective.
        lr: Adam learning rate.
        use_lr_scheduler: If True, attach a ``ReduceLROnPlateau`` scheduler
            monitoring ``"train_loss"``.
        equiv_function: Optional callable ``equiv_function(x, out)`` returning
            a transformed ``(x, out)`` pair. Supplying it enables the
            equivariance metrics in ``validation_step``.
        additional_loss_function: Optional metric ``f(pred, target)``. It is
            only logged, never optimized.
        use_eval_mode: Value forwarded as ``eval_mode`` to ``net`` for the
            untransformed validation inputs.
        normalize_val: If True, transformed validation pairs are passed
            through ``normalize_pair`` before being evaluated.
    """

    def __init__(self, net, loss_fn=nn.functional.mse_loss, lr=3e-4,
                 use_lr_scheduler=False, equiv_function=None,
                 additional_loss_function=None, use_eval_mode=False,
                 normalize_val=False):
        super().__init__()
        self.save_hyperparameters()
        self.net = net
        # count_parameters comes from the project's utils module
        # (presumably the number of trainable parameters; confirm there).
        self.num_param = count_parameters(net)
        self.loss_fn = loss_fn
        self.lr = lr
        self.use_lr_scheduler = use_lr_scheduler
        # Called as self.equiv_function(x, out) -> (x_transformed, out_transformed).
        self.equiv_function = equiv_function
        self.do_equiv_loss = equiv_function is not None
        self.use_eval_mode = use_eval_mode
        self.additional_loss_function = additional_loss_function
        self.normalize_val = normalize_val  # only used for transform validation
        # Detached per-step training losses, appended in training_step.
        # NOTE(review): this list grows for the whole run and is never cleared here.
        self.train_losses = []

    def forward(self, batch):
        """Run the wrapped network on ``batch``."""
        return self.net(batch)

    def training_step(self, batch, batch_idx):
        """Compute the training loss on ``batch = (x, y)`` and log metrics.

        Returns:
            The ``loss_fn(pred, y)`` tensor that Lightning back-propagates.
        """
        x, y = batch
        pred = self.net(x)
        loss = self.loss_fn(pred, y)

        self.log("train_loss", loss, on_epoch=True)
        self.train_losses.append(loss.detach())

        # The remaining quantities are diagnostics only, so no autograd graph
        # needs to be built for them.
        with torch.no_grad():
            normalized_loss = normalized_mse_loss(pred, y)
            self.log("train_loss_normalized", normalized_loss, on_epoch=True)
            if self.additional_loss_function is not None:
                train_additional_loss = self.additional_loss_function(pred, y)
                self.log("train_additional_loss", train_additional_loss, on_epoch=True)

        # For a per-epoch logging alternative see e.g.
        # https://www.pytorchlightning.ai/blog/tensorboard-with-pytorch-lightning
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and, if enabled, equivariance diagnostics.

        Lightning runs validation with gradients disabled and the module in
        eval mode. ``self.use_eval_mode`` additionally controls the
        ``eval_mode`` flag that is forwarded to ``net`` for the untransformed input.
        """
        x, y = batch
        pred = self.net(x, eval_mode=bool(self.use_eval_mode))
        loss = self.loss_fn(pred, y)

        if self.do_equiv_loss:
            self._log_equivariance_metrics(x, y, pred)

        if self.additional_loss_function is not None:
            val_additional_loss = self.additional_loss_function(pred, y)
            self.log("val_additional_loss", val_additional_loss, on_epoch=True)
        self.log("val_loss", loss, on_epoch=True)
        self.log("num_param", self.num_param, on_epoch=True)

    def _log_equivariance_metrics(self, x, y, pred):
        """Log how well ``net`` commutes with ``equiv_function`` on one batch.

        Args:
            x: Validation inputs.
            y: Validation targets.
            pred: ``net``'s prediction on ``x`` (already computed by the caller).
        """
        # torch.is_complex covers every complex dtype (complex32/64/128).
        # mse_loss is only computed for real-valued predictions.
        real_valued = not torch.is_complex(pred)

        with torch.no_grad():
            # (1) Equivariance error: apply the transform to the network's own
            #     prediction and compare it with the prediction on the
            #     transformed input.
            x_transformed, expected_pred_transformed = self.equiv_function(x, pred)
            if self.normalize_val:
                x_transformed, expected_pred_transformed = normalize_pair(
                    x_transformed, expected_pred_transformed)
            # eval_mode should always be True on transformed inputs.
            pred_transformed = self.net(x_transformed, eval_mode=True)
            equiv_loss = self.loss_fn(pred_transformed, expected_pred_transformed)
            if self.additional_loss_function is not None:
                val_additional_loss_on_transformed = self.additional_loss_function(
                    pred_transformed, expected_pred_transformed)
                self.log("val_additional_loss_on_transformed",
                         val_additional_loss_on_transformed, on_epoch=True)

            # (2) Error on the transformed ground truth: transform (x, y) and
            #     evaluate the network on it, before and after normalization.
            x_transformed, y_transformed = self.equiv_function(x, y)

            pred_transformed_before_norm = self.net(x_transformed, eval_mode=True)
            self.log("val_loss_on_transformed_unnormalized",
                     self.loss_fn(pred_transformed_before_norm, y_transformed),
                     on_epoch=True)
            self.log("val_numerator_loss_unnormalized",
                     torch.norm(pred_transformed_before_norm - y_transformed),
                     on_epoch=True)

            if self.normalize_val:
                x_transformed, y_transformed = normalize_pair(x_transformed, y_transformed)

            pred_transformed = self.net(x_transformed, eval_mode=True)
            val_loss_on_transformed = self.loss_fn(pred_transformed, y_transformed)

            if real_valued:
                val_unnorm_loss_on_transformed = nn.functional.mse_loss(
                    pred_transformed, y_transformed)
                self.log("val_unnorm_loss_on_transformed",
                         val_unnorm_loss_on_transformed, on_epoch=True)

            self.log("val_norm_of_pred", torch.norm(pred), on_epoch=True)
            self.log("val_norm_of_y", torch.norm(y), on_epoch=True)
            self.log("val_norm_of_pred_minus_y", torch.norm(pred - y), on_epoch=True)
            # The original check compared against the factory function
            # `torch.tensor`, so it never fired. isinstance is what was intended.
            if isinstance(x, torch.Tensor):
                self.log("val_norm_of_x_should_const", torch.norm(x), on_epoch=True)
            self.log("val_norm_of_pred_transformed", torch.norm(pred_transformed), on_epoch=True)
            self.log("val_norm_of_y_transformed", torch.norm(y_transformed), on_epoch=True)
            self.log("val_norm_of_pred_minus_y_transformed",
                     torch.norm(pred_transformed - y_transformed), on_epoch=True)

            self.log("val_loss_on_transformed", val_loss_on_transformed, on_epoch=True)

            self.log("log10_val_equivariance_loss", torch.log10(equiv_loss), on_epoch=True)
            if real_valued:
                self.log("mse_loss", nn.functional.mse_loss(pred, y), on_epoch=True)

    def configure_optimizers(self):
        """Return Adam over all parameters, plus an optional ReduceLROnPlateau scheduler.

        Returns:
            A dict with ``"optimizer"`` and, when ``use_lr_scheduler`` is set,
            an ``"lr_scheduler"`` config that monitors ``"train_loss"``.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        out_dict = {"optimizer": optimizer}
        if self.use_lr_scheduler:
            # If "monitor" referenced a validation metric, "frequency" would
            # need to be a multiple of trainer.check_val_every_n_epoch.
            out_dict["lr_scheduler"] = {
                "scheduler": ReduceLROnPlateau(optimizer),
                "monitor": "train_loss",
                "frequency": 1,
            }
        return out_dict
