import os
from option import opt
# Echo the run options so they appear at the top of the output.
print(opt)
# Choose the visible GPU(s) through environment variables. This happens
# before the `import torch` statement further down.
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id
from architecture import *
from utils import *
import torch
import scipy.io as scio
import time
import numpy as np
from torch.autograd import Variable
import datetime
import torch.nn.functional as F
import random
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from PIL import Image
# Module-level gradient scaler; train() uses it for the mixed-precision
# scale -> backward -> step -> update sequence.
scaler = torch.GradScaler(device="cuda")

# Enable cuDNN and its autotuner. NOTE: set_seed() (called below) switches
# benchmark back to False, and the __main__ guard turns it on again.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# The rest of the script calls .cuda() unconditionally, so stop right away
# when no CUDA device is available.
if not torch.cuda.is_available():
    raise Exception('NO GPU!')

def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds PyTorch (CPU, current CUDA device and all CUDA devices), NumPy and
    Python's ``random`` with ``seed``, turns the cuDNN autotuner off and
    cuDNN deterministic mode on, and exports ``PYTHONHASHSEED`` and
    ``CUBLAS_WORKSPACE_CONFIG`` into the process environment.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN: disable autotuning and request deterministic mode.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    # Environment variables related to reproducibility.
    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'CUBLAS_WORKSPACE_CONFIG': ':4096:8',
    })
    

# Seed all RNGs before any mask or dataset is created.
set_seed(42)  

# init mask
# Masks are built once for training (third argument opt.batch_size) and once
# for testing, where that argument is hard-coded to 10 - presumably the
# number of scenes returned by LoadTest; TODO confirm.
mask3d_batch_train, input_mask_train = init_mask(opt.mask_path, opt.input_mask, opt.batch_size)
mask3d_batch_test, input_mask_test = init_mask(opt.mask_path, opt.input_mask, 10)

# dataset
# Loaded once at import time and shared by train() / test() as module globals.
train_set = LoadTraining(opt.data_path)
test_data = LoadTest(opt.test_path)

# saving path
# Timestamp converted by time2file_name; it is not used in the paths below.
date_time = str(datetime.datetime.now())
date_time = time2file_name(date_time)
# Output locations are built by plain string concatenation, so opt.outf is
# expected to already end with a path separator.
result_path = opt.outf + opt.name + '/result/'
model_path = opt.outf + opt.name + '/model/'
# exist_ok=True creates each directory only when it is missing and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(result_path, exist_ok=True)
os.makedirs(model_path, exist_ok=True)
    
# Build the network on the GPU and wrap it with torch.compile.
model = model_generator(opt.method, opt.pretrained_model_path).cuda()
model = torch.compile(model)

# optimizing
optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate, betas=(0.9, 0.999))
if opt.scheduler == 'MultiStepLR':
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=opt.milestones, gamma=opt.gamma)
elif opt.scheduler == 'CosineAnnealingLR':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.max_epoch, eta_min=1e-6)
else:
    # Fail at start-up. Without this branch an unrecognised name left
    # `scheduler` undefined and the run only crashed with a NameError at the
    # first scheduler.step(), i.e. after a complete training epoch.
    raise ValueError(
        "Unsupported scheduler {!r}; expected 'MultiStepLR' or 'CosineAnnealingLR'.".format(opt.scheduler))
# Loss module used by train(), which takes its square root (RMSE).
mse = torch.nn.MSELoss().cuda()


def train(epoch, logger):
    """Run one mixed-precision training epoch and return the mean batch loss.

    The epoch consists of ``floor(opt.epoch_sam_num / opt.batch_size)``
    iterations. Each iteration draws a batch with ``shuffle_crop``, splits it
    into chroma and intensity, builds the measurement and mask inputs from
    them, and minimises the RMSE between the model output and the chroma
    ground truth.

    Args:
        epoch: Epoch number; only used in the log message.
        logger: Logger that receives the end-of-epoch summary.

    Returns:
        The loss averaged over all batches of the epoch.

    Raises:
        ValueError: If the configuration yields no batches at all
            (``opt.epoch_sam_num`` smaller than ``opt.batch_size``).
    """
    batch_num = int(np.floor(opt.epoch_sam_num / opt.batch_size))
    if batch_num < 1:
        # Without this guard the averaging below would divide by zero.
        raise ValueError(
            "opt.epoch_sam_num ({}) must be at least opt.batch_size ({})".format(
                opt.epoch_sam_num, opt.batch_size))

    epoch_loss = 0
    begin = time.time()

    for _ in range(batch_num):
        gt = shuffle_crop(train_set, opt.batch_size).cuda().float()

        # chroma-intensity decomposition
        gt_chroma, gt_intensity = chroma_intensity_decom(gt)
        # The intensity modulates both the 3D mask used to form the
        # measurement and the mask handed to the network.
        mask3d_batch_train_new = mask3d_batch_train * gt_intensity
        input_meas = init_meas(gt_chroma, mask3d_batch_train_new, opt.input_setting)
        input_mask_train_new = chroma_intensity_mask(
            input_mask_train, gt_intensity, input_mask_sign=opt.input_mask)

        # model forward (under autocast); the loss is the RMSE on the chroma
        optimizer.zero_grad()
        with torch.autocast("cuda", enabled=True):
            model_out = model(input_meas, input_mask_train_new)
            loss = torch.sqrt(mse(model_out, gt_chroma))

        epoch_loss += loss.item()
        # Backward and optimizer step go through the module-level GradScaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    end = time.time()
    logger.info("===> Epoch {} Complete: Avg. Loss: {:.6f} time: {:.2f}".
                format(epoch, epoch_loss / batch_num, (end - begin)))
    return epoch_loss / batch_num

def test(epoch, logger):
    """Evaluate the model on ``test_data`` and log mean PSNR / SSIM.

    The test data is decomposed into chroma and intensity, the model is run
    in eval mode without gradients, and its output is multiplied by the
    intensity before being compared with the original test data. The model
    is switched back to train mode before returning.

    Args:
        epoch: Epoch number; only used in the log message.
        logger: Logger that receives the evaluation summary.

    Returns:
        Tuple ``(pred, truth, psnr_list, ssim_list, psnr_mean, ssim_mean)``.
        ``pred`` and ``truth`` are float32 NumPy arrays with axes reordered
        by ``(0, 2, 3, 1)`` (presumably NCHW -> NHWC; confirm against the
        loader). The lists hold one value per test sample.
    """
    psnr_list, ssim_list = [], []
    test_gt = test_data.cuda().float()
    test_gt_chroma, test_gt_intensity = chroma_intensity_decom(test_gt)
    mask3d_batch_test_new = mask3d_batch_test * test_gt_intensity
    input_meas = init_meas(test_gt_chroma, mask3d_batch_test_new, opt.input_setting)
    input_mask_test_new = chroma_intensity_mask(
        input_mask_test, test_gt_intensity, input_mask_sign=opt.input_mask)

    model.eval()
    # CUDA kernels run asynchronously: synchronise before and after the timed
    # region so the logged time covers the actual GPU work, not just the
    # kernel launches.
    torch.cuda.synchronize()
    begin = time.time()
    with torch.no_grad():
        model_out = model(input_meas, input_mask_test_new)
        # Re-apply the intensity so the output is comparable with test_gt.
        model_out = model_out * test_gt_intensity
    torch.cuda.synchronize()
    end = time.time()

    for k in range(test_gt.shape[0]):
        psnr_val = torch_psnr(model_out[k, :, :, :], test_gt[k, :, :, :])
        ssim_val = torch_ssim(model_out[k, :, :, :], test_gt[k, :, :, :])
        psnr_list.append(psnr_val.detach().cpu().numpy())
        ssim_list.append(ssim_val.detach().cpu().numpy())

    pred = np.transpose(model_out.detach().cpu().numpy(), (0, 2, 3, 1)).astype(np.float32)
    truth = np.transpose(test_gt.cpu().numpy(), (0, 2, 3, 1)).astype(np.float32)
    psnr_mean = np.mean(np.asarray(psnr_list))
    ssim_mean = np.mean(np.asarray(ssim_list))
    logger.info('===> Epoch {}: testing psnr = {:.2f}, ssim = {:.3f}, time: {:.2f}'
                .format(epoch, psnr_mean, ssim_mean, (end - begin)))
    model.train()
    return pred, truth, psnr_list, ssim_list, psnr_mean, ssim_mean

def main():
    """Train for ``opt.max_epoch`` epochs, evaluating after every epoch.

    After each epoch the scheduler is stepped. Whenever the mean test PSNR
    beats the best value seen so far it becomes the new best; if it also
    exceeds 28, the predictions and metrics are written to a ``.mat`` file
    under ``result_path`` and a model checkpoint is saved.
    """
    logger = gen_log(model_path)
    logger.info("Learning rate:{}, batch_size:{}.\n".format(opt.learning_rate, opt.batch_size))

    best_psnr = 0
    for epoch in range(1, opt.max_epoch + 1):
        train(epoch, logger)
        pred, truth, psnr_all, ssim_all, psnr_mean, ssim_mean = test(epoch, logger)
        scheduler.step()

        # Written as "not >" so a NaN PSNR is treated as "no improvement".
        if not psnr_mean > best_psnr:
            continue
        best_psnr = psnr_mean

        # Results are only persisted once the PSNR is above 28.
        if psnr_mean > 28:
            stem = 'Test_{}_{:.2f}_{:.3f}'.format(epoch, best_psnr, ssim_mean)
            mat_file = result_path + '/' + stem + '.mat'
            payload = {
                'truth': truth,
                'pred': pred,
                'psnr_list': psnr_all,
                'ssim_list': ssim_all,
            }
            scio.savemat(mat_file, payload)
            checkpoint(model, epoch, model_path, logger)

def save_iteration_results(iteration_results, epoch_count, save_path):
    """Save greyscale visualisations of intermediate iteration results.

    Only the first sample (``result[0]``, indexed as ``[C, H, W]``) of every
    entry is used. Three PNG files are written to
    ``<save_path>/epoch_<epoch_count>/``, each stacking one row per iteration:

    * ``all_iterations_montage.png``: all channels side by side, min-max
      normalised over the whole sample.
    * ``all_iterations_chroma.png``: the same layout for the chroma, i.e.
      every channel divided by the channel-mean image (plus a small epsilon).
    * ``all_iterations_avg_comparison.png``: the channel-mean image (left)
      next to the channel-mean of the chroma (right), each normalised on its
      own.

    The number of channels is taken from the tensor itself. Values are
    clamped to ``[0, 1]`` before the 8-bit conversion so that constant images,
    which are left un-normalised, cannot wrap around.

    Args:
        iteration_results: Iterable of tensors, one per iteration.
        epoch_count: Value used in the output directory name.
        save_path: Parent directory; the epoch directory is created inside it.

    Returns:
        None. When ``iteration_results`` is empty only the directory is
        created.
    """
    save_dir = os.path.join(save_path, f'epoch_{epoch_count}')
    os.makedirs(save_dir, exist_ok=True)

    eps = 1e-6  # keeps the chroma division away from zero

    def rescale(tensor):
        # Min-max normalise to [0, 1]; a constant tensor is returned unchanged.
        low, high = tensor.min(), tensor.max()
        if high > low:
            return (tensor - low) / (high - low)
        return tensor

    def to_cpu(tensor):
        # Detached float32 CPU copy, ready for concatenation and saving.
        return tensor.detach().to(device='cpu', dtype=torch.float32)

    def channel_montage(cube):
        # [C, H, W] -> [H, C * W]: channels laid out left to right.
        channels, height, width = cube.shape
        tiled = rescale(cube).permute(1, 0, 2).reshape(height, channels * width)
        return to_cpu(tiled)

    def save_png(image, filename):
        # Clamp before scaling so out-of-range values cannot wrap in uint8.
        pixels = np.clip(image.numpy(), 0.0, 1.0) * 255
        Image.fromarray(pixels.astype(np.uint8)).save(os.path.join(save_dir, filename))

    spectral_rows, chroma_rows, avg_rows = [], [], []
    for result in iteration_results:
        sample = result[0]  # [C, H, W]
        avg_img = sample.mean(dim=0)  # [H, W]
        chroma = sample / (avg_img.unsqueeze(0) + eps)  # [C, H, W]
        chroma_avg = chroma.mean(dim=0)  # [H, W]

        spectral_rows.append(channel_montage(sample))
        chroma_rows.append(channel_montage(chroma))
        avg_rows.append(
            torch.cat([to_cpu(rescale(avg_img)), to_cpu(rescale(chroma_avg))], dim=1))

    if not spectral_rows:
        return

    save_png(torch.cat(spectral_rows, dim=0), 'all_iterations_montage.png')       # raw spectra
    save_png(torch.cat(chroma_rows, dim=0), 'all_iterations_chroma.png')          # chroma maps
    save_png(torch.cat(avg_rows, dim=0), 'all_iterations_avg_comparison.png')     # mean-image comparison
    


if __name__ == '__main__':
    # NOTE(review): set_seed(42) at import time sets cudnn.benchmark = False
    # (with deterministic = True); the assignment below turns the autotuner
    # back on. If run-to-run reproducibility is the goal, confirm that this
    # override is intended.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    main()


