import numpy as np
import torch
from pytorch_lightning import LightningModule, Trainer
import cooper

from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split

from utils import evaluate_prediction
from data import  DataModule


# Global autograd anomaly detection: when a backward pass fails (or produces NaN),
# PyTorch reports the forward operation responsible. This is switched on as soon
# as the module is imported, and the extra checks slow down every backward pass,
# so consider disabling it for long training runs.
torch.autograd.set_detect_anomaly(mode=True)
                                   
def construct_mlp(input_dim, output_dim, hidden_dim, hidden_layers):
    """Build a fully connected network with LeakyReLU activations.

    Layout: ``input_dim -> hidden_dim``, then ``hidden_layers`` additional
    ``hidden_dim -> hidden_dim`` layers, and finally a linear projection to
    ``output_dim``. Every linear layer except the output one is followed by a
    ``LeakyReLU``.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output features.
        hidden_dim: Width of every hidden layer.
        hidden_layers: Number of extra hidden-to-hidden layers (values <= 0 add none).

    Returns:
        A ``torch.nn.Sequential`` holding the layers in order.
    """
    # One (fan_in, fan_out) pair per activated linear layer; a negative
    # repeat count yields an empty list, matching range() semantics.
    shapes = [(input_dim, hidden_dim)] + [(hidden_dim, hidden_dim)] * hidden_layers
    modules = []
    for fan_in, fan_out in shapes:
        modules.extend((torch.nn.Linear(fan_in, fan_out), torch.nn.LeakyReLU()))
    modules.append(torch.nn.Linear(hidden_dim, output_dim))
    return torch.nn.Sequential(*modules)

## optimization
class InvariantLoss(cooper.ConstrainedMinimizationProblem):
    """Constrained autoencoder objective for paired views.

    The objective is the sum of the reconstruction MSEs of the two views
    ``x[:, 0]`` and ``x[:, 1]``. The equality-constraint defect is
    ``MMDLoss`` (presumably a maximum mean discrepancy -- confirm in
    ``utils``) between the invariant latent columns of the two views, so that
    a Lagrangian formulation can drive it towards zero.

    Args:
        encoder: Module mapping a view to its latent code.
        decoder: Module mapping a latent code back to a view.
        invariant_indices: Latent columns constrained to match across the two
            views. Defaults to ``[0, 2]``.
    """

    def __init__(self, encoder, decoder, invariant_indices=None):
        super().__init__(is_constrained=True)
        self.criterion = torch.nn.MSELoss()
        self.encoder = encoder
        self.decoder = decoder
        # `MMDLoss` is imported further down this module; the name is looked up
        # when an instance is created, i.e. after the module has fully loaded.
        self.reg = MMDLoss()
        # Fresh list: no shared mutable default, and any sequence is accepted.
        self.invariant_indices = [0, 2] if invariant_indices is None else list(invariant_indices)

    def closure(self, batch):
        """Evaluate the objective and the constraint defect on one batch.

        Args:
            batch: Indexable whose first element ``x`` holds both views along
                dim 1 (``x[:, 0]`` and ``x[:, 1]``); other elements are unused.

        Returns:
            ``cooper.CMPState`` with the reconstruction loss as ``loss`` and the
            MMD between invariant latents as ``eq_defect``; no inequality
            constraints.
        """
        x = batch[0]
        x1, x2 = x[:, 0], x[:, 1]
        z1 = self.encoder(x1)
        x_hat1 = self.decoder(z1)
        z2 = self.encoder(x2)
        x_hat2 = self.decoder(z2)

        loss = self.criterion(x_hat1, x1) + self.criterion(x_hat2, x2)
        idx = self.invariant_indices
        reg = self.reg(z1[:, idx], z2[:, idx])

        return cooper.CMPState(loss=loss, eq_defect=reg, ineq_defect=None)

    

from utils import MMDLoss

class LitAE(LightningModule):
    """Autoencoder trained on paired views with invariant latent coordinates.

    Each batch is indexed as ``(x, z)``: ``x[:, 0]`` and ``x[:, 1]`` are the two
    views that get encoded and reconstructed, and ``z`` holds ground-truth
    latents used only for validation metrics (its exact layout is not visible
    here -- presumably ``(batch, 2, dim)``; TODO confirm against ``DataModule``).

    Training minimises the reconstruction MSE of both views while matching the
    invariant latent columns of the two views. There are two modes:

    * ``constrained_opt=False``: add a weighted ``MMDLoss`` penalty and use
      Lightning's automatic optimisation.
    * ``constrained_opt=True``: treat the MMD as an equality constraint,
      solved with Cooper's Lagrangian formulation (manual optimisation).

    Args:
        latent_dim: Input/output width of both encoder and decoder.
        hidden_dim: Width of the hidden layers.
        hidden_layers: Number of extra hidden layers in each MLP.
        constrained_opt: Use the Cooper constrained formulation instead of a penalty.
        invariant_indices: Latent columns expected to be shared by the two
            views. Defaults to ``[0, 2]``.
        mmd_weight: Weight of the MMD penalty in unconstrained training.
    """

    def __init__(self, latent_dim, hidden_dim, hidden_layers, constrained_opt=False,
                 invariant_indices=None, mmd_weight=3.0):
        super().__init__()
        self.encoder = construct_mlp(latent_dim, latent_dim, hidden_dim, hidden_layers)
        self.decoder = construct_mlp(latent_dim, latent_dim, hidden_dim, hidden_layers)
        self.constrained_opt = constrained_opt
        # Set in BOTH modes: validation reads these even when training is constrained.
        self.invariant_indices = [0, 2] if invariant_indices is None else list(invariant_indices)
        self.mmd_weight = mmd_weight

        if self.constrained_opt:
            # define the constraint optimization problem
            # NOTE(review): InvariantLoss chooses its own invariant columns
            # (currently [0, 2]); they are not taken from `invariant_indices`.
            self.cmp = InvariantLoss(encoder=self.encoder, decoder=self.decoder)
            self.formulation = cooper.LagrangianFormulation(self.cmp)

            # Define the primal parameters and optimizer
            primal_optimizer = cooper.optim.ExtraSGD(self.parameters(), lr=1e-5)

            # Define the dual optimizer. Note that this optimizer has NOT been fully instantiated
            # yet. Cooper takes care of this, once it has initialized the formulation state.
            dual_optimizer = cooper.optim.partial_optimizer(cooper.optim.ExtraSGD, lr=1e-5/2)

            # Wrap the formulation and both optimizers inside a ConstrainedOptimizer
            self.coop = cooper.ConstrainedOptimizer(self.formulation, primal_optimizer, dual_optimizer)

            # deactivates automatic optimization
            self.automatic_optimization = False
        else:
            self.automatic_optimization = True

        # Created after the primal optimizer so its parameter list is unchanged.
        self.mmd_loss = MMDLoss()

        # Per-epoch buffers filled by validation_step, consumed and cleared by
        # on_validation_epoch_end.
        self.validation_outputs = {"zs": [], "zs_hat": []}

    def forward(self, x):
        """Reconstruct ``x`` by passing it through the encoder and decoder."""
        return self.decoder(self.encoder(x))

    def training_step(self, batch, batch_idx):
        """Run one training step and return the logged training loss.

        In constrained mode, Cooper computes the Lagrangian, backpropagates and
        steps both the primal and the dual optimizer itself. Otherwise, the
        penalised loss is returned for Lightning's automatic optimisation.
        """
        if self.constrained_opt:
            self.coop.zero_grad()
            lagrangian = self.formulation.composite_objective(self.cmp.closure, batch)
            self.formulation.custom_backward(lagrangian)
            self.coop.step(self.cmp.closure, batch)
            loss = self.cmp.state.loss
            self.log("train_loss", loss, prog_bar=True)
            self.log("dual", self.formulation.state()[-1], prog_bar=True)
            return loss

        x = batch[0]
        x1, x2 = x[:, 0], x[:, 1]
        z1 = self.encoder(x1)
        x_hat1 = self.decoder(z1)
        z2 = self.encoder(x2)
        x_hat2 = self.decoder(z2)

        idx = self.invariant_indices
        loss = (
            torch.nn.functional.mse_loss(x_hat1, x1)
            + torch.nn.functional.mse_loss(x_hat2, x2)
            + self.mmd_weight * self.mmd_loss(z1[:, idx], z2[:, idx])
        )
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and stash latents for epoch-end metrics.

        ``val_loss`` is the reconstruction MSE of both views plus a
        moment-matching term: the MSE between the per-column means and standard
        deviations of the invariant latents of the two views. The MMD between
        the invariant latents is logged separately as ``mmd loss``.
        """
        # NOTE(review): batch appears to be (x, z) with x[:, 0] / x[:, 1] the two
        # views and z the ground-truth latents -- confirm against DataModule.
        x, z = batch[0], batch[1]
        x1, x2 = x[:, 0], x[:, 1]
        z1 = self.encoder(x1)
        x_hat1 = self.decoder(z1)
        z2 = self.encoder(x2)
        x_hat2 = self.decoder(z2)

        idx = self.invariant_indices
        m1, std1 = z1.mean(0), z1.std(0)
        m2, std2 = z2.mean(0), z2.std(0)

        loss = (
            torch.nn.functional.mse_loss(x_hat1, x1)
            + torch.nn.functional.mse_loss(x_hat2, x2)
            + torch.nn.functional.mse_loss(m1[idx], m2[idx])
            + torch.nn.functional.mse_loss(std1[idx], std2[idx])
        )
        # Detached so the buffered tensors can never keep an autograd graph alive.
        self.validation_outputs["zs"].append(z.detach())
        self.validation_outputs["zs_hat"].append(torch.stack([z1, z2], dim=1).detach())
        self.log("val_loss", loss, prog_bar=True)
        self.log('mmd loss', self.mmd_loss(z1[:, idx], z2[:, idx]), prog_bar=True)
        return loss

    def mcc(self, z, z_hat):
        """Return |Pearson correlation| between same-index columns of ``z`` and ``z_hat``.

        Column ``i`` of ``z`` is compared only with column ``i`` of ``z_hat``;
        no permutation matching is performed.

        Args:
            z: 2-D tensor of reference values, one column per coordinate.
            z_hat: 2-D tensor with at least ``z.shape[1]`` columns.

        Returns:
            1-D tensor of length ``z.shape[1]`` on ``z.device``.
        """
        corrs = torch.zeros(z.shape[1], device=z.device)
        for i in range(z.shape[1]):
            data = torch.stack([z[:, i], z_hat[:, i]], dim=1).T
            corrs[i] = torch.abs(torch.corrcoef(data)[0, 1])
        return corrs

    def r2(self, z, z_hat, input_indices=None):
        """Score how well each true latent is predictable from selected learned latents.

        Both inputs are standardised independently. For every column ``i`` of
        ``z``, the data is split randomly with ``train_test_split``, and
        ``evaluate_prediction`` is called with an ``MLPRegressor``, ``r2_score``
        and the split, predicting ``z[:, i]`` from ``z_hat[:, input_indices]``
        (presumably fit on train, scored on test -- see ``utils``).

        Args:
            z: 2-D tensor of ground-truth latents.
            z_hat: 2-D tensor of learned latents with the same row order as ``z``.
            input_indices: Columns of ``z_hat`` used as regression inputs.
                Defaults to ``[0, 2]``.

        Returns:
            ``np.ndarray`` with one score per column of ``z``.
        """
        input_indices = [0, 2] if input_indices is None else list(input_indices)
        r2_scores = np.zeros(z.shape[1])
        z_hat = StandardScaler().fit_transform(z_hat.detach().cpu().numpy())
        z = StandardScaler().fit_transform(z.detach().cpu().numpy())
        # A one-element tuple: a single hidden layer with 64 units.
        eval_model = MLPRegressor(hidden_layer_sizes=(64,), max_iter=500)
        for i in range(z.shape[1]):
            train_inputs, test_inputs, train_labels, test_labels = train_test_split(
                z_hat[:, input_indices], z[:, i]
            )
            r2_scores[i] = evaluate_prediction(
                eval_model, r2_score, train_inputs, train_labels, test_inputs, test_labels
            )
        return r2_scores

    def on_validation_epoch_end(self) -> None:
        """Log per-column R^2 scores for the latents collected in this epoch.

        The buffers are cleared afterwards, so each validation epoch is scored
        only on its own batches and memory does not grow across epochs.
        """
        if not self.validation_outputs["zs"]:
            return
        zs_hat = torch.cat(self.validation_outputs["zs_hat"])
        zs = torch.cat(self.validation_outputs["zs"])
        # zs_hat is (N, 2, latent_dim) from the stack in validation_step; flatten
        # the view axis and give zs the same number of rows so row k matches.
        zs_hat = zs_hat.reshape(-1, zs_hat.shape[-1])
        zs = zs.reshape(zs_hat.shape[0], -1)
        self.validation_outputs["zs"].clear()
        self.validation_outputs["zs_hat"].clear()

        r2_scores = self.r2(zs, zs_hat, input_indices=self.invariant_indices)
        for i, score in enumerate(r2_scores):
            self.log(f"r2_{i}", score, prog_bar=True)
        print('r2', r2_scores)

    def configure_optimizers(self):
        """Return Adam (lr=1e-4) over all parameters.

        In constrained mode, automatic optimisation is disabled and
        ``training_step`` steps ``self.coop`` instead, so this optimizer is
        not stepped there.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer



if __name__ == "__main__":
    import pytorch_lightning as pl

    pl.seed_everything(41)
    ds = DataModule(batch_size=4000)

    # No explicit `.cuda()`: with accelerator="auto" the Trainer selects the
    # device and moves the model there itself, so this also runs without a GPU.
    model = LitAE(
        latent_dim=3,
        hidden_dim=128,
        hidden_layers=3,  # alternative setting: hidden_dim=128, hidden_layers=12
        constrained_opt=False,
        invariant_indices=[0, 2],
    )
    trainer = Trainer(
        max_epochs=10000,
        accelerator="auto",
        log_every_n_steps=10,
        check_val_every_n_epoch=200,
    )
    trainer.fit(model, ds)

