from typing import Literal, Optional, Sequence, Union

import pytorch_lightning as pl
import torch
from einops import repeat

from affinityenhancer.data.constants import MAX_HEAVY_LEN
from affinityenhancer.propen.propen import BasePropEnModel
from affinityenhancer.propen.utils import calculate_edit_distance

def enhance(batch: Sequence[torch.Tensor],
            enhancer: Union[torch.nn.Module, pl.LightningModule, BasePropEnModel],
            iterations: int = 5,
            save_trajectory: bool = False,
            enhance_mode: Literal["latents", "inputs"] = "latents",
            mask: Optional[torch.Tensor] = None
            ):
    """Iteratively refine a batch with ``enhancer`` and return the final output.

    The first pass calls ``enhancer.infer`` on the batch. Each of the remaining
    ``iterations - 1`` passes works in one of two ways:

    * ``enhance_mode="latents"``: re-decodes from the latent ``z_hat``
      returned by the previous pass.
    * ``enhance_mode="inputs"``: feeds the previous output ``x_hat`` back in
      as the input.

    Args:
        batch: ``(inputs, labels)``, or ``(inputs, edges, labels)`` when
            ``enhancer.model`` is a ``GraphTransformerEncoder``. Every element
            is moved to ``enhancer.device``.
        enhancer: model exposing ``model``, ``device``, ``infer`` and
            ``infer_from_latents``.
        iterations: total number of inference passes, including the first one.
        save_trajectory: if True, also return the output of every pass.
        enhance_mode: ``"latents"`` or ``"inputs"`` (see above).
        mask: forwarded to the enhancer as ``mask_region``.

    Returns:
        ``x_hat``, the output of the last pass. Its last entry along the final
        dimension (the "X" token, per the original comment) is set to ``-inf``
        so it is never selected downstream.

        When ``save_trajectory`` is True, returns ``(x_hat, x_hats)``.
        ``x_hats`` stacks the raw (unmasked) output of every pass along a new
        leading dimension.

    Raises:
        ValueError: if ``enhance_mode`` is not ``"latents"`` or ``"inputs"``.
    """
    if enhance_mode not in ("latents", "inputs"):
        raise ValueError(f"Unknown enhance_mode: {enhance_mode!r}")

    is_graph = enhancer.model.__class__.__name__ in ["GraphTransformerEncoder"]
    batch = tuple(t.to(enhancer.device) for t in batch)
    if is_graph:
        _, edges, labels = batch
        # Graph models also need the edges when decoding from latents.
        infer_kwargs = {'mask_region': mask, 'edges': edges}
    else:
        _, labels = batch
        edges = None
        infer_kwargs = {'mask_region': mask}

    z_hat, x_hat = enhancer.infer(batch, mask_region=mask)

    x_hats = None
    if save_trajectory:
        # Allocate from x_hat itself: the output has a trailing vocabulary
        # dimension and a floating dtype, unlike the integer (b, l) inputs.
        # At least one pass always runs, so keep room for it.
        x_hats = x_hat.new_zeros((max(iterations, 1), *x_hat.shape))
        x_hats[0] = x_hat

    for step in range(1, iterations):
        if enhance_mode == "latents":
            z_hat, x_hat = enhancer.infer_from_latents(z_hat, **infer_kwargs)
        else:
            batch = (x_hat, labels) if edges is None else (x_hat, edges, labels)
            _, x_hat = enhancer.infer(batch, mask_region=mask)
        if x_hats is not None:
            # Slot 0 holds the first pass, so pass `step` goes to slot `step`.
            x_hats[step] = x_hat

    # Exclude the "X" token (last vocabulary entry) from the returned output.
    x_hat[:, :, -1] = -float('inf')
    if save_trajectory:
        return x_hat, x_hats
    return x_hat

@torch.no_grad()
def decode(enhancer: Union[torch.nn.Module, pl.LightningModule, BasePropEnModel],
           x_hat: torch.Tensor,
           labels: torch.Tensor,
           x_hats: Optional[torch.Tensor] = None,
           sample_mode: Literal["logits", "decoded", "categorical"] = "logits",
           temp: float = 0.1,
           samples: int = 100,
           mask: Optional[torch.Tensor] = None
           ):
    """Convert enhancer outputs into sequences and edit distances to the inputs.

    Args:
        enhancer: model exposing ``ignore_index`` and ``decode_to_sequence``.
            If it has an ``autoencoder`` attribute, the settings
            ``autoencoder.encoder.by_chain`` and ``max_heavy_len`` choose how
            the decoded strings are split and stripped.
        x_hat: scores indexed as ``(b, l, vocab)``, with the vocabulary on the
            last dimension.
        labels: ``(b, l)`` reference tokens that the outputs are compared
            against.
        x_hats: optional trajectory of outputs. It is arg-maxed over its last
            dimension and compared against ``labels``.
        sample_mode: how sequences are produced.

            * ``"decoded"``: takes the arg-max, giving one sequence per row.
            * ``"logits"`` and ``"categorical"``: both draw ``samples``
              sequences per row from ``softmax(x_hat / temp)``.
        temp: softmax temperature for the sampling modes.
        samples: number of draws per input row. Ignored for ``"decoded"``.
        mask: unused; kept for signature compatibility.

    Returns:
        ``(sequences, sequences_wt, ed_to_input_sampled)``, plus
        ``ed_to_input_trajectories`` when ``x_hats`` is given.

        ``sequences`` has ``b * samples`` entries, ordered input-major (all
        samples of input 0 first, then input 1, and so on). ``sequences_wt``
        has one entry per input row.

        On the chain-split paths, each entry is a pair
        ``[first max_heavy_len positions, remainder]``. The reference chains
        have their ``'X'`` characters removed, and each predicted chain is
        truncated to its reference chain's length.

        On the single-chain autoencoder path, each entry is a plain string.
        Positions where the reference is ``'X'`` are dropped.

    Raises:
        ValueError: for an unknown ``sample_mode``.
    """
    if sample_mode not in ("logits", "decoded", "categorical"):
        raise ValueError(f"Unknown sample_mode: {sample_mode!r}")

    if sample_mode == "decoded":
        x_hat = x_hat.argmax(-1)
        samples = 1
    else:
        if sample_mode == "logits":
            probs = torch.nn.functional.softmax(x_hat / temp, dim=-1)
            # multinomial over each (l, vocab) slice gives (l, samples);
            # transpose so that each row is one sampled sequence.
            drawn = torch.stack(
                [torch.multinomial(p, samples, replacement=True).permute(1, 0)
                 for p in probs],
                dim=0,
            )
        else:
            # Pass the scores as *logits*: Categorical's first positional
            # argument is `probs`, which rejects unnormalised scores.
            dist = torch.distributions.categorical.Categorical(logits=x_hat / temp)
            # sample((samples,)) -> (samples, b, l); move the batch axis first.
            drawn = dist.sample((samples,)).permute(1, 0, 2)
        b, _, l = drawn.shape
        x_hat = drawn.reshape(b * samples, l)

    # .numpy() requires CPU tensors.
    x_hat = x_hat.cpu()
    labels = labels.cpu()
    b, l = labels.shape
    # Repeat each reference row `samples` times. Row i*samples + j of
    # labels_rep then lines up with sample j of input i in x_hat.
    labels_rep = repeat(labels, 'b l -> b repeat l', repeat=samples).reshape(b * samples, l)
    ed_to_input_sampled = calculate_edit_distance(x_hat.numpy(),
                                                  labels_rep.numpy(),
                                                  enhancer.ignore_index
                                                  )
    sequences = enhancer.decode_to_sequence(x_hat.numpy().tolist())
    sequences_wt = enhancer.decode_to_sequence(labels_rep.numpy().tolist())

    encoder = enhancer.autoencoder.encoder if hasattr(enhancer, 'autoencoder') else None
    if encoder is not None and not encoder.by_chain:
        # Single sequence: keep predicted residues only where the reference
        # is not 'X'.
        sequences = [''.join(ps for ps, rs in zip(pseq, rseq) if rs != 'X')
                     for pseq, rseq in zip(sequences, sequences_wt)]
        sequences_wt = [rseq.replace('X', '') for rseq in sequences_wt]
    else:
        max_heavy_len = MAX_HEAVY_LEN if encoder is None else encoder.max_heavy_len
        sequences_wt = [[seq[:max_heavy_len].replace('X', ''),
                         seq[max_heavy_len:].replace('X', '')]
                        for seq in sequences_wt]
        # Truncate each predicted chain to the length of its reference chain.
        sequences = [[seq[:max_heavy_len][:len(seqwt[0])],
                      seq[max_heavy_len:][:len(seqwt[1])]]
                     for seq, seqwt in zip(sequences, sequences_wt)]
    # Keep one reference per input row instead of one per sample.
    sequences_wt = sequences_wt[::samples]

    if x_hats is not None:
        # Compare the trajectory against the un-repeated (b, l) labels; the
        # per-sample repetition above does not apply to it.
        # NOTE(review): presumably x_hats is (iterations, b, l, vocab) as built
        # by `enhance`, and calculate_edit_distance broadcasts over the leading
        # dimension. Confirm.
        x_hats = x_hats.argmax(-1).cpu()
        ed_to_input_trajectories = \
            calculate_edit_distance(x_hats.numpy(),
                                    labels.numpy(),
                                    enhancer.ignore_index
                                    )
        return sequences, sequences_wt, ed_to_input_sampled, ed_to_input_trajectories

    return sequences, sequences_wt, ed_to_input_sampled


def sample(batch: Sequence[torch.Tensor],
           enhancer: Union[torch.nn.Module, pl.LightningModule, BasePropEnModel],
           iterations: int = 5,
           save_trajectory: bool = False,
           enhance_mode: Literal["latents", "inputs"] = "latents",
           sample_mode: Literal["logits", "decoded", "categorical"] = "logits",
           temp: float = 0.1,
           samples: int = 100,
           mask: Optional[torch.Tensor] = None
           ):
    """Run ``enhance`` on a batch, then ``decode`` the result into sequences.

    Args:
        batch: ``(inputs, labels)`` or ``(inputs, edges, labels)``. The last
            element is used as the reference labels for ``decode``.
        enhancer: model passed through to both ``enhance`` and ``decode``.
        iterations, save_trajectory, enhance_mode: forwarded to ``enhance``.
        sample_mode, temp, samples: forwarded to ``decode``.
        mask: forwarded to both ``enhance`` and ``decode``.

    Returns:
        Whatever ``decode`` returns. That includes the trajectory edit
        distances when ``save_trajectory`` is True.
    """
    outputs = enhance(batch,
                      enhancer,
                      iterations=iterations,
                      save_trajectory=save_trajectory,
                      enhance_mode=enhance_mode,
                      mask=mask
                      )
    # enhance returns (final, trajectory) only when a trajectory was requested.
    if save_trajectory:
        x_hat, x_hats = outputs[0], outputs[1]
    else:
        x_hat, x_hats = outputs, None
    return decode(enhancer,
                  x_hat,
                  batch[-1],
                  x_hats=x_hats,
                  sample_mode=sample_mode,
                  temp=temp,
                  samples=samples,
                  mask=mask
                  )
