import contextlib
import glob
import os
import pdb
import subprocess
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml

sys.path.append(".")
from Data.GM12878_DataModule import GM12878Module
from Data.K562_DataModule    import K562Module
from Data.IMR90_DataModule   import IMR90Module

# Genomic bin size in base pairs: bin k is written out as starting at k * RES.
RES = 10_000
# Side length, in bins, of each square sub-matrix that is passed through the models.
PIECE_SIZE = 256

#Load Models
from Models.hicbridge  import Unet, GaussianDiffusion
import Models.VEHiCLE_Module as vehicle
import Models.deephic as deephic
import Models.hicplus as hicplus
import Models.hicsr   as hicsr

# ---------------------------------------------------------------------------
# Model setup. Every network is put in eval() mode so that any dropout or
# batch-norm layers use inference behaviour rather than per-batch statistics
# (the export loop runs with batch size 1).
# ---------------------------------------------------------------------------

# HiCBridge: a U-Net denoiser wrapped by the GaussianDiffusion sampler.
model_hicbridge = Unet(
    dim=64,
    dim_mults=(1, 1, 2, 2, 4, 4),
    channels=1,
    self_condition=False,
)

diffusion = GaussianDiffusion(
    model_hicbridge,
    image_size=256,
    beta_schedule='linear',
    timesteps=1000,   # number of steps
    indi=True,
    objective='pred_x0',
    noise_schedule='brownian',
    indi_step_size=1000,
    loss_type='l1',
)

model_hicbridge.load_state_dict(torch.load("Trained_Models/HiCBridge+.ckpt"))
model_hicbridge.eval()

devices = 'cuda:0'
model_hicbridge.to(device=devices)
diffusion.to(device=devices)
diffusion.eval()

# VeHiCLE: load_from_checkpoint is a classmethod; it is invoked through an
# instance here only to keep the existing `vehicleModel` name available.
vehicleModel = vehicle.GAN_Model()
model_vehicle = vehicleModel.load_from_checkpoint("Trained_Models/vehicle_gan.ckpt")
# Checkpoints load onto the CPU by default; place VeHiCLE on the same device
# as HiCBridge so it can consume the same on-device input tensor.
model_vehicle.to(devices)
model_vehicle.eval()

# HiCPlus
model_hicplus = hicplus.Net()
model_hicplus.load_state_dict(torch.load("Trained_Models/HiCPLUS.ckpt"))
model_hicplus.eval()

# HiCSR
model_hicsr = hicsr.Generator(num_res_blocks=15)
model_hicsr.load_state_dict(torch.load("Trained_Models/HiCSR.ckpt"))
model_hicsr.eval()

# DeepHiC
model_deephic = deephic.Generator(scale_factor=1, in_channel=1, resblock_num=5)
model_deephic.load_state_dict(torch.load("Trained_Models/Deephic.ckpt"))
model_deephic.eval()


# Directory that receives every contact list, bins file and metric manifest.
# exist_ok avoids the check-then-create race and makes re-runs a no-op.
os.makedirs("hicqc_inputs", exist_ok=True)


# ---------------------------------------------------------------------------
# Export pipeline. For each test chromosome:
#   1. run every model over each test piece;
#   2. append the upper-triangle contacts of each prediction (plus the
#      down-sampled input and the original target) to per-tool text files,
#      and the covered bins to a .bed file;
#   3. gzip those files;
#   4. write the .samples / .pairs manifests that pair each tool with the
#      original map.
# ---------------------------------------------------------------------------

_TILE = 40                   # window size fed to the patch-based models
_MARGIN = 6                  # bins trimmed from each side of a window / piece
_CORE = _TILE - 2 * _MARGIN  # 28: central region kept per window, also the tiling stride
# Only the first 224 rows/cols of each cropped piece are exported: with
# PIECE_SIZE = 256 the tiling below leaves the remaining ones unfilled (zero).
_KEPT_BINS = 224
# Bins by which successive test pieces advance along the chromosome, as implied
# by the coordinate arithmetic below.
# NOTE(review): confirm this matches the DataModule's piece stride.
_NEW_BINS_PER_PIECE = 50
# One gzipped contact list is produced per name: "<name>_<chrom>.gz".
_CONTACT_OUTPUTS = ("hicbridge", "vehicle", "hicsr", "deephic", "hicplus", "original", "down")
BASE_STR = 'hicqc_inputs/'


def _run_tiled(model, data, crop_tile=False):
    """Apply a patch-based model to a PIECE_SIZE x PIECE_SIZE input by tiling.

    The input is cut into _TILE x _TILE windows on a _CORE-bin stride. The
    _CORE x _CORE prediction for each window is pasted at the window's centre
    in a PIECE_SIZE x PIECE_SIZE canvas, which is then cropped by _MARGIN on
    every side.

    Args:
        model: callable mapping ``data[:, :, r:r+_TILE, c:c+_TILE]`` to a
            _CORE x _CORE prediction, or to a _TILE x _TILE prediction when
            ``crop_tile`` is True.
        data: 4-D tensor indexed as ``data[:, :, row, col]``.
        crop_tile: keep only ``[:, :, _MARGIN:_MARGIN+_CORE, ...]`` of each
            prediction (for models whose output is the size of their input).

    Returns:
        Detached (PIECE_SIZE - 2*_MARGIN)-square tensor; cells beyond the last
        full window stay zero.
    """
    canvas = torch.zeros((PIECE_SIZE, PIECE_SIZE))
    starts = range(0, PIECE_SIZE - _TILE, _CORE)
    for r in starts:
        for c in starts:
            pred = model(data[:, :, r:r + _TILE, c:c + _TILE])
            if crop_tile:
                pred = pred[:, :, _MARGIN:_MARGIN + _CORE, _MARGIN:_MARGIN + _CORE]
            canvas[r + _MARGIN:r + _MARGIN + _CORE, c + _MARGIN:c + _MARGIN + _CORE] = pred
    return canvas.detach()[_MARGIN:-_MARGIN, _MARGIN:-_MARGIN]


def _write_contacts(handle, chrom, rows, cols, starts, matrix):
    """Append ``chrom\\tstart_a\\tchrom\\tstart_b\\tcount`` lines to ``handle``.

    ``count`` is ``matrix[row, col] * 100`` truncated toward zero, for each
    index pair in ``rows``/``cols``; ``starts`` holds the matching
    ``(start_a, start_b)`` genomic coordinates.

    Raises:
        ValueError: if any selected value is NaN or infinite.
    """
    values = matrix.detach().cpu()[rows, cols] * 100
    if not torch.isfinite(values).all():
        raise ValueError(f"non-finite contact value for chromosome {chrom}")
    counts = values.to(torch.int64).tolist()
    handle.writelines(
        f"{chrom}\t{a}\t{chrom}\t{b}\t{count}\n" for (a, b), count in zip(starts, counts)
    )


for CHRO in [4, 14, 16, 20]:

    # ExitStack closes every output file even if a model call raises.
    with contextlib.ExitStack() as stack:
        handles = {
            name: stack.enter_context(open(f"{BASE_STR}{name}_{CHRO}", 'w'))
            for name in _CONTACT_OUTPUTS
        }
        bins_file = stack.enter_context(open(f"{BASE_STR}bins_{CHRO}.bed", 'w'))

        dm_test = GM12878Module(batch_size=1, res=RES, piece_size=PIECE_SIZE)
        # dm_test = K562Module(batch_size=1, res=10000, piece_size=256)
        # dm_test = IMR90Module(batch_size=1, res=10000, piece_size=256)

        dm_test.prepare_data()
        dm_test.setup(stage=CHRO)

        loader = dm_test.test_dataloader()
        n_pieces = loader.dataset.data.shape[0]

        for s, sample in enumerate(loader):
            print(f"{s}/{n_pieces}")
            data, target = sample
            downs = data[0][0][_MARGIN:-_MARGIN, _MARGIN:-_MARGIN]
            target = target[0][0][_MARGIN:-_MARGIN, _MARGIN:-_MARGIN]

            # Pure inference: no autograd graph is needed for any model.
            with torch.no_grad():
                hicplus_out = _run_tiled(model_hicplus, data)
                deephic_out = _run_tiled(model_deephic, data, crop_tile=True)
                hicsr_out = torch.clamp(_run_tiled(model_hicsr, data), 0, 100000000)

                condition = data.to(devices)
                hicbridge_out = diffusion.sample(batch_size=1, condition=condition).detach().cpu()[0][0][_MARGIN:-_MARGIN, _MARGIN:-_MARGIN]

                # Feed VeHiCLE on whatever device its weights live on.
                vehicle_device = next(model_vehicle.parameters()).device
                vehicle_out = model_vehicle(data.to(vehicle_device)).detach().cpu()[0][0]

            # The first piece exports every bin; later pieces export only the
            # trailing bins the previous piece did not cover, so overlapping
            # contacts are written exactly once.
            first_new = 0 if s == 0 else _KEPT_BINS - _NEW_BINS_PER_PIECE
            # Genomic bin index of cropped row/col 0 of this piece.
            origin = _NEW_BINS_PER_PIECE * s + _MARGIN

            for k in range(first_new, _KEPT_BINS):
                start = (origin + k) * RES
                bins_file.write(f"{CHRO}\t{start}\t{start + RES}\t{start}\n")

            # Upper-triangle pairs (i <= j) whose column is a newly covered bin.
            pairs = [(i, j) for i in range(_KEPT_BINS) for j in range(max(i, first_new), _KEPT_BINS)]
            rows = torch.tensor([i for i, _ in pairs], dtype=torch.long)
            cols = torch.tensor([j for _, j in pairs], dtype=torch.long)
            starts = [((origin + i) * RES, (origin + j) * RES) for i, j in pairs]

            predictions = {
                "down": downs,
                "original": target,
                "hicplus": hicplus_out,
                "deephic": deephic_out,
                "hicsr": hicsr_out,
                "vehicle": vehicle_out,
                "hicbridge": hicbridge_out,
            }
            for name, matrix in predictions.items():
                _write_contacts(handles[name], CHRO, rows, cols, starts, matrix)

    # -f replaces archives left by a previous run; without it gzip refuses to
    # overwrite and the stale .gz would silently survive.
    for name in _CONTACT_OUTPUTS:
        subprocess.run(["gzip", "-f", f"{BASE_STR}{name}_{CHRO}"], check=True)
    subprocess.run(["gzip", "-f", f"{BASE_STR}bins_{CHRO}.bed"], check=True)

    # One .samples / .pairs manifest per tool, pairing it with the original map.
    for tool_name in ('hicplus', 'deephic', 'hicsr', 'vehicle', 'hicbridge', 'down'):
        metric_stem = f"{BASE_STR}metric_{tool_name}_{CHRO}"
        with open(metric_stem + ".samples", 'w') as fh:
            fh.write(f"original     {BASE_STR}original_{CHRO}.gz\n{tool_name}    {BASE_STR}{tool_name}_{CHRO}.gz")
        with open(metric_stem + ".pairs", 'w') as fh:
            fh.write(f"original\t{tool_name}")

# To use the "hicqc_inputs/'model'_'chr'.gz" contact files with Fit-Hi-C,
# first run the fragment-creation step, which converts 'model'_'chr'.gz
# into the input format Fit-Hi-C expects.
