from unet import *
from torch.optim.lr_scheduler import ReduceLROnPlateau



from torch.utils.data import Dataset
import os
from PIL import Image
from torchvision import datasets, transforms
# Restrict the GPUs visible to this process to the ids 0, 1, 2 and 3.
# NOTE(review): this assignment runs after `from unet import *`; presumably it
# still takes effect because no CUDA call has been issued yet — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"



# Preprocessing applied to every training image:
#   1. resize to 256x256 using Lanczos resampling,
#   2. convert the PIL image to a float tensor,
#   3. rescale the tensor values linearly with t -> 2 * t - 1.
#
# `Image.ANTIALIAS` was merely an alias of the Lanczos filter and no longer
# exists in Pillow >= 10, so the equivalent torchvision enum is used instead.
#
# Optional steps that were tried and are currently disabled:
#   transforms.RandomRotation(30, interpolation=transforms.InterpolationMode.BILINEAR)
#   transforms.Grayscale(1)
#   transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform = transforms.Compose([
    transforms.Resize(
        (256, 256),
        interpolation=transforms.InterpolationMode.LANCZOS,
    ),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) - 1),  # Scale between [-1, 1]
])

class MyDataset(Dataset):
    """Unlabelled image dataset backed by the files of one directory.

    Every regular file located directly inside ``file_path`` is one sample.
    A sample is opened with PIL, converted to RGB and, when a transform is
    set, passed through it.

    Args:
        file_path: Directory that contains the image files.
        transform: Callable applied to each RGB ``PIL.Image``. Defaults to the
            module-level ``transform`` pipeline; pass a falsy value (for
            example ``None``) to receive the untransformed PIL image.

    Raises:
        OSError: If ``file_path`` cannot be listed (raised by ``os.listdir``).
    """

    def __init__(self, file_path, transform=transform):
        self.file_path = file_path
        # os.listdir() yields entries in arbitrary order and also returns
        # sub-directories (e.g. Jupyter's ".ipynb_checkpoints"). Sorting makes
        # the index -> file mapping reproducible, and keeping regular files
        # only guarantees Image.open() is never handed a directory.
        self.image_names = sorted(
            name
            for name in os.listdir(file_path)
            if os.path.isfile(os.path.join(file_path, name))
        )
        self.transform = transform

    def __getitem__(self, idx):
        """Return the (optionally transformed) RGB image stored at ``idx``."""
        img_name = os.path.join(self.file_path, self.image_names[idx])
        # The context manager releases the file handle immediately; convert()
        # returns a loaded image, so the result stays usable afterwards.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

    def __len__(self):
        """Return the number of image files found in ``file_path``."""
        return len(self.image_names)
    



device = "cuda"


def _make_parallel_unet():
    """Build one UNet, wrap it in ``nn.DataParallel`` and move it to ``device``."""
    net = UNetModel(256, 128, dropout=0.1, n_heads=4, in_channels=3)
    return nn.DataParallel(net).to(device)


# Two separate networks sharing one architecture; each gets its own Adam
# optimizer with the same learning rate.
model_mu = _make_parallel_unet()
model_logvar = _make_parallel_unet()

optimizer_mu = torch.optim.Adam(model_mu.parameters(), lr=2e-5)
optimizer_logvar = torch.optim.Adam(model_logvar.parameters(), lr=2e-5)

# Sub-folder (category) of the dataset to train on.
# Category names noted for earlier runs (the grouping per dataset is inferred
# from the alternative roots listed below — verify against the real folders):
#   MVTec: 'bottle', 'cable', 'capsule', 'hazelnut', 'metal_nut', 'pill',
#          'screw', 'toothbrush', 'transistor', 'zipper', 'carpet', 'grid',
#          'leather', 'tile', 'wood'
#   VisA:  'candle', 'capsules', 'cashew', 'chewinggum', 'fryum', 'macaroni1',
#          'macaroni2', 'pcb1', 'pcb2', 'pcb3', 'pcb4', 'pipe_fryum'
#   BTech: 01, 02, 03
#   MPDD:  'bracket_black', 'bracket_brown', 'bracket_white', 'connector',
#          'metal_plate', 'tubes'
category = 'tubes'

# Directory holding the training images. Alternative roots:
#   f'/home/jovyan/MVTec/{category}/train/good'
#   f'/home/jovyan/visa/{category}/Data/Images/Normal'
#   f'/home/jovyan/dataset/BTech_Dataset_transformed/{category}/train/ok'
file_path = f'/home/jovyan/dataset/MPDD/{category}/train/good'

dataset = MyDataset(file_path=file_path)
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=24,
    shuffle=True,
    num_workers=8,
    drop_last=False,
)

# model_mu.load_state_dict(torch.load(f'./duad_mu_epoch_{category}_500.pth'))
# model_logvar.load_state_dict(torch.load(f'./ck_xishu/MVTec/{category}/duad_logvar_epoch_{category}_1000.pth'))
# model_mu.load_state_dict(torch.load(f'./ck_new/MPDD/{category}/duad_mu_epoch_{category}_2000.pth',map_location='cpu'))
# model_logvar.load_state_dict(torch.load(f'./ck_new/MPDD/{category}/duad_logvar_epoch_{category}_2000.pth',map_location='cpu'))


# Switch both networks to training mode before the optimisation loop starts.
for _net in (model_mu, model_logvar):
    _net.train()







# scheduler_mu = ReduceLROnPlateau(optimizer_mu, 'min', factor=0.5, patience=10)
# scheduler_logvar = ReduceLROnPlateau(optimizer_logvar, 'min', factor=0.5, patience=10)



# Noise schedule: `timesteps` linearly spaced betas in float64.
timesteps = 1000
scale = 1000 / timesteps

# The endpoints are kept 1000x larger than the effective betas; the factor
# 0.001 applied when computing `alphas` undoes that, i.e. this matches
# beta_start = scale * 0.0001 and beta_end = scale * 0.02.
beta_start, beta_end = scale * 0.1, scale * 20
betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)

alphas = 1. - betas * 0.001
alphas_cumprod = alphas.cumprod(dim=0)
# Cumulative product shifted right by one step, with 1.0 in the first slot.
alphas_cumprod_prev = torch.cat(
    [alphas_cumprod.new_ones(1), alphas_cumprod[:-1]]
)

# Coefficients indexed inside the training loop live on `device`.
sqrt_alphas_cumprod = alphas_cumprod.sqrt().to(device)
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt().to(device)

# Further derived quantities; these stay on the CPU.
log_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).log()
sqrt_recip_alphas_cumprod = (1.0 / alphas_cumprod).sqrt()
sqrt_recipm1_alphas_cumprod = (1.0 / alphas_cumprod - 1).sqrt()


num_train_epochs = 1000

# Device-resident copy of the cumulative alpha products. The timestep tensor
# `t` lives on the training device, and recent PyTorch releases refuse to
# index a CPU tensor with CUDA indices, so the lookup table must be there too.
alphas_cumprod_dev = alphas_cumprod.to(device)

# Training loop: each batch is noised at a random timestep per sample, one
# network predicts the noise and the other a per-element log-variance, and
# both are optimised jointly on a variance-weighted squared-error loss.
for epoch in range(num_train_epochs):
    loss_epoch = 0
    for images in data_loader:
        optimizer_mu.zero_grad()
        optimizer_logvar.zero_grad()
        images = images.to(device).float()
        batch_size = images.shape[0]

        # Sample over the whole schedule; `timesteps` (not a literal 1000)
        # keeps the indices valid when the schedule length is changed.
        t = torch.randint(0, timesteps, (batch_size,), device=images.device)
        noise = torch.randn_like(images).float()

        # Per-sample coefficients reshaped to broadcast over (C, H, W).
        sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[t].reshape(batch_size, 1, 1, 1).float()
        sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t].reshape(batch_size, 1, 1, 1).float()
        perturbed_data = sqrt_alphas_cumprod_t * images + sqrt_one_minus_alphas_cumprod_t * noise

        predicted_noise = model_mu(perturbed_data, t)
        predicted_logvar = model_logvar(perturbed_data, t)
        predicted_var = torch.exp(predicted_logvar)

        # Weight of the squared error: 1 - alphas_cumprod at step t.
        # (A previously tried alternative weight was
        #  betas_t * betas_t / (alphas_t * (1 - alphas_cumprod_t)).)
        alphas_cumprod_t = alphas_cumprod_dev[t].reshape(batch_size, 1, 1, 1).float()
        a1 = 1 - alphas_cumprod_t

        # Squared noise error divided by the predicted variance plus the
        # log-variance term, summed over all elements and averaged per sample.
        residual = noise - predicted_noise
        loss = torch.sum(a1 * residual * residual / predicted_var + predicted_logvar) / batch_size

        loss.backward()
        optimizer_mu.step()
        optimizer_logvar.step()
        loss_epoch += loss.item()

    # Values below refer to the last batch of the epoch.
    print('loss:',epoch,loss.item())
    print('var:',torch.max(predicted_var),torch.min(predicted_var))
    if epoch % 10 == 0:
        print('loss_epoch:',loss_epoch)

    # Checkpoint every 500 epochs and additionally after the final epoch, so
    # the fully trained weights are not lost when the run ends.
    if epoch % 500 == 0 or epoch == num_train_epochs - 1:
        print('save epoch',epoch)
        torch.save(model_mu.state_dict(),f'0125_duad_mu_epoch_{category}_{epoch}.pth')
        torch.save(model_logvar.state_dict(),f'0125_duad_logvar_epoch_{category}_{epoch}.pth')
    # Disabled learning-rate scheduling:
    #     scheduler_mu.step(loss_epoch)
    #     scheduler_logvar.step(loss_epoch)