import os
from asyncio import base_tasks
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import math
import argparse
import random
from PIL import Image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
import numpy as np
import skimage
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.autograd as autograd
import torchvision.transforms as transforms
#import rff
from torchvision import datasets, transforms
import torchvision.utils as vutils
from torch import Tensor, index_select, nn
from models.model_codec_fix import Masked_INR
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

from utils.upsampling import Upsampling
from utils.eval_model import eval_model,compute_model_rate

#manual_seed=42
def seed_everything(seed=1029):
    """Seed every RNG this script relies on so runs are reproducible.

    Seeds Python's ``random``, torch (CPU, current CUDA device and all CUDA
    devices) and NumPy, exports ``PYTHONHASHSEED`` and puts cuDNN into
    deterministic, non-benchmark mode.

    Note: ``PYTHONHASHSEED`` is read at interpreter startup, so setting it here
    only affects Python processes started afterwards, not the current one.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    # Previously misspelled 'PATHONHASHSEED', which Python never reads.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may pick different kernels run to run; keep it off.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
 
def get_mgrid(w_sidelen, h_sidelen, dim=2):
    """Return a flattened grid of coordinates evenly spaced in [-1, 1].

    For ``dim == 2`` the grid has ``h_sidelen`` rows and ``w_sidelen`` columns.
    It is flattened row-major, so entry ``r * w_sidelen + c`` is ``(x_c, y_r)``.
    The first coordinate is the width position and varies fastest, which matches
    pixels flattened from an (H, W, C) image.

    For any other ``dim`` every axis uses ``w_sidelen`` samples
    (``h_sidelen`` is ignored), and the same convention holds: the first
    coordinate varies fastest. Previously every ``dim != 2`` raised inside
    ``permute``.

    Returns:
        Tensor of shape ``(1, N, dim)``, where N is the number of grid points.
    """
    x = torch.linspace(-1, 1, steps=w_sidelen)  # samples along the width
    if dim == 2:
        y = torch.linspace(-1, 1, steps=h_sidelen)  # samples along the height
        tensors = (x, y)
    else:
        tensors = (x,) * dim
    # 'ij' indexing puts coordinate k on axis k: shape (len_0, ..., len_{dim-1}, dim).
    mgrid = torch.stack(torch.meshgrid(*tensors, indexing='ij'), dim=-1)
    # Reverse the spatial axes so the first coordinate varies fastest once
    # flattened; for dim == 2 this is permute(1, 0, 2): (w, h, 2) -> (h, w, 2).
    mgrid = mgrid.permute(*reversed(range(dim)), dim)
    return mgrid.reshape(1, -1, dim)

def add_noise(img_index, img_tensor, save_rec_folder, noise_type='gaussian', noise_parameter=25):
    """Corrupt an image tensor with synthetic noise and clamp it to [0, 1].

    Args:
        img_index: Not used by this function; kept for call-site compatibility.
        img_tensor: Image tensor; it is cast to float32 first. Values are
            presumably in [0, 1] (TODO confirm with callers).
        save_rec_folder: Not used by this function; kept for compatibility.
        noise_type: ``'gaussian'`` or ``'poisson'`` (case-insensitive).
        noise_parameter: For ``'gaussian'``, the standard deviation on a
            0-255 scale (it is divided by 255). For ``'poisson'``, the factor
            the image is multiplied by before sampling and divided by after.

    Returns:
        The noisy float32 tensor, clamped to [0, 1].

    Raises:
        ValueError: If ``noise_type`` is neither gaussian nor poisson.
    """
    clean = img_tensor.float()
    kind = noise_type.lower()

    if kind == 'gaussian':
        sigma = noise_parameter / 255.0
        perturbation = torch.randn_like(clean) * sigma
        corrupted = clean + perturbation
    elif kind == 'poisson':
        corrupted = torch.poisson(clean * noise_parameter) / noise_parameter
    else:
        raise ValueError('noise_type must be "gaussian" or "poisson".')

    return corrupted.clamp(0.0, 1.0)

def make_path(path):
    """Ensure ``path`` exists as a directory, reporting what happened.

    Creates any missing parent directories. If something already exists at
    ``path``, it only prints a message. Always returns 0.
    """
    if os.path.exists(path):
        print(f"Directory '{path}' already exists.")
        return 0

    os.makedirs(path)
    print(f"Directory '{path}' created.")
    return 0

def loss_to_psnr(loss, max=1):
    """Convert an MSE value (scalar or array-like) to PSNR in decibels.

    ``max`` is the peak signal value: 1 for images scaled to [0, 1]. The
    parameter name shadows the builtin but is kept for existing callers.
    """
    mse = np.asarray(loss)
    peak_squared = max ** 2
    return 10 * np.log10(peak_squared / mse)

def _stage_image_path(save_img_folder, stage, img_index, seed):
    """Build the reconstruction PNG path for training ``stage`` (1 or 2).

    Produces ``<folder>/stage_<stage>_<img_index>_mask_<sparsity>_seed_<seed>.png``,
    reading the sparsity from the module-level ``args``.
    """
    return (save_img_folder + '/stage_' + str(stage) + '_' + str(img_index)
            + '_mask_' + str(args.sparsity) + '_seed_' + str(seed) + '.png')


def _save_reconstruction(model_output, batch_size, height, width, path, nrow):
    """Save ``model_output``, viewed as (batch, height, width, 3), as an image grid at ``path``."""
    img_out = model_output.detach().view(batch_size, height, width, 3).permute(0, 3, 1, 2)
    vutils.save_image(img_out, path, nrow=nrow)


def _run_stage_1(model, optim, scheduler, criterion, coords, pixels, shape,
                 total_steps, steps_til_summary, image_path, nrow):
    """Stage I: rate-distortion training with linearly annealed quantizer settings.

    Over ``total_steps`` this anneals ``model.noise_parameter`` linearly from
    2.0 to 1.0 and ``model.soft_round_temperature`` from 0.3 to 0.1. It
    minimizes ``args.lambda_rate * rate.sum() / (width * height) + MSE``.
    ``scheduler`` is stepped after every optimizer step, so a CosineAnnealingLR
    with ``T_max=total_steps`` actually anneals over the stage.

    Returns:
        List of per-step total losses (Python floats).
    """
    batch_size, height, width = shape
    initial_noise_param = 2.0
    final_noise_param = 1.0
    initial_temperature = 0.3
    final_temperature = 0.1
    print('start temperature:', initial_temperature,'noise parameter:',initial_noise_param)
    print('end temperature:', final_temperature,'noise parameter:',final_noise_param)
    print("********************Start from stage I")

    # Gradients of every trainable parameter except the per-layer log scales
    # are clipped. The list is built once because it is loop-invariant.
    clip_params = [p for name, p in model.named_parameters()
                   if p.requires_grad and "log_scale_per_layer" not in name]
    # Guard the annealing fraction against total_steps == 0.
    denom = max(total_steps, 1)
    losses = []

    for step in range(total_steps + 1):
        model.train()
        progress = step / denom
        model.noise_parameter = initial_noise_param - progress * (initial_noise_param - final_noise_param)
        model.soft_round_temperature = initial_temperature - progress * (initial_temperature - final_temperature)
        model_output, rate, _ = model(coords)
        model_output = model_output.view(batch_size, -1, 3)
        bits_rate = rate.sum() / (width * height)
        loss_mse = criterion(model_output, pixels)
        loss = args.lambda_rate * bits_rate + loss_mse
        losses.append(loss.item())

        # The loop's last step is total_steps (it runs total_steps + 1 times),
        # so the final summary check must use total_steps, not total_steps - 1.
        if (steps_til_summary > 0 and step % steps_til_summary == 0) or step == total_steps:
            psnr_this_iter = loss_to_psnr(loss_mse.item())
            print(f'stage I step {step}: psnr {psnr_this_iter:.2f} dB, rate {bits_rate.item():.4f}')
            _save_reconstruction(model_output, batch_size, height, width, image_path, nrow)

        optim.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(clip_params, 10, norm_type=2.0, error_if_nonfinite=False)
        optim.step()
        # Per-step scheduling: previously this ran once per batch, outside the loop.
        scheduler.step()
    return losses


def _run_stage_2(model, criterion, coords, pixels, shape, total_steps_2,
                 steps_til_summary, image_path, nrow):
    """Stage II: fine-tune all trainable parameters with soft-round quantization and no noise.

    Uses a fresh Adam optimizer (lr 1e-4) and ReduceLROnPlateau on the total
    loss. Stops early once the learning rate drops below 1e-8. Saves the
    *current* reconstruction at summary steps and once more after the loop.

    Returns:
        List of per-step total losses (Python floats).
    """
    batch_size, height, width = shape
    print("********************going into stage II")
    optimizer_stage_2 = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-4)
    scheduler_stage_2 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_stage_2, mode='min', factor=0.8, patience=20)
    losses_2 = []
    model_output = None

    for step in range(total_steps_2):
        model.train()
        # Re-applied every step, as before, since the model's forward is opaque here.
        model.quantizer_type = "softround_alone"
        model.quantizer_noise_type = "none"
        model.soft_round_temperature = 1e-4
        model_output, rate, _ = model(coords)
        model_output = model_output.view(batch_size, -1, 3)
        bits_rate = rate.sum() / (width * height)
        loss_mse = criterion(model_output, pixels)
        loss_2 = args.lambda_rate * bits_rate + loss_mse
        losses_2.append(loss_2.item())

        # Previously this saved the stale stage-I image on every step.
        if steps_til_summary > 0 and step % steps_til_summary == 0:
            _save_reconstruction(model_output, batch_size, height, width, image_path, nrow)

        optimizer_stage_2.zero_grad()
        loss_2.backward()
        optimizer_stage_2.step()
        scheduler_stage_2.step(loss_2.item())
        current_lr = optimizer_stage_2.param_groups[0]['lr']
        if current_lr < 1e-8:
            print(f"Current learning rate: {current_lr}")
            print(f"Stopping training early: Learning rate has dropped below lr_threshold")
            break

    # Persist the last reconstruction computed in this stage, including after an early stop.
    if model_output is not None:
        _save_reconstruction(model_output, batch_size, height, width, image_path, nrow)
    return losses_2


def train(model, dataloader, total_steps, total_steps_2, steps_til_summary, img_index, save_img_folder, seed, fixed_noisy=None):
    """Fit ``model`` to each batch from ``dataloader`` in two stages.

    Stage I runs ``total_steps + 1`` steps with annealed quantizer noise and
    temperature. Stage II runs up to ``total_steps_2`` steps in soft-round mode.
    Reconstructions are written under ``save_img_folder``. Reads the
    module-level ``args`` (``lr``, ``lambda_rate``, ``sparsity``) and needs CUDA.

    Args:
        model: Network called as ``model(coords)`` and expected to return
            ``(output, rate, _)``.
        dataloader: Yields ``(image, label)``; the image is assumed to be
            (batch, 3, H, W) in [0, 1] (TODO confirm).
        total_steps: Stage I step count; also the cosine schedule's ``T_max``.
        total_steps_2: Maximum number of Stage II steps.
        steps_til_summary: Interval for PSNR logging and image saves; 0 disables periodic summaries.
        img_index: Image index used in output filenames.
        save_img_folder: Directory for output PNGs.
        seed: Seed value used in output filenames.
        fixed_noisy: Target image to fit instead of the dataloader image, for
            example a pre-noised copy. If None, the clean dataloader image is
            used (previously None crashed).

    Returns:
        0.
    """
    vis_colum = 3
    criterion = nn.MSELoss().cuda()
    # Score parameters get their own learning rate. Per-layer log scales are
    # left out of the Stage I optimizer entirely.
    base_params = [p for name, p in model.named_parameters() if p.requires_grad and "scores" not in name and "log_scale_per_layer" not in name]
    score_params = [p for name, p in model.named_parameters() if p.requires_grad and "scores" in name]
    optim = torch.optim.Adam([{'params': base_params, 'lr': args.lr}, {'params': score_params, 'lr': 0.1}])
    # T_max must be positive; guard total_steps == 0.
    scheduler = CosineAnnealingLR(optim, T_max=max(total_steps, 1))

    for img_in_perfect, _ in dataloader:
        model.train()
        batch_size, _, height, width = img_in_perfect.shape
        img_in = img_in_perfect if fixed_noisy is None else fixed_noisy.to(img_in_perfect.device)
        # (B, 3, H, W) -> (B, H*W, 3); reshape also handles non-contiguous inputs.
        pixels = img_in.permute(0, 2, 3, 1).reshape(batch_size, -1, 3).cuda()
        coords = get_mgrid(width, height, 2).repeat(batch_size, 1, 1).cuda()
        shape = (batch_size, height, width)

        _run_stage_1(model, optim, scheduler, criterion, coords, pixels, shape,
                     total_steps, steps_til_summary,
                     _stage_image_path(save_img_folder, 1, img_index, seed), vis_colum)
        _run_stage_2(model, criterion, coords, pixels, shape,
                     total_steps_2, steps_til_summary,
                     _stage_image_path(save_img_folder, 2, img_index, seed), vis_colum)

        torch.cuda.empty_cache()

    return 0


def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'.

    Without this, argparse stored any passed value as a string, so
    ``--static_upsampling_kernel False`` was truthy.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='PyTorch Example')
parser.add_argument('--batch_size', type=int, default=1, help='Batch-size')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',help='learning rate')
parser.add_argument('--data', type=str, default='../data', help='Location to store data')
parser.add_argument('--sparsity', type=float, default=0.6, help='prune rate')
parser.add_argument('--upsampling_kernel_size', type=int, default=8, help='2, 4 or 8')
# Accepts a bare flag (-> True) or an explicit boolean value.
parser.add_argument('--static_upsampling_kernel', type=_str2bool, nargs='?', const=True, default=False, help='Use this flag to **not** learn the upsampling kernel')
parser.add_argument('--latent_factor', type=int, default=1, help='Full resolution -> 1, other W,H/factor')
parser.add_argument('--mod_base', type=int, default=7, help='Number of base')
parser.add_argument('--highest_flag', type=int, default=1, help='Full resolution -> 1, other W,H/factor')
parser.add_argument('--context_arm', type=int, default=32, help='8,16,24,32')
parser.add_argument('--dim_arm_mod', type=int, default=32, help='arm dimension')
parser.add_argument('--mod_hid_layer', type=int, default=0, help='3x3 mod layer')
parser.add_argument('--hidden_features', type=int, default=64, help='hidden')
parser.add_argument('--hidden_layer', type=int, default=8, help='layer')
parser.add_argument('--total_steps', type=int, default=50000, help='steps')
# Only these three values select an image range below; reject others at parse time.
parser.add_argument('--part_flag', type=int, default=0, choices=(0, 1, 2), help='Kodak subset: 0 -> images 1-24, 1 -> images 1-12, 2 -> images 13-24')

parser.add_argument('--lambda_rate', type=float, default=0.004, metavar='LAMBDA', help='weight of the rate term in the loss (lambda * rate + MSE)')

args = parser.parse_args()
# Metric accumulators; currently unused elsewhere in this script.
all_psnr=[]
all_rate=[]
eval_all_psnr=[]
eval_all_y_rate=[]
eval_all_mlp_rate=[]

# Half-open Kodak image index range [start, end) for each --part_flag value.
_PART_RANGES = {0: (0, 24), 1: (0, 12), 2: (12, 24)}
# Previously an unknown part_flag left start/end unbound (NameError later).
if args.part_flag not in _PART_RANGES:
    raise ValueError(f'--part_flag must be one of {sorted(_PART_RANGES)}, got {args.part_flag!r}')
start, end = _PART_RANGES[args.part_flag]

# ToTensor is the only preprocessing. It is stateless, so one instance serves every image.
transform_val = transforms.Compose([transforms.ToTensor()])

for it in range(start, end):
    # Kodak folders are kodim01 ... kodim24 (zero-padded to two digits).
    val_folder = f'./dataset/kodak_data_test/kodim{it + 1:02d}'

    val_dataset = datasets.ImageFolder(val_folder, transform_val)
    dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=1,
        pin_memory=True,
    )
    img_in, _ = next(iter(dataloader))
    # Record spatial size (dims 2 and 3 of the batch) on args for the model.
    args.patch_h, args.patch_w = img_in.shape[2:4]

    print(args)
    total_steps = args.total_steps
    total_steps_2 = total_steps // 10
    steps_til_summary = 50
    print('top %:', args.sparsity)

    save_rec_folder = f'./rec_25_arm/re_kodim_{it+1:02d}/'
    make_path(save_rec_folder)

    # Fixed seed before corruption so every run sees the same noisy target.
    noise_sigma = 25
    seed_everything(42)
    noisy_img = add_noise(it, img_in, save_rec_folder, noise_type='gaussian', noise_parameter=noise_sigma)

    model_seeds = [60, 80, 100]
    for m_seed in model_seeds:
        print("\n==============================")
        print(f" Train image {it} with MODEL SEED = {m_seed}")
        print("==============================\n")
        seed_everything(m_seed)
        mask_model = Masked_INR(args, sparsity=args.sparsity, hidden_layers=args.hidden_layer)
        print(mask_model)
        save_img_folder = f'{save_rec_folder}{args.context_arm}_{total_steps}'
        make_path(save_img_folder)
        mask_model.cuda()
        print('train the', it, '-th image with seed:', m_seed)
        train(mask_model, dataloader, total_steps, total_steps_2, steps_til_summary,
              it, save_img_folder, m_seed, fixed_noisy=noisy_img)
  