import sys
sys.path.append(".")
import numpy as np

import matplotlib.pyplot as plt

import torch

from Models.hicbridge  import Unet, GaussianDiffusion
import Models.VEHiCLE_Module as vehicle
import Models.hicsr   as hicsr
import Models.deephic as deephic
import Models.hicplus as hicplus

from Utils import vision_metrics as vm
from Data.GM12878_DataModule import GM12878Module
from Data.K562_DataModule import K562Module
from Data.IMR90_DataModule import IMR90Module

# ---------------------------------------------------------------------------
# Load the test split and pick a single sample to run every model on.
# To evaluate a different cell line, swap the active data module for one of
# the commented-out alternatives below.
# ---------------------------------------------------------------------------
dm_test = GM12878Module(batch_size=1, res=10000, piece_size=256)
# dm_test = K562Module(batch_size=1, res=10000, piece_size=256)
# dm_test = IMR90Module(batch_size=1, res=10000, piece_size=256)

dm_test.prepare_data()
dm_test.setup(stage=4)

# A one-element slice (rather than a plain index) keeps the leading axis, so
# ds / target can be fed to the models as a batch of one.
sample = slice(8, 9)
ds = torch.from_numpy(dm_test.test_dataloader().dataset.data[sample])
target = torch.from_numpy(dm_test.test_dataloader().dataset.target[sample])

# ---------------------------------------------------------------------------
# Build every model and load its trained weights, leaving each in eval mode.
# ---------------------------------------------------------------------------


def _load_weights(model, path):
    """Load the state dict stored at `path` into `model` and switch it to eval mode.

    Args:
        model: an ``nn.Module`` whose architecture matches the checkpoint.
        path: checkpoint file; presumably a plain ``state_dict`` saved with
            ``torch.save`` — confirm for each file.

    Returns:
        The same `model` object, so construction and loading can be chained.
    """
    # map_location="cpu" lets checkpoints saved from a GPU load on a machine
    # without CUDA; load_state_dict copies the values into the model's own
    # parameters, so where the model lives is unaffected.
    model.load_state_dict(torch.load(path, map_location="cpu"))
    model.eval()
    return model


model_hicbridge = Unet(
    dim = 64,
    dim_mults = (1, 1, 2, 2, 4, 4),
    channels = 1,
    self_condition= False
)

# NOTE(review): the wrapper is given model_hicbridge and presumably keeps a
# reference to it, so loading weights into the Unet afterwards is visible to
# diffusion.sample() — confirm in Models.hicbridge.
diffusion = GaussianDiffusion(
    model_hicbridge,
    image_size = 256,
    beta_schedule = 'linear',
    timesteps = 1000,
    indi = True,
    objective = 'pred_x0',
    noise_schedule = 'brownian',
    indi_step_size = 1000,
    loss_type = 'l1'
)

_load_weights(model_hicbridge, "Trained_Models/HiCBridge+.ckpt")

model_hicplus = _load_weights(hicplus.Net(), "Trained_Models/HiCPlus.ckpt")

model_hicsr = _load_weights(
    hicsr.Generator(num_res_blocks=15),
    "Trained_Models/HiCSR.ckpt",
)

# DeepHiC is put in eval mode itself (not HiCSR a second time).
model_deephic = _load_weights(
    deephic.Generator(scale_factor=1, in_channel=1, resblock_num=5),
    "Trained_Models/Deephic.ckpt",
)

# load_from_checkpoint is a classmethod, so call it on the class. Recent
# Lightning releases refuse to run it on an instance, and building a
# throwaway GAN_Model() first only wastes work.
# NOTE(review): assumes GAN_Model is a LightningModule — confirm in
# Models.VEHiCLE_Module.
model_vehicle = vehicle.GAN_Model.load_from_checkpoint(
    "Trained_Models/vehicle_gan.ckpt", map_location='cpu'
)
model_vehicle.eval()

# ---------------------------------------------------------------------------
# Pass the selected sample through every model.
#
# All outputs except VEHiCLE's (already 244 x 244 per the note below) are
# trimmed by TRIM pixels on each side, so every panel has the same size.
# NOTE(review): assumes ds / target are (1, 1, FULL_RES, FULL_RES); this
# matches piece_size=256 in the data module, but confirm.
# ---------------------------------------------------------------------------
# Fall back to the CPU so the script still runs on a machine without CUDA.
devices = 'cuda:0' if torch.cuda.is_available() else 'cpu'

FULL_RES = 256   # side length of one input piece
TRIM = 6         # border removed from each side of every full-size output
TILE = 40        # side of the patches fed to the patch-based models
STRIDE = 28      # = TILE - 2 * TRIM: each patch contributes its 28 x 28 centre

lowres_out = ds[0][0][TRIM:-TRIM, TRIM:-TRIM]
target_out = target[0][0][TRIM:-TRIM, TRIM:-TRIM]

model_hicbridge.to(device=devices)
diffusion.to(device=devices)
condition = ds.to(devices)

# NOTE(review): left outside torch.no_grad(); presumably sample() manages
# gradients itself — confirm in Models.hicbridge.
hicbridge_out = (
    diffusion.sample(batch_size=1, condition=condition)
    .detach()
    .cpu()[0][0][TRIM:-TRIM, TRIM:-TRIM]
)


def _tiled_inference(model, x, crop_output=False):
    """Run a patch-based model over `x` tile by tile and stitch the results.

    Each TILE x TILE patch of `x`, stepping by STRIDE in both directions, is
    passed through `model`. The prediction is written into that patch's
    central (TILE - 2*TRIM) square of a FULL_RES x FULL_RES zero canvas.
    Canvas pixels never covered by a patch centre stay zero.

    Args:
        model: callable taking an ``x[:, :, rows, cols]`` patch. Its output
            must broadcast into a (TILE - 2*TRIM) square, either directly or,
            when `crop_output` is True, after taking its central region.
        x: input tensor indexed as ``x[:, :, rows, cols]``.
        crop_output: if True, keep only the central region of each
            prediction (for models that return a full TILE x TILE patch).

    Returns:
        The stitched canvas with TRIM pixels removed from each side.
    """
    out = torch.zeros((FULL_RES, FULL_RES))
    inner = slice(TRIM, TILE - TRIM)
    for i in range(0, FULL_RES - TILE, STRIDE):
        for j in range(0, FULL_RES - TILE, STRIDE):
            pred = model(x[:, :, i:i + TILE, j:j + TILE])
            if crop_output:
                pred = pred[:, :, inner, inner]
            out[i + TRIM:i + TILE - TRIM, j + TRIM:j + TILE - TRIM] = pred
    return out[TRIM:-TRIM, TRIM:-TRIM]


# Inference only: skip autograd bookkeeping so the stitched canvases do not
# silently record a computation graph.
with torch.no_grad():
    vehicle_out = model_vehicle(ds).detach()[0][0]  # 244 * 244
    hicplus_out = _tiled_inference(model_hicplus, ds)    # 28 * 28 per tile
    hicsr_out = _tiled_inference(model_hicsr, ds)
    deephic_out = _tiled_inference(model_deephic, ds, crop_output=True)  # 40 * 40 per tile


# ---------------------------------------------------------------------------
# Comparison figure.
#   Row 0: each contact map (trailing CROP rows/cols dropped).
#   Row 1: its residual against the high-resolution target, shown on the ZOOM
#          window and titled with the Frobenius norm of the whole residual.
# ---------------------------------------------------------------------------
CROP = 20                 # trailing rows/cols dropped from every panel
ZOOM = slice(40, 140)     # window displayed in the residual row

panels = [
    ("Low-resolution", lowres_out),
    ("High-resolution", target_out),
    ("hicbridge", hicbridge_out),
    ("HiCPlus", hicplus_out),
    ("HiCSR", hicsr_out),
    ("DeepHiC", deephic_out),
    ("VEHiCLE", vehicle_out),
]
panels = [(title, img[:-CROP, :-CROP]) for title, img in panels]
target_crop = target_out[:-CROP, :-CROP]

fig, ax = plt.subplots(2, len(panels))
for axis in ax.flat:
    axis.set_xticks([])
    axis.set_yticks([])

for col, (title, img) in enumerate(panels):
    ax[0, col].imshow(img, cmap="Reds", vmin=0, vmax=1)
    ax[0, col].set_title(title, fontsize=5)

    residual = img - target_crop
    ax[1, col].imshow(residual[ZOOM, ZOOM], cmap="Greys", vmin=-0.3, vmax=0.3)
    # .item() so the title reads "1.2345" rather than "tensor(1.2345)".
    ax[1, col].set_title(f"{torch.linalg.norm(residual).item():.4f}", fontsize=5)

plt.savefig('standard_vision_metrics.png', dpi=400, bbox_inches='tight')
plt.show()

# ========================================================================================
# Vision metrics for each chromosome in the loop below. Results are stored in
# v_m under (chromosome, model key). Each model gets a fresh VisionMetrics
# object bound to that chromosome's dataset.

v_m = {}

# (printed label, result key / spliter, model, extra getMetrics kwargs)
metric_runs = (
    ("hicbridge", "hicbridge", diffusion, {"device": devices}),
    ("vehicle", "vehicle", model_vehicle, {}),
    ("hicplus", "hicplus", model_hicplus, {}),
    ("HiCSR", "hicsr", model_hicsr, {}),
    ("deephic", "deephic", model_deephic, {}),
)

for chro in [20]:
    print(chro)
    for label, key, net, extra_kwargs in metric_runs:
        print(label)
        metrics = vm.VisionMetrics()
        metrics.setDataset(chro)
        v_m[chro, key] = metrics.getMetrics(model=net, spliter=key, **extra_kwargs)





