import numpy as np
import wandb
import torch
from dc_ldm.util import instantiate_from_config
from omegaconf import OmegaConf
import torch.nn as nn
import os
from dc_ldm.models.diffusion.plms import PLMSSampler
from einops import rearrange, repeat
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import torch.nn.functional as F
from sc_mbm.mae_for_fmri import fmri_encoder
from config_cross import Config_MBM_finetune_cross
from sc_mbm.mae_for_fmri_cross import MAEforFMRICross
from sc_mbm.mae_for_image import ViTMAEConfig

def create_model_from_config(config, num_voxels, global_pool, is_cross_mae=False):
    """Instantiate the fMRI MAE encoder described by ``config``.

    Args:
        config: for the plain encoder, an object exposing ``patch_size``,
            ``embed_dim``, ``depth``, ``num_heads`` and ``mlp_ratio``
            attributes; for the cross MAE, a mapping holding the keys read
            below (``num_voxels``, ``patch_size``, ``embed_dim``, ...).
        num_voxels: input length for ``fmri_encoder``. Ignored when
            ``is_cross_mae`` is True; ``config['num_voxels']`` is used instead.
        global_pool: forwarded to ``fmri_encoder`` only.
        is_cross_mae: build ``MAEforFMRICross`` instead of ``fmri_encoder``.

    Returns:
        The constructed model.
    """
    if is_cross_mae:
        # Cross MAE: every hyper-parameter comes from the dict-style config.
        return MAEforFMRICross(
            num_voxels=config['num_voxels'],
            patch_size=config['patch_size'],
            embed_dim=config['embed_dim'],
            decoder_embed_dim=config['decoder_embed_dim'],
            depth=config['depth'],
            num_heads=config['num_heads'],
            decoder_num_heads=config['decoder_num_heads'],
            mlp_ratio=config['mlp_ratio'],
            focus_range=None,
            use_nature_img_loss=False,
            do_cross_attention=config['do_cross_attention'],
            cross_encoder_config=config['cross_img_encoder_config'],
            decoder_depth=config['decoder_depth'],
        )

    return fmri_encoder(
        num_voxels=num_voxels,
        patch_size=config.patch_size,
        embed_dim=config.embed_dim,
        depth=config.depth,
        num_heads=config.num_heads,
        mlp_ratio=config.mlp_ratio,
        global_pool=global_pool,
    )



class cond_stage_model(nn.Module):
    """Conditioning encoder: a pretrained fMRI MAE followed by projection layers.

    The MAE is rebuilt from ``metafile`` and loaded with its stored weights.
    Its output tokens are optionally remapped along the token axis to 77
    channels (only when ``global_pool == False``) and finally projected to
    ``cond_dim`` features by a linear layer.
    """

    def __init__(self, metafile, num_voxels, cond_dim=1280, global_pool=True, is_cross_mae=False):
        super().__init__()
        # NOTE(review): the ``num_voxels`` argument is ignored; the value is
        # recomputed from the checkpoint's positional-embedding length.
        # The ``- 1`` presumably drops a class-token position -- confirm.
        pos_len = metafile['model']['pos_embed'].shape[1]
        num_voxels = (pos_len - 1) * metafile["config"].patch_size

        encoder = create_model_from_config(metafile['config_merge'], num_voxels, global_pool, is_cross_mae=is_cross_mae)
        self.is_cross_mae = is_cross_mae
        if is_cross_mae:
            encoder.load_state_dict(metafile['model'], strict=True)
        else:
            encoder.load_checkpoint(metafile['model'])
        self.mae = encoder

        self.fmri_seq_len = encoder.num_patches
        self.fmri_latent_dim = encoder.embed_dim

        # Deliberately ``== False`` (not ``not``): values such as None must not
        # enable the channel mapper.
        if global_pool == False:  # noqa: E712
            n_tokens = self.fmri_seq_len + 1
            self.channel_mapper = nn.Sequential(
                nn.Conv1d(n_tokens, n_tokens // 2, 1, bias=True),
                nn.Conv1d(n_tokens // 2, 77, 1, bias=True),
            )

        self.dim_mapper = nn.Linear(self.fmri_latent_dim, cond_dim, bias=True)
        self.global_pool = global_pool

    def forward(self, x):
        """Encode ``x`` with the MAE and project it to the conditioning space."""
        if self.is_cross_mae:
            # forward_encoder returns a sequence; its first element is the latent.
            # mask_ratio=0 is passed so no tokens are requested to be masked.
            features = self.mae.forward_encoder(x, mask_ratio=0)[0]
        else:
            features = self.mae(x)

        if self.global_pool == False:  # noqa: E712
            features = self.channel_mapper(features)

        return self.dim_mapper(features)

class fLDM:
    """Latent-diffusion model whose conditioning stage is an fMRI encoder.

    Loads a pretrained LDM (``config.yaml`` + ``model.ckpt`` under
    ``pretrain_root``), replaces its ``cond_stage_model`` with one built from
    ``metafile`` and exposes ``finetune`` and ``generate``.
    """

    def __init__(self, metafile, num_voxels, device=torch.device('cpu'),
                 pretrain_root='../pretrains/ldm/label2img',
                 logger=None, ddim_steps=250, global_pool=True, use_time_cond=True,
                 is_cross_mae=False, feature_adapter=None, fmri_vqgan_model=None):
        """Build the diffusion model and attach the fMRI conditioning encoder.

        Args:
            metafile: checkpoint dict of the pretrained fMRI MAE; forwarded to
                ``cond_stage_model``.
            num_voxels: forwarded to ``cond_stage_model``.
            device: device ``generate`` moves the model to for sampling.
            pretrain_root: directory holding ``model.ckpt`` and ``config.yaml``.
            logger: optional; if given, ``logger.watch(model, ...)`` is called.
            ddim_steps: stored on the model as ``model.ddim_steps``.
            global_pool: written into the UNet config and forwarded to the encoder.
            use_time_cond: written into the UNet config.
            is_cross_mae: forwarded to ``cond_stage_model``.
            feature_adapter: attached to the model as ``model.feature_adapter``.
            fmri_vqgan_model: attached to the model as ``model.fmri_vqgan_model``.
        """
        self.ckp_path = os.path.join(pretrain_root, 'model.ckpt')
        self.config_path = os.path.join(pretrain_root, 'config.yaml')
        config = OmegaConf.load(self.config_path)
        config.model.params.unet_config.params.use_time_cond = use_time_cond
        config.model.params.unet_config.params.global_pool = global_pool

        self.cond_dim = config.model.params.unet_config.params.context_dim

        model = instantiate_from_config(config.model)
        pl_sd = torch.load(self.ckp_path, map_location="cpu")['state_dict']
        # strict=False tolerates missing/unexpected keys; the conditioning
        # stage is replaced right below anyway.
        model.load_state_dict(pl_sd, strict=False)
        model.cond_stage_trainable = True
        model.cond_stage_model = cond_stage_model(metafile, num_voxels, self.cond_dim, global_pool=global_pool, is_cross_mae=is_cross_mae)

        model.ddim_steps = ddim_steps
        model.re_init_ema()
        if logger is not None:
            logger.watch(model, log="all", log_graph=False)

        model.p_channels = config.model.params.channels
        model.p_image_size = config.model.params.image_size
        model.ch_mult = config.model.params.first_stage_config.params.ddconfig.ch_mult

        self.device = device
        self.model = model
        self.model.feature_adapter = feature_adapter
        self.model.fmri_vqgan_model = fmri_vqgan_model

        self.ldm_config = config
        self.pretrain_root = pretrain_root
        self.metafile = metafile

    def finetune(self, trainers, dataset, test_dataset, bs1, lr1,
                output_path, config=None):
        """Stage one: train the conditional encoder, then write a checkpoint.

        Args:
            trainers: object exposing ``fit(model, train_loader, val_dataloaders=...)``.
            dataset: training dataset, batched with ``config.batch_size``.
            test_dataset: validation dataset, loaded as one single batch.
            bs1: unused; the batch size is taken from ``config.batch_size``.
            lr1: assigned to ``self.model.learning_rate``.
            output_path: directory that receives ``checkpoint.pth``.
            config: run configuration providing ``batch_size`` and ``eval_avg``;
                its ``trainer`` and ``logger`` fields are cleared in place.

        Raises:
            ValueError: if ``config`` is None (the default can never work).
        """
        if config is None:
            raise ValueError("finetune() requires a config object; got None")

        # Cleared before the config is attached to the model and saved in the
        # checkpoint -- presumably so it can be pickled without those handles.
        config.trainer = None
        config.logger = None
        self.model.main_config = config
        self.model.output_path = output_path
        self.model.run_full_validation_threshold = 0.15

        print('\n##### Stage One: only optimize conditional encoders #####')
        dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)
        self.model.unfreeze_whole_model()
        self.model.freeze_first_stage()

        self.model.learning_rate = lr1
        self.model.train_cond_stage_only = True
        self.model.eval_avg = config.eval_avg
        trainers.fit(self.model, dataloader, val_dataloaders=test_loader)

        self.model.unfreeze_whole_model()
        torch.save(
            {
                'model_state_dict': self.model.state_dict(),
                'config': config,
                'state': torch.random.get_rng_state()
            },
            os.path.join(output_path, 'checkpoint.pth')
        )

    @torch.no_grad()
    def generate(self, fmri_embedding, num_samples, ddim_steps, HW=None, limit=None, state=None, vqgan_model=None):
        """Sample ``num_samples`` images per item of ``fmri_embedding``.

        Args:
            fmri_embedding: iterable of dicts with ``'fmri'`` (2-D; repeated
                ``num_samples`` times along a new leading axis) and ``'image'``
                (``h w c``; presumably in [-1, 1] given the rescaling below).
            num_samples: samples drawn per item.
            ddim_steps: number of sampler steps.
            HW: optional (height, width); the latent size is derived by dividing
                by ``2 ** (len(ch_mult) - 1)``.
            limit: optional cap on the number of items processed.
            state: optional RNG state passed to ``torch.cuda.set_rng_state``.
            vqgan_model: unused.

        Returns:
            ``(grid, samples)``. ``grid`` is an ``h w c`` float array scaled by
            255, with the ground truth first in every row. ``samples`` is the
            ``n b c h w`` uint8 array of ground truth + samples per item.

        Raises:
            RuntimeError: if no item was processed (empty input or ``limit <= 0``).
        """
        params = self.ldm_config.model.params
        if HW is None:
            shape = (params.channels, params.image_size, params.image_size)
        else:
            num_resolutions = len(params.first_stage_config.params.ddconfig.ch_mult)
            downsample = 2 ** (num_resolutions - 1)
            shape = (params.channels, HW[0] // downsample, HW[1] // downsample)

        all_samples = []
        model = self.model.to(self.device)
        try:
            sampler = PLMSSampler(model)
            if state is not None:
                # NOTE(review): finetune() saves torch.random.get_rng_state()
                # (the CPU generator); confirm callers pass a CUDA RNG state.
                torch.cuda.set_rng_state(state)

            with model.ema_scope():
                model.eval()
                for count, item in enumerate(fmri_embedding):
                    if limit is not None and count >= limit:
                        break
                    latent = item['fmri']
                    gt_image = rearrange(item['image'], 'h w c -> 1 c h w')
                    print(f"rendering {num_samples} examples in {ddim_steps} steps.")

                    c = model.get_learned_conditioning(
                        repeat(latent, 'h w -> c h w', c=num_samples).to(self.device))
                    samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                     conditioning=c,
                                                     batch_size=num_samples,
                                                     shape=shape,
                                                     verbose=False)

                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                    gt_image = torch.clamp((gt_image + 1.0) / 2.0, min=0.0, max=1.0)
                    # Ground truth first; both moved to CPU so the concat cannot
                    # fail on a device mismatch.
                    all_samples.append(torch.cat([gt_image.cpu(), x_samples_ddim.detach().cpu()], dim=0))
        finally:
            # Release accelerator memory even if sampling fails part-way.
            self.model.to('cpu')

        if not all_samples:
            raise RuntimeError(
                "generate() processed no items: fmri_embedding is empty or limit <= 0")

        stacked = torch.stack(all_samples, 0)

        # display as grid
        grid = rearrange(stacked, 'n b c h w -> (n b) c h w')
        grid = make_grid(grid, nrow=num_samples + 1)

        # to image
        grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()

        return grid, (255. * stacked.cpu().numpy()).astype(np.uint8)


