import torch
import numpy as np
import logging
import time
import torch.distributions as dists
import torch.nn.functional as F
import pandas as pd
from biggs.models.predictors import BaseCNN
from omegaconf import OmegaConf
import os
from biggs.data.utils.tokenize import Encoder

from typing import List

def to_np(x):
    """Move tensor ``x`` to CPU, detach it from autograd, and return a NumPy array."""
    return x.cpu().detach().numpy()


def to_list(x):
    """Convert tensor ``x`` into a (possibly nested) Python list via ``to_np``."""
    return to_np(x).tolist()


def _mutagenesis_tensor(base_seq):
    base_seq = torch.squeeze(base_seq)  # Remove batch dimension.
    seq_len, vocab_len = base_seq.shape
    # Create mutagenesis tensor
    all_seqs = []
    for i in range(seq_len):
        for j in range(vocab_len):
            new_seq = base_seq.clone()
            new_seq[i][j] = 1
            all_seqs.append(new_seq)
    all_seqs = torch.stack(all_seqs)
    return all_seqs

def _calc_q_ij(logits, one_hot_seqs, not_mutated_indices):
    probs = F.softmax(logits, dim=-1)
    # Calculate probabilities of mutations under source.
    # We set probability one for indices that aren't mutated
    # [num_seq, L]
    q_ij = torch.sum(probs * one_hot_seqs, dim=-1) + not_mutated_indices
    # Log-sum-exp to avoid numerical instability
    # [num_seq]
    q_ij = torch.sum(torch.log(q_ij), dim=-1)
    q_ij = torch.exp(q_ij)
    return q_ij

def _calc_q_i(logits, mutated_indices, not_mutated_indices):
    probs = F.softmax(logits, dim=-1)
    q_i = torch.sum(torch.log(probs * mutated_indices + not_mutated_indices), dim=-1)
    q_i = torch.exp(q_i)
    return q_i

class GwgPairSampler(torch.nn.Module):
    """Generate (source, mutant) sequence pairs with gradient-guided proposals.

    A ``BaseCNN`` predictor loaded from a checkpoint scores sequences. Its input
    gradients define proposal distributions over single-site edits. Mutants
    are drawn in one of two ways:

    * Gibbs-With-Gradients (``use_gwg=True``): one (position, token) draw per
      sample.
    * A two-stage sampler (``use_gwg=False``): draw residues, then a token for
      each drawn residue.

    Mutants are kept in one of two ways:

    * ``criterion='absolute'``: their score exceeds the source score by more
      than ``edit_tolerance``.
    * ``criterion='mh'``: they pass a Metropolis-Hastings test.
    """

    def __init__(
            self,
            predictor_dir: str,
            edit_tolerance: float,
            residue_temperature: List[float],
            aa_temperature: List[float],
            hamming_distance: List[int],
            ckpt_name: str,
            criterion: str,
            use_gwg: bool,
            verbose: bool = False,
            gibbs_samples: int = 500,
            use_hessian: bool = False,
            device: str = "cuda",
        ):
        """Load the predictor and store the sampling hyperparameters.

        Args:
            predictor_dir: Directory holding the checkpoint and ``config.yaml``.
            edit_tolerance: Minimum score gain over the source required under
                the ``'absolute'`` criterion.
            residue_temperature: Residue-proposal temperatures to sweep.
            aa_temperature: Token-proposal temperatures to sweep.
            hamming_distance: Residue draws per sample for the two-stage
                sampler. The GWG sampler ignores the value, so each entry just
                triggers another sweep.
            ckpt_name: Checkpoint file name inside ``predictor_dir``.
            criterion: ``'absolute'`` or ``'mh'``.
            use_gwg: Use the single-edit GWG sampler instead of the two-stage one.
            verbose: Print per-sweep and per-sequence acceptance statistics.
            gibbs_samples: Proposals drawn per hyperparameter setting.
            use_hessian: Add a second-order term to the local differences.
            device: Torch device string for the predictor and samples.
        """
        super().__init__()
        self._ckpt_name = ckpt_name
        self._criterion = criterion
        self._use_gwg = use_gwg
        self._log = logging.getLogger(__name__)
        self.device = torch.device(device)
        self._log.info(f'Using device: {self.device}')
        self.predictor_tokenizer = Encoder()
        self.predictor = self._setup_predictor(predictor_dir)
        self.num_tokens = len(self.predictor_tokenizer.alphabet)
        self.edit_tolerance = edit_tolerance
        self.residue_temperature = residue_temperature
        self.aa_temperature = aa_temperature
        self.hamming_distance = hamming_distance
        # Bin count passed as ``minlength`` to the edit-distance histogram.
        self._num_hamming_bins = len(self.hamming_distance) + 1
        self.total_pairs = 0
        self.num_current_src_seqs = 0
        self.gibbs_samples = gibbs_samples
        self.use_hessian = use_hessian
        self._verbose = verbose

    def _setup_predictor(self, predictor_dir: str):
        """Build a ``BaseCNN`` from ``config.yaml`` and load checkpoint weights."""
        predictor_path = os.path.join(predictor_dir, self._ckpt_name)
        mdl_info = torch.load(predictor_path, map_location=self.device)
        cfg_path = os.path.join(predictor_dir, 'config.yaml')
        ckpt_cfg = OmegaConf.load(cfg_path)
        predictor = BaseCNN(make_one_hot=False, **ckpt_cfg.model.predictor)
        # Strip the 'scorer.' prefix so keys match BaseCNN's parameter names.
        state_dict = {
            k.replace('scorer.', ''): v for k, v in mdl_info['state_dict'].items()
        }
        predictor.load_state_dict(state_dict)
        predictor.eval()
        predictor.to(self.device)
        self._log.info(predictor)
        return predictor

    def tokenize_seqs(self, seqs):
        """Encode ``seqs`` with the predictor tokenizer (the only one this class owns)."""
        return self.predictor_tokenizer.encode(seqs)

    def _calc_local_diff(self, seq_one_hot):
        """Estimate the score change of every single-site edit from gradients.

        Args:
            seq_one_hot: Differentiable one-hot batch ``[B, L, V]``.

        Returns:
            ``(delta_i, delta_ij)``:

            * ``delta_ij[b, i, j]`` is the first-order (plus optional
              second-order) change from setting position ``i`` to token ``j``.
            * ``delta_i`` is the maximum of ``delta_ij`` over tokens, ``[B, L]``.
        """
        gx = torch.autograd.grad(self.predictor(seq_one_hot).sum(), seq_one_hot)[0]
        # Gradient at the currently selected token, broadcast over the vocabulary.
        gx_cur = (gx * seq_one_hot).sum(-1)[:, :, None]
        delta_ij = gx - gx_cur
        if self.use_hessian:
            # Second-order GWG. NOTE(review): the einsum below assumes a batch
            # of one and a Hessian layout matching the mutagenesis batch; confirm.
            seq_len, vocab_len = seq_one_hot.shape[1:]
            start_time = time.time()
            mutagenesis_batch = _mutagenesis_tensor(seq_one_hot)
            delta = mutagenesis_batch - seq_one_hot
            hx = torch.autograd.functional.hessian(self.predictor, seq_one_hot)
            second_order = torch.einsum(
                'bij,bijbpq,bpq->b', delta, hx, delta).reshape(seq_len, vocab_len)
            elapsed_time = time.time() - start_time
            self._log.info(f'Computed Hessian in {elapsed_time:.1f}s')
            delta_ij = delta_ij + second_order
        delta_i = torch.max(delta_ij, dim=-1)[0]
        return delta_i, delta_ij

    def _gibbs_sampler(self, seq_one_hot):
        """Build a proposal sampler around a single source sequence.

        Args:
            seq_one_hot: Differentiable one-hot source with a batch of one, ``[1, L, V]``.

        Returns:
            Callable ``sampler(num_mutations, residue_temp, aa_temp)`` returning
            ``[gibbs_samples, L]`` token ids. The GWG variant draws one edit per
            sample and ignores ``num_mutations`` and ``residue_temp``.
        """
        seq_len = seq_one_hot.shape[1]
        delta_i, delta_ij = self._calc_local_diff(seq_one_hot)
        # Drop the batch dimension: delta_i is [L], delta_ij is [L, V].
        delta_i, delta_ij = delta_i[0], delta_ij[0]
        # Source token ids [1, L]; each sampler edits fresh copies of this row.
        seq_token = torch.argmax(seq_one_hot, dim=-1)

        def _big_sample(num_mutations, residue_temp, aa_temp):
            """Two-stage sampling: draw residues, then a token for each drawn residue."""
            residue_proposals = dists.OneHotCategorical(logits=delta_i / residue_temp)
            aa_proposals = [
                dists.OneHotCategorical(logits=delta_ij[i] / aa_temp)
                for i in range(seq_len)
            ]

            # Residue indices drawn with replacement, [gibbs_samples, num_mutations].
            # Repeated draws of one residue within a sample collapse to one edit.
            r_ij = torch.argmax(
                residue_proposals.sample((self.gibbs_samples, num_mutations)),
                dim=-1
            )

            # [gibbs_samples, L]
            mutated_seqs = seq_token.repeat(self.gibbs_samples, 1)
            samples_per_residue = torch.bincount(r_ij.reshape(-1))
            for i in range(samples_per_residue.shape[0]):
                n = int(samples_per_residue[i])
                if n == 0:
                    continue
                m_ij = torch.argmax(aa_proposals[i].sample((n,)), dim=-1)
                mutated_seqs[torch.where(r_ij == i)[0], i] = m_ij
            return mutated_seqs

        def _gwg_sample(num_mutations, residue_temp, aa_temp):
            """GWG: draw one (position, token) pair per sample from the flat proposal."""
            num_positions, vocab_len = delta_ij.shape
            gwg_proposal = dists.OneHotCategorical(logits=delta_ij.flatten() / aa_temp)
            r_ij = gwg_proposal.sample((self.gibbs_samples,)).reshape(
                self.gibbs_samples, num_positions, vocab_len)

            # [gibbs_samples, L]
            mutated_seqs = seq_token.repeat(self.gibbs_samples, 1)
            seq_idx, res_idx, aa_idx = torch.where(r_ij)
            mutated_seqs[(seq_idx, res_idx)] = aa_idx
            return mutated_seqs

        return _gwg_sample if self._use_gwg else _big_sample

    def _make_one_hot(self, seq, differentiable=False):
        """One-hot encode token ids; as a float leaf requiring grad if ``differentiable``."""
        seq_one_hot = F.one_hot(seq, num_classes=self.num_tokens)
        if differentiable:
            seq_one_hot = seq_one_hot.float().requires_grad_()
        return seq_one_hot

    def _evaluate_one_hot(self, seq):
        """Score token ids with the predictor.

        NOTE(review): the one-hot here is integer-typed; presumably BaseCNN
        casts its input to float. Confirm.
        """
        input_one_hot = self._make_one_hot(seq)
        return self.predictor(input_one_hot)

    def _decode(self, tokens):
        """Decode a list of token ids back to a sequence string."""
        return self.predictor_tokenizer.decode(tokens)

    def _metropolis_hasting(
            self, mutants, source_one_hot, delta_score, res_temp, aa_temp):
        """Metropolis-Hastings accept/reject for a batch of proposed mutants.

        The acceptance ratio is ``exp(f(x') - f(x)) * q(x | x') / q(x' | x)``.
        Each ``q`` factorises into a residue term and a token term over the
        mutated positions:

        * The forward move ``x -> x'`` selects the mutant tokens under the
          source's local differences.
        * The reverse move ``x' -> x`` selects the source tokens under the
          mutant's local differences.

        Args:
            mutants: Proposed token ids ``[num_seq, L]``.
            source_one_hot: Differentiable one-hot source ``[L, V]``.
            delta_score: Mutant score minus source score, one per mutant
                (assumed ``[num_seq]``).
            res_temp: Residue-proposal temperature.
            aa_temp: Token-proposal temperature.

        Returns:
            Boolean mask of accepted mutants.
        """
        source = torch.argmax(source_one_hot, dim=-1)

        # [num_seq, L]
        mutated_indices = mutants != source[None]
        not_mutated_indices = ~mutated_indices
        # [num_seq, L, V]
        mutant_one_hot = self._make_one_hot(mutants, differentiable=True)
        # Tokens picked by the forward move, restricted to mutated sites.
        forward_one_hot = mutant_one_hot * mutated_indices[..., None]
        # Tokens the reverse move must pick to restore the source at those sites.
        reverse_one_hot = source_one_hot.detach()[None] * mutated_indices[..., None]

        source_delta_i, source_delta_ij = self._calc_local_diff(source_one_hot[None])
        mutant_delta_i, mutant_delta_ij = self._calc_local_diff(mutant_one_hot)

        # Ratio of residue-selection probabilities: reverse / forward.
        q_i_source = _calc_q_i(
            source_delta_i / res_temp, mutated_indices, not_mutated_indices)
        q_i_mutant = _calc_q_i(
            mutant_delta_i / res_temp, mutated_indices, not_mutated_indices)
        q_i_ratio = q_i_mutant / q_i_source

        # Ratio of token-selection probabilities: reverse / forward.
        q_ij_source = _calc_q_ij(
            source_delta_ij / aa_temp, forward_one_hot, not_mutated_indices)
        q_ij_mutant = _calc_q_ij(
            mutant_delta_ij / aa_temp, reverse_one_hot, not_mutated_indices)
        q_ij_ratio = q_ij_mutant / q_ij_source

        accept_prob = torch.exp(delta_score) * q_ij_ratio * q_i_ratio
        # Accept when a uniform draw falls below the acceptance ratio.
        uniform = torch.rand(accept_prob.shape, device=self.device)
        return uniform < accept_prob

    def _evaluate_mutants(
            self,
            *,
            mutants,
            score,
            source_one_hot,
            res_temp,
            aa_temp
        ):
        """Score mutants and keep those passing the configured criterion.

        Returns:
            ``(DataFrame[mutant_sequences, mutant_scores], accepted_token_ids)``.

        Raises:
            ValueError: If the criterion is neither ``'absolute'`` nor ``'mh'``.
        """
        if mutants.shape[0] == 0:
            return pd.DataFrame({
                'mutant_sequences': [],
                'mutant_scores': [],
            }), mutants

        all_mutated_scores = self._evaluate_one_hot(mutants)
        delta_score = all_mutated_scores - score

        if self._criterion == 'absolute':
            accept_mask = delta_score > self.edit_tolerance
        elif self._criterion == 'mh':
            accept_mask = self._metropolis_hasting(
                mutants, source_one_hot, delta_score, res_temp, aa_temp)
        else:
            raise ValueError(f'Unknown criterion: {self._criterion}')
        accepted = mutants[accept_mask]
        accepted_seq = [self._decode(x) for x in to_list(accepted)]
        accepted_score = to_list(all_mutated_scores[accept_mask])
        return pd.DataFrame({
            'mutant_sequences': accepted_seq,
            'mutant_scores': accepted_score,
        }), accepted

    def compute_mutant_stats(self, source_seq, mutant_seqs):
        """Count edits per mutant and histogram those counts.

        Returns:
            ``(num_mutated_res, num_mutated_bins)``: a per-mutant tensor of
            differing positions, and a NumPy bincount with at least
            ``len(hamming_distance) + 1`` bins.
        """
        num_mutated_res = torch.sum(mutant_seqs != source_seq[None], dim=-1)
        num_mutated_bins = np.bincount(
            to_np(num_mutated_res), minlength=self._num_hamming_bins)
        return num_mutated_res, num_mutated_bins

    def forward(self, batch):
        """Sample and filter mutant pairs for every sequence in ``batch['sequences']``.

        Returns:
            A ``DataFrame`` with columns ``mutant_sequences``, ``mutant_scores``,
            ``source_sequences`` and ``source_scores``, de-duplicated on
            (source, mutant). Returns ``None`` when nothing was accepted.
        """
        seqs = batch['sequences']
        tokenized_seqs = self.predictor_tokenizer.encode(seqs).to(self.device)
        total_num_seqs = len(tokenized_seqs)

        all_mutant_pairs = []
        for i, (real_seq, token_seq) in enumerate(zip(seqs, tokenized_seqs)):
            start_time = time.time()

            # Float one-hot so gradients can flow through the predictor.
            seq_one_hot = self._make_one_hot(token_seq, differentiable=True)
            pred_score = self._evaluate_one_hot(token_seq[None]).item()
            sampler = self._gibbs_sampler(seq_one_hot[None])

            seq_pairs = []
            total_num_proposals = 0
            # Sweep over sampling hyperparameters.
            for res_temp in self.residue_temperature:
                for aa_temp in self.aa_temperature:
                    for num_mutations in self.hamming_distance:
                        proposed_mutants = sampler(num_mutations, res_temp, aa_temp)
                        num_proposals = proposed_mutants.shape[0]
                        total_num_proposals += num_proposals
                        # Drop proposals identical to the source.
                        num_edits = torch.sum(
                            proposed_mutants != token_seq[None], dim=-1)
                        proposed_mutants = proposed_mutants[num_edits > 0]

                        sample_outputs, _ = self._evaluate_mutants(
                            mutants=proposed_mutants,
                            score=pred_score,
                            source_one_hot=seq_one_hot,
                            res_temp=res_temp,
                            aa_temp=aa_temp,
                        )
                        sample_outputs['source_sequences'] = real_seq
                        sample_outputs['source_scores'] = pred_score
                        seq_pairs.append(sample_outputs)
                        if self._verbose:
                            num_pairs = len(sample_outputs)
                            print(
                                f'Res temp: {res_temp:.3f}, AA temp: {aa_temp:.3f}, '
                                f'Num mutations: {num_mutations}, '
                                f'Accepted: {num_pairs}/{num_proposals} ({num_pairs/num_proposals:.2f})'
                            )

            num_new_pairs = 0
            if seq_pairs:
                deduped_pairs = pd.concat(seq_pairs).drop_duplicates(
                    subset=['source_sequences', 'mutant_sequences'],
                    ignore_index=True
                )
                num_new_pairs = len(deduped_pairs)
                all_mutant_pairs.append(deduped_pairs)
            if self._verbose:
                elapsed_time = time.time() - start_time
                accept_rate = (
                    num_new_pairs / total_num_proposals if total_num_proposals else 0.0)
                print(
                    f'Done with sequence {i+1}/{total_num_seqs} in {elapsed_time:.1f}s. '
                    f'Accepted {num_new_pairs}/{total_num_proposals} ({accept_rate:.2f}) sequences. \n'
                )

        if not all_mutant_pairs:
            return None
        return pd.concat(all_mutant_pairs).drop_duplicates(
            subset=['source_sequences', 'mutant_sequences'],
            ignore_index=True
        )
