import os
import torch
import torch.nn as nn
import hydra
import random
import numpy as np
from typing import Any, Literal
import pytorch_lightning as pl
from omegaconf import DictConfig
from torchmetrics import Accuracy
from affinityenhancer.data.datasets.gearnet_data_utils \
    import protein_sequence_decoder
from affinityenhancer.paratope_encoder.pointnet import SpatialSurfaceEncoder

class Base(pl.LightningModule):
    """Shared training logic for encoder -> sequence-decoder models.

    Subclasses implement ``forward(inputs, mask=None, ret_embedding=False)``
    returning per-position class logits of shape (B, L, C). This base class
    provides the loss, accuracy / edit-distance metrics, logging, the
    train/val steps and the optimizer configuration.

    The structure encoder is instantiated from config and frozen
    (``requires_grad = False``); the decoder remains trainable.
    """

    def __init__(self,
                 structure_encoder: DictConfig,
                 sequence_decoder: DictConfig,
                 training: DictConfig,
                 ignore_index: int = 20):
        """Instantiate encoder, decoder and training config via Hydra.

        Args:
            structure_encoder: Hydra config of the encoder module; its
                parameters are frozen.
            sequence_decoder: Hydra config of the decoder module; the
                instantiated object must expose ``num_classes``.
            training: Hydra config with training hyper-parameters (this class
                reads ``lr``, ``weight_decay`` and ``patience``).
            ignore_index: label value marking padded positions; excluded from
                the loss, the accuracy and the edit distance.
        """
        super().__init__()
        self.training_cfg = hydra.utils.instantiate(training)

        self.encoder = hydra.utils.instantiate(structure_encoder)
        # The encoder is used as a fixed feature extractor.
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.decoder = hydra.utils.instantiate(sequence_decoder)
        self.ignore_index = ignore_index
        # reduction='none' so the per-position loss can be reshaped to (B, L)
        # and summed over non-padded positions in do_step.
        self.loss = nn.CrossEntropyLoss(reduction='none',
                                        ignore_index=ignore_index)
        self.accuracy = Accuracy(task='multiclass',
                                 num_classes=self.decoder.num_classes,
                                 ignore_index=ignore_index
                                 )
        self.save_hyperparameters()

    def calculate_edit_distance(self, sequence_tensor, reference,
                                verbose=False):
        """Count, per sequence, the positions where prediction != reference.

        This is a position-wise (Hamming-style) mismatch count, not an
        alignment-based edit distance. Positions whose reference value equals
        ``self.ignore_index`` (padding) are never counted.

        Args:
            sequence_tensor: array-like of predicted class ids, same shape as
                ``reference``.
            reference: array-like of reference class ids.
            verbose: if True, print the first reference row.

        Returns:
            numpy integer array of mismatch counts, reduced over the last axis.
        """
        predicted = np.asarray(sequence_tensor)
        reference = np.asarray(reference)
        mismatches = (predicted != reference) & (reference != self.ignore_index)
        if verbose:
            print('ref: ', reference[0])
        return np.sum(mismatches, axis=-1)

    def do_step(self, batch, data_label, batch_idx):
        """Run one forward pass; compute and log loss and metrics.

        Args:
            batch: ``(inputs, labels)``; ``labels`` holds per-position class
                ids, with ``self.ignore_index`` at padded positions.
            data_label: prefix for logged metric names (e.g. 'train', 'val').
            batch_idx: index of the batch within the current epoch.

        Returns:
            Scalar loss summed over all non-padded positions.
        """
        inputs, labels = batch

        # True at padded positions (label == ignore_index); the model
        # receives this as its mask.
        pad_mask = labels == self.ignore_index

        output = self(inputs, mask=pad_mask)
        B, L, C = output.shape
        # reshape (not view): the decoder output is not guaranteed contiguous.
        loss = self.loss(output.reshape(B * L, C), labels.flatten())
        # Sum over non-padded residues only.
        loss = loss.reshape(B, L)[~pad_mask].sum()

        self.log(f"{data_label}/loss",
                 loss,
                 on_step=True,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True,
                 sync_dist=True
                 )

        if (data_label != 'val' and batch_idx % 50 == 0
                and self.current_epoch % 10 == 0):
            self._log_sample_sequences(output, labels, data_label)

        logits = output.detach()
        accuracy = self.accuracy(logits.reshape(-1, C), labels.flatten())

        edit_distance = self.calculate_edit_distance(
            logits.cpu().argmax(-1).numpy(),
            labels.detach().cpu().numpy()
        )
        self.log_dict({f"{data_label}/ED_mean": edit_distance.mean(),
                       f"{data_label}/ED_sd": edit_distance.std(),
                       f"{data_label}/accuracy": accuracy},
                      on_step=True,
                      on_epoch=True,
                      prog_bar=False,
                      logger=True,
                      sync_dist=True
                      )

        return loss

    def _log_sample_sequences(self, output, labels, data_label, n=4):
        """Log up to ``n`` decoded reference/predicted sequences as text.

        A predicted residue is replaced by 'X' wherever the decoded reference
        has 'X'. Skipped silently when the active logger has no ``log_text``
        method (e.g. no logger configured).
        """
        log_text = getattr(self.logger, 'log_text', None)
        if log_text is None:
            return
        predicted_ids = output.detach().argmax(-1)[:n].cpu().tolist()
        reference_ids = labels.detach()[:n].cpu().tolist()
        seqs = self.decode_sequence(predicted_ids)
        seqs_input = self.decode_sequence(reference_ids)
        seqs = [''.join(ps if rs != 'X' else 'X'
                        for ps, rs in zip(pseq, rseq))
                for pseq, rseq in zip(seqs, seqs_input)]
        log_text(key=f'Sequences_{data_label}',
                 columns=['epoch', 'Input', 'Output'],
                 data=[[self.current_epoch,
                        ' ; '.join(t for t in seqs_input if t is not None),
                        ' ; '.join(t for t in seqs if t is not None)
                        ]]
                 )

    def decode_sequence(self, output):
        """Decode lists of class ids with ``protein_sequence_decoder``."""
        return protein_sequence_decoder(output)

    def training_step(self, batch, batch_idx):
        """Lightning training hook; delegates to ``do_step`` as 'train'."""
        return self.do_step(batch, 'train', batch_idx)

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook; delegates to ``do_step`` as 'val'."""
        return self.do_step(batch, 'val', batch_idx)

    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau driven by the logged 'val/loss'.

        Reads ``lr``, ``weight_decay`` and ``patience`` from the training
        config. The LR is multiplied by 0.1 on plateau, down to 1e-5.
        """
        optimizer_adam = torch.optim.Adam(
            self.parameters(),
            lr=self.training_cfg.lr,
            weight_decay=self.training_cfg.weight_decay)

        lr_scheduler = {
            # NOTE(review): `verbose` is deprecated in recent PyTorch
            # releases; kept here to preserve existing behavior.
            'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_adam,
                verbose=True,
                min_lr=0.00001,
                patience=self.training_cfg.patience,
                factor=0.1),
            'monitor': 'val/loss',  # apply on validation loss
            'name': 'lr'
        }

        return {'optimizer': optimizer_adam, 'lr_scheduler': lr_scheduler}
    


class Gearnet_MLP(Base):
    """Frozen structure encoder feeding an MLP-style sequence decoder.

    All training behavior comes from ``Base``. The decoder is called on the
    embedding alone; the padding mask is accepted but not forwarded.
    """

    def __init__(self,
                 structure_encoder: DictConfig,
                 sequence_decoder: DictConfig,
                 training: DictConfig,
                 ignore_index: int = 20):
        super().__init__(structure_encoder, sequence_decoder, training,
                         ignore_index=ignore_index)

    def forward(self, inputs, mask=None, ret_embedding=False):
        """Encode ``inputs`` and decode the embedding into logits.

        Returns the logits, or ``(logits, embedding)`` when
        ``ret_embedding`` is truthy. ``mask`` is ignored.
        """
        latent = self.encoder(inputs)
        logits = self.decoder(latent)
        return (logits, latent) if ret_embedding else logits


class Gearnet_Transformer(Base):
    """Frozen structure encoder feeding a mask-aware sequence decoder.

    Inherits from ``Base`` (previously ``pl.LightningModule``, which made
    ``super().__init__(...)`` fail and left ``do_step``, ``encoder``,
    ``decoder`` and ``training_cfg`` undefined). Loss, metrics, logging and
    the train/val steps come from ``Base``. The optimizer setup is
    overridden with AdamW plus a per-step linear learning-rate warm-up.
    """

    def __init__(self,
                 structure_encoder: DictConfig,
                 sequence_decoder: DictConfig,
                 training: DictConfig,
                 ignore_index: int = 20):
        super().__init__(structure_encoder,
                         sequence_decoder,
                         training,
                         ignore_index=ignore_index
                         )

    def forward(self, inputs, mask=None, ret_embedding=False):
        """Encode ``inputs`` and decode with the padding ``mask``.

        Returns the decoder output, or ``(output, embedding)`` when
        ``ret_embedding`` is truthy.
        """
        embedding = self.encoder(inputs)
        output = self.decoder(embedding, mask=mask)
        if ret_embedding:
            return output, embedding
        return output

    def configure_optimizers(self):
        """AdamW with a linear LR warm-up stepped every training batch.

        Reads ``lr``, ``weight_decay``, ``lr_start_factor`` and
        ``warmup_batches`` from the training config. The LR ramps from
        ``lr * lr_start_factor`` to ``lr`` over ``warmup_batches`` steps.
        """
        optimizer_adamw = torch.optim.AdamW(
            self.parameters(),
            lr=self.training_cfg.lr,
            weight_decay=self.training_cfg.weight_decay)

        lr_scheduler_trans = {
            "scheduler": torch.optim.lr_scheduler.LinearLR(
                optimizer_adamw,
                start_factor=self.training_cfg.lr_start_factor,
                end_factor=1.0,
                total_iters=self.training_cfg.warmup_batches,
            ),
            "frequency": 1,
            "interval": "step",
        }

        return {'optimizer': optimizer_adamw,
                'lr_scheduler': lr_scheduler_trans}
    



class IdentityEncoder(nn.Module):
    """Parameter-free pass-through module.

    ``forward``, ``encoder`` and ``decoder`` all return their first argument
    unchanged. Extra keyword arguments (e.g. ``mask``) are accepted and
    ignored.
    """

    def __init__(self,
                 ignore_index: int = 20
                 ):
        # nn.Module.__init__ must run first: it creates the internal
        # registries that __call__ and parameters() rely on. Without it,
        # calling the module or freezing its parameters raises AttributeError.
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, inputs, **kwargs):
        """Return ``inputs`` unchanged."""
        return self.encoder(inputs)

    def encoder(self, inputs, **kwargs):
        """Return ``inputs`` unchanged."""
        return inputs

    def decoder(self, latents, **kwargs):
        """Return ``latents`` unchanged."""
        return latents




class General_Embedder_MLP(Base):
    """Generic (frozen) embedder followed by an MLP-style sequence decoder.

    Training behavior is inherited from ``Base``. The padding mask passed to
    ``forward`` is not used by this model.
    """

    def __init__(self,
                 structure_encoder: DictConfig,
                 sequence_decoder: DictConfig,
                 training: DictConfig,
                 ignore_index: int = 20):
        super().__init__(structure_encoder, sequence_decoder, training,
                         ignore_index=ignore_index)

    def forward(self, inputs, mask=None, ret_embedding=False):
        """Embed ``inputs`` and decode the embedding into logits.

        Returns the logits, or ``(logits, embedding)`` when
        ``ret_embedding`` is truthy. ``mask`` is ignored.
        """
        features = self.encoder(inputs)
        predictions = self.decoder(features)
        if not ret_embedding:
            return predictions
        return predictions, features
