import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import time

import cv2
from tqdm import tqdm
import os 
import numpy as np

def train_one_top_seg(model , optimizer , imgs , targets , do_predict_hmap, device="cuda"):
    """Run one optimisation step of the segmentation model on a single batch.

    Args:
        model: network called with a list of image tensors. It must return
            log-probabilities for the segmentation head (they are fed to
            ``NLLLoss``), or a ``(seg, hmap)`` pair when ``do_predict_hmap``
            is true.
        optimizer: optimizer over the model parameters; zeroed and stepped here.
        imgs: iterable of image tensors; each is cast to float32 and moved
            to ``device``.
        targets: indexable ``(target_seg, mask[, target_hmap])``.
            ``target_seg`` holds class indices; pixels labelled -100 are
            ignored. ``mask`` and ``target_hmap`` are only read when
            ``do_predict_hmap`` is true.
        do_predict_hmap: also train the heat-map output with a masked MSE loss
            that is added to the segmentation loss.
        device: device that inputs and targets are moved to (default
            ``"cuda"``, i.e. the current CUDA device, as before).

    Returns:
        One-element list holding the total loss as a Python float.
    """
    # NLLLoss accepts (N, C, H, W) inputs directly; NLLLoss2d is only a
    # deprecated alias for it.
    loss_fn_seg = nn.NLLLoss(ignore_index=-100)

    model.train()

    # Variable/FloatTensor wrapping is obsolete: a plain cast + move is enough.
    imgs = [x.float().to(device) for x in imgs]
    target_seg = targets[0].to(device)

    optimizer.zero_grad()
    output = model(imgs)

    loss_l2 = None
    if do_predict_hmap:
        output_seg, output_hmap = output
        mask = targets[1].to(device)
        target_hmap = targets[2].to(device)
        # Mask prediction and target alike so pixels outside the mask
        # contribute zero error.
        masked_pred = torch.flatten(output_hmap * mask, 1)
        masked_target = torch.flatten(target_hmap * mask, 1)
        loss_l2 = nn.MSELoss()(masked_pred, masked_target)
    else:
        output_seg = output

    loss = loss_fn_seg(output_seg, target_seg)
    if loss_l2 is not None:
        # Out-of-place add: do not mutate the segmentation-loss tensor.
        loss = loss + loss_l2

    loss.backward()
    optimizer.step()

    return [loss.item()]







def adjust_learning_rate(optimizer, epoch):
    """Apply a two-step learning-rate schedule to every param group.

    Epochs up to and including 200 use a rate of 0.001; every later epoch
    uses 0.0001. The selected rate is printed before it is applied.
    """
    new_lr = 0.001 if epoch <= 200 else 0.0001
    print(new_lr)
    for group in optimizer.param_groups:
        group['lr'] = new_lr

        
# Module-wide count of training iterations across all train_epoch() calls.
# train_epoch increments it once per batch and writes a checkpoint whenever it
# reaches a multiple of 400.
total_iters: int = 0
        
def train_epoch(model, optimizer, data_loader, save_path, epoch, do_predict_hmap, sanity_test):
    """Train ``model`` for one pass over ``data_loader``.

    The learning rate is first set via ``adjust_learning_rate``. Each batch
    ``(imgs, targets)`` goes through ``train_one_top_seg``, and its loss and
    wall-clock time are printed. The module-level ``total_iters`` counter is
    bumped per batch; whenever it hits a multiple of 400, a checkpoint dict
    ``{'epoch', 'state_dict'}`` is written to
    ``save_path + 'model_every_400.tar'``, overwriting the previous one.

    When ``sanity_test`` is true, only the first 3 batches are used.

    Returns:
        Mean of the per-batch losses, computed by ``np.mean``; NaN if no
        batch was processed.
    """
    global total_iters
    adjust_learning_rate(optimizer, epoch)

    epoch_losses = []
    for step, (batch_imgs, batch_targets) in enumerate(data_loader):
        tic = time.time()
        if sanity_test and step >= 3:
            break

        step_losses = train_one_top_seg(
            model=model,
            optimizer=optimizer,
            imgs=batch_imgs,
            targets=batch_targets,
            do_predict_hmap=do_predict_hmap,
        )
        epoch_losses.append(step_losses[0])
        print('Epoch', epoch, 'Iter', step, "loss", step_losses, "time", time.time() - tic)

        total_iters += 1
        if total_iters % 400 != 0:
            continue

        ckpt_path = save_path + 'model_every_400.tar'
        checkpoint = {'epoch': epoch, 'state_dict': model.state_dict()}
        torch.save(checkpoint, ckpt_path)
        print("saved ", ckpt_path)

    return np.mean(np.array(epoch_losses))






def save_test_preds(model , data_loader , save_path, sanity_test, device="cuda"):
    """Run ``model`` over ``data_loader`` and write its predictions as images.

    Only the first sample of each batch is saved, numbered by batch index:

    * ``<save_path>test_outs/<i>.png``: per-pixel argmax of the segmentation
      output, cast to uint8.
    * ``<save_path>test_outs/<i>.hmap.png``: written only when the model
      returns a ``(seg, hmap)`` pair. The heat map is scaled by 1000 and
      stored as uint16, saturating at the uint16 range.

    The model is evaluated in ``eval()`` mode without gradient tracking. Its
    previous train/eval mode is restored afterwards.

    Args:
        model: network called with a list of image tensors; returns either a
            segmentation tensor or a ``(seg, hmap)`` tuple/list.
        data_loader: iterable of ``(imgs, targets)``; ``targets`` is unused.
        save_path: string prefix of the output directory. It is concatenated
            directly, so it normally ends with a path separator.
        sanity_test: when true, stop after the first 3 batches.
        device: device the input tensors are moved to (default ``"cuda"``).

    Raises:
        OSError: if ``cv2.imwrite`` reports a failed write.
    """
    save_path_imgs = save_path + "test_outs/"
    # makedirs also creates missing parents and tolerates an existing dir.
    os.makedirs(save_path_imgs, exist_ok=True)

    def _write(path, image):
        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(path, image):
            raise OSError("cv2.imwrite failed for " + path)

    was_training = model.training
    # Train-mode BatchNorm/Dropout would distort predictions and update
    # BatchNorm running stats with test data; no_grad skips building an
    # autograd graph that is never used.
    model.eval()
    try:
        with torch.no_grad():
            for batch_idx, (imgCrp, _targets) in tqdm(enumerate(data_loader)):
                if sanity_test and batch_idx >= 3:
                    break

                imgs = [x.float().to(device) for x in imgCrp]
                output_seg = model(imgs)
                # len() of a bare tensor is its batch size, so test the type
                # rather than only the length.
                if isinstance(output_seg, (tuple, list)) and len(output_seg) == 2:
                    output_seg, output_hmap = output_seg
                    hmap = output_hmap[0].detach().cpu().numpy() * 1000
                    # Clip before casting so out-of-range values saturate
                    # instead of wrapping around.
                    hmap = np.clip(hmap, 0, np.iinfo(np.uint16).max).astype("uint16")
                    _write(save_path_imgs + str(batch_idx) + ".hmap.png", hmap)

                imm = output_seg[0].argmax(0).cpu().numpy().astype("uint8")
                _write(save_path_imgs + str(batch_idx) + ".png", imm)
    finally:
        model.train(was_training)

    
def is_degenerate( model , data_loader , n_iter=100, device="cuda"):
    """Heuristically detect a collapsed segmentation model.

    The model is run (in ``eval()`` mode, without gradients) on at most
    ``n_iter`` batches. For each batch, the argmax label map of the first
    sample is taken. The labels are histogrammed after a uint8 cast, so
    labels above 255 wrap. The model counts as degenerate when the single
    most frequent label covers more than 75% of all counted pixels.

    Args:
        model: network called with a list of image tensors; returns either a
            segmentation tensor or a ``(seg, hmap)`` tuple/list.
        data_loader: iterable of ``(imgs, targets)``; ``targets`` is unused.
        n_iter: maximum number of batches to inspect.
        device: device the input tensors are moved to (default ``"cuda"``).

    Returns:
        True if the dominant label exceeds the 75% threshold, else False
        (also False when no batch was processed).
    """
    from itertools import islice

    count_vec = np.zeros(300)
    total_pixels = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # islice bounds the loop to exactly n_iter batches (the old
            # counter let n_iter + 1 through) without loading an extra batch.
            for imgCrp, _targets in tqdm(islice(data_loader, max(0, n_iter))):
                imgs = [x.float().to(device) for x in imgCrp]
                output_seg = model(imgs)
                # len() of a bare tensor is its batch size, so test the type.
                if isinstance(output_seg, (tuple, list)) and len(output_seg) == 2:
                    output_seg = output_seg[0]
                imm = output_seg[0].argmax(0).cpu().numpy().astype("uint8")
                # Count the pixels actually seen; this stays correct for
                # non-square images and loaders shorter than n_iter.
                total_pixels += imm.size
                count_vec += np.bincount(imm.ravel(), minlength=300)
    finally:
        model.train(was_training)

    count_vec.sort()
    mx = count_vec[-1]
    mx2 = count_vec[-2]

    if mx > 0.75 * float(total_pixels):
        print("mmm" , mx , mx2 )
        return True
    return False
    
    