import warnings
import sys
sys.path.append(".")

from Models.hicbridge  import Unet, GaussianDiffusion, Trainer
import Models.VEHiCLE_Module as vehicle
import Models.hicsr   as hicsr
import Models.deephic as deephic
import Models.hicplus as hicplus

import pdb
import os
import tmscoring
import glob
import subprocess
import pdb
import numpy as np
import matplotlib.pyplot as plt
import torch

from Data.GM12878_DataModule import GM12878Module
from Data.K562_DataModule    import K562Module
from Data.IMR90_DataModule   import IMR90Module

# Side length of the square matrix pieces requested from the data modules and
# of the prediction canvases built in convertChroToConstraints (presumably in
# bins at the chosen resolution ``res`` — confirm against the data modules).
PIECE_SIZE = 256

def buildFolders():
    """Create the ``3D_Mod`` working tree used by the 3D-modelling pipeline.

    Creates ``3D_Mod/Constraints`` (constraint files written by
    convertChroToConstraints), ``3D_Mod/Parameters`` (3DMax parameter files
    written by buildParameters) and ``3D_Mod/output`` (3DMax output folder),
    plus the ``3D_Mod`` parent itself. Safe to call repeatedly: directories
    that already exist are left untouched.
    """
    for sub_dir in ('Constraints', 'output', 'Parameters'):
        # makedirs creates the missing parent too; exist_ok avoids the
        # check-then-create race of a separate os.path.exists() test.
        os.makedirs('3D_Mod/' + sub_dir, exist_ok=True)

def _tile_starts(size, window=40, step=28):
    """Return start offsets of ``window``-wide tiles advancing by ``step``.

    A final tile flush with ``size - window`` is always included so the
    trailing edge is covered: a plain ``range(0, size - window, step)`` stops
    short (for size 256 the last tile would start at 196, leaving rows and
    columns 230-249 unpredicted).

    Raises:
        ValueError: if ``size`` is smaller than ``window``.
    """
    last = size - window
    if last < 0:
        raise ValueError(f"size {size} is smaller than the tile window {window}")
    starts = list(range(0, last + 1, step))
    if starts[-1] != last:
        starts.append(last)
    return starts


def _predict_tiled(model, data, crop_output):
    """Stitch a PIECE_SIZE x PIECE_SIZE prediction from 40x40 tiles of ``data``.

    Each window ``data[:, :, i:i+40, j:j+40]`` is passed through ``model`` and
    a 28x28 prediction is written to ``[i+6:i+34, j+6:j+34]`` of the canvas.
    With ``crop_output`` true the model returns a window-sized map whose
    ``[6:34, 6:34]`` centre is sliced out first (DeepHiC). Otherwise the model
    output is written directly (HiCPlus, HiCSR). The outer 6-wide border is
    never written and stays 0; callers crop it away.
    """
    canvas = torch.zeros((PIECE_SIZE, PIECE_SIZE))
    starts = _tile_starts(PIECE_SIZE)
    for i in starts:
        for j in starts:
            pred = model(data[:, :, i:i + 40, j:j + 40])
            if crop_output:
                pred = pred[:, :, 6:34, 6:34]
            canvas[i + 6:i + 34, j + 6:j + 34] = pred
    return canvas.detach()


def _load_models(device):
    """Load the five models compared by this script from ``Trained_Models/``.

    HiCBridge (wrapped in its GaussianDiffusion sampler) and VEHiCLE are moved
    to ``device``. HiCPlus, HiCSR and DeepHiC stay on the CPU because they are
    fed CPU tiles of the input. Every network is switched to eval mode.

    Returns:
        (diffusion, model_hicplus, model_hicsr, model_deephic, model_vehicle)
    """
    model_hicbridge = Unet(
        dim=64,
        dim_mults=(1, 1, 2, 2, 4, 4),
        channels=1,
        self_condition=False,
    )
    diffusion = GaussianDiffusion(
        model_hicbridge,
        image_size=256,
        beta_schedule='linear',
        timesteps=1000,
        indi=True,
        objective='pred_x0',
        noise_schedule='brownian',
        indi_step_size=1000,
        loss_type='l1',
    )
    model_hicbridge.load_state_dict(
        torch.load("Trained_Models/HiCBridge+.ckpt", map_location="cpu"))
    model_hicbridge.eval()
    model_hicbridge.to(device=device)
    diffusion.to(device=device)

    model_hicplus = hicplus.Net()
    model_hicplus.load_state_dict(
        torch.load("Trained_Models/HiCPlus.ckpt", map_location="cpu"))
    model_hicplus.eval()

    model_hicsr = hicsr.Generator(num_res_blocks=15)
    model_hicsr.load_state_dict(
        torch.load("Trained_Models/HiCSR.ckpt", map_location="cpu"))
    model_hicsr.eval()

    model_deephic = deephic.Generator(scale_factor=1, in_channel=1, resblock_num=5)
    model_deephic.load_state_dict(
        torch.load("Trained_Models/Deephic.ckpt", map_location="cpu"))
    model_deephic.eval()

    # load_from_checkpoint is a classmethod; call it on the class rather than
    # on a throwaway instance. Presumably GAN_Model is a LightningModule, in
    # which case the loaded model is not put in eval mode automatically.
    model_vehicle = vehicle.GAN_Model.load_from_checkpoint("Trained_Models/vehicle_gan.ckpt")
    model_vehicle.eval()
    # It is fed ``condition``, which lives on ``device``.
    model_vehicle.to(device)
    return diffusion, model_hicplus, model_hicsr, model_deephic, model_vehicle


def _write_constraints(path, matrix, n_rows, n_cols):
    """Write the upper triangle (diagonal included) of ``matrix`` to ``path``.

    Writes one tab-separated ``i j value`` line per cell with
    ``0 <= i < n_rows`` and ``i <= j < n_cols``.
    """
    values = matrix.tolist()  # one bulk conversion instead of .item() per cell
    with open(path, 'w') as handle:
        handle.writelines(
            str(i) + "\t" + str(j) + "\t" + str(values[i][j]) + "\n"
            for i in range(n_rows)
            for j in range(i, n_cols)
        )


def convertChroToConstraints(chro,
                            cell_line="GM12878",
                            res=10000,
                            device='cuda:0'):
    """Predict every 5th test piece of ``chro`` and write 3DMax constraint files.

    For each selected piece ``s``, the following are written to
    ``3D_Mod/Constraints/chro_<chro>_<name>_<s>_`` as tab-separated
    ``i j value`` lines covering the upper triangle:

    - the model input (``data``) and ``target``;
    - the outputs of HiCBridge, HiCPlus, HiCSR, DeepHiC and VEHiCLE.

    All matrices except the VEHiCLE output are cropped by 6 on each side. All
    matrices except ``target`` are clipped to [0, 1].

    Args:
        chro: chromosome passed to the data module's ``setup(stage=...)``.
        cell_line: one of "GM12878", "K562", "IMR90".
        res: resolution forwarded to the data module.
        device: torch device for HiCBridge and VEHiCLE. The tile-based models
            run on the CPU.

    Raises:
        ValueError: if ``cell_line`` is not supported.
    """
    data_modules = {"GM12878": GM12878Module, "K562": K562Module, "IMR90": IMR90Module}
    if cell_line not in data_modules:
        raise ValueError("Unsupported cell_line " + repr(cell_line)
                         + "; expected one of " + str(sorted(data_modules)))
    dm_test = data_modules[cell_line](batch_size=1, res=res, piece_size=PIECE_SIZE)
    dm_test.prepare_data()
    dm_test.setup(stage=chro)

    diffusion, model_hicplus, model_hicsr, model_deephic, model_vehicle = _load_models(device)

    loader = dm_test.test_dataloader()
    num_entries = loader.dataset.data.shape[0]
    # Pure inference: no autograd graphs are needed.
    with torch.no_grad():
        for s, (data, target) in enumerate(loader):
            if s % 5 != 0:
                continue  # only every fifth piece is modelled
            print(str(s) + "/" + str(num_entries))
            condition = data.to(device)

            hicbridge_output = diffusion.accelated_sample(
                num_timesteps=250, condition=condition).detach().cpu()[0][0][6:-6, 6:-6]
            hicplus_output = _predict_tiled(model_hicplus, data, crop_output=False)[6:-6, 6:-6]
            hicsr_output = _predict_tiled(model_hicsr, data, crop_output=False)[6:-6, 6:-6]
            deephic_output = _predict_tiled(model_deephic, data, crop_output=True)[6:-6, 6:-6]
            # NOTE(review): unlike the other outputs this one is not cropped by
            # 6; presumably the VEHiCLE generator already trims that border —
            # confirm its output shape.
            vehicle_output = model_vehicle(condition).detach().cpu()[0][0]

            # Model inputs/outputs are thresholded to [0, 1]; target is written raw.
            matrices = {
                "target": target[0][0][6:-6, 6:-6],
                "data": data[0][0][6:-6, 6:-6].clamp(0, 1),
                "hicbridge": hicbridge_output.clamp(0, 1),
                "hicplus": hicplus_output.clamp(0, 1),
                "deephic": deephic_output.clamp(0, 1),
                "hicsr": hicsr_output.clamp(0, 1),
                "vehicle": vehicle_output.clamp(0, 1),
            }
            n_rows, n_cols = matrices["data"].shape[0], matrices["data"].shape[1]
            for name, matrix in matrices.items():
                path = "3D_Mod/Constraints/chro_" + str(chro) + "_" + name + "_" + str(s) + "_"
                _write_constraints(path, matrix, n_rows, n_cols)

def buildParameters(chro,
                cell_line="GM12878",
                res=10000):
    """Write one 3DMax parameter file per constraint file of chromosome ``chro``.

    For every ``3D_Mod/Constraints/chro_<chro>_*`` file, a parameter file with
    the same base name is written to ``3D_Mod/Parameters/``. It points 3DMax
    at that constraint file and at the ``3D_Mod/output/`` folder. Lines are
    CRLF-terminated, and the last line ends with a bare ``\\r``.

    ``cell_line`` and ``res`` are accepted but unused.
    """
    pattern = "3D_Mod/Constraints/chro_" + str(chro) + "_*"
    for constraint in glob.glob(pattern):
        # basename copes with both "/" and the Windows "\\" separator.
        suffix = os.path.basename(constraint)
        params = ("NUM = 3\r\n"
                  "OUTPUT_FOLDER = 3D_Mod/output/\r\n"
                  "INPUT_FILE = " + constraint + "\r\n"
                  "CONVERT_FACTOR = 0.6\r\n"
                  "VERBOSE = true\r\n"
                  "LEARNING_RATE = 1\r\n"
                  "MAX_ITERATION = 10000\r")
        # newline="" keeps the explicit \r\n intact on every platform (text
        # mode on Windows would otherwise emit \r\r\n).
        with open("3D_Mod/Parameters/" + suffix, 'w', newline="") as param_f:
            param_f.write(params)
    
# Path to the 3DMax jar executed by runSegmentParams / runParams (relative to
# the current working directory).
JAR_LOCATION = "other_tools/examples/3DMax.jar"
# NOTE(review): the jar must already exist. If this disabled auto-clone is
# revived, pass an argv list (or shell=True): subprocess.run with a single
# string and no shell treats the whole string as the program name on POSIX.
# if not os.path.exists(JAR_LOCATION):
#     subprocess.run("git clone https://github.com/BDM-Lab/3DMax.git other_tools")

def runSegmentParams(chro, position_index):
    """Run 3DMax on the seven parameter files of one piece of chromosome ``chro``.

    ``position_index`` is the piece index ``s`` embedded in the file names
    ``3D_Mod/Parameters/chro_<chro>_<struc>_<position_index>_``. Failures of
    individual 3DMax runs are not raised (best effort, as before).
    """
    for struc in ['data', 'target', 'hicbridge', 'hicplus', 'deephic', 'hicsr', 'vehicle']:
        param_file = ("3D_Mod/Parameters/chro_" + str(chro) + "_" + struc
                      + "_" + str(position_index) + "_")
        # argv list instead of a shell string: the path is never shell-parsed.
        subprocess.run(["java", "-Xmx5000m", "-jar", JAR_LOCATION, param_file])

def runParams(chro):
    """Run 3DMax on every parameter file written for chromosome ``chro``.

    Failures of individual 3DMax runs are not raised (best effort, as before).
    """
    for par in glob.glob("3D_Mod/Parameters/chro_" + str(chro) + "_*"):
        # argv list instead of a shell string: the path is never shell-parsed.
        subprocess.run(["java", "-Xmx5000m", "-jar", JAR_LOCATION, par])

def getSegmentTMScores(chro, position_index):
    """TM-score each input's 3D structures against the target structures.

    For piece ``position_index`` of chromosome ``chro``, every ``.pdb`` file
    that 3DMax produced for an input type is aligned against every ``.pdb``
    produced from the ``target`` constraints.

    Returns:
        dict mapping ``'data_strucs'``, ``'hicbridge_strucs'``,
        ``'hicplus_strucs'``, ``'deephic_strucs'``, ``'hicsr_strucs'`` and
        ``'vehicle_strucs'`` to the list of TM-scores of all such pairs.
    """
    def structures_of(kind):
        return glob.glob("3D_Mod/output/chro_" + str(chro) + "_" + kind
                         + "_" + str(position_index) + "_*.pdb")

    references = structures_of('target')
    scores = {}
    for kind in ('data', 'hicbridge', 'hicplus', 'deephic', 'hicsr', 'vehicle'):
        bucket = scores[kind + '_strucs'] = []
        for candidate in structures_of(kind):
            for reference in references:
                aligner = tmscoring.TMscoring(candidate, reference)
                aligner.optimise()
                bucket.append(aligner.tmscore(**aligner.get_current_values()))
    return scores

def getTMScores(chro):
    """Collect TM-scores for every modelled piece of chromosome ``chro``.

    Piece indices are recovered from the parameter-file names
    (``chro_<chro>_<name>_<index>_``). Each index appears once per input type
    (seven files per piece), so the indices are de-duplicated before scoring.
    Otherwise every piece would be aligned, and its scores appended, seven
    times.

    Prints the mean score per model and returns the dict of score lists,
    keyed like getSegmentTMScores.
    """
    relative_scores = {key: [] for key in ('data_strucs', 'hicbridge_strucs',
                                           'hicplus_strucs', 'deephic_strucs',
                                           'hicsr_strucs', 'vehicle_strucs')}

    param_files = glob.glob("3D_Mod/Parameters/chro_" + str(chro) + "_*")
    # dict.fromkeys de-duplicates while keeping first-seen order.
    position_indices = dict.fromkeys(path.split("_")[-2] for path in param_files)
    for position_index in position_indices:
        segment_scores = getSegmentTMScores(chro, position_index)
        for key, scores in segment_scores.items():
            relative_scores[key].extend(scores)

    print("RELATIVE SCORES")
    for key, scores in relative_scores.items():
        # np.mean([]) warns and returns nan; print nan explicitly instead.
        mean = np.mean(scores) if scores else float("nan")
        print(key + ":\t" + str(mean))
    return relative_scores


def viewModels(chro, struc_index=100):
    """Open every reconstructed structure of one piece of ``chro`` in PyMOL.

    ``struc_index`` is the piece index (``s`` in convertChroToConstraints,
    which only models multiples of 5). It defaults to 100, the previously
    hard-coded value.
    """
    models = glob.glob("3D_Mod/output/chro_" + str(chro) + "_*_"
                       + str(struc_index) + "_*.pdb")
    # argv list instead of a shell string: file names are never shell-parsed.
    subprocess.run(["pymol", *models])

def parallelScatter(chros=(4, 14, 16, 20)):
    """Box-plot per-model TM-score distributions, grouped by chromosome.

    For each chromosome, getTMScores supplies one score list per model. Boxes
    of one chromosome sit side by side, with one empty slot between
    chromosome groups, and each group is labelled ``Chro<n>`` at its centre.
    The figure is saved to ``Relative_TMScore_K562.png`` and shown.

    Args:
        chros: chromosomes to plot; defaults to the previously hard-coded set.
    """
    colors = ['crimson', 'bisque', 'dodgerblue', 'darkorchid', 'forestgreen', 'gold']
    relative_data = []
    positions = []
    tick_positions = []
    cursor = 1
    for chro in chros:
        group = list(getTMScores(chro).values())
        relative_data.extend(group)
        positions.extend(range(cursor, cursor + len(group)))
        # Centre the chromosome label under its group of boxes.
        tick_positions.append(cursor + (len(group) - 1) / 2)
        cursor += len(group) + 1  # leave one empty slot between groups

    fig, ax = plt.subplots()
    bp = ax.boxplot(relative_data, positions=positions, patch_artist=True)
    for b, box in enumerate(bp['boxes']):
        # Colours cycle per model, so the same model has the same colour in
        # every chromosome group.
        box.set(facecolor=colors[b % len(colors)])

    ax.set_xticks(tick_positions)
    ax.set_xticklabels(['Chro' + str(chro) for chro in chros])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_title("Relative")

    plt.savefig("Relative_TMScore_K562.png", dpi=300)
    plt.show()
    
if __name__ == "__main__":
    buildFolders()
    for chro in [4, 14, 16, 20]:
        # Predict with every model, write constraint and parameter files,
        # then reconstruct 3D structures with 3DMax.
        convertChroToConstraints(chro, cell_line="K562")
        buildParameters(chro, cell_line="K562")
        runParams(chro)
    # parallelScatter() computes and prints getTMScores for every chromosome
    # it plots, including 20. A separate getTMScores(20) call would only
    # repeat the expensive TM-score alignments.
    parallelScatter()
    viewModels(20)
