import os
#os.environ["CUDA_VISIBLE_DEVICES"] ="0"
import torch
from torch import nn
import torch.nn.functional as F
import argparse
import random
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision.transforms as transforms
from torchvision import datasets, transforms
from torch import  nn
from models.model import MoRIC
from utils.quantizer import quantize
from utils.arm import (
    Arm,
    _get_neighbor,
    _get_non_zero_pixel_ctx_index,
    _laplace_cdf,
)
from utils.eval_model import eval_model
from enc.misc import (
    MAX_ARM_MASK_SIZE,
    POSSIBLE_DEVICE,
    DescriptorCoolChic,
    DescriptorNN,
    measure_expgolomb_rate,
)
import cv2
from lossless_contour_algorthm import get_border_bits
from C_star_contour_algorithm import get_border_bits as get_border_bits_c_star

def get_mgrid(w_sidelen,h_sidelen, dim=2):
    """Build a coordinate grid whose values span [-1, 1] along each axis.

    For ``dim == 2`` the result has shape ``(1, 2, h_sidelen, w_sidelen)``:
    channel 0 holds the coordinate that varies along the width and channel 1
    the coordinate that varies along the height. For any other ``dim`` the
    width linspace is repeated ``dim`` times before meshing.
    """
    xs = torch.linspace(-1, 1, steps=w_sidelen)
    ys = torch.linspace(-1, 1, steps=h_sidelen)

    if dim == 2:
        axes = (xs, ys)
    else:
        axes = (xs,) * dim

    grid = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)

    # For dim == 2: (w, h, 2) -> (1, w, h, 2) -> (1, 2, h, w)
    return grid.unsqueeze(0).permute(0, 3, 2, 1)


def make_path(path):
    """Ensure that the directory ``path`` exists, creating parents as needed.

    Prints whether the directory was created or was already present.

    Returns:
        Always ``0``.
    """
    # Try the creation directly rather than checking first: a separate
    # exists() test leaves a window in which another process can create the
    # directory and make makedirs() fail.
    try:
        os.makedirs(path)
    except FileExistsError:
        print(f"Directory '{path}' already exists.")
    else:
        print(f"Directory '{path}' created.")
    return 0

def loss_to_psnr(loss, max=1):
    """Convert a mean-squared-error value into PSNR in decibels.

    ``max`` is the peak signal value. ``loss`` may be a scalar or any
    array-like; the result follows NumPy broadcasting.
    """
    mse = np.asarray(loss)
    peak_power = max ** 2
    return 10 * np.log10(peak_power / mse)
def get_mask_h_w(mask_path):
    """Return the bounding-box size and cropped foreground of a mask image.

    The image at ``mask_path`` is read as grayscale; pixels equal to 255 are
    treated as foreground.

    Returns:
        ``(width, height, cropped_mask)`` where ``width`` and ``height`` are
        the extents of the tight bounding box around the foreground and
        ``cropped_mask`` is a boolean tensor of shape
        ``(1, 1, height, width)``.

    Raises:
        FileNotFoundError: if the image cannot be read.
        ValueError: if the image contains no pixel equal to 255.
    """
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read mask image: {mask_path}")

    target_mask = mask == 255
    y_indices, x_indices = np.where(target_mask)
    if x_indices.size == 0:
        # Without foreground there is no bounding box to return.
        raise ValueError(f"Mask image has no foreground (255) pixels: {mask_path}")

    min_x, max_x = x_indices.min(), x_indices.max()
    min_y, max_y = y_indices.min(), y_indices.max()

    width = max_x - min_x + 1
    height = max_y - min_y + 1
    cropped_mask = target_mask[min_y:max_y + 1, min_x:max_x + 1]
    return width, height, torch.from_numpy(cropped_mask).unsqueeze(0).unsqueeze(0)

def mm(mask_path):
    """Load a mask image and return its foreground mask in two layouts.

    The image at ``mask_path`` is read as grayscale; pixels equal to 255 are
    treated as foreground.

    Returns:
        ``(flat_mask, mask_4d)``: a flattened boolean tensor with one entry
        per pixel, and a boolean tensor of shape ``(1, 1, h, w)``.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)  # (h, w)
    if mask is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read mask image: {mask_path}")

    target_mask = mask == 255
    # flatten() copies, so the flat tensor does not share memory with the 4-D one.
    target_mask_tensor = torch.from_numpy(target_mask.flatten()).bool()

    return target_mask_tensor, torch.from_numpy(target_mask).unsqueeze(0).unsqueeze(0)

def _snapshot_state_dict(model):
    """Return an independent deep copy of ``model.state_dict()``.

    ``state_dict()`` hands back references to the live parameter tensors, so
    a "best" checkpoint that merely stores it keeps changing with every later
    optimizer step. Copying freezes the weights at the moment of the call.
    """
    import copy  # function-scope import: not among the file-level imports

    return copy.deepcopy(model.state_dict())


def _project_to_eval_mask(model_output, template, target_mask, target_mask_lossless):
    """Scatter ``model_output`` into a zero canvas at ``target_mask`` and read it back at ``target_mask_lossless``."""
    canvas = torch.zeros_like(template)
    canvas[:, target_mask, :] = model_output
    return canvas[:, target_mask_lossless, :]


def train(target_mask,target_mask_lossless, model,dataloader, total_steps, total_steps_2,steps_til_summary,img_index,saved_path):
    """Fit ``model`` to each batch of ``dataloader`` in two stages and save it.

    Stage I minimises ``args.lambda_rate * rate + MSE`` with Adam
    (``args.lr``, cosine annealing, gradient-norm clipping at 10) while
    linearly moving ``model.noise_parameter`` from 2.0 to 1.0 and
    ``model.soft_round_temperature`` from 0.3 to 0.1. Stage II restarts from
    the best stage-I weights with ``quantizer_type="softround_alone"``,
    ``quantizer_noise_type="none"`` and temperature 1e-4, using Adam at 1e-4
    with ``ReduceLROnPlateau``; it stops early once the learning rate falls
    below 1e-8. At every summary step the weights with the lowest total loss
    so far are snapshotted; the best snapshot is what gets evaluated and
    written to ``saved_path``.

    Reads ``args.lr``, ``args.lambda_rate``, ``args.all_pix_num`` and
    ``args.eval_pix_num`` from the module-level ``args`` and requires CUDA.
    The dataloader is assumed to yield at least one batch; the returned
    values describe the last batch processed.

    Args:
        target_mask: flat boolean mask selecting the pixels the model output
            is compared against in the training loss.
        target_mask_lossless: flat boolean mask selecting the pixels used for
            the reported PSNR.
        model: called as ``model(coords)`` and expected to return
            ``(output, rate)``.
        dataloader: yields ``(image, label)`` pairs; images are reshaped to
            ``(batch, H*W, 3)``.
        total_steps: stage-I length (the loop runs ``total_steps + 1``
            iterations).
        total_steps_2: maximum number of stage-II iterations.
        steps_til_summary: interval, in steps, between best-checkpoint checks.
        img_index: image index, used only in the final log line.
        saved_path: file the checkpoint is written to.

    Returns:
        ``(psnr_eval, bits_rate_eval, bits_rate_eval_num)``: evaluation PSNR,
        ``rate.sum() / args.eval_pix_num`` and ``rate.sum()`` as Python
        numbers.
    """
    criterion = nn.MSELoss().cuda()

    optim = torch.optim.Adam([{'params': list(model.parameters()), 'lr': args.lr}])
    scheduler = CosineAnnealingLR(optim, T_max=total_steps)

    optimizer_stage_2 = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-4)
    scheduler_stage_2 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_stage_2, mode='min', factor=0.8, patience=20, verbose=True)

    for img_in, _ in dataloader:
        model.train()
        batch_size, _, height, width = img_in.shape
        pixels = img_in.permute(0, 2, 3, 1).view(batch_size, -1, 3).cuda()

        pixels1 = pixels[:, target_mask, :]
        pixels2 = pixels[:, target_mask_lossless, :]

        coords = get_mgrid(width, height, 2).cuda()

        # Linear decay schedule for the noise parameter and the temperature.
        initial_noise_param = 2.0
        final_noise_param = 1.0
        initial_temperature = 0.3
        final_temperature = 0.1
        print('start temperature:', initial_temperature,'noise parameter:',initial_noise_param)
        print('end temperature:', final_temperature,'noise parameter:',final_noise_param)

        # ------------------------------ stage I ------------------------------
        print("********************Start from stage I")
        checkpoint = None
        best_rd = 1000
        for step in range(total_steps + 1):
            model.train()
            progress = step / total_steps
            model.noise_parameter = initial_noise_param - progress * (initial_noise_param - final_noise_param)
            model.soft_round_temperature = initial_temperature - progress * (initial_temperature - final_temperature)

            model_output, rate = model(coords)
            bits_rate = rate.sum() / (args.all_pix_num)
            loss_mse = criterion(model_output, pixels1)
            loss = args.lambda_rate * bits_rate + loss_mse

            if not step % steps_til_summary or (step == total_steps - 1):
                loss_value = loss.item()
                if (loss_value < best_rd) and (step > 0):
                    best_rd = loss_value
                    # Snapshot BEFORE optim.step() so the stored weights are
                    # the ones that produced this loss.
                    checkpoint = {
                        'model_state_dict': _snapshot_state_dict(model),
                    }
                    model_out1 = _project_to_eval_mask(model_output.detach(), pixels, target_mask, target_mask_lossless)
                    print_loss_mse = criterion(model_out1, pixels2)
                    print_psnr_this_iter = loss_to_psnr(print_loss_mse.item())
                    print_bits_rate = rate.sum() / args.eval_pix_num
                    print("Step %d, BEST RD cost: %0.6f, PSNR: %0.6f, Total loss %0.6f" % (step, best_rd, print_psnr_this_iter, loss_value),'with its rate', print_bits_rate.item(), 'latent_bits', rate.sum().item())

            optim.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_([p for p in model.parameters() if p.requires_grad], 10, norm_type=2.0, error_if_nonfinite=False)

            optim.step()
            scheduler.step()

        if checkpoint is None:
            # No summary step beat the initial bound (or stage I was too
            # short to reach one): fall back to the current weights instead
            # of failing on an unbound checkpoint.
            checkpoint = {'model_state_dict': _snapshot_state_dict(model)}
        torch.save(checkpoint, saved_path)
        checkpoints = torch.load(saved_path)
        model.load_state_dict(checkpoints['model_state_dict'])

        # ------------------------------ stage II -----------------------------
        print("********************going into stage II")
        best_rd_2 = 1000
        for step in range(total_steps_2):
            model.train()
            model.quantizer_type = "softround_alone"
            model.quantizer_noise_type = "none"
            model.soft_round_temperature = 1e-4

            model_output, rate = model(coords)

            bits_rate = rate.sum() / (args.all_pix_num)
            loss_mse = criterion(model_output, pixels1)

            loss_2 = args.lambda_rate * bits_rate + loss_mse
            if not step % steps_til_summary or (step == total_steps_2 - 1):
                loss_2_value = loss_2.item()
                if (loss_2_value < best_rd_2) and (step > 0):
                    best_rd_2 = loss_2_value
                    psnr_this_iter = loss_to_psnr(loss_mse.item())
                    checkpoint = {
                        'model_state_dict': _snapshot_state_dict(model),
                    }
                    print('Print rate', bits_rate)
                    print('latent_bits_num', rate.sum().item())
                    print("Step %d, BEST RD cost:  %0.6f, PSNR: %0.6f, Total loss %0.6f" % (step, best_rd_2, psnr_this_iter, loss_2_value))

            optimizer_stage_2.zero_grad()
            loss_2.backward()
            optimizer_stage_2.step()
            scheduler_stage_2.step(loss_2)
            current_lr = optimizer_stage_2.param_groups[0]['lr']
            if current_lr < 1e-8:
                print(f"Current learning rate: {current_lr}")
                print(f"Stopping training early: Learning rate has dropped below lr_threshold")
                break
        torch.cuda.empty_cache()

        # ----------------------------- evaluation ----------------------------
        # Evaluate exactly the weights that are written to disk below, so the
        # returned metrics describe the saved model.
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        with torch.no_grad():
            model_output, rate = model(coords)
            model_out = _project_to_eval_mask(model_output, pixels, target_mask, target_mask_lossless)
            bits_rate_eval = rate.sum() / (args.eval_pix_num)
            bits_rate_eval_num = rate.sum()
            loss_mse = criterion(model_out, pixels2)
        psnr_eval = loss_to_psnr(loss_mse.item())
        print("********************Evaluation the Image %d-th, after Step %d, BEST PSNR: %0.6f, Print rate %0.6f. *************************" % (img_index,step, psnr_eval,bits_rate_eval.item()))
        torch.cuda.empty_cache()
        torch.save(checkpoint, saved_path)
        print('Saved model at',saved_path)
    return psnr_eval,bits_rate_eval.item(), bits_rate_eval_num.item()

def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line options; the parsed namespace is stored in the module-level
# ``args`` that train() and the script below read and extend.
parser = argparse.ArgumentParser(description='MoRIC Example')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',help='learning rate')
parser.add_argument('--upsampling_kernel_size', type=int, default=4, help='2, 4 or 8')
# Parsed as a real boolean: without a type, "--if_lossy False" produced the
# non-empty string "False", which selected the lossy branch. A bare
# "--if_lossy" is accepted as True.
parser.add_argument('--if_lossy', type=_str2bool, nargs='?', const=True, default=False, help='use lossy contour')
parser.add_argument('--context_arm', type=int, default=8, help='8,16,24,32')
parser.add_argument('--dim_arm_mod', type=int, default=8, help='arm dimension')
parser.add_argument('--sythesis_features', type=int, default=5, help='hidden')
parser.add_argument('--lambda_rate_list', type=float, default=1e-3,help='list of lambda weights')
args = parser.parse_args()
# ---- outer containers: each receives one of the per-run lists below --------
# Training results.
(
    all_psnr_list_of_lists,
    all_rate_list_of_lists,
    all_rate_num_list_of_lists,
) = ([] for _ in range(3))

# Evaluation PSNR and rates.
(
    eval_all_psnr_list_of_lists,
    eval_all_y_rate_list_of_lists,
    eval_all_mlp_rate_list_of_lists,
    eval_all_border_rate_list_of_lists,
    eval_all_total_rate_list_of_lists,
    eval_all_arm_rate_list_of_lists,
    eval_all_conv_rate_list_of_lists,
) = ([] for _ in range(7))

# Evaluation "_num" counterparts of the rates above.
(
    eval_all_y_rate_num_list_of_lists,
    eval_all_mlp_rate_num_list_of_lists,
    eval_all_border_rate_num_list_of_lists,
    eval_all_total_rate_num_list_of_lists,
    eval_all_arm_rate_num_list_of_lists,
    eval_all_conv_rate_num_list_of_lists,
) = ([] for _ in range(6))


# ---- per-run lists: one entry is appended per processed image --------------
# Values returned by train().
(
    all_psnr,
    all_rate,
    all_rate_num,
) = ([] for _ in range(3))

# Evaluation PSNR and rates.
(
    eval_all_psnr,
    eval_all_y_rate,
    eval_all_mlp_rate,
    eval_all_rate_arm,
    eval_all_rate_conv,
    eval_all_border_rate,
    eval_all_total_rate,
) = ([] for _ in range(7))

# Evaluation "_num" counterparts of the rates above.
(
    eval_all_y_rate_num,
    eval_all_mlp_rate_num,
    eval_all_rate_arm_num,
    eval_all_rate_conv_num,
    eval_all_border_rate_num,
    eval_all_total_rate_num,
) = ([] for _ in range(6))

# A single rate weight is used per run (despite the option's "_list" name);
# it is exposed as args.lambda_rate for train().
lambda_rate = args.lambda_rate_list
args.lambda_rate = lambda_rate

# Train and evaluate one model per image, then report the aggregated results.
for it in range(0,1):
    # ---- locate the image folder and its masks -----------------------------
    image_name = 'davis0'+str(it+1)
    val_folder = './dataset/img_data/'+image_name
    mask_path = './dataset/img_data/davis_mask/'+image_name+'.png'
    if not args.if_lossy:
        lossy_mask_path = mask_path
    else:
        # The return value is not used here; presumably this call is what
        # produces the lossy mask file read below — TODO confirm.
        get_border_bits_c_star(mask_path,it,T=5,thread=10,rate=0.3)
        lossy_mask_path = './dataset/lossy_mask/'+image_name+'.png'

    # ---- data and per-image settings stored on args -------------------------
    args.lambda_rate = lambda_rate
    transform_val = transforms.Compose([transforms.ToTensor() ])
    val_dataset = datasets.ImageFolder(val_folder,transform_val)
    dataloader = torch.utils.data.DataLoader(val_dataset,batch_size=1,shuffle=False,num_workers=1, pin_memory=True)
    img_in, _ = next(iter(dataloader))
    args.patch_h = img_in.shape[2]
    args.patch_w = img_in.shape[3]
    ifmask = True
    target_mask_tensor_lossless, target_mask_lossless = mm(mask_path)
    target_mask_tensor, target_mask = mm(lossy_mask_path)

    # Both pixel counts come from the (possibly lossy) training mask.
    args.all_pix_num = target_mask_tensor.sum()
    args.eval_pix_num = target_mask_tensor.sum()

    print(args)
    folder_path = './save_checkpoint/'
    make_path(folder_path)
    folder_path_ = folder_path
    saved_path = folder_path_+'/MoRIC_pw_'+str(args.lambda_rate)+'_use_lossy'+str(args.if_lossy)+'.pth'

    total_steps = 100
    total_steps_2 = 80
    steps_til_summary = 10

    # ---- build and train ----------------------------------------------------
    # The model receives the 4-D mask; train()/eval_model() get the flat ones.
    mask_model = MoRIC(args,target_mask)

    target_mask = target_mask.flatten()
    target_mask_lossless = target_mask_lossless.flatten()
    print(mask_model)
    mask_model.cuda()

    print('train the',it,'-th image')
    out_psnr, out_rate, rate_num = train(target_mask,target_mask_lossless, mask_model, dataloader, total_steps, total_steps_2,steps_til_summary,it,saved_path)
    all_psnr.append(out_psnr)
    all_rate.append(out_rate)
    all_rate_num.append(rate_num)
    print('Trained the image with PSNR:',out_psnr,' latent bits',out_rate, 'latent bits number', rate_num)
    print(all_psnr)
    print(all_rate)
    print(all_rate_num)

    # ---- evaluate the checkpoint written by train() -------------------------
    checkpoints = torch.load(saved_path)
    mask_model.load_state_dict(checkpoints['model_state_dict'])

    binary_mask = None
    print('load the model:',saved_path, 'for the ',it,'-th image')
    mask_model.cuda()
    mask_model.eval()

    eval_out_psnr,eval_y_rate, eval_y_rate_num,eval_network_rate,eval_network_rate_num,eval_network_rate_arm,eval_network_rate_arm_num,eval_network_rate_conv,eval_network_rate_conv_num = eval_model(target_mask,target_mask_lossless, args,mask_model, dataloader,it)
    eval_all_psnr.append(eval_out_psnr)
    eval_all_y_rate.append(eval_y_rate)
    eval_all_y_rate_num.append(eval_y_rate_num)
    eval_all_mlp_rate.append(eval_network_rate)
    eval_all_mlp_rate_num.append(eval_network_rate_num)
    eval_all_rate_arm.append(eval_network_rate_arm)
    eval_all_rate_arm_num.append(eval_network_rate_arm_num)
    eval_all_rate_conv.append(eval_network_rate_conv)
    eval_all_rate_conv_num.append(eval_network_rate_conv_num)

    # Cost of the mask contour, normalised by the same pixel count as above.
    # NOTE(review): in lossy mode this repeats the get_border_bits_c_star call
    # made before training; the result could be reused if it is deterministic.
    if not args.if_lossy:
        eval_border_rate_num = get_border_bits(mask_path)
    else:
        eval_border_rate_num = get_border_bits_c_star(mask_path,it,T=5,thread=10,rate=0.3)
    eval_border_rate = (eval_border_rate_num/args.eval_pix_num).item()
    eval_all_border_rate.append(eval_border_rate)
    eval_all_border_rate_num.append(eval_border_rate_num)

    # Totals (latent + network + border) for every image evaluated so far.
    eval_all_rate_y_mlp_latent = [y + mlp + border  for y, mlp, border in zip(eval_all_y_rate, eval_all_mlp_rate, eval_all_border_rate)]
    eval_all_rate_y_mlp_latent_num = [y + mlp + border for y, mlp, border in zip(eval_all_y_rate_num, eval_all_mlp_rate_num, eval_all_border_rate_num)]
    eval_all_total_rate.append(eval_all_rate_y_mlp_latent)
    eval_all_total_rate_num.append(eval_all_rate_y_mlp_latent_num)
    print('Evaluate the image: PSNR:',eval_out_psnr,'All bits:',eval_all_rate_y_mlp_latent[-1],' latent bits:',eval_y_rate,' network bits:',eval_network_rate, 'border bits:',eval_border_rate)
    print('Evaluate arm network bits:',eval_network_rate_arm, 'Evaluate synthesis network bits:',eval_network_rate_conv)
    print('Image All bits num:',eval_all_rate_y_mlp_latent_num[-1],'latent bits num:',eval_y_rate_num, 'mlp bits num:', eval_network_rate_num, 'border bits num:', eval_border_rate_num)
    print('arm bits num:', eval_network_rate_arm_num, 'synthesis bits num:', eval_network_rate_conv_num)
    print(eval_all_psnr)
    print(eval_all_rate_y_mlp_latent)
    print(eval_all_rate_y_mlp_latent_num)
    print('Current eval Ave PSNR:',np.mean(eval_all_psnr),'Ave Bits',np.mean(eval_all_rate_y_mlp_latent))

# ---- end-of-run aggregation -------------------------------------------------
# Done once after all images: inside the loop the same list objects would be
# appended once per image and the final summary printed repeatedly.
all_psnr_list_of_lists.append(all_psnr)
all_rate_list_of_lists.append(all_rate)
all_rate_num_list_of_lists.append(all_rate_num)
eval_all_psnr_list_of_lists.append(eval_all_psnr)
eval_all_y_rate_list_of_lists.append(eval_all_y_rate)
eval_all_y_rate_num_list_of_lists.append(eval_all_y_rate_num)
eval_all_mlp_rate_list_of_lists.append(eval_all_mlp_rate)
eval_all_mlp_rate_num_list_of_lists.append(eval_all_mlp_rate_num)
eval_all_border_rate_list_of_lists.append(eval_all_border_rate)
eval_all_border_rate_num_list_of_lists.append(eval_all_border_rate_num)
eval_all_total_rate_list_of_lists.append(eval_all_total_rate)
eval_all_total_rate_num_list_of_lists.append(eval_all_total_rate_num)
eval_all_arm_rate_list_of_lists.append(eval_all_rate_arm)
eval_all_arm_rate_num_list_of_lists.append(eval_all_rate_arm_num)
eval_all_conv_rate_list_of_lists.append(eval_all_rate_conv)
eval_all_conv_rate_num_list_of_lists.append(eval_all_rate_conv_num)

print('.......Copmlete all dataset training......')
print('Ave Training PSNR:',np.mean(all_psnr),'Ave Training Bits',np.mean(all_rate))

print('Training PSNR:', all_psnr)
print('Training rate:', all_rate)

print('Evalutation: Ave Eval PSNR:',np.mean(eval_all_psnr),'Ave Eval all bits:',np.mean(eval_all_rate_y_mlp_latent),'Ave Eval Bits',np.mean(eval_all_y_rate),'Ave Eval Network',np.mean(eval_all_mlp_rate))
print('Eval All PSNR:',eval_all_psnr)
print('Eval All rate',eval_all_rate_y_mlp_latent)
print('Eval Latent rate',eval_all_y_rate)
print('Eval MLP rate',eval_all_mlp_rate)
