import torch
import torch.nn as nn
import wandb
from tqdm import tqdm
from forward_process import *
from dataset import *
from dataset import *
import timm
import random
from torch import Tensor, nn
from typing import Callable, List, Tuple, Union
from unet import *
from omegaconf import OmegaConf
from sample import *
from visualize import *
from resnet import *
from de_resnet import de_wide_resnet50_2
import torchvision.transforms as T
from diffusers import AutoencoderKL


#os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2"

# Seed torch's global RNG once at import time so torch-side sampling is
# repeatable between runs. NOTE(review): the NumPy RNG used by
# Domain_adaptation (np.random.randint) is not seeded here.
torch.manual_seed(42)

def build_model(config):
    """Construct the UNet used by this module.

    Args:
        config: Configuration object; only ``config.data.fe_input_channel``
            is read here and forwarded as the UNet's ``in_channels``.

    Returns:
        A freshly constructed ``UNetModel`` instance.
    """
    return UNetModel(
        256,
        64,
        dropout=0,
        n_heads=4,
        in_channels=config.data.fe_input_channel,
    )



def patchify(features, return_spatial_info=False):
    """Extract 3x3 neighbourhoods (stride 1, zero padded) and mean-pool them.

    Args:
        features: Tensor of shape ``(bs, c, h, w)``.
        return_spatial_info: When true, return the raw patches and the patch
            grid size instead of the pooled feature map.

    Returns:
        If ``return_spatial_info`` is false: tensor of shape
        ``(bs, c, n_rows, n_cols)`` in which every location holds the mean of
        its 3x3 neighbourhood (padding zeros are included in the mean).
        If ``return_spatial_info`` is true: a tuple
        ``(patches, [n_rows, n_cols])`` where ``patches`` has shape
        ``(bs, n_rows * n_cols, c, patchsize, patchsize)``.
    """
    patchsize = 3
    stride = 1
    padding = (patchsize - 1) // 2
    unfolder = torch.nn.Unfold(
        kernel_size=patchsize, stride=stride, padding=padding, dilation=1
    )
    # Unfold yields (bs, c * patchsize * patchsize, n_patches).
    unfolded_features = unfolder(features)

    # Patch count along each spatial axis (standard sliding-window formula
    # with dilation 1).
    number_of_total_patches = [
        (s + 2 * padding - (patchsize - 1) - 1) // stride + 1
        for s in features.shape[-2:]
    ]

    unfolded_features = unfolded_features.reshape(
        *features.shape[:2], patchsize, patchsize, -1
    )
    # -> (bs, n_patches, c, patchsize, patchsize)
    unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3)
    if return_spatial_info:
        # Return before pooling: the pooled map is not needed on this path.
        return unfolded_features, number_of_total_patches

    mean_features = torch.mean(unfolded_features, dim=(3, 4))
    # Rebuild the spatial grid from the real patch counts rather than
    # sqrt(n_patches), so non-square feature maps keep their layout.
    n_rows, n_cols = number_of_total_patches
    return mean_features.reshape(
        features.shape[0], n_rows, n_cols, mean_features.shape[-1]
    ).permute(0, 3, 1, 2)



def loss_fucntion(a, b):
    """Sum of mean cosine distances between paired feature maps.

    Args:
        a: Sequence of tensors, each with a leading batch dimension.
        b: Sequence of tensors indexable in parallel with ``a``; ``b[i]``
            must flatten to the same per-sample size as ``a[i]``.

    Returns:
        A scalar tensor equal to ``sum_i mean(1 - cos(a[i], b[i]))`` where
        every map is flattened per sample first, or the int ``0`` when ``a``
        is empty.
    """
    cos_sim = torch.nn.CosineSimilarity()
    loss = 0
    for idx, feat_a in enumerate(a):
        feat_b = b[idx]
        # reshape (not view) so non-contiguous maps, e.g. permuted or
        # channels_last tensors, are accepted instead of raising.
        flat_a = feat_a.reshape(feat_a.shape[0], -1)
        flat_b = feat_b.reshape(feat_b.shape[0], -1)
        loss = loss + torch.mean(1 - cos_sim(flat_a, flat_b))
    return loss


def roundup(x, n=10):
    """Round ``x`` to the nearest multiple of ``n``, with halves going up.

    The value is first taken up to the next multiple of ``n``; if the
    remainder of ``x`` is non-zero and strictly below half of ``n``, one
    multiple is subtracted again.
    """
    remainder = x % n
    next_multiple = math.ceil(x / n) * n
    if 0 < remainder < n / 2:
        return next_multiple - n
    return next_multiple
              

def _da_checkpoint_path(config, epoch):
    """Return the feature-extractor checkpoint path for the given epoch tag."""
    prefix = 'feature_' if config.model.DA_half else 'feature_recon_sim'
    return os.path.join(
        os.getcwd(),
        config.model.checkpoint_dir,
        config.data.category,
        f'{prefix}{epoch}',
    )


def Domain_adaptation(unet, feature_extractor, vae, config, fine_tune, constants_dict, dataloader):
    """Fine-tune the feature extractor on reconstructions, or load a tuned one.

    When ``fine_tune`` is true, each batch is encoded with ``vae``, optionally
    noised, reconstructed through ``DA_generalized_steps`` with ``unet``,
    decoded, and the feature extractor is optimised (AdamW, lr 3e-4) so that
    its features of the reconstruction match its features of the target
    according to ``loss_fucntion``. The extractor's state dict is saved after
    every epoch. When ``fine_tune`` is false, the checkpoint written for
    epoch ``config.model.DA_epochs`` is loaded instead.

    Args:
        unet: Denoising network; switched to eval mode while fine-tuning.
        feature_extractor: Module being adapted; returned to the caller.
        vae: Autoencoder exposing ``encode(...).latent_dist.sample()`` and
            ``decode(...).sample``.
        config: Configuration; reads ``config.model.*`` (device, DA_epochs,
            DA_rnd_step, DA_half, noise_sampling, downscale_first, eta2,
            test_trajectoy_steps_DA, skip_DA, checkpoint_dir) and
            ``config.data.category``.
        fine_tune: Train when true, load from checkpoint when false.
        constants_dict: Must provide ``"betas"``; forwarded to the sampler.
        dataloader: Iterable of batches whose first element is the image
            tensor. Only used when ``fine_tune`` is true.

    Returns:
        The (fine-tuned or loaded) ``feature_extractor``.
    """
    device = config.model.device

    if not fine_tune:
        # map_location lets a checkpoint saved on another device be loaded on
        # the configured one.
        checkpoint = torch.load(
            _da_checkpoint_path(config, config.model.DA_epochs),
            map_location=device,
        )
        feature_extractor.load_state_dict(checkpoint)
        if config.model.DA_half:
            print("loaded fe DA_half")
        else:
            print("loaded fe recon sim")
        return feature_extractor

    unet.eval()
    feature_extractor.train()
    for param in feature_extractor.parameters():
        param.requires_grad = True

    # Map images from [-1, 1] to [0, 1], then apply the mean/std below.
    transform = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / (2)),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    optimizer = torch.optim.AdamW(feature_extractor.parameters(), lr=3e-4)
    betas = constants_dict["betas"]
    # Latent scaling applied after encoding and undone before decoding.
    # NOTE(review): presumably the Stable Diffusion VAE scaling factor.
    latent_scale = 0.18215

    for epoch in range(config.model.DA_epochs):
        for batch in dataloader:
            if config.model.DA_rnd_step:
                # Draw a random fraction (10%..100%) of the trajectory length.
                step_percentage = np.random.randint(1, 11)
                step_size = config.model.test_trajectoy_steps_DA / 10 * step_percentage
                test_trajectoy_steps = torch.Tensor([step_size]).type(torch.int64).to(device)
                step_size = roundup(step_size)
                skip = int(max(step_size / 10, 1))
                seq = range(0, step_size, skip)
            else:
                test_trajectoy_steps = torch.Tensor(
                    [config.model.test_trajectoy_steps_DA]
                ).type(torch.int64).to(device)
                seq = range(0, config.model.test_trajectoy_steps_DA, config.model.skip_DA)

            at = compute_alpha(betas, test_trajectoy_steps.long(), config)

            if config.model.DA_half:
                # DDAD DA: first half of the batch is the target, second half
                # is noised and used as the starting point.
                half_batch_size = batch[0].shape[0] // 2
                target = batch[0][:half_batch_size].to(device)
                source = batch[0][half_batch_size:].to(device)

                target_vae = vae.encode(target).latent_dist.sample() * latent_scale
                source_vae = vae.encode(source).latent_dist.sample() * latent_scale

                # randn_like already allocates on source_vae's device, so no
                # hard-coded device move is needed.
                noisy_image = at.sqrt() * source_vae + (1 - at).sqrt() * torch.randn_like(source_vae)
            else:
                target = batch[0].to(device)
                target_vae = vae.encode(target).latent_dist.sample() * latent_scale
                if config.model.noise_sampling:
                    noise = torch.randn_like(target_vae)
                    noisy_image = at.sqrt() * target_vae + (1 - at).sqrt() * noise
                else:
                    noisy_image = target_vae
                    if config.model.downscale_first:
                        noisy_image = noisy_image * at.sqrt()

            reconstructed, _ = DA_generalized_steps(
                target_vae,
                noisy_image,
                seq,
                unet,
                betas,
                config,
                eta2=config.model.eta2,
                eta3=0,
                constants_dict=constants_dict,
                eraly_stop=False,
            )
            data_reconstructed = reconstructed[-1].to(device)

            # Undo the latent scaling before decoding back to image space.
            data_reconstructed = 1 / latent_scale * data_reconstructed
            data_reconstructed = vae.decode(data_reconstructed).sample

            data_reconstructed = transform(data_reconstructed)
            reconst_fe = feature_extractor(data_reconstructed)

            target = transform(target)
            target_fe = feature_extractor(target)

            loss = loss_fucntion(reconst_fe, target_fe)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch} | Loss: {loss.item()}")
        checkpoint_path = _da_checkpoint_path(config, epoch + 1)
        # Make sure the category directory exists before writing into it.
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(feature_extractor.state_dict(), checkpoint_path)

    return feature_extractor


