import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from functools import partial
from einops import rearrange
import utils.misc as utils
from utils.metrics import WRMSE
from utils.builder import get_optimizer, get_lr_scheduler
#import modules
from utils.distributions import DiagonalGaussianDistribution
from models.vaeformer import VAEformer
from models.diffusion.score import ScoreSiT
from models.diffusion.score import VPSDE
from models.diffusion.score import GuidanceSampling
import pandas as pd
import h5py

class EMAHelper(object):
    """Exponential moving average (EMA) of a module's trainable parameters.

    Shadow tensors live in ``self.shadow`` keyed by parameter name. Every public method
    accepts either a plain module or an ``nn.DataParallel`` wrapper, which is unwrapped.
    """

    def __init__(self, mu=0.999):
        # Decay factor: shadow <- mu * shadow + (1 - mu) * param
        self.mu = mu
        self.shadow = {}

    @staticmethod
    def _unwrap(module):
        """Return ``module.module`` for an ``nn.DataParallel`` wrapper, else ``module``."""
        return module.module if isinstance(module, nn.DataParallel) else module

    @staticmethod
    def _trainable(module):
        """Iterate ``(name, param)`` over the parameters of ``module`` that require grad."""
        return (
            (name, param)
            for name, param in module.named_parameters()
            if param.requires_grad
        )

    def register(self, module):
        """Initialise the shadow with a detached copy of every trainable parameter."""
        for name, param in self._trainable(self._unwrap(module)):
            self.shadow[name] = param.data.clone()

    def update(self, module):
        """Blend the current parameters into the shadow with decay ``self.mu``.

        Raises KeyError for a trainable parameter that was never registered.
        """
        decay = self.mu
        for name, param in self._trainable(self._unwrap(module)):
            tracked = self.shadow[name]
            tracked.data = (1. - decay) * param.data + decay * tracked.data

    def ema(self, module):
        """Overwrite the trainable parameters of ``module`` in place with the shadow values."""
        for name, param in self._trainable(self._unwrap(module)):
            param.data.copy_(self.shadow[name].data)

    def ema_copy(self, module):
        """Return a fresh copy of ``module`` whose trainable parameters hold the shadow values.

        The copy is rebuilt as ``type(inner)(inner.config)`` and moved to ``inner.config.device``,
        so the wrapped class must expose a ``config`` attribute with a ``device`` field. A
        DataParallel input yields a DataParallel copy.
        """
        wrapped = isinstance(module, nn.DataParallel)
        inner = module.module if wrapped else module
        module_copy = type(inner)(inner.config).to(inner.config.device)
        module_copy.load_state_dict(inner.state_dict())
        if wrapped:
            module_copy = nn.DataParallel(module_copy)
        self.ema(module_copy)
        return module_copy

    def state_dict(self):
        """Return the live shadow dict (not a copy)."""
        return self.shadow

    def load_state_dict(self, state_dict):
        """Replace the shadow dict with ``state_dict`` (stored by reference)."""
        self.shadow = state_dict

class DDPO_model(nn.Module):
    """Reference / fine-tuned pair of conditional latent diffusion models for Diffusion-DPO.

    Both members are ``VPSDE`` wrappers around identically configured ``ScoreSiT`` backbones
    and are initialised from the same checkpoint. ``diff_ref`` is frozen
    (``requires_grad=False``); ``diff_dpo`` is the trainable model.

    Args:
        ref_model_path: path of the reference checkpoint (a ``VPSDE`` state dict) used to
            initialise both models.
    """

    def __init__(self, ref_model_path='/path/to/reference_model'):
        super().__init__()

        data_size = (69, 32, 64)

        diff_backbone1 = ScoreSiT(channels=69, input_size=(32, 64), patch_size=(2, 2), depth=28, hidden_size=1152, num_heads=16)
        self.diff_ref = VPSDE(diff_backbone1, data_size)

        # Read the reference checkpoint. The modules are still on the CPU at this point, so
        # map the tensors to CPU instead of the device they were saved from (which may be
        # missing or full on this machine).
        state_dict = torch.load(ref_model_path, map_location='cpu')
        # The checkpoint may lack these entries; fill them from the freshly built model so the
        # strict load_state_dict below succeeds.
        # NOTE(review): presumably pos_embed is a fixed (non-learned) embedding and 'device' a
        # placeholder buffer registered by VPSDE - confirm.
        if 'eps.network.pos_embed' not in state_dict:
            state_dict['eps.network.pos_embed'] = self.diff_ref.eps.network.pos_embed
        if 'device' not in state_dict:
            state_dict['device'] = torch.empty(())
        self.diff_ref.load_state_dict(state_dict)

        # Freeze the reference model.
        # NOTE(review): nn.Module.train() is recursive, so putting this module in training mode
        # also puts diff_ref in training mode; if the backbone uses dropout, the reference
        # predictions are stochastic.
        for param in self.diff_ref.parameters():
            param.requires_grad = False

        # The fine-tuned model starts from the same weights and is trainable.
        diff_backbone2 = ScoreSiT(channels=69, input_size=(32, 64), patch_size=(2, 2), depth=28, hidden_size=1152, num_heads=16)
        self.diff_dpo = VPSDE(diff_backbone2, data_size)
        self.diff_dpo.load_state_dict(state_dict)
        for param in self.diff_dpo.parameters():
            param.requires_grad = True

        # The autoencoder is currently not loaded, so get_latent() needs an ``AE`` attribute to be
        # attached from outside. To restore the in-model autoencoder:
        #   self.AE = VAEformer(model_version=69)
        #   self.AE.load_state_dict(torch.load("/path/to/VAE_model"))
        #   for param in self.AE.parameters():
        #       param.requires_grad = False

    def forward(self, bkg_latent, ana_w_latent, ana_l_latent):
        """Run the DPO forward pass of ``diff_dpo`` against the frozen reference network.

        Delegates to ``self.diff_dpo.dpoloss`` with the reference noise network
        ``self.diff_ref.eps``, the winner/loser latents and the background latent as condition.

        Returns the seven values produced by ``dpoloss``, in order: fine-tuned predictions for
        winner and loser, reference predictions for winner and loser, the winner and loser
        noise targets, and ``wt`` (presumably a per-timestep weight - confirm in VPSDE).
        """
        noise_prediction_w, noise_prediction_l, noise_prediction_w_ref, noise_prediction_l_ref, noise_w, noise_l, wt = self.diff_dpo.dpoloss(self.diff_ref.eps, ana_w_latent, ana_l_latent, bkg_latent)

        return noise_prediction_w, noise_prediction_l, noise_prediction_w_ref, noise_prediction_l_ref, noise_w, noise_l, wt

    def forward_ref(self, bkg_latent, ana_latent):
        """Return ``(noise_prediction, noise)`` from the reference model's denoising loss."""
        noise_prediction, noise = self.diff_ref.loss(ana_latent, bkg_latent)

        return noise_prediction, noise

    def forward_tune(self, bkg_latent, ana_latent):
        """Return ``(noise_prediction, noise)`` from the fine-tuned model's denoising loss."""
        noise_prediction, noise = self.diff_dpo.loss(ana_latent, bkg_latent)

        return noise_prediction, noise

    def sample(self, bkg_latent):
        """Sample one analysis latent from the fine-tuned model conditioned on ``bkg_latent``.

        The input and output stay in whatever latent normalisation the diffusion model was
        trained with; (de)normalisation and decoding are the caller's job.

        Returns:
            A single latent tensor (not a tuple).
        """
        # NOTE(review): the first argument is the int 1 (the original wrote ``(1)``, which is not
        # a tuple); confirm what VPSDE.sample expects here.
        return self.diff_dpo.sample(1, bkg_latent)

    def get_latent(self, field):
        """Encode ``field`` with the autoencoder and return the mode of its latent posterior.

        Raises AttributeError unless an ``AE`` attribute has been attached, because __init__
        does not create one.
        """
        return self.AE.encode(field).mode()

class DDPO(object):
    """Training / evaluation driver for Diffusion-DPO fine-tuning of a latent diffusion model.

    Holds a :class:`DDPO_model` in ``self.kernel``. In multi-process runs ``self.kernel`` is
    accessed through ``.module``, so it is presumably wrapped (and moved to ``self.device``) by
    the caller.

    Recognised ``model_params`` keys: ``optimizer`` and ``lr_scheduler`` (stored, not used by
    ``train``) and ``latent_stats_path`` (HDF5 file with ``mean``/``std`` datasets used to
    normalise latents; defaults to ``LATENT_STATS_PATH``).
    """

    # Default HDF5 file holding the latent normalisation statistics ("mean" and "std").
    LATENT_STATS_PATH = "/path/to/latent_mean_std_h5_file"

    def __init__(self, **model_params) -> None:
        super().__init__()

        print('model_params:',model_params)
        self.optimizer_params = model_params.get('optimizer', {})
        self.scheduler_params = model_params.get('lr_scheduler', {})
        self.latent_stats_path = model_params.get('latent_stats_path', self.LATENT_STATS_PATH)

        self.kernel = DDPO_model()

        self.best_loss = 9999999
        self.criterion_mae = nn.L1Loss()
        self.criterion_mse = nn.MSELoss()

        if utils.is_dist_avail_and_initialized():
            self.device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
            if self.device == torch.device('cpu'):
                raise EnvironmentError('No GPUs, cannot initialize multigpu training.')
        else:
            self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # ------------------------------------------------------------------ helpers

    def _core_kernel(self):
        """Return the bare DDPO_model: ``self.kernel.module`` if world size > 1, else ``self.kernel``."""
        return self.kernel.module if utils.get_world_size() > 1 else self.kernel

    @staticmethod
    def _is_main_process():
        """True for single-process runs and for rank 0 of multi-process runs."""
        return utils.get_world_size() == 1 or utils.get_rank() == 0

    def _load_latent_stats(self, to_float32):
        """Read the latent mean/std from ``self.latent_stats_path`` onto ``self.device``.

        Args:
            to_float32: cast both tensors to float32; otherwise keep the dtype stored in the file.

        Returns:
            ``(mean, std)`` tensors.
        """
        with h5py.File(self.latent_stats_path, "r") as f:
            mean = torch.from_numpy(f["mean"][:])
            std = torch.from_numpy(f["std"][:])
        if to_float32:
            mean, std = mean.float(), std.float()
        return mean.to(self.device), std.to(self.device)

    @staticmethod
    def _to_ae_order(field):
        """Move channel 3 to the front - the channel order fed to the autoencoder."""
        return torch.cat([field[:, 3:4], field[:, 0:3], field[:, 4:]], dim=1)

    @staticmethod
    def _from_ae_order(field):
        """Inverse of ``_to_ae_order``: move channel 0 back to position 3."""
        return torch.cat([field[:, 1:4], field[:, 0:1], field[:, 4:]], dim=1)

    def _forecast_background(self, batch_data, args, lead_time=48, interval=6, size=(128, 256)):
        """Build a background field by autoregressively rolling ``args.forecast_model`` forward.

        The initial model input is ``batch_data[0][0]`` and ``batch_data[0][1]`` concatenated
        along dim 1 and bilinearly resized to ``size``. The model is run ``lead_time // interval``
        times (presumably 48 h in 6 h steps with the defaults). Each run keeps the last ``C``
        input channels plus the new ``C``-channel prediction, where ``C`` is the channel count of
        the truth ``batch_data[0][-1]``. ``args.forecast_model.run(None, {'input': ...})`` is
        presumably an onnxruntime-style session - confirm.

        Returns:
            ``(bkg_field, ana_field)``: the last prediction and the truth resized to ``size``,
            both on ``self.device`` and in the dataset's channel order.

        Raises:
            ValueError: if ``lead_time // interval`` < 1 (no forecast would be produced).
        """
        n_steps = lead_time // interval
        if n_steps < 1:
            raise ValueError(f'lead_time // interval must be >= 1, got {lead_time} // {interval}')

        inp_data = torch.cat([batch_data[0][0], batch_data[0][1]], dim=1)
        inp_data = F.interpolate(inp_data, size=size, mode='bilinear').numpy()
        truth_ana = batch_data[0][-1].to(self.device, non_blocking=True)
        n_vars = truth_ana.shape[1]
        for _ in range(n_steps):
            predict_data = args.forecast_model.run(None, {'input': inp_data})[0][:, :n_vars]
            inp_data = np.concatenate([inp_data[:, -n_vars:], predict_data], axis=1)

        bkg_field = torch.from_numpy(predict_data).to(self.device)
        ana_field = F.interpolate(truth_ana, size=size, mode='bilinear').to(self.device)
        return bkg_field, ana_field

    # ------------------------------------------------------------------ training

    def train(self, train_data_loader, valid_data_loader, logger, args):
        """Fine-tune ``diff_dpo`` with the Diffusion-DPO preference loss.

        Training batches are ``(background, winner, loser)`` latents. After each epoch, both
        the fine-tuned and the reference model are scored on ``valid_data_loader`` (batches
        ``(background, analysis, ...)``) with the plain denoising MSE. Whenever the fine-tuned
        validation loss improves, the EMA shadow weights of ``diff_dpo`` are saved to
        ``{args.rundir}/models_{epoch}.pth`` on the main process.
        """
        # Kept for parity with the evaluation methods; the training loop itself does not use them.
        self.latent_mean, self.latent_std = self._load_latent_stats(to_float32=True)

        train_step = len(train_data_loader)
        valid_step = len(valid_data_loader)
        # NOTE(review): the learning rate is hard-coded; self.optimizer_params is not consulted.
        # Frozen reference parameters never get gradients, so AdamW leaves them untouched.
        self.optimizer = torch.optim.AdamW(self.kernel.parameters(), lr=0.000001, weight_decay=0)

        core = self._core_kernel()
        is_main = self._is_main_process()
        ema_helper = EMAHelper()
        if is_main:
            ema_helper.register(core.diff_dpo)

        beta = 10000  # lambda*T in spin-diffusion paper
        for epoch in range(args.max_epoch):
            begin_time = time.time()
            self.kernel.train()
            # Running sums for the periodic debug print. They are reset every epoch so that
            # dividing by (step + 1) gives the average over the current epoch.
            train_loss = 0
            model_w_loss = 0
            model_l_loss = 0

            for step, batch_data in enumerate(train_data_loader):
                bkg_latent = batch_data[0].float().to(self.device)
                win_latent = batch_data[1].float().to(self.device)
                loss_latent = batch_data[2].float().to(self.device)

                self.optimizer.zero_grad()

                # The seventh output (wt) is not used by this loss.
                (noise_prediction_w, noise_prediction_l, noise_prediction_w_ref,
                 noise_prediction_l_ref, noise_w, noise_l, wt) = self.kernel(bkg_latent, win_latent, loss_latent)

                # Denoising errors of the fine-tuned and reference models on winner / loser,
                # computed once and reused for the loss and the logging below.
                mse_w = self.criterion_mse(noise_prediction_w, noise_w)
                mse_l = self.criterion_mse(noise_prediction_l, noise_l)
                mse_w_ref = self.criterion_mse(noise_prediction_w_ref, noise_w)
                mse_l_ref = self.criterion_mse(noise_prediction_l_ref, noise_l)
                model_diff = mse_w - mse_l
                ref_diff = mse_w_ref - mse_l_ref

                if step == 0 and epoch == 0:
                    print(beta)
                # Diffusion-DPO objective: -log sigmoid(-beta * (model_diff - ref_diff)).
                # NOTE(review): criterion_mse averages over the whole batch, so the log-sigmoid is
                # applied once to a batch-level scalar (torch.mean is then a no-op). Per-sample
                # formulations apply it per example before averaging - confirm which is intended.
                loss = torch.mean(-F.logsigmoid(-beta * (model_diff - ref_diff)))

                train_loss += loss.item()
                model_w_loss += mse_w.item()
                model_l_loss += mse_l.item()

                if step % 30 == 0 and utils.get_rank() == 0:
                    print("step:")
                    print("model_w_loss:", mse_w.item(), "model_l_loss:", mse_l.item(), "diff:", model_diff.item())
                    print("ref_w_loss:", mse_w_ref.item(), "ref_l_loss:", mse_l_ref.item(), "diff:", ref_diff.item())
                    print("win_diff:", mse_w.item() - mse_w_ref.item())
                    print("lose_diff:", mse_l.item() - mse_l_ref.item())
                    print("train_loss:", train_loss / (step + 1))
                    print("avergaed model w/l loss:", model_w_loss / (step + 1), model_l_loss / (step + 1))
                    print("    ")

                loss.backward()
                clip_grad_norm_(core.diff_dpo.parameters(), max_norm=1)
                self.optimizer.step()

                if is_main:
                    ema_helper.update(core.diff_dpo)

                if (step + 1) % 100 == 0 or step + 1 == train_step:
                    logger.info(f"Train epoch:[{epoch+1}/{args.max_epoch}], step:[{step+1}/{train_step}], lr:[{self.optimizer.param_groups[0]['lr']}], loss:[{loss.item()}]")

            self.kernel.eval()
            total_loss = 0
            total_loss_ref = 0
            with torch.no_grad():
                for step, batch_data in enumerate(valid_data_loader):
                    bkg_latent = batch_data[0].float().to(self.device)
                    ana_latent = batch_data[1].float().to(self.device)

                    noise_prediction, noise = core.forward_tune(bkg_latent, ana_latent)
                    loss = self.criterion_mse(noise_prediction, noise).item()

                    noise_prediction, noise = core.forward_ref(bkg_latent, ana_latent)
                    loss_ref = self.criterion_mse(noise_prediction, noise).item()

                    total_loss += loss
                    total_loss_ref += loss_ref

                    if (step + 1) % 100 == 0 or step + 1 == valid_step:
                        logger.info(f'Valid epoch:[{epoch+1}/{args.max_epoch}], step:[{step+1}/{valid_step}], loss:[{loss}], loss_ref:[{loss_ref}]')

            avg_loss = total_loss / valid_step
            avg_loss_ref = total_loss_ref / valid_step
            if avg_loss < self.best_loss:
                if is_main:
                    torch.save(ema_helper.state_dict(), f'{args.rundir}/models_{epoch}.pth')
                logger.info(f'New best model appears in epoch {epoch+1}.')
                self.best_loss = avg_loss
            logger.info(f'Epoch {epoch+1} average loss:[{avg_loss}], time:[{time.time()-begin_time}]')
            logger.info(f'Epoch {epoch+1} average ref loss:[{avg_loss_ref}], time:[{time.time()-begin_time}]')

    # ------------------------------------------------------------------ evaluation

    def test(self, test_data_loader, logger, args):
        """Sample an analysis latent per batch, log its MSE to the target latent, dump CSVs.

        Batches are ``(background_latent, analysis_latent, ...)``.
        """
        test_step = len(test_data_loader)
        self.kernel.eval()
        with torch.no_grad():
            for step, batch_data in enumerate(test_data_loader):
                bkg_latent = batch_data[0].to(self.device)
                ana_latent = batch_data[1].to(self.device)

                # DDPO_model.sample returns a single latent tensor; unpacking it into two names
                # would split it along dim 0 (or fail for batch sizes other than 2).
                ana_sampled = self.kernel.sample(bkg_latent)

                loss = self.criterion_mse(ana_latent, ana_sampled).item()

                print(bkg_latent.shape)
                self.save_sample(step, ana_sampled, ana_latent)
                logger.info(f'Valid step:[{step+1}/{test_step}], loss:[{loss}]')

    def test_da(self, test_data_loader, logger, args):
        """Evaluate unguided data assimilation in physical space.

        For each batch, a background field is built by rolling the forecast model forward.
        Background and truth are then encoded with the autoencoder, and an analysis latent is
        sampled conditioned on the normalised background latent, denormalised and decoded.
        Metrics are logged against the truth:
        - MAE / MSE / RMSE: decoded sample vs. truth.
        - MSE1 / RMSE1: AE reconstruction of the background vs. AE reconstruction of the truth.
        - MSE2 / RMSE2: decoded sample vs. AE reconstruction of the truth.

        Each step's prediction and truth are written to ``result_{step}.h5``.
        NOTE(review): requires ``self.kernel.AE``, which DDPO_model.__init__ does not create.
        """
        test_step = len(test_data_loader)
        _, data_std = test_data_loader.dataset.get_meanstd()
        self.data_std = data_std.to(self.device)
        print(self.data_std.shape)
        self.kernel.eval()
        self.latent_mean, self.latent_std = self._load_latent_stats(to_float32=False)
        with torch.no_grad():
            total_loss = 0
            total_mae = 0
            total_mse = 0
            total_rmse = 0
            for step, batch_data in enumerate(test_data_loader):
                print(batch_data[1])
                # Both fields are in the dataset's channel order at the forecast resolution.
                bkg_field, ana_field = self._forecast_background(batch_data, args)

                bkg_latent = self.kernel.get_latent(self._to_ae_order(bkg_field))
                ana_latent = self.kernel.get_latent(self._to_ae_order(ana_field))

                recons_bkg = self._from_ae_order(self.kernel.AE.decode(bkg_latent))
                recons_ana = self._from_ae_order(self.kernel.AE.decode(ana_latent))
                rmse1 = WRMSE(recons_bkg, recons_ana, self.data_std)
                mse1 = self.criterion_mse(recons_bkg, recons_ana).item()

                # Normalise the conditioning latent, sample, then undo the normalisation.
                bkg_latent = ((bkg_latent - self.latent_mean) / self.latent_std).float().to(self.device)
                ana_sampled_latent = self.kernel.sample(bkg_latent)
                ana_sampled_latent = (ana_sampled_latent * self.latent_std + self.latent_mean).float().to(self.device)
                ana_sampled = self.kernel.AE.decode(ana_sampled_latent)

                y_pred = self._from_ae_order(ana_sampled)
                # _from_ae_order(_to_ae_order(x)) == x, so the truth is ana_field itself.
                y_target = ana_field

                mae = self.criterion_mae(y_pred, y_target).item()
                mse = self.criterion_mse(y_pred, y_target).item()
                rmse = WRMSE(y_pred, y_target, self.data_std)
                mse2 = self.criterion_mse(y_pred, recons_ana).item()
                rmse2 = WRMSE(y_pred, recons_ana, self.data_std)
                loss = self.criterion_mse(ana_sampled_latent, ana_latent).item()

                total_loss += mse
                total_mae += mae
                total_mse += mse
                total_rmse += rmse

                with h5py.File("result_{}.h5".format(step), "w") as f:
                    f.create_dataset("sampled", data=y_pred.detach().cpu().numpy())
                    f.create_dataset("truth", data=y_target.detach().cpu().numpy())
                logger.info(f'Valid step:[{step+1}/{test_step}], MAE:[{mae}], MSE:[{mse}] RMSE:[{rmse}], MSE1:[{mse1}] RMSE1:[{rmse1}], MSE2:[{mse2}] RMSE2:[{rmse2}],latent_loss [{loss}]')

        logger.info(f'Average loss:[{total_loss/test_step}], MAE:[{total_mae/test_step}], MSE:[{total_mse/test_step}]')
        logger.info(f'Average RMSE:[{total_rmse/test_step}]')

    def test_da_guide(self, test_data_loader, logger, args):
        """Evaluate observation-guided data assimilation in physical space.

        The flow is the same as ``test_da``, but sampling uses ``GuidanceSampling`` on the
        fine-tuned model. Guidance comes from simulated observations: a random mask keeps each
        truth value (in AE channel order) with probability 0.1. The observation operator
        denormalises a latent and decodes it with the autoencoder. MSE1 / RMSE1 compare the AE
        reconstruction of the background to the truth. Each step's prediction and truth are
        written to ``result_{step}.h5``.
        NOTE(review): requires ``self.kernel.AE``, which DDPO_model.__init__ does not create.
        """
        test_step = len(test_data_loader)
        _, data_std = test_data_loader.dataset.get_meanstd()
        self.data_std = data_std.to(self.device)

        self.latent_mean, self.latent_std = self._load_latent_stats(to_float32=False)

        def A(x):
            # Observation operator: denormalise a latent and decode it (AE channel order).
            x = (x * self.latent_std + self.latent_mean).float().to(self.device)
            return self.kernel.AE.decode(x)

        # Guide the fine-tuned model (DDPO_model has no ``diff`` attribute).
        sampling = GuidanceSampling(A=A, std=0.1, sde=self.kernel.diff_dpo, device=self.device)

        self.kernel.eval()
        with torch.no_grad():
            total_loss = 0
            total_mae = 0
            total_mse = 0
            total_rmse = 0
            for step, batch_data in enumerate(test_data_loader):
                # Both fields are in the dataset's channel order at the forecast resolution.
                bkg_field, ana_field = self._forecast_background(batch_data, args)
                ana_ae = self._to_ae_order(ana_field)

                bkg_latent = self.kernel.get_latent(self._to_ae_order(bkg_field))
                ana_latent = self.kernel.get_latent(ana_ae)

                recons_bkg = self._from_ae_order(self.kernel.AE.decode(bkg_latent))
                rmse1 = WRMSE(recons_bkg, ana_field, self.data_std)
                mse1 = self.criterion_mse(recons_bkg, ana_field).item()

                bkg_latent = ((bkg_latent - self.latent_mean) / self.latent_std).float().to(self.device)

                # Simulated observations: each truth value is kept with probability 0.1.
                obs_mask = torch.rand(ana_ae.shape, device=self.device) >= 0.9
                observation = obs_mask * ana_ae

                ana_sampled_latent = sampling.sample(shape=(1,), guidance=observation, c=bkg_latent, obs_mask=obs_mask)
                ana_sampled_latent = (ana_sampled_latent * self.latent_std + self.latent_mean).float().to(self.device)
                ana_sampled = self.kernel.AE.decode(ana_sampled_latent)

                y_pred = self._from_ae_order(ana_sampled)
                # _from_ae_order(_to_ae_order(x)) == x, so the truth is ana_field itself.
                y_target = ana_field

                mae = self.criterion_mae(y_pred, y_target).item()
                mse = self.criterion_mse(y_pred, y_target).item()
                rmse = WRMSE(y_pred, y_target, self.data_std)
                loss = self.criterion_mse(ana_sampled_latent, ana_latent).item()

                total_loss += mse
                total_mae += mae
                total_mse += mse
                total_rmse += rmse

                with h5py.File("result_{}.h5".format(step), "w") as f:
                    f.create_dataset("sampled", data=y_pred.detach().cpu().numpy())
                    f.create_dataset("truth", data=y_target.detach().cpu().numpy())
                logger.info(f'Valid step:[{step+1}/{test_step}], MAE:[{mae}], MSE:[{mse}] RMSE:[{rmse}], MSE1:[{mse1}] RMSE1:[{rmse1}], latent_loss [{loss}] {rmse}')

        logger.info(f'Average loss:[{total_loss/test_step}], MAE:[{total_mae/test_step}], MSE:[{total_mse/test_step}]')
        logger.info(f'Average RMSE:[{total_rmse/test_step}]')

    def save_sample(self, step, x_pred, x_target):
        """Write each batch element of ``x_pred`` / ``x_target`` to its own headerless CSV.

        Each sample is reshaped to ``(-1, sample.shape[2])``. Predictions go to
        ``.result_{step}_{i}.csv`` and targets to ``result_{step}_{i}.csv`` in the working
        directory.
        NOTE(review): the leading '.' makes the prediction files hidden - confirm intended.
        """
        for i in range(x_pred.shape[0]):
            data = x_pred[i].cpu().numpy()
            print(data.shape)
            stacked_data = data.reshape(-1, data.shape[2])
            print(stacked_data.shape)
            pd.DataFrame(stacked_data).to_csv('.result_{}_{}.csv'.format(step, i), index=False, header=False)

            data = x_target[i].cpu().numpy()
            stacked_data = data.reshape(-1, data.shape[2])
            pd.DataFrame(stacked_data).to_csv('result_{}_{}.csv'.format(step, i), index=False, header=False)

    