import torch
import numpy as np
import json, copy
import random
import glob
import csv
import os
import re
import Bio.PDB
import pickle
import torch.nn.functional as F
from tqdm import tqdm
from collections import defaultdict
from bindgen.utils import full_square_dist
import ipdb

RESTYPE_1to3 = {
     "A": "ALA", "R": "ARG", "N": "ASN", "D": "ASP", "C": "CYS", "Q": "GLN","E": "GLU", "G": "GLY", "H": "HIS", "I": "ILE", "L": "LEU", "K": "LYS", "M": "MET", "F": "PHE", "P": "PRO", "S": "SER", "T": "THR", "W": "TRP", "Y": "TYR", "V": "VAL",
}

ALPHABET = ['#', 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
ATOM_TYPES = [
    '', 'N', 'CA', 'C', 'O', 'CB', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
    'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
    'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
    'CZ3', 'NZ', 'OXT'
]
RES_ATOM14 = [
    [''] * 14,
    ['N', 'CA', 'C', 'O', 'CB', '',    '',    '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'NE',  'CZ',  'NH1', 'NH2', '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'OD1', 'ND2', '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'OD1', 'OD2', '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'SG',  '',    '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'OE1', 'NE2', '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'OE1', 'OE2', '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', '',   '',    '',    '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'ND1', 'CD2', 'CE1', 'NE2', '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  'CE',  'NZ',  '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'SD',  'CE',  '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'CE1', 'CE2', 'CZ',  '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD',  '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'OG',  '',    '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '',    '',    '',    '',    '',    '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'],
    ['N', 'CA', 'C', 'O', 'CB', 'CG',  'CD1', 'CD2', 'CE1', 'CE2', 'CZ',  'OH',  '',    ''],
    ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '',    '',    '',    '',    '',    '',    ''],
]


class AntibodyComplexDataset():

    def __init__(self, jsonl_file, cdr_type, L_target):
        self.data = []
        with open(jsonl_file) as f:
            all_lines = f.readlines()
            for line in tqdm(all_lines):
                entry = json.loads(line)
                
                #print(entry['antibody_coords'])
                assert len(entry['antibody_coords']) == len(entry['antibody_seq'])
                assert len(entry['antigen_coords']) == len(entry['antigen_seq'])
                #if entry['antibody_cdr'].count(cdr_type) <= 4:
                #    continue
                
                # paratope region
                cdr_type = str(1)
                surface = torch.tensor(
                        [i for i,v in enumerate(entry['antibody_cdr']) if v in cdr_type]
                )
                entry['binder_surface'] = surface
                
                entry['binder_seq'] = ''.join([entry['antibody_seq'][i] for i in surface.tolist()])
                
                entry['binder_coords'] = torch.tensor(entry['antibody_coords'])[surface]
                entry['binder_atypes'] = torch.tensor(
                        [[ATOM_TYPES.index(a) for a in RES_ATOM14[ALPHABET.index(s)]] for s in entry['binder_seq']]
                )
                mask = (entry['binder_coords'].norm(dim=-1) > 1e-6).long()
                entry['binder_atypes'] *= mask

                # Create target
                entry['target_seq'] = entry['antigen_seq']
                entry['target_coords'] = torch.tensor(entry['antigen_coords'])
                entry['target_atypes'] = torch.tensor(
                        [[ATOM_TYPES.index(a) for a in RES_ATOM14[ALPHABET.index(s)]] for s in entry['target_seq']]
                )
                mask = (entry['target_coords'].norm(dim=-1) > 1e-6).long()
                entry['target_atypes'] *= mask

                # Find target surface
                dist, _ = full_square_dist(
                        entry['target_coords'][None,...], 
                        entry['binder_coords'][None,...], 
                        entry['target_atypes'][None,...], 
                        entry['binder_atypes'][None,...], 
                        contact=True
                )
                K = min(len(dist[0]), L_target)
                epitope = dist[0].amin(dim=-1).topk(k=K, largest=False).indices
                entry['target_surface'] = torch.sort(epitope).values

                if len(entry['binder_coords']) > 4 and len(entry['target_coords']) > 4 and entry['antibody_cdr'].count('001') <= 1:
                    self.data.append(entry)
                #ipdb.set_trace()
                 #   print(entry) 
                # print("entry binder_coords{}".format(entry['binder_coords'].shape))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class ComplexLoader():

    def __init__(self, dataset, batch_tokens):
        self.dataset = dataset
        self.size = len(dataset)
        self.lengths = [len(dataset[i]['binder_seq']) for i in range(self.size)]
        self.batch_tokens = batch_tokens
        sorted_ix = np.argsort(self.lengths)

        # Cluster into batches of similar sizes
        clusters, batch = [], []
        for ix in sorted_ix:
            size = self.lengths[ix]
            batch.append(ix)
            if size * (len(batch) + 1) > self.batch_tokens:
                clusters.append(batch)
                batch = []

        self.clusters = clusters
        if len(batch) > 0:
            clusters.append(batch)

    def __len__(self):
        return len(self.clusters)

    def __iter__(self):
        np.random.shuffle(self.clusters)
        for b_idx in self.clusters:
            batch = [self.dataset[i] for i in b_idx]
            yield batch


def make_batch_from_seq(batch):
    B = len(batch)
    L_max = max([len(seq) for seq in batch])
    S = np.zeros([B, L_max], dtype=np.int32)
    mask = np.zeros([B, L_max], dtype=np.float32)

    for i,seq in enumerate(batch):
        l = len(seq)
        indices = np.asarray([ALPHABET.index(a) for a in seq], dtype=np.int32)
        S[i, :l] = indices
        mask[i, :l] = 1.

    S = torch.from_numpy(S).long().cuda()
    mask = torch.from_numpy(mask).float().cuda()
    return S, mask


def featurize(batch, name):
    B = len(batch)
    L_max = max([len(b[name + "_seq"]) for b in batch])
    X = torch.zeros([B, L_max, 14, 3])
    S = torch.zeros([B, L_max]).long()
    A = torch.zeros([B, L_max, 14]).long()
    V = torch.zeros([B, L_max, 12])

    # Build the batch
    for i, b in enumerate(batch):
        l = len(b[name + '_seq'])
        X[i,:l] = b[name + '_coords']
        A[i,:l] = b[name + '_atypes']
        V[i,:l] = b[name + '_dihedrals'] if name + '_dihedrals' in b else 0
        indices = torch.tensor([ALPHABET.index(a) for a in b[name + '_seq']])
        S[i,:l] = indices
    return X.cuda(), S.cuda(), A.cuda(), V.cuda()


def make_batch(batch):
    target = featurize(batch, 'target')
    binder = featurize(batch, 'binder')
    surface = ([b['binder_surface'] for b in batch], [b['target_surface'] for b in batch])
    return binder, target, surface
