import os
import sys
import logging
from tqdm import tqdm

import torch
import numpy as np

logger = logging.getLogger("TopoDiff.experiment.sampler")

#. model
from TopoDiff.model.diffusion import Diffusion
from TopoDiff.model.latent_diffusion import LatentDiffusion
from TopoDiff.model.aux_head import SCHead

#. config
from TopoDiff.config.config import model_config
from TopoDiff.config.latent_config import latent_model_config

#. data
from TopoDiff.data.structure import StructureBuilder

#. pdb
from myopenfold.np import protein

#. debug
from TopoDiff.utils.visualize import visualize_prot

class Sampler:
    def __init__(self, sampling_config):
        """Store the sampling configuration, pick the CUDA device and load all models.

        Args:
            sampling_config: Mapping-style configuration. This method reads the
                ``'gpu'`` entry; the remaining entries are read by the other
                methods of the class.
        """
        self.config = sampling_config

        # A missing GPU index falls back to device 0. The fallback is written
        # back into the config object itself, so the caller's mapping is mutated.
        gpu_index = self.config['gpu']
        if gpu_index is None:
            gpu_index = 0
            self.config['gpu'] = gpu_index
        self.device = torch.device('cuda:%d' % gpu_index)

        # Models first, then the helper that writes the sampled coordinates out.
        self._init_model()
        self.structure_builder = StructureBuilder()

    def _init_model(self):
        """Build the three networks, load their weights and switch them to eval mode.

        Reads ``'ckpt_path'``, ``'model_preset'`` and ``'script_dir'`` from
        ``self.config``. On return the instance carries ``label_dict``, the
        three config objects, ``model_diffusion``, ``model_latent``,
        ``model_sc`` and ``latent_dim``.
        """
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints that come from a trusted source.
        ckpt = torch.load(self.config['ckpt_path'], map_location='cpu')
        self.label_dict = ckpt['label_dict']

        self.config_diffusion = model_config(self.config['model_preset'])
        self.config_diffusion.Model.Diffuser.SO3.cache_dir = os.path.join(self.config['script_dir'], 'cache')

        # The latent model's normalisation statistics come from the checkpoint,
        # not from the preset.
        self.config_latent = latent_model_config(self.config['model_preset'])
        self.config_latent.Data.common.normalize.mu = self.label_dict['label_latent_mu']
        self.config_latent.Data.common.normalize.sigma = self.label_dict['label_latent_std']

        self.config_sc_head = self.config_diffusion.Model.Aux_head.SC

        logger.info('Loading structure diffusion model...')
        self.model_diffusion = Diffusion(self.config_diffusion.Model)
        self._load_weights(self.model_diffusion, ckpt['diffusion'])

        logger.info('Loading latent diffusion model...')
        self.model_latent = LatentDiffusion(self.config_latent)
        self._load_weights(self.model_latent, ckpt['latent_diffusion'])

        logger.info('Loading designability prediction model...')
        self.model_sc = SCHead(self.config_sc_head)
        self._load_weights(self.model_sc, ckpt['sc_head'])

        # Only the two diffusion models are moved to the target device. The SC
        # head is deliberately not moved: _sample_new_latent calls it on latents
        # that have already been transferred to the CPU.
        self.model_diffusion.to(self.device)
        self.model_latent.to(self.device)

        for model in (self.model_diffusion, self.model_latent, self.model_sc):
            model.eval()

        self.latent_dim = self.config_diffusion.Model.Embedder_v2.topo_embedder.embed_dim

    @staticmethod
    def _load_weights(model, state_dict):
        """Load ``state_dict`` into ``model`` non-strictly and warn about key mismatches."""
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if missing_keys:
            logger.warning('Missing keys: {}'.format(missing_keys))
        if unexpected_keys:
            logger.warning('Unexpected keys: {}'.format(unexpected_keys))

    
    @torch.no_grad()
    def run(self):
        """Run the full pipeline: sample latents, save them, then sample structures.

        The whole method executes with gradient tracking disabled.
        """
        # Stage 1: latent sampling, followed by writing the latents to disk.
        logger.info('Start sampling latent...')
        self.rejection_sampling()
        logger.info('Finish sampling latent...')
        self.save_latent_result()

        # Stage 2: structure sampling conditioned on the latents from stage 1.
        logger.info('Start sampling structure...')
        self.strcuture_sampling()
        logger.info('Finish sampling structure...')

    def strcuture_sampling(self):
        """Sample one structure per stored latent and write each one as a PDB file.

        Iterates over the ``num_total_sample`` rows of ``sampled_latent``,
        ``sampled_length`` and ``sampled_idx_in_length`` and writes each result
        to ``<output_dir>/length_<length>/sample_<repeat>.pdb``.

        When ``config['seed']`` is set, sample ``i`` is generated after
        re-seeding with ``seed + i``, and that seed is included in the label
        passed to the structure builder.

        The misspelled method name is kept because callers use it.
        """
        base_seed = self.config['seed']
        total = self.num_total_sample

        with tqdm(total=total) as pbar:
            for idx in range(total):
                # Re-seed per sample so that a single sample can be regenerated
                # from its own seed.
                sd = None if base_seed is None else base_seed + idx
                if sd is not None:
                    self._set_seed(sd)

                latent = self.sampled_latent[idx]
                length = int(self.sampled_length[idx].item())
                rep = self.sampled_idx_in_length[idx].item()

                pbar.set_description('Sampling structure for length %d, repeat %d' % (length, rep))

                # NOTE(review): the number of timesteps is hard-coded to 200 here.
                sampling_kwargs = dict(
                    latent = latent,
                    return_traj = True,
                    return_frame = False,
                    return_position = True,
                    reconstruct_position = True,
                    num_res = length,
                    timestep = 200,
                )
                prediction = self.model_diffusion.sample_latent_conditional(**sampling_kwargs)

                coord37_record, coord37_mask = self.structure_builder.coord14_to_coord37(
                    prediction['coord_hat'], trunc=True)

                if sd is None:
                    label = 'length %d, repeat %d' % (length, rep)
                else:
                    label = 'length %d, repeat %d, seed %d' % (length, rep, sd)
                prot_traj = self.structure_builder.get_coord_traj(
                    coord37_record[None],
                    aa_mask=coord37_mask,
                    label_override=label,
                    default_res='G',
                )

                # One sub-directory per target length, one file per repeat.
                length_dir = os.path.join(self.config['output_dir'], 'length_%d' % length)
                save_path = os.path.join(length_dir, 'sample_%d.pdb' % rep)
                if not os.path.exists(length_dir):
                    os.makedirs(length_dir, exist_ok=True)
                # Only the first entry of the returned trajectory list is written.
                with open(save_path, 'w') as f:
                    f.write(protein.to_pdb(prot_traj[0]))

                pbar.update(1)

    def _init_sample_goal_list(self):
        self.num_total_sample = len(range(self.config['length_start'], self.config['length_end']+1, self.config['length_step'])) * self.config['sample_per_length']
        self.sampled_latent = torch.zeros((self.num_total_sample, self.latent_dim))
        self.sampled_length = torch.zeros((self.num_total_sample))
        self.sampled_length_pred = torch.zeros((self.num_total_sample))
        self.sampled_sc_pred = torch.zeros((self.num_total_sample))
        self.sampled_idx_in_length = torch.zeros((self.num_total_sample))
        self.num_total_sampled = 0

    def _sample_new_latent(self):
        latent_cache = self.model_latent.sample(n_sample = self.config['latent_batch_size'])
        latent_cache = {k: v.to('cpu') for k, v in latent_cache.items()}

        # pred sc
        sc_pred = self.model_sc(latent_cache['latent_sample'])

        # pred length
        dis_mat = torch.sum((latent_cache['latent_sample'][:, None, :] - self.label_dict['label_latent'][None, :, :])**2, dim=-1)
        min_val, min_idx = torch.topk(dis_mat, k=self.config['length_k'], dim=-1, largest=False, sorted=True)
        length_pred = self.label_dict['label_length'][min_idx].float().mean(dim=-1).long()

        self.latent_cache_dict = {
            'latent': latent_cache['latent_sample'],
            'sc_pred': sc_pred,
            'length_pred': length_pred,
        }

        self.sample_idx = 0

    def _get_next_sample(self):
        """Return the next cached latent sample, refilling the cache when it runs out.

        Returns:
            dict: one entry per key of ``self.latent_cache_dict``, holding the
            element at the current cache position.

        Side effects: advances ``self.sample_idx`` and increments
        ``self.num_total_sampled``, which counts every drawn latent whether or
        not it is later accepted.
        """
        if self.sample_idx >= self.config['latent_batch_size']:
            # Report through the module logger rather than print(), so the
            # message follows the logging configuration and stays off stdout.
            logger.info('sample new latent')
            self._sample_new_latent()

        position = self.sample_idx
        sample = {key: val[position] for key, val in self.latent_cache_dict.items()}
        self.sample_idx = position + 1
        self.num_total_sampled += 1

        return sample
    
    def rejection_sampling(self):
        """Fill the sample buffers with accepted latents for every target length.

        Target lengths run from ``length_start`` to ``length_end`` inclusive in
        steps of ``length_step``. For each one, latents are drawn until
        ``sample_per_length`` of them satisfy both conditions:

        * the target length lies within ``length_pred * (1 -/+ length_epsilon)``;
        * ``min_sc <= sc_pred <= max_sc``.

        A drawn latent is only tested against the length currently being
        filled; a rejected latent is discarded.

        Side effects: re-seeds the RNGs when ``config['seed']`` is set,
        re-allocates the buffers, and maintains ``self.num_total_accepted``
        (the number of buffer rows filled so far).
        """
        if self.config['seed'] is not None:
            self._set_seed(self.config['seed'])
        self._init_sample_goal_list()
        # Number of filled buffer rows. This is distinct from num_total_sampled,
        # which also counts rejected draws.
        self.num_total_accepted = 0
        self._sample_new_latent()

        epsilon = self.config['length_epsilon']
        target_lengths = range(self.config['length_start'], self.config['length_end']+1, self.config['length_step'])

        with tqdm(total=self.num_total_sample) as pbar:
            for length in target_lengths:
                num_accepted = 0
                # NOTE(review): there is no cap on the number of attempts. If
                # the criteria can never be met for a length, this loop does
                # not terminate.
                while num_accepted < self.config['sample_per_length']:
                    pbar.set_description('Sampling latent for length %d, repeat %d' % (length, num_accepted))

                    sample = self._get_next_sample()
                    length_pred = sample['length_pred']
                    sc_pred = sample['sc_pred']

                    length_ok = (length_pred * (1 - epsilon) <= length and
                                 length_pred * (1 + epsilon) >= length)
                    if not length_ok:
                        continue
                    if not (sc_pred >= self.config['min_sc'] and sc_pred <= self.config['max_sc']):
                        continue

                    row = self.num_total_accepted
                    self.sampled_latent[row] = sample['latent']
                    self.sampled_length[row] = length
                    self.sampled_length_pred[row] = length_pred
                    self.sampled_sc_pred[row] = sc_pred
                    self.sampled_idx_in_length[row] = num_accepted
                    self.num_total_accepted = row + 1
                    num_accepted += 1
                    pbar.update(1)

        logger.info('Successfully accepted %d latents' % self.num_total_accepted)
        logger.info('Total sampled %d latents' % self.num_total_sampled)

    def save_latent_result(self):
        """Write the accepted latents and their metadata to ``<output_dir>/latent_dict.pt``.

        The saved dict has the keys ``'latent'``, ``'length'``,
        ``'length_pred'``, ``'sc_pred'`` and ``'idx_in_length'``, each holding
        only the rows that were actually filled.
        """
        os.makedirs(self.config['output_dir'], exist_ok=True)
        save_path = os.path.join(self.config['output_dir'], 'latent_dict.pt')

        # Slice by the number of accepted rows, not by the number of drawn
        # latents. Fall back to the old counter if the buffers were filled
        # without going through rejection_sampling.
        num_filled = getattr(self, 'num_total_accepted', self.num_total_sampled)

        buffers = {
            'latent': self.sampled_latent,
            'length': self.sampled_length,
            'length_pred': self.sampled_length_pred,
            'sc_pred': self.sampled_sc_pred,
            'idx_in_length': self.sampled_idx_in_length,
        }
        latent_dict = {key: buf[:num_filled] for key, buf in buffers.items()}
        torch.save(latent_dict, save_path)
        logger.info('Saving sample result to %s' % save_path)

    def _set_seed(self, sd):
        torch.manual_seed(sd)
        np.random.seed(sd)