import torch
import tqdm
from core.base_model import BaseModel
from core.logger import LogTracker
from models.constraint import Constraint
import sys
import math
import copy
import numpy
import pandas as pd

class EMA():
    """Exponential moving average helper for model parameters.

    ``beta`` is the weight kept from the running average on each update; the
    remaining ``1 - beta`` is taken from the current value.
    """

    def __init__(self, beta=0.9999):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Blend ``current_model``'s parameters into ``ma_model`` in place.

        Parameters are paired positionally in the order returned by each
        model's ``parameters()``; only ``ma_model``'s ``.data`` is rewritten.
        """
        pairs = zip(current_model.parameters(), ma_model.parameters())
        for latest, averaged in pairs:
            averaged.data = self.update_average(averaged.data, latest.data)

    def update_average(self, old, new):
        """Return ``old * beta + (1 - beta) * new``, or ``new`` if ``old`` is None."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

class Palette(BaseModel):
    def __init__(self, networks, losses, sample_num, task, optimizers, ema_scheduler=None, **kwargs):
        """Set up the generator, optional EMA copy, optimizer and frozen reference clone.

        Args:
            networks: sequence of networks; only ``networks[0]`` is used (as ``netG``).
            losses: sequence of loss callables; only ``losses[0]`` is used.
            sample_num: forwarded as ``sample_num`` to ``restoration`` in val/test.
            task: task name; ``'inpainting'`` and ``'uncropping'`` enable mask handling.
            optimizers: keyword arguments for ``torch.optim.Adam``; ``lr`` is
                coerced to ``float`` in place.
            ema_scheduler: optional mapping providing ``ema_decay``, ``ema_start``
                and ``ema_iter``. ``None`` disables EMA entirely.
            **kwargs: forwarded unchanged to ``BaseModel.__init__``.
        """
        # BaseModel must be initialised with the remaining keyword arguments.
        super(Palette, self).__init__(**kwargs)

        # Networks, optimizers, losses, etc.
        self.loss_fn = losses[0]
        self.netG = networks[0]

        # EMA is active exactly when the caller supplies a scheduler. The value
        # must not be reset afterwards: training, loading and saving all gate
        # their EMA handling on ``self.ema_scheduler is not None``.
        self.ema_scheduler = ema_scheduler
        if self.ema_scheduler is not None:
            self.netG_EMA = copy.deepcopy(self.netG)
            self.EMA = EMA(beta=self.ema_scheduler['ema_decay'])

        # Networks must go through self.set_device (handles multi-GPU wrapping).
        self.netG = self.set_device(self.netG, distributed=self.opt['distributed'])
        if self.ema_scheduler is not None:
            self.netG_EMA = self.set_device(self.netG_EMA, distributed=self.opt['distributed'])
        self.load_networks()

        optimizers['lr'] = float(optimizers['lr'])
        self.optG = torch.optim.Adam(list(filter(lambda p: p.requires_grad, self.netG.parameters())), **optimizers)
        self.optimizers.append(self.optG)
        self.resume_training()

        # In distributed mode the network's own attributes live behind
        # ``.module``; unwrap once and use the same handle everywhere below.
        generator = self.netG.module if self.opt['distributed'] else self.netG
        generator.set_loss(self.loss_fn)
        generator.set_new_noise_schedule(phase=self.phase)

        # NOTE(review): self.method is not set in this class; presumably it is
        # provided by BaseModel from the config -- confirm.
        method_type = self.method['method_type']
        if method_type in ['ours', 'iclr']:
            # Freezing happens after the optimizer was built on purpose: the
            # statement order of the original setup is kept unchanged.
            self.freeze_blocks(generator.denoise_fn)
        if method_type == 'ours':
            self.optimizer = Constraint(self.netG.parameters(), base_optimizer=self.optG)
        else:
            self.optimizer = self.optG

        # Keep a frozen, fully independent copy of the generator as the
        # reference model for the training losses (train phase only).
        if self.phase == 'train':
            reference = copy.deepcopy(self.netG)
            for param in reference.parameters():
                param.requires_grad = False
            self.netG_clone = self.set_device(reference)

        # Loss used to compare hidden features against the reference model.
        self.criterion = torch.nn.MSELoss(reduction='mean')

        # Metric trackers use fixed forget/retain keys in every phase.
        self.train_metrics = LogTracker(*['mse_forget', 'mse_retain'], phase='train')
        self.val_metrics = LogTracker(*['mse_forget', 'mse_retain'], phase='val')
        self.test_metrics = LogTracker(*['mse_forget', 'mse_retain'], phase='test')

        self.sample_num = sample_num
        self.task = task

    def freeze_blocks(self, model):
        """Turn off gradients for ``model.middle_block`` and all ``model.output_blocks``.

        Args:
            model: network exposing a ``middle_block`` module and an iterable
                ``output_blocks`` of modules.
        """
        def _disable_grad(module):
            for weight in module.parameters():
                weight.requires_grad = False

        # Middle block first, then every output block in order.
        _disable_grad(model.middle_block)
        for block in model.output_blocks:
            _disable_grad(block)

        self.logger.info('All parameters except for the encoder are now frozen.')
    
    def get_current_visuals(self, phase='train'):
        """Collect the current forget/retain tensors for visualisation.

        Image tensors are detached, cast to float, moved to CPU and rescaled
        with ``(x + 1) / 2`` (presumably mapping [-1, 1] to [0, 1] -- confirm).
        Mask entries are added for the inpainting/uncropping tasks, and model
        outputs are added for any phase other than ``'train'``.

        Returns:
            dict mapping a display name to a tensor.
        """
        def _as_mask(tensor):
            return tensor.detach()[:].float().cpu()

        def _as_image(tensor):
            return (tensor.detach()[:].float().cpu() + 1) / 2

        visuals = {
            'gt_image_f': _as_image(self.gt_image_f),
            'cond_image_f': _as_image(self.cond_image_f),
            'gt_image_r': _as_image(self.gt_image_r),
            'cond_image_r': _as_image(self.cond_image_r),
        }

        if self.task in ['inpainting', 'uncropping']:
            visuals['mask_f'] = _as_mask(self.mask_f)
            visuals['mask_r'] = _as_mask(self.mask_r)
            # mask_image_* are taken straight from the batch, so they are only rescaled.
            visuals['mask_image_f'] = (self.mask_image_f + 1) / 2
            visuals['mask_image_r'] = (self.mask_image_r + 1) / 2

        if phase != 'train':
            visuals['output_f'] = _as_image(self.output_f)
            visuals['output_r'] = _as_image(self.output_r)

        return visuals

    def save_current_results(self):
        """Assemble names and CPU tensors of the current batch for saving.

        For each sample index the ground truth, the strided slice
        ``visuals[idx::batch_size]`` and the entry ``visuals[idx - batch_size]``
        are emitted for both the forget and the retain set (presumably the
        sampling trajectory and the final frame -- confirm). Mask images are
        appended afterwards for the inpainting/uncropping tasks.

        Returns:
            ``self.results_dict`` (with ``name`` and ``result`` replaced) as a dict.
        """
        def _host(tensor):
            return tensor.detach().float().cpu()

        names, tensors = [], []
        count = self.batch_size
        for idx in range(count):
            forget_name, retain_name = self.path_f[idx], self.path_r[idx]
            names += [
                'forget_GT_{}'.format(forget_name),
                'retain_GT_{}'.format(retain_name),
                'forget_Process_{}'.format(forget_name),
                'retain_Process_{}'.format(retain_name),
                'forget_Out_{}'.format(forget_name),
                'retain_Out_{}'.format(retain_name),
            ]
            tensors += [
                _host(self.gt_image_f[idx]),
                _host(self.gt_image_r[idx]),
                _host(self.visuals_f[idx::count]),
                _host(self.visuals_r[idx::count]),
                # Negative index: counts back from the end of the visuals.
                _host(self.visuals_f[idx - count]),
                _host(self.visuals_r[idx - count]),
            ]

        if self.task in ['inpainting', 'uncropping']:
            names.extend('forget_Mask_{}'.format(name) for name in self.path_f)
            tensors.extend(self.mask_image_f)
            names.extend('retain_Mask_{}'.format(name) for name in self.path_r)
            tensors.extend(self.mask_image_r)

        self.results_dict = self.results_dict._replace(name=names, result=tensors)
        return self.results_dict._asdict()

    def set_input(self, data):
        """Store one plain (non forget/retain) batch on the model.

        ``cond_image``, ``gt_image`` and ``mask`` are passed through
        ``self.set_device``; ``mask_image`` and ``path`` are kept as given.
        """
        for field in ('cond_image', 'gt_image', 'mask'):
            setattr(self, field, self.set_device(data.get(field)))
        self.mask_image = data.get('mask_image')

        paths = data['path']
        self.path = paths
        self.batch_size = len(paths)

    def set_input_phase(self, data, phase):
        """Store a forget/retain batch (plus gaussian inputs when training).

        Args:
            data: mapping with ``'forget'`` and ``'retain'`` sub-batches, and a
                ``'gaussian'`` sub-batch when ``phase == 'train'``.
            phase: current phase name; only ``'train'`` triggers the gaussian setup.
        """
        forget = data['forget']
        self.batch_size = len(forget['path'])

        # Forget images.
        self.path_f = forget['path']
        self.cond_image_f = self.set_device(forget.get('cond_image'))
        self.gt_image_f = self.set_device(forget.get('gt_image'))
        self.mask_f = self.set_device(forget.get('mask'))
        self.mask_image_f = forget.get('mask_image')

        # Retain images.
        retain = data['retain']
        self.path_r = retain['path']
        self.cond_image_r = self.set_device(retain.get('cond_image'))
        self.gt_image_r = self.set_device(retain.get('gt_image'))
        self.mask_r = self.set_device(retain.get('mask'))
        self.mask_image_r = retain.get('mask_image')

        if phase != 'train':
            return

        # Gaussian images reuse the forget batch's mask.
        gaussian = data['gaussian']
        self.gt_image_g = self.set_device(gaussian.get('gt_image'))
        self.mask_g = self.set_device(forget.get('mask'))

        # Conditioning input: gaussian image outside the mask, fresh noise inside.
        visible = self.gt_image_g * (1. - self.mask_g)
        noise = torch.randn_like(self.gt_image_g)
        self.cond_image_g = self.set_device(visible + self.mask_g * noise)

        # Display version built from the untouched batch tensors: masked area set to 1.
        self.mask_image_g = gaussian.get('gt_image') * (1. - forget.get('mask')) + forget.get('mask')
        
    def train_step(self):
        """Run one training (unlearning) epoch over ``self.phase_loader``.

        For every batch the forget / retain / gaussian tensors are loaded with
        ``set_input_phase`` and one optimisation step is taken according to
        ``self.method['method_type']``:

        * ``'ours'``: ``self.optimizer.step`` receives a forget closure and a
          retain closure (each runs its own forward/backward pass) together
          with ``opt_type`` and ``opt_g`` from ``self.method``.
        * ``'max_loss'``, ``'retain_label'``, ``'noise_label'``, ``'iclr'``:
          a single backward pass on ``loss_retain -/+ 0.25 * loss_forget``.

        Side effects: updates ``self.iter``, the train metrics, the writer,
        ``self.epoch_forget_loss`` / ``self.epoch_retain_loss`` (averaged over
        the batches at the end), the EMA model when enabled, steps every
        scheduler, and writes ``t-sne-diffusion/<epoch>.csv`` with the losses
        sampled every 15 batches.

        Returns:
            The value of ``self.train_metrics.result()``.

        Raises:
            ValueError: if ``method_type`` is not one of the strategies above.
        """
        import os  # local import: only needed to create the CSV directory

        method_type = self.method['method_type']
        joint_methods = ('max_loss', 'retain_label', 'noise_label', 'iclr')
        if method_type != 'ours' and method_type not in joint_methods:
            # Fail with a clear message instead of an unbound-variable error
            # after the first batch.
            raise ValueError('Unsupported method_type: {!r}'.format(method_type))

        def forget_closure():
            return self._closure_loss(forget=True)

        def retain_closure():
            return self._closure_loss(forget=False)

        self.netG.train()
        self.train_metrics.reset()
        batch_num = 0
        data_list = []
        for train_data in tqdm.tqdm(self.phase_loader):
            self.set_input_phase(train_data, phase='train')
            self.optimizer.zero_grad()

            if method_type == 'ours':
                loss_forget, loss_retain = self.optimizer.step(
                    forget_closure=forget_closure, retain_closure=retain_closure,
                    mode=self.method['opt_type'], g_constraint=self.method['opt_g'])
            else:
                loss_forget, loss_retain = self._joint_loss_step(method_type)

            self.iter += self.batch_size
            self.writer.set_iter(self.epoch, self.iter, phase='train')
            self.train_metrics.update('mse_forget', loss_forget.item())
            self.train_metrics.update('mse_retain', loss_retain.item())

            if batch_num % 15 == 0:
                data_list.append({'f1': loss_forget.detach().cpu().numpy(), 'f2': loss_retain.detach().cpu().numpy()})
                self.logger.info('Epoch {:.0f}---forget loss {:.4f}---retain loss {:.4f}'.format(self.epoch, loss_forget, loss_retain))

            # Detach before accumulating so the running sums do not keep the
            # autograd history of every batch alive for the whole epoch.
            self.epoch_forget_loss += loss_forget.detach()
            self.epoch_retain_loss += loss_retain.detach()
            batch_num += 1

            if batch_num % 50 == 0:
                for key, value in self.train_metrics.result().items():
                    self.logger.info('{:5s}: {}\t'.format(str(key), value))
                    self.writer.add_scalar(key, value)
                for key, value in self.get_current_visuals().items():
                    self.writer.add_images(key, value)
            if self.ema_scheduler is not None:
                if self.iter > self.ema_scheduler['ema_start'] and self.iter % self.ema_scheduler['ema_iter'] == 0:
                    self.EMA.update_model_average(self.netG_EMA, self.netG)

        # An empty loader leaves the running sums untouched rather than dividing by zero.
        if batch_num:
            self.epoch_forget_loss = self.epoch_forget_loss / batch_num
            self.epoch_retain_loss = self.epoch_retain_loss / batch_num

        for scheduler in self.schedulers:
            scheduler.step()

        # to_csv does not create missing directories; make sure the target
        # exists so a finished epoch is not lost to a missing folder.
        os.makedirs('t-sne-diffusion', exist_ok=True)
        df_loss = pd.DataFrame(data_list)
        df_loss.to_csv(f't-sne-diffusion/{self.epoch}.csv', index=False)
        return self.train_metrics.result()

    @staticmethod
    def _abort_if_not_finite(loss):
        """Print a message and exit the process with code 1 when ``loss`` is NaN/inf."""
        if not math.isfinite(loss):
            print("Loss is {}, stopping training".format(loss))
            sys.exit(1)

    def _closure_loss(self, forget):
        """Forward/backward pass used as a closure by the ``'ours'`` optimizer.

        Args:
            forget: ``True`` for the forget objective (forget inputs compared
                against a reference pass on the gaussian inputs), ``False`` for
                the retain objective (retain inputs for both passes).

        Returns:
            The loss tensor, after ``backward()`` has been called on it.
        """
        if forget:
            inputs = (self.gt_image_f, self.cond_image_f, self.mask_f)
            ref_inputs = (self.gt_image_g, self.cond_image_g, self.mask_g)
        else:
            inputs = (self.gt_image_r, self.cond_image_r, self.mask_r)
            ref_inputs = inputs

        is_encoder = self.method['is_encoder']
        is_clone = self.method['is_clone']
        # The reference pass uses the frozen clone when ``is_clone`` is set,
        # otherwise the live generator. The forget objective on hidden
        # features always uses the clone, regardless of ``is_clone``.
        use_clone = bool(is_clone) or (forget and bool(is_encoder))
        reference_net = self.netG_clone if use_clone else self.netG
        # The loss callable lives on the underlying network, behind
        # ``.module`` when running distributed.
        generator = self.netG.module if self.opt['distributed'] else self.netG

        with torch.cuda.amp.autocast(), torch.enable_grad():
            self.optimizer.zero_grad()
            _, noise_hat, hidden_list = self.netG(inputs[0], inputs[1], mask=inputs[2])
            with torch.no_grad():
                _, noise_hat_ref, hidden_list_ref = reference_net(
                    ref_inputs[0], ref_inputs[1], mask=ref_inputs[2])
            if is_encoder:
                loss = self.criterion(hidden_list, hidden_list_ref)
            else:
                # Applied for both objectives and both ``is_clone`` settings,
                # so the reference prediction is always the target.
                loss = generator.loss_fn(noise_hat_ref, noise_hat)
            self._abort_if_not_finite(loss)
            loss.backward()
        return loss

    def _joint_loss_step(self, method_type):
        """One optimizer step on the combined hidden-feature losses.

        The forget reference is always produced by the frozen clone from
        ``self.cond_image_g``; the ground-truth image and mask passed with it
        depend on ``method_type``. ``'max_loss'`` subtracts the weighted forget
        loss, every other method adds it.

        Returns:
            ``(loss_forget, loss_retain)`` tensors.
        """
        forget_reference = {
            'max_loss': (self.gt_image_f, self.mask_f),
            'retain_label': (self.gt_image_r, self.mask_f),
            'noise_label': (self.gt_image_g, self.mask_f),
            'iclr': (self.gt_image_g, self.mask_g),
        }
        ref_gt, ref_mask = forget_reference[method_type]

        with torch.cuda.amp.autocast():
            with torch.enable_grad():
                self.optimizer.zero_grad()
                _, _, hidden_list_f = self.netG(
                    self.gt_image_f, self.cond_image_f, mask=self.mask_f)
                _, _, hidden_list_r = self.netG(
                    self.gt_image_r, self.cond_image_r, mask=self.mask_r)

                with torch.no_grad():
                    _, _, hidden_list_clone_f = self.netG_clone(
                        ref_gt, self.cond_image_g, mask=ref_mask)
                    _, _, hidden_list_clone_r = self.netG_clone(
                        self.gt_image_r, self.cond_image_r, mask=self.mask_r)

                loss_forget = self.criterion(hidden_list_f, hidden_list_clone_f)
                loss_retain = self.criterion(hidden_list_r, hidden_list_clone_r)

                if method_type == 'max_loss':
                    loss = loss_retain - loss_forget * 0.25
                else:
                    loss = loss_retain + loss_forget * 0.25

        self._abort_if_not_finite(loss)
        loss.backward()
        self.optimizer.step()
        return loss_forget, loss_retain
    
    def val_step(self):
        """Run restoration on every validation batch and log forget/retain metrics.

        For each batch the forget and retain inputs are restored, every
        callable in ``self.metrics`` is evaluated on both, and the values are
        recorded under the fixed keys ``'mse_forget'`` / ``'mse_retain'``.
        Visuals and results are passed to ``self.writer``.

        Returns:
            The value of ``self.val_metrics.result()``.
        """
        # In distributed mode the network's methods live behind ``.module``.
        net = self.netG.module if self.opt['distributed'] else self.netG
        uses_mask = self.task in ['inpainting', 'uncropping']

        def restore(cond_image, gt_image, mask):
            # Mask-based tasks start from the conditioning image and pass the
            # ground truth and mask along; other tasks only need the condition.
            if uses_mask:
                return net.restoration(
                    cond_image, y_t=cond_image,
                    y_0=gt_image, mask=mask, sample_num=self.sample_num)
            return net.restoration(cond_image, sample_num=self.sample_num)

        self.netG.eval()
        self.val_metrics.reset()
        with torch.no_grad():
            for val_data in tqdm.tqdm(self.val_loader):
                self.set_input_phase(val_data, phase='val')
                self.output_f, self.visuals_f = restore(
                    self.cond_image_f, self.gt_image_f, self.mask_f)
                self.output_r, self.visuals_r = restore(
                    self.cond_image_r, self.gt_image_r, self.mask_r)

                self.iter += self.batch_size
                self.writer.set_iter(self.epoch, self.iter, phase='val')

                for met in self.metrics:
                    value_f = met(self.gt_image_f, self.output_f)
                    value_r = met(self.gt_image_r, self.output_r)
                    self.val_metrics.update('mse_forget', value_f)
                    self.writer.add_scalar('mse_forget', value_f)

                    self.val_metrics.update('mse_retain', value_r)
                    self.writer.add_scalar('mse_retain', value_r)

                for key, value in self.get_current_visuals(phase='val').items():
                    self.writer.add_images(key, value)
                self.writer.save_images(self.save_current_results())

        return self.val_metrics.result()

    def test(self):
        """Run restoration on every batch of ``self.phase_loader`` and log test metrics.

        For each batch the forget and retain inputs are restored, every
        callable in ``self.metrics`` is evaluated on both, and the values are
        recorded in ``self.test_metrics`` under ``'mse_forget'`` /
        ``'mse_retain'``. Visuals and results are passed to ``self.writer``.
        The aggregated metrics, plus the current epoch and iteration, are
        written to the logger. Nothing is returned.
        """
        # In distributed mode the network's methods live behind ``.module``.
        net = self.netG.module if self.opt['distributed'] else self.netG
        uses_mask = self.task in ['inpainting', 'uncropping']

        def restore(cond_image, gt_image, mask):
            # Mask-based tasks start from the conditioning image and pass the
            # ground truth and mask along; other tasks only need the condition.
            if uses_mask:
                return net.restoration(
                    cond_image, y_t=cond_image,
                    y_0=gt_image, mask=mask, sample_num=self.sample_num)
            return net.restoration(cond_image, sample_num=self.sample_num)

        self.netG.eval()
        self.test_metrics.reset()
        with torch.no_grad():
            for phase_data in tqdm.tqdm(self.phase_loader):
                self.set_input_phase(phase_data, phase='test')
                self.output_f, self.visuals_f = restore(
                    self.cond_image_f, self.gt_image_f, self.mask_f)
                self.output_r, self.visuals_r = restore(
                    self.cond_image_r, self.gt_image_r, self.mask_r)

                self.iter += self.batch_size
                self.writer.set_iter(self.epoch, self.iter, phase='test')
                for met in self.metrics:
                    value_f = met(self.gt_image_f, self.output_f)
                    value_r = met(self.gt_image_r, self.output_r)
                    # Record into the *test* tracker: it is the one that is
                    # reset above and reported below.
                    self.test_metrics.update('mse_forget', value_f)
                    self.writer.add_scalar('mse_forget', value_f)
                    self.test_metrics.update('mse_retain', value_r)
                    self.writer.add_scalar('mse_retain', value_r)

                for key, value in self.get_current_visuals(phase='test').items():
                    self.writer.add_images(key, value)
                self.writer.save_images(self.save_current_results())

        test_log = self.test_metrics.result()
        # Add bookkeeping information to the logged dict.
        test_log.update({'epoch': self.epoch, 'iters': self.iter})

        # Emit every entry through the logger.
        for key, value in test_log.items():
            self.logger.info('{:5s}: {}\t'.format(str(key), value))

    def load_networks(self):
        """Load weights for ``netG`` and, when EMA is enabled, for ``netG_EMA``.

        The network label is the class name of the underlying generator
        (unwrapped from ``.module`` in distributed mode); the EMA copy uses the
        same label with an ``'_ema'`` suffix. Both loads pass ``strict=False``.
        """
        generator = self.netG.module if self.opt['distributed'] else self.netG
        netG_label = generator.__class__.__name__

        self.load_network(network=self.netG, network_label=netG_label, strict=False)
        if self.ema_scheduler is None:
            return
        self.load_network(network=self.netG_EMA, network_label=netG_label+'_ema', strict=False)

    def save_everything(self):
        """Save ``netG``, the EMA copy when enabled, and the training state.

        The network label is the class name of the underlying generator
        (unwrapped from ``.module`` in distributed mode); the EMA copy uses the
        same label with an ``'_ema'`` suffix.
        """
        generator = self.netG.module if self.opt['distributed'] else self.netG
        netG_label = generator.__class__.__name__

        self.save_network(network=self.netG, network_label=netG_label)
        ema_enabled = self.ema_scheduler is not None
        if ema_enabled:
            self.save_network(network=self.netG_EMA, network_label=netG_label+'_ema')
        self.save_training_state()
