import os
import copy
import torch
torch.set_float32_matmul_precision('high')
from torch import nn
import torch.nn.functional as F
import argparse
import random
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision.transforms as transforms
from torchvision import datasets, transforms
from torch import  nn
from models.model import ReReIC
from utils.eval_model import eval_model
import cv2
import torchvision.utils as vutils
from lossy_contour_algorithm import get_border_bits
from models.candidate_train import train_with_candidates
from wasserstein_distortion import VGG16WassersteinDistortion

# Seed applied once at import time (see the call below).
manual_seed = 1


def seed_everything(seed=1029):
    """Seed every RNG this script relies on and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU and CUDA
    generators, then forces deterministic cuDNN kernel selection.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # PYTHONHASHSEED is read at interpreter start-up, so setting it here only
    # affects Python child processes started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may select different kernels from run to run.
    torch.backends.cudnn.benchmark = False


seed_everything(manual_seed)
print('seed', manual_seed)

def get_mgrid(w_sidelen, h_sidelen, dim=2):
    """Build a coordinate grid with values in [-1, 1], laid out channels-first.

    For ``dim == 2`` the result has shape ``(1, 2, h_sidelen, w_sidelen)``:
    channel 0 holds the x coordinate (varies along the width axis) and
    channel 1 the y coordinate (varies along the height axis).

    Note: the final 4-axis ``permute`` only fits the 2-D case; any other
    ``dim`` raises a ``RuntimeError``.
    """
    if dim == 2:
        axes = (
            torch.linspace(-1, 1, steps=w_sidelen),
            torch.linspace(-1, 1, steps=h_sidelen),
        )
    else:
        axes = (torch.linspace(-1, 1, steps=w_sidelen),) * dim
    # For dim == 2 this is (W, H, 2): grid[a, b] = (x[a], y[b]).
    grid = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)
    return grid[None].permute(0, 3, 2, 1)


def make_path(path):
    """Create directory *path* (including missing parents) if it does not exist.

    Safe to call concurrently from several processes: an already-existing path
    is not an error.

    Args:
        path: directory to create.

    Returns:
        0 in all non-error cases (kept for backward compatibility).
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        # Already present, possibly created by another process between any
        # existence check and this call; nothing to do.
        return 0
    print(f"Directory '{path}' created.")
    return 0

def loss_to_psnr(loss, max=1):
    """Convert mean-squared-error value(s) to PSNR in decibels.

    Args:
        loss: scalar or array-like of MSE values.
        max: peak signal value (1 for images scaled to [0, 1]). The name
            shadows the builtin but is kept for caller compatibility.

    Returns:
        ``10 * log10(max**2 / loss)`` as a NumPy scalar or array.
    """
    mse = np.asarray(loss)
    peak_power = max ** 2
    return 10 * np.log10(peak_power / mse)

def get_mask_h_w(mask_path):
    """Return the bounding-box size of the white (== 255) region of a mask.

    Args:
        mask_path: path to an image, read as 8-bit grayscale with OpenCV.

    Returns:
        (width, height, cropped): width and height of the tight bounding box
        around all pixels equal to 255, and the mask cropped to that box as a
        ``(1, 1, height, width)`` bool tensor.

    Raises:
        FileNotFoundError: if OpenCV cannot read *mask_path*.
        ValueError: if the mask contains no pixel equal to 255.
    """
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read mask image: {mask_path}")
    target_mask = mask == 255
    y_indices, x_indices = np.nonzero(target_mask)
    if x_indices.size == 0:
        raise ValueError(f"Mask {mask_path} contains no pixels equal to 255")

    min_x, max_x = x_indices.min(), x_indices.max()
    min_y, max_y = y_indices.min(), y_indices.max()
    width = max_x - min_x + 1
    height = max_y - min_y + 1
    cropped_mask = target_mask[min_y:max_y + 1, min_x:max_x + 1]
    return width, height, torch.from_numpy(cropped_mask).unsqueeze(0).unsqueeze(0)

def mm(mask_path):
    """Load a grayscale mask and return its pixels equal to 255 as bool tensors.

    Args:
        mask_path: path to an image, read as 8-bit grayscale with OpenCV.

    Returns:
        (flat, full): ``flat`` is a 1-D bool tensor of length H*W (row-major
        flattening of the mask); ``full`` is the same mask as a
        ``(1, 1, H, W)`` bool tensor.

    Raises:
        FileNotFoundError: if OpenCV cannot read *mask_path*.
    """
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read mask image: {mask_path}")
    target_mask = mask == 255
    # A NumPy bool array already converts to a torch.bool tensor.
    target_mask_tensor = torch.from_numpy(target_mask.flatten())
    return target_mask_tensor, torch.from_numpy(target_mask).unsqueeze(0).unsqueeze(0)
def extract_individual_region_masks(mask_path):
    """Split a mask into one bool mask per outer region plus a background mask.

    The image is binarized at 127. Every external contour is drawn filled, so
    any holes inside a region are included in that region's mask. The last
    list element is the background: pixels covered by none of the regions.

    Args:
        mask_path: path to a mask image, read as 8-bit grayscale with OpenCV.

    Returns:
        list of ``(1, 1, H, W)`` bool tensors: one per external contour (in
        ``cv2.findContours`` order), followed by the background mask.

    Raises:
        FileNotFoundError: if OpenCV cannot read *mask_path*.
    """
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read mask image: {mask_path}")
    _, binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

    # OpenCV 3 returns (image, contours, hierarchy) while OpenCV 4 returns
    # (contours, hierarchy); indexing with [-2] works for both.
    contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    masks_tensor_list = []
    all_foreground_mask = np.zeros_like(binary, dtype=np.uint8)
    for contour in contours:
        region_mask = np.zeros_like(binary, dtype=np.uint8)
        # thickness=-1 fills the contour interior.
        cv2.drawContours(region_mask, [contour], -1, color=255, thickness=-1)
        all_foreground_mask = cv2.bitwise_or(all_foreground_mask, region_mask)
        masks_tensor_list.append(torch.from_numpy(region_mask > 0).unsqueeze(0).unsqueeze(0))

    background_mask = all_foreground_mask == 0  # bool array
    masks_tensor_list.append(torch.from_numpy(background_mask).unsqueeze(0).unsqueeze(0))

    print(f"from {mask_path} extract {len(masks_tensor_list)} region mask")
    return masks_tensor_list
def _snapshot_checkpoint(model):
    """Return a checkpoint dict holding a frozen copy of ``model``'s current weights.

    ``model.state_dict()`` returns tensors that share storage with the live
    parameters. Without the deep copy, later in-place optimizer updates would
    silently rewrite the "best" snapshot.
    """
    return {
        'model_state_dict': copy.deepcopy(model.state_dict()),
        'binary mask': None,
    }


def _forward_rd(model, wd_loss, coords, img_in):
    """Render the image from ``coords`` and score it with ``wd_loss``.

    Returns:
        (img_out, rate, bits_rate, distortion, wd_information):
        - ``img_out`` is the model output reshaped to ``img_in``'s (B, 3, H, W) layout.
        - ``rate`` is the model's raw rate tensor.
        - ``bits_rate`` is ``rate.sum() / args.all_pix_num``.
        - ``distortion`` and ``wd_information`` come from ``wd_loss``.
    """
    batch_size, _, height, width = img_in.shape
    model_output, rate = model(coords)
    bits_rate = rate.sum() / args.all_pix_num
    img_out = model_output.view(batch_size, height, width, 3).permute(0, 3, 1, 2)
    distortion, wd_information = wd_loss(
        img_out, img_in, model.log2_sigma, num_scales=3, saliency=model.saliency_tensor
    )
    return img_out, rate, bits_rate, distortion, wd_information


def _save_reconstruction(img_out, img_index):
    """Save ``img_out`` as a PNG under ./rec/<dataset>_train/ when ``args.rec_flag == 1``."""
    if args.rec_flag != 1:
        return
    vis_folder = './rec/' + args.dataset + '_train/'
    make_path(vis_folder)
    # Name the file by the lambda of the current run so that runs with
    # different lambdas do not overwrite each other's reconstructions.
    vutils.save_image(img_out, vis_folder + str(img_index + 1) + '_' + str(args.lambda_rate) + '.png', nrow=1)


def train(target_mask_list, model, dataloader, total_steps, total_steps_2, steps_til_summary, img_index, saved_path):
    """Fit ``model`` to each image batch in two stages and save the best weights.

    Stage I minimizes ``args.lambda_rate * bits_per_pixel + distortion``:
    - The optimizer is Adam with lr ``args.lr``, cosine-annealed to 1e-5 over
      ``total_steps``.
    - ``model.noise_parameter`` is annealed linearly from 2.0 to 1.0, and
      ``model.soft_round_temperature`` from 0.3 to 0.1.
    - If no new best is recorded for ``patience`` steps, the model and
      optimizer roll back to the best snapshot, keeping the scheduler's current
      learning rate.

    Stage II starts from the best stage-I weights:
    - The model is switched to ``quantizer_type='softround_alone'`` with no
      quantizer noise and temperature 1e-4.
    - It is fine-tuned with Adam (lr 1e-4) and ``ReduceLROnPlateau``.
    - Training stops early once the learning rate falls below 1e-8.

    In both stages the best loss is checked every ``steps_til_summary`` steps
    (and on the last step). The best snapshot is written to ``saved_path``, and
    the returned metrics are measured on that saved snapshot. Configuration is
    read from the module-level ``args``.

    Args:
        target_mask_list: unused here; kept for interface compatibility.
        model: model exposing ``get_param``/``set_param``, ``log2_sigma``,
            ``saliency_tensor`` and returning ``(output, rate)`` from ``model(coords)``.
        dataloader: yields ``(image, label)`` batches; labels are ignored.
        total_steps: number of stage-I steps.
        total_steps_2: maximum number of stage-II steps.
        steps_til_summary: interval between best-checkpoint checks.
        img_index: image index used in reconstruction file names.
        saved_path: destination of ``torch.save`` for the checkpoint dict.

    Returns:
        (distortion, bits_per_pixel, total_bits) for the last batch, as floats.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    patience = 5000  # stage-I steps without a new best before rolling back
    initial_noise_param, final_noise_param = 2.0, 1.0
    initial_temperature, final_temperature = 0.3, 0.1

    wd_loss = VGG16WassersteinDistortion().cuda()

    optim = torch.optim.Adam([{'params': list(model.parameters()), 'lr': args.lr}])
    best_optimizer_state = copy.deepcopy(optim.state_dict())
    scheduler = CosineAnnealingLR(optim, T_max=total_steps, eta_min=0.00001, last_epoch=-1)
    # NOTE(review): assumes get_param() returns a copy rather than live parameters.
    best_model = model.get_param()

    optimizer_stage_2 = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-4)
    scheduler_stage_2 = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_stage_2, mode='min', factor=0.8, patience=20, verbose=True
    )

    result = None
    for img_in, _ in dataloader:
        img_in = img_in.cuda()
        _, _, height, width = img_in.shape
        coords = get_mgrid(width // args.scale, height // args.scale, 2).cuda()

        print('start temperature:', initial_temperature, 'noise parameter:', initial_noise_param)
        print('end temperature:', final_temperature, 'noise parameter:', final_noise_param)

        # ------------------------------ Stage I ------------------------------
        print("********************Start from stage I")
        checkpoint = None
        best_rd = float('inf')
        cnt_record = 0
        for step in range(total_steps):
            if step - cnt_record > patience:
                model.set_param(best_model)
                # Load a copy: Optimizer.load_state_dict may reuse the given
                # tensors, and Adam updates its state in place, which would
                # corrupt the stored best state for later rollbacks.
                optim.load_state_dict(copy.deepcopy(best_optimizer_state))
                current_lr = scheduler.get_last_lr()[0]
                for group in optim.param_groups:
                    group["lr"] = current_lr
                cnt_record = step

            model.train()
            progress = step / total_steps
            model.noise_parameter = initial_noise_param - progress * (initial_noise_param - final_noise_param)
            model.soft_round_temperature = initial_temperature - progress * (initial_temperature - final_temperature)

            img_out, rate, bits_rate, loss_d, wd_information = _forward_rd(model, wd_loss, coords, img_in)
            loss = args.lambda_rate * bits_rate + loss_d
            loss_value = loss.item()

            is_summary_step = (step + 1) % steps_til_summary == 0 or step + 1 == total_steps
            if is_summary_step and step > 0 and loss_value < best_rd:
                # Snapshot before optim.step() so the saved weights are exactly
                # the ones that produced `loss`.
                best_rd = loss_value
                checkpoint = _snapshot_checkpoint(model)
                best_model = model.get_param()
                best_optimizer_state = copy.deepcopy(optim.state_dict())
                cnt_record = step
                print("Step %d, BEST Distortion: %0.6f, Total loss %0.6f" % (step + 1, loss_d.item(), loss_value),
                      'with its rate', bits_rate.item(), 'latent_bits', rate.sum().item())
                print(wd_information)
                _save_reconstruction(img_out, img_index)

            optim.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_([p for p in model.parameters() if p.requires_grad], 1e-1,
                                     norm_type=2.0, error_if_nonfinite=False)
            optim.step()
            scheduler.step()

        if checkpoint is None:
            # No summary step improved on best_rd (e.g. total_steps == 1 or a
            # non-finite loss); fall back to the final weights.
            checkpoint = _snapshot_checkpoint(model)
        torch.save(checkpoint, saved_path)
        model.load_state_dict(checkpoint['model_state_dict'])

        # ------------------------------ Stage II -----------------------------
        print("********************going into stage II")
        best_rd_2 = float('inf')
        for step in range(total_steps_2):
            model.train()
            model.quantizer_type = "softround_alone"
            model.quantizer_noise_type = "none"
            model.soft_round_temperature = 1e-4

            img_out, rate, bits_rate_2, loss_d_2, _ = _forward_rd(model, wd_loss, coords, img_in)
            loss_2 = args.lambda_rate * bits_rate_2 + loss_d_2
            loss_2_value = loss_2.item()

            is_summary_step = step % steps_til_summary == 0 or step == total_steps_2 - 1
            if is_summary_step and step > 0 and loss_2_value < best_rd_2:
                best_rd_2 = loss_2_value
                checkpoint = _snapshot_checkpoint(model)
                print('Print rate', bits_rate_2.item())
                print('latent_bits', rate.sum().item())
                print("Step %d, BEST Distortion: %0.6f, Total loss %0.6f" % (step, loss_d_2.item(), loss_2_value),
                      'with its rate', bits_rate_2.item(), 'latent_bits', rate.sum().item())
                _save_reconstruction(img_out, img_index)

            optimizer_stage_2.zero_grad()
            loss_2.backward()
            optimizer_stage_2.step()
            scheduler_stage_2.step(loss_2_value)
            current_lr = optimizer_stage_2.param_groups[0]['lr']
            if current_lr < 1e-8:
                print(f"Current learning rate: {current_lr}")
                print(f"Stopping training early: Learning rate has dropped below lr_threshold")
                break

        # ----------------------------- Evaluation ----------------------------
        torch.cuda.empty_cache()
        # Report metrics for the weights that are actually saved (the best
        # snapshot), not for the last optimizer iterate.
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        with torch.no_grad():
            _, rate, bits_rate_eval, distortion_eval, _ = _forward_rd(model, wd_loss, coords, img_in)
        result = (distortion_eval.item(), bits_rate_eval.item(), rate.sum().item())

        torch.cuda.empty_cache()
        torch.save(checkpoint, saved_path)
        print('Saved model at', saved_path)
        torch.cuda.empty_cache()

    if result is None:
        raise ValueError("dataloader yielded no batches; nothing to train")
    return result

parser = argparse.ArgumentParser(description='PyTorch Example')
parser.add_argument('--batch_size', type=int, default=1, help='Batch-size')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR', help='learning rate')
parser.add_argument('--data', type=str, default='../data', help='Location to store data')

parser.add_argument('--p_min', type=float, default=0.5, help='lower bounding likelihood')
parser.add_argument('--sigma_max', type=float, default=16, help='sigma_max：8-16, also need to change wd.py')

parser.add_argument('--local_upsampling_kernel_size', type=int, default=8, help='2, 4 or 8')
parser.add_argument('--upsampling_kernel_size', type=int, default=8, help='2, 4 or 8')
# NOTE(review): no `type`/`action` is given, so any value passed on the command
# line arrives as a non-empty (truthy) string, e.g. "False". The same applies to
# --use_candidate.
parser.add_argument('--static_upsampling_kernel', default=True, help='Use this flag to **not** learn the upsampling kernel')
parser.add_argument('--latent_factor', type=int, default=1, help='Full resolution -> 1, other W,H/factor')
parser.add_argument('--mod_base', type=int, default=7, help='Number of base')

parser.add_argument('--highest_flag', type=int, default=1, help='Full resolution -> 1, other W,H/factor')
parser.add_argument('--context_arm', type=int, default=24, help='8,16,24,32')
parser.add_argument('--dim_arm_mod', type=int, default=24, help='arm dimension')
parser.add_argument('--dataset', type=str, default='clic', choices=['clic', 'kodak'], help='kodak or clic')

parser.add_argument('--mod_hid_layer', type=int, default=0, help='3x3 mod layer')
parser.add_argument('--part', type=int, default=3, help='0: full, 1, 2, 3,4')

parser.add_argument('--use_candidate', default=False, help='Use candidate')
parser.add_argument('--sythesis_features', type=int, default=18, help='hidden')
parser.add_argument('--hidden_features', type=int, default=64, help='hidden')
parser.add_argument('--scale', type=int, default=1, help='Predict every scale*1 pixel')

parser.add_argument('--wd_flag', type=int, default=1, help='Predict every scale*1 pixel')
parser.add_argument('--rec_flag', type=int, default=1, help='Rec or not')


parser.add_argument(
    '--lambda_rate_list',
    type=float,
    nargs='+',
    default=[1.1],
    metavar='LR',
    help='list of lambda weights'
)
# Which images to train. Index k selects file number k+1 (e.g. clic01 / kodim01).
parser.add_argument(
    '--img_indices',
    type=int,
    nargs='+',
    default=[0],
    help='0-based indices of the images to train (index k -> file number k+1)'
)

args = parser.parse_args()

all_psnr_list_of_lists = []
all_rate_list_of_lists = []
all_rate_num_list_of_lists = []
eval_all_psnr_list_of_lists = []
eval_all_y_rate_list_of_lists = []
eval_all_mlp_rate_list_of_lists = []
eval_all_y_rate_num_list_of_lists = []
eval_all_mlp_rate_num_list_of_lists = []
eval_all_border_rate_list_of_lists = []
eval_all_border_rate_num_list_of_lists = []
eval_all_total_rate_list_of_lists = []
eval_all_total_rate_num_list_of_lists = []
eval_all_arm_rate_list_of_lists = []
eval_all_arm_rate_num_list_of_lists = []
eval_all_conv_rate_list_of_lists = []
eval_all_conv_rate_num_list_of_lists = []

transform_val = transforms.Compose([transforms.ToTensor()])

for num, lambda_rate in enumerate(args.lambda_rate_list):
    seed_everything(manual_seed)
    all_psnr = []
    all_rate = []
    all_rate_num = []
    eval_all_psnr = []
    eval_all_y_rate = []
    eval_all_y_rate_num = []
    eval_all_mlp_rate = []
    eval_all_mlp_rate_num = []
    eval_all_rate_arm = []
    eval_all_rate_arm_num = []
    eval_all_rate_conv = []
    eval_all_rate_conv_num = []
    eval_all_border_rate = []
    eval_all_border_rate_num = []
    eval_all_total_rate = []
    eval_all_total_rate_num = []

    # Read by train() through the module-level args.
    args.lambda_rate = lambda_rate
    it_range = args.img_indices

    for it in it_range:
        idx_str = f"{it + 1:02d}"
        if args.dataset == 'clic':
            val_folder = f'./dataset/clic_dataset/clic_data/clic{idx_str}'
            lossy_path = f'./dataset/clic_dataset/clic_lossy_mask/clic{idx_str}.png'
            lossyless_path = f'./dataset/clic_dataset/clic_new_mask/clic{idx_str}.png'
            saliency_path = f'./dataset/clic_dataset/clic_vis_saliency/clic{idx_str}_saliency.png'
        elif args.dataset == 'kodak':
            val_folder = f'./dataset/kodak_dataset/kodak_data/kodim{idx_str}'
            lossy_path = f'./dataset/kodak_dataset/kodak_lossy_mask/kodim{idx_str}.png'
            lossyless_path = f'./dataset/kodak_dataset/kodak_new_mask/kodim{idx_str}.png'
            saliency_path = f'./dataset/kodak_dataset/kodak_vis_saliency/kodim{idx_str}_saliency.png'
        print('train the', it, '-th image')

        val_dataset = datasets.ImageFolder(val_folder, transform_val)
        dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=1, pin_memory=True)
        # Read the size from the first sample directly (C, H, W after ToTensor)
        # instead of starting a DataLoader worker just to inspect the shape.
        first_img, _ = val_dataset[0]
        args.patch_h = first_img.shape[1]
        args.patch_w = first_img.shape[2]

        saliency_img = cv2.imread(saliency_path, cv2.IMREAD_GRAYSCALE)
        if saliency_img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read saliency map: {saliency_path}")
        saliency_map = saliency_img / 255.0
        saliency_tensor = torch.from_numpy(saliency_map).float().unsqueeze(0).unsqueeze(0).cuda()  # shape: (1, 1, H, W)
        saliency_mean = saliency_tensor.mean()
        eps = 1e-6

        # Per-pixel p = p_min + (1 - p_min) * s / mean(s), sigma = sigma_max * p_min / p,
        # clamped to >= 1: zero-saliency pixels get sigma_max, and higher
        # saliency gives a smaller sigma.
        p = args.p_min + (1 - args.p_min) * saliency_tensor / (saliency_mean + eps)
        sigma = args.sigma_max * args.p_min / p
        sigma = torch.clamp(sigma, min=1.0)
        log2_sigma = torch.log2(sigma)

        target_mask_list = extract_individual_region_masks(lossy_path)

        args.all_pix_num = args.patch_h * args.patch_w
        args.eval_pix_num = args.patch_h * args.patch_w
        print(args)
        folder_path_dataset = './saved/' + str(args.dataset)
        make_path(folder_path_dataset)
        folder_path = folder_path_dataset + '/arm_' + str(args.dim_arm_mod)
        make_path(folder_path)
        # os.path.join keeps the file inside folder_path. Concatenating with './'
        # pointed into a never-created 'arm_N.' directory on POSIX systems.
        saved_path = os.path.join(
            folder_path,
            args.dataset + '_' + str(args.dim_arm_mod) + '_' + str(args.sythesis_features)
            + '_pw_' + str(args.lambda_rate) + '_img_' + str(it) + '.pth',
        )
        total_steps = 80000
        total_steps_2 = 8000
        steps_til_summary = 50

        mask_model = ReReIC(args, target_mask_list, log2_sigma, saliency_tensor)

        print(mask_model)
        mask_model.cuda()
        out_distortion, out_rate, rate_num = train(target_mask_list, mask_model, dataloader, total_steps, total_steps_2, steps_til_summary, it, saved_path)
        all_psnr.append(out_distortion)
        all_rate.append(out_rate)
        all_rate_num.append(rate_num)
        print('Trained the image with Distortion:', out_distortion, ' latent bits', out_rate, 'latent bits number', rate_num)
        print(all_psnr)
        print(all_rate)
        print(all_rate_num)

