import os
# GPU selection has to be configured through the environment before torch
# initialises CUDA, which is why these assignments precede `import torch`.
# PCI_BUS_ID makes CUDA number the devices in PCI bus order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Expose a single physical GPU (index 3) to this process; the rest of the
# script addresses it as 'cuda:0'.
os.environ["CUDA_VISIBLE_DEVICES"] = '3'
import torch
# Enable cuDNN and its benchmark mode (algorithm auto-tuning per input shape).
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
import torch.nn as nn 

# Log the name of every CUDA device visible to this process.
for i in range(torch.cuda.device_count()):
    print(torch.cuda.get_device_name(i))

import torch.optim as optim
from torch.utils.data import DataLoader
import random
# from utils.val_utils import AverageMeter, compute_psnr_ssim
import time
# import utils
from snn import model
# from warmup_scheduler import GradualWarmupScheduler
from tqdm import tqdm
from utils.losses import FFTLoss, L1Loss, MSELoss
# from torch.utils.tensorboard import SummaryWriter
import argparse
from utils.dataset_utils import TrainDataset
from options import options as opt
# import wandb
from utils.schedulers import CosineAnnealingRestartCyclicLR, LinearWarmupCosineAnnealingLR
from torch.optim.lr_scheduler import MultiplicativeLR
from spikingjelly.activation_based import neuron, functional, surrogate, layer, encoding
from natsort import natsorted
from glob import glob
from restormer import Restormer
from PromptIR import PromptIR


def mkdirs(paths):
    """Create a single directory or every directory in a collection.

    Args:
        paths: Either one path, or a list/tuple/set of paths. Each entry is
            handed to ``mkdir``.
    """
    # A str is never a list/tuple/set, so no separate string check is needed;
    # a lone path simply falls through to the single-directory branch.
    if isinstance(paths, (list, tuple, set)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)

def mkdir(path):
    """Create directory ``path``, including missing parents, if it is absent.

    Args:
        path: Directory to create.

    Raises:
        FileExistsError: If ``path`` already exists but is not a directory.
    """
    # exist_ok=True replaces the separate exists() check, so a directory
    # created by another process between "check" and "create" is not an error.
    os.makedirs(path, exist_ok=True)

def get_last_path(path, session):
    """Return the last match, in ``natsorted`` order, of ``*<session>`` under ``path``.

    Args:
        path: Directory to search.
        session: Suffix that matching entry names must end with.

    Raises:
        IndexError: If nothing under ``path`` matches the pattern.
    """
    pattern = os.path.join(path, '*%s' % session)
    candidates = natsorted(glob(pattern))
    return candidates[-1]
# Directory into which main() writes its per-epoch checkpoints.
model_dir = '/data/SDA/suxin/SNN/snn-dpdd-restormer-as-t/'
# writer = SummaryWriter(mkdir(model_dir))
# Make sure the checkpoint directory exists before training starts.
mkdir(model_dir)


########################################################################################
# Device shared by the teacher and the student model.
device2 = torch.device("cuda:0")

# Teacher network used as the feature reference for the student.
teature_model = Restormer(
        inp_channels=3,
        out_channels=3,
        dim = 48,
        num_blocks = [4,6,6,8],
        num_refinement_blocks = 4,
        heads = [1,2,4,8],
        ffn_expansion_factor = 2.66,
        bias = False,
        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
        dual_pixel_task = True).to(device2)

# Deserialize onto the CPU so loading does not depend on which GPU index the
# checkpoint was saved from; load_state_dict copies onto the model's device.
checkpoint = torch.load('/data/SDA/suxin/SNN/dual_pixel_defocus_deblurring.pth',
                        map_location='cpu')

# The weights may be nested under a wrapper key (the student checkpoint in this
# file uses 'state_dict'). NOTE(review): the upstream Restormer checkpoints are
# believed to use 'params' -- confirm against the actual file.
_teacher_weights = checkpoint
for _wrapper_key in ('params', 'state_dict'):
    if isinstance(_teacher_weights, dict) and isinstance(_teacher_weights.get(_wrapper_key), dict):
        _teacher_weights = _teacher_weights[_wrapper_key]
        break

# Build the model's state dict once instead of once per checkpoint key.
_teacher_state = teature_model.state_dict()
# Keep only entries the model can accept: same name and same tensor shape
# (strict=False tolerates missing keys but not shape mismatches).
filtered_state_dict = {
    k: v for k, v in _teacher_weights.items()
    if k in _teacher_state and getattr(v, 'shape', None) == _teacher_state[k].shape
}
if len(filtered_state_dict) < len(_teacher_state):
    # strict=False would otherwise hide a partially (or entirely) unloaded teacher.
    print('WARNING: teacher checkpoint matched {} of {} model tensors'.format(
        len(filtered_state_dict), len(_teacher_state)))

# Load the filtered state dict.
teature_model.load_state_dict(filtered_state_dict, strict=False)
# teature_model.load_state_dict(checkpoint)
teature_model.eval()
# The teacher is a fixed reference and is never optimised, so no gradients
# need to be tracked for its parameters.
teature_model.requires_grad_(False)
##########################################################################################
def _feature_distillation_loss(student, teacher, inputs, criterion, num_levels=4):
    """Return the feature-matching loss between student and teacher.

    Args:
        student: Student network exposing ``get_feature(inputs)``.
        teacher: Teacher network exposing ``get_feature(inputs)``.
        inputs: Input batch, already on the training device.
        criterion: Callable ``criterion(prediction, target)`` returning a scalar tensor.
        num_levels: Divisor applied to the summed per-level losses.

    Returns:
        A scalar tensor that carries gradients with respect to the student.
    """
    # The teacher is only a reference, so no autograd graph is needed for it.
    with torch.no_grad():
        teacher_feats = teacher.get_feature(inputs)

    # Clear the neuron state left behind by the student's preceding forward pass.
    # NOTE(review): assumes get_feature runs its own forward pass -- confirm.
    functional.reset_net(student)
    student_feats = student.get_feature(inputs)

    total = 0.0
    for student_feat, teacher_feat in zip(student_feats, teacher_feats):
        # mean(0) averages over the leading dimension of the student feature
        # (presumably the time-step axis -- TODO confirm) before comparison.
        reduced = student_feat.mean(0).to(teacher_feat.device)
        # Accumulate tensors, not .item() floats, so the term stays differentiable.
        total = total + criterion(reduced, teacher_feat)
    return total / num_levels


def main(distill_weight=0.0):
    """Fine-tune the student model and save a checkpoint after every epoch.

    The reconstruction objective is ``L1 + FFT`` between the student output and
    the ground truth. A feature-distillation term against ``teature_model`` is
    added only when ``distill_weight`` is positive.

    Args:
        distill_weight: Weight of the feature-distillation term. The default of
            0.0 leaves the objective at ``L1 + FFT`` and skips the extra
            student/teacher feature passes entirely.
    """
    print("Options")
    print(opt)

    student_model = model.to(device2)
    # Deserialize onto the CPU so loading does not depend on the GPU index the
    # checkpoint was saved from; load_state_dict copies onto the model's device.
    student_ckpt = torch.load('/data/SDA/suxin/SNN/snn-dpdd-restormer-as-t/best1.pth',
                              map_location='cpu')
    student_model.load_state_dict(student_ckpt["state_dict"])

    epochs = 300
    initial_lr = 4e-4
    warm_up = 20
    save_every = 1  # checkpoint interval, in epochs

    optimizer = optim.AdamW(student_model.parameters(), lr=initial_lr, eps=1e-5, weight_decay=0.05)

    # NOTE(review): max_epochs (280) is smaller than epochs (300); confirm the
    # scheduler behaves as intended during the final 20 epochs.
    scheduler = LinearWarmupCosineAnnealingLR(optimizer=optimizer, warmup_start_lr=5e-7,
                                              warmup_epochs=warm_up, max_epochs=280)

    loss_f1 = L1Loss(loss_weight=1.0, reduction='mean')
    loss_f2 = FFTLoss(loss_weight=0.1, reduction='mean')
    loss_mse = MSELoss(loss_weight=1.0, reduction='mean').cuda()

    trainset = TrainDataset(opt)
    trainloader = DataLoader(trainset, batch_size=2, pin_memory=True, shuffle=True, persistent_workers=True,
                             drop_last=True, num_workers=16)

    # Nothing inside the loop changes these, so set them once up front.
    student_model.train()
    functional.set_step_mode(student_model, step_mode='m')

    for epoch in range(epochs):
        print('epoch:{}'.format(epoch))

        loss_sum = 0.0
        num_batches = 0
        for batch in tqdm(trainloader):
            blurred = batch[0].to(device2)
            gtimg = batch[1].to(device2)

            optimizer.zero_grad()
            output = student_model(blurred)
            loss = loss_f1(output, gtimg) + loss_f2(output, gtimg)

            if distill_weight > 0:
                loss = loss + distill_weight * _feature_distillation_loss(
                    student_model, teature_model, blurred, loss_mse)

            loss.backward()
            optimizer.step()
            # Reset the spiking neurons' state so the next batch starts clean.
            functional.reset_net(student_model)

            loss_sum += loss.item()
            num_batches += 1

        if num_batches:
            print('mean loss: {:.6f}'.format(loss_sum / num_batches))
        else:
            # With drop_last=True a dataset smaller than one batch yields no iterations.
            print('WARNING: the data loader produced no batches this epoch')

        scheduler.step()
        if epoch % save_every == 0:
            torch.save({'epoch': epoch,
                        'state_dict': student_model.state_dict(),
                        'optimizer': optimizer.state_dict()
                        }, os.path.join(model_dir, f"model_epoch_{epoch}.pth"))

# Start training only when this file is executed as a script.
if __name__ == '__main__':
    main()
