import torch
from torch import nn
import torch
import lightning as L
import numpy as np 

class R_model(L.pytorch.LightningModule):
    """Pairwise reward model over (grid, inventory) observations.

    A small CNN encodes the grid, the flattened features are concatenated
    with the inventory vector, and an MLP ending in a sigmoid maps the result
    to a single score in (0, 1). During training/validation two observations
    are scored and the pair of scores is handed to ``loss_fn`` together with
    the label.
    """

    def __init__(self,
                 loss_fn,
                 sample,
                 dropout=0,
                 lr=0.0005,
                 n_input_channels=22,
                 features_dim=64,):
        """Store hyper-parameters, build the layers and initialise them.

        Args:
            loss_fn: callable invoked in ``step`` as
                ``loss_fn(score_a, score_b, labels)``.
            sample: indexable whose first element is one grid tensor; it is
                pushed through the CNN once to size the first linear layer.
                Presumably shaped (C, H, W) with C == 3 -- confirm with caller.
            dropout: dropout probability used inside the MLP head.
            lr: learning rate handed to Adam.
            n_input_channels: stored on the instance only; ``build_model``
                hard-codes 3 input channels and does not read this value.
            features_dim: width of the hidden linear layer.
        """
        super().__init__()
        self.features_dim = features_dim
        self.lr = lr
        self.dropout = dropout
        self.out_shape = 1
        self.n_input_channels = n_input_channels
        self.loss_fn = loss_fn
        self.build_model(sample)
        self.init_weights()

    def build_model(self, sample):
        """Create ``self.cnn`` and ``self.linear``.

        Args:
            sample: see ``__init__``; only ``sample[0]`` is used.
        """
        self.cnn = nn.Sequential(
                    nn.Conv2d(3, 32, (2, 2)),
                    nn.ReLU(),
                    nn.Flatten(),
                )

        # Compute shape by doing one forward pass
        with torch.no_grad():
            sample = sample[0].unsqueeze(0)
            n_flatten = self.cnn(sample).shape[1]

        # "+3" is the width of the inventory vector concatenated in forward().
        self.linear = nn.Sequential(nn.Linear(n_flatten+3, self.features_dim),
                                    nn.ReLU(),
                                    nn.Dropout(self.dropout),
                                    nn.Linear(self.features_dim, self.out_shape),
                                    nn.Sigmoid())

    def init_weights(self):
        """Initialize the network parameters.

        Conv and linear weights get an orthogonal init, biases are zeroed.
        """
        # Walk both containers directly instead of `self.cnn + self.linear`:
        # Sequential concatenation is not available in every PyTorch release
        # and would build a throwaway module just to iterate it.
        for container in (self.cnn, self.linear):
            for m in container:
                if isinstance(m, (nn.Linear, nn.Conv2d)):
                    nn.init.orthogonal_(m.weight)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

    def forward(self, grid1, inventory1):
        """Run forward pass.

        Args:
            grid1: batch of grids fed to the CNN.
            inventory1: batch of inventory vectors, concatenated to the CNN
                features along dim 1 (so it must be 2-D with 3 columns to
                match the first linear layer).

        Returns:
            Tensor of sigmoid scores, one column per sample.
        """
        features1 = self.cnn(grid1)
        rewards1 = self.linear(torch.cat([features1, inventory1], dim=1))

        return rewards1

    def step(self, batch, batch_idx):
        """Shared train/validation logic.

        Args:
            batch: ``(x, Y)`` where ``x[0]``/``x[1]`` are the two grids and
                ``x[2]``/``x[3]`` the matching inventories. ``Y`` is indexed
                like a 1-D tensor below; values > 1 are replaced by 0.5
                (presumably "no preference" -- confirm with the dataset).
            batch_idx: unused.

        Returns:
            ``(loss, acc)``; ``acc`` is the fraction of samples with a label
            other than 0.5 whose higher-scored candidate index equals the
            label, or 0.0 when the batch holds no such sample.
        """
        x, Y = batch
        y_hat = [self(x[0], x[2]), self(x[1], x[3])]

        # Work on a copy so the caller's batch is not modified in place.
        Y = Y.clone()
        Y[torch.where(Y > 1)[0]] = 0.5
        loss = self.loss_fn(*y_hat, Y)

        reward_argmax = np.argmax(torch.cat(y_hat, dim=-1).detach().cpu().numpy(), axis=1)
        labels = Y.detach().cpu().numpy()
        decisive = labels != 0.5
        if decisive.any():
            acc = np.mean(reward_argmax[decisive] == labels[decisive])
        else:
            # Mean of an empty selection would be NaN (and poison the epoch
            # average); report 0 like R_goal_model does for its empty case.
            acc = 0.0

        return loss, acc

    def training_step(self, batch, batch_idx):
        """Training step: log loss/accuracy and return the loss."""
        loss, acc = self.step(batch, batch_idx)

        self.log("train_loss", loss, on_epoch=True)
        self.log("train_acc", acc, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step: log loss/accuracy (synced across devices)."""
        loss, acc = self.step(batch, batch_idx)

        self.log("val_loss", loss, sync_dist=True)
        self.log("val_acc", acc, sync_dist=True)

        return loss

    def configure_optimizers(self):
        """Configure optimizer: Adam over the CNN and MLP parameters."""
        print("lr : ", self.lr)
        optimizer = torch.optim.Adam(list(self.cnn.parameters()) + list(self.linear.parameters()), lr=self.lr)

        return optimizer

    def predict_step(self, batch, batch_idx):
        """Prediction step.

        Args:
            batch: indexable with the grid at position 0 and the inventory at
                position 1.
            batch_idx: unused.

        Returns:
            The scores produced by ``forward``.
        """
        x = batch

        y_hat = self(x[0], x[1])

        return y_hat
    
    
class R_goal_model(R_model):
    """Variant of ``R_model`` with a deeper CNN and a single-input objective.

    Instead of scoring pairs, ``step`` scores individual (grid, inventory)
    samples against their label on a class-balanced subsample of the batch.
    """

    def build_model(self, sample):
        """Create ``self.cnn`` (two conv layers) and ``self.linear``.

        Args:
            sample: indexable whose first element is one grid tensor; it is
                pushed through the CNN once to size the first linear layer.
        """
        self.cnn = nn.Sequential(
                    nn.Conv2d(3, 32, (2, 2)),
                    nn.ReLU(),
                    nn.Conv2d(32, 64, (2, 2)),
                    nn.ReLU(),
                    nn.Flatten(),
                )

        # Compute shape by doing one forward pass
        with torch.no_grad():
            sample = sample[0].unsqueeze(0)
            n_flatten = self.cnn(sample).shape[1]

        # "+3" is the width of the inventory vector concatenated in forward().
        self.linear = nn.Sequential(nn.Linear(n_flatten+3, self.features_dim),
                                    nn.ReLU(),
                                    nn.Dropout(self.dropout),
                                    nn.Linear(self.features_dim, self.out_shape),
                                    nn.Sigmoid())

    def step(self, batch, batch_idx):
        """Shared train/validation logic on a balanced subsample.

        All samples with a non-zero label are kept, plus an equal number of
        randomly drawn samples from ``argwhere(1 - Y)`` (fewer if the batch
        does not contain that many).

        Args:
            batch: ``(x, Y)`` with ``x[0]`` the grids, ``x[1]`` the
                inventories and ``Y`` the labels (assumed 1-D with values in
                {0, 1} -- confirm with the dataset).
            batch_idx: unused.

        Returns:
            ``(loss, acc)``; ``acc`` is the fraction of selected samples whose
            score is within 0.5 of the label, or 0 when the batch has no
            non-zero label.
        """
        x, Y = batch
        index_non_zero = torch.argwhere(Y)
        count_non_zero = index_non_zero.shape[0]

        index_zero = torch.argwhere(1 - Y)
        shuffled = torch.randperm(index_zero.shape[0])[:count_non_zero]
        used_index_zero = index_zero[shuffled]

        # Drop only the trailing argwhere column. A bare squeeze() would also
        # remove the sample axis when exactly one sample is selected, turning
        # the index into a scalar and stripping the batch dimension from the
        # model inputs.
        used_index = torch.cat([index_non_zero, used_index_zero]).squeeze(-1)

        # Same reasoning: keep the batch axis even for a single prediction.
        y_hat = self(x[0][used_index], x[1][used_index]).squeeze(-1)
        targets = Y[used_index]
        # NOTE(review): with no non-zero label the selection is empty and
        # loss_fn receives empty tensors; what it returns then depends on
        # loss_fn -- confirm this case is acceptable.
        loss = self.loss_fn(y_hat, targets)

        acc = torch.mean(torch.abs(y_hat - targets) <= 0.5, dtype=torch.float32)
        if count_non_zero == 0:
            acc = 0
        return loss, acc
    
    

    
    