# ---------------------------------------------------------------
# This file has been modified from the following sources:
# Source:
# 1. https://github.com/NVlabs/LSGM/blob/main/util/ema.py (NVIDIA License)
# 2. https://github.com/NVlabs/denoising-diffusion-gan/blob/main/train_ddgan.py (NVIDIA License)
# 3. https://github.com/nhartland/KL-divergence-estimators (MIT License)
# ---------------------------------------------------------------
import os
import json
import warnings
import numpy as np
import torch
from torch.optim import Optimizer
import torch.nn.functional as F
from torchvision.utils import save_image
import math
from torch.distributions import Beta
from data.dataset import build_boundary_distribution
from data.pytorch_fid.fid_score import calculate_fid_given_paths


# ------------------------
# Initialize Transport Maps
# ------------------------
def initialize(netG, optimizerG, src_dataset, trg_dataset, args, exp_path):
    """Warm-start ``netG`` towards an entropic optimal-transport pairing of two datasets.

    Runs ``args.init_num_iterations`` steps. Each step:

    1. Draw a source batch ``x`` and a target batch ``y``.
    2. Solve entropic OT (POT ``ot.sinkhorn``) on their pairwise mean squared distances.
    3. Pair each ``x[k]`` with the target that receives the most plan mass.
    4. Take one ``optimizerG`` step on ``args.tau * ||netG(x, z) - y_paired||^2``
       (summed per sample, averaged over the batch).

    Every 100 steps a preview ``init_<step>.png`` is written to ``exp_path``.

    Args:
        netG: generator called as ``netG(x, latent_z)``.
        optimizerG: optimizer over ``netG``'s parameters.
        src_dataset, trg_dataset: objects whose ``sample()`` returns a batch tensor.
        args: reads ``device``, ``init_num_iterations``, ``batch_size``, ``nz``, ``tau``.
        exp_path: directory for preview images.

    Returns:
        ``netG`` (trained in place).
    """
    device = args.device
    if args.init_num_iterations <= 0:
        return netG

    import ot  # POT is only needed for this optional warm-up phase.

    # Unnormalised uniform marginals; both sum to batch_size, so the OT problem is balanced.
    uniform = torch.ones(args.batch_size)
    for i in range(args.init_num_iterations):
        netG.zero_grad()
        x = src_dataset.sample()
        y = trg_dataset.sample()
        latent_z = torch.randn((args.batch_size, args.nz))

        # M[a, b] = mean squared difference between flattened x[a] and y[b].
        M = ((x.reshape(args.batch_size, -1)[:, None, :] - y.reshape(args.batch_size, -1)[None, :, :]) ** 2).mean(2)
        try:
            A = ot.sinkhorn(uniform, uniform, M, 0.01)
        except Exception:
            # Retry with stronger entropic regularisation if the small-epsilon solve fails.
            A = ot.sinkhorn(uniform, uniform, M, 0.1)
        # Hard assignment: each source sample takes the target with the largest plan mass.
        idx = torch.argmax(A, dim=1)
        y = y[idx]

        # Move everything to the training device once; both the loss and the preview use these.
        x, y, latent_z = x.to(device), y.to(device), latent_z.to(device)

        loss = args.tau * ((netG(x, latent_z) - y) ** 2).reshape(args.batch_size, -1).sum(1)
        loss.mean().backward()
        optimizerG.step()

        if (i + 1) % 100 == 0:
            with torch.no_grad():
                # Map generator output from [-1, 1] to [0, 1] for saving.
                save_image(0.5 * netG(x, latent_z).detach().cpu() + 0.5, os.path.join(exp_path, f'init_{i+1}.png'))
    return netG


# ------------------------
# Select Phi_star
# ------------------------
def select_phi(name):
    """Return the element-wise function Phi-star registered under ``name``.

    Available choices:
        'linear'   -> x
        'kl'       -> exp(x) - 1
        'chi'      -> 0.25 * y**2 + y, where y = max(x, -2)
        'softplus' -> 2 * softplus(x) - 2 * softplus(0)

    Raises:
        NotImplementedError: if ``name`` is not one of the choices above.
    """
    def _linear(x):
        return x

    def _kl(x):
        return torch.exp(x) - 1

    def _chi(x):
        # Clamp from below at -2, the minimiser of 0.25*y**2 + y, so the result is non-decreasing.
        clamped = F.relu(x + 2) - 2
        return 0.25 * clamped ** 2 + clamped

    def _softplus(x):
        # Subtracting 2*softplus(0) makes the function pass through the origin.
        return 2 * F.softplus(x) - 2 * F.softplus(0 * x)

    choices = (
        ('linear', _linear),
        ('kl', _kl),
        ('chi', _chi),
        ('softplus', _softplus),
    )
    for key, fn in choices:
        if name == key:
            return fn
    raise NotImplementedError


class EvalAdapted:
    """Sample from trained generators and score the samples with FID.

    Directions:
      - forward ('f'): source sample -> netG -> generated target.
      - backward ('b'): target sample -> netG -> generated source. This is only
        evaluated when the source is a dataset, i.e. not 'gaussian'/'uniform'.
    """

    def __init__(self, args):
        # NOTE: mutates the caller's ``args``. build_boundary_distribution presumably
        # reads ``args.train`` to select evaluation data — TODO confirm.
        args.train = False
        self.problem_name = args.problem_name
        self.batch_size = args.batch_size
        self.device = args.device
        self.nz = args.nz
        self.sample_path = f'train_logs/{args.problem_name}/{args.exp}/generated_samples'
        os.makedirs(self.sample_path, exist_ok=True)

        # '<source>_to_<target>' names a translation problem; otherwise the source is Gaussian noise.
        if args.problem_name.find('_to_') != -1:
            self.source_data_name, self.target_data_name = args.problem_name.split('_to_')
        else:
            self.source_data_name = 'gaussian'
            self.target_data_name = args.problem_name

        self.source_sampler, self.target_sampler = build_boundary_distribution(args)

        # Precomputed FID statistics files; the source side exists only for dataset sources.
        self.backward = False
        if self.source_data_name not in ['gaussian', 'uniform']:
            self.fid_source_path_train = f'data/pytorch_fid/{self.source_data_name}_{args.image_size}_train.npy'
            self.fid_source_path_test = f'data/pytorch_fid/{self.source_data_name}_{args.image_size}_test.npy'
            self.backward = True
            self.sample_path2 = f'train_logs/{args.problem_name}/{args.exp}/generated_samples2'
            os.makedirs(self.sample_path2, exist_ok=True)
        self.fid_target_path_train = f'data/pytorch_fid/{self.target_data_name}_{args.image_size}_train.npy'
        self.fid_target_path_test = f'data/pytorch_fid/{self.target_data_name}_{args.image_size}_test.npy'

    def generate(self, netG, fwd_or_bwd):
        """Run ``netG`` on one batch and return ``(source_side, target_side)`` on CPU.

        ``fwd_or_bwd == 'f'`` maps a source batch and returns (real source, generated target).
        Any other value maps a target batch and returns (generated source, real target).
        Both tensors are rescaled with 0.5 * (v + 1), presumably from [-1, 1] to [0, 1].
        """
        if fwd_or_bwd == 'f':
            data = self.source_sampler.sample().to(self.device)
        else:
            data = self.target_sampler.sample().to(self.device)
        latent_z = torch.randn(self.batch_size, self.nz, device=self.device)
        # The output is detached below, so no autograd graph is needed.
        with torch.no_grad():
            generated_data = netG(data, latent_z)
        generated_data = (0.5 * (generated_data + 1)).detach().cpu()
        data = (0.5 * (data + 1)).detach().cpu()

        if fwd_or_bwd == 'f':
            return data, generated_data
        return generated_data, data

    def _save_samples(self, netG, fwd_or_bwd, num, path):
        """Generate exactly ``num`` images with ``netG`` and write them to ``path`` as '<index>.jpg'."""
        saved = 0
        with torch.no_grad():
            while saved < num:
                source_side, target_side = self.generate(netG, fwd_or_bwd)
                generated = target_side if fwd_or_bwd == 'f' else source_side
                if len(generated) == 0:
                    raise RuntimeError('generator returned an empty batch')
                # The last batch is truncated so the total is exactly ``num``.
                for x in generated[:num - saved]:
                    save_image(x, os.path.join(path, f'{saved}.jpg'))
                    saved += 1

    def calculate_fid(self, info, num=None):
        """Generate samples and compute FID against precomputed statistics.

        Args:
            info: dict with 'netG1' (source->target) and, when ``self.backward``,
                'netG2' (target->source).
            num: samples per direction. When None:
                - 50000 for a 'gaussian'/'uniform' source;
                - otherwise the source dataset size (forward) and the target
                  dataset size (backward).

        Returns:
            ``(fid_target_train, fid_target_test, fid_source_train, fid_source_test)``.
            The last three are None when there is no backward direction.
        """
        requested = num
        if num is None:
            if self.source_data_name in ('gaussian', 'uniform'):
                print('Need to specify the number of samples when evaluating FIDs, for now, we use 50000 samples')
                num = 50000
            else:
                num = len(self.source_sampler.dataloader.dataset)

        fid_kwargs = {'batch_size': 100, 'device': self.device, 'dims': 2048}

        self._save_samples(info['netG1'], 'f', num, self.sample_path)
        fid_target_train = calculate_fid_given_paths(paths=[self.fid_target_path_train, self.sample_path], **fid_kwargs)

        if not self.backward:
            # NOTE(review): target-test FID is only reported alongside the backward direction.
            return fid_target_train, None, None, None

        fid_target_test = calculate_fid_given_paths(paths=[self.fid_target_path_test, self.sample_path], **fid_kwargs)

        # Backward direction: default to translating the whole target dataset.
        num_bwd = requested if requested is not None else len(self.target_sampler.dataloader.dataset)
        self._save_samples(info['netG2'], 'b', num_bwd, self.sample_path2)

        fid_source_train = calculate_fid_given_paths(paths=[self.fid_source_path_train, self.sample_path2], **fid_kwargs)
        fid_source_test = calculate_fid_given_paths(paths=[self.fid_source_path_test, self.sample_path2], **fid_kwargs)

        return fid_target_train, fid_target_test, fid_source_train, fid_source_test


# ------------------------
# Sampler
# ------------------------
class Sampler:
    """Draw interpolation times t and build the linear interpolant between x0 and x1.

    ``args.time_sample`` selects the distribution of t:
        'uniform'          -> t ~ U[0, 1)
        'beta_<a>_<b>'     -> t ~ Beta(a, b)
        'discrete_<N>'     -> t uniform on {0, 1/N, ..., 1}
        't<s>' (e.g. 't0.1') -> t ~ U[s, 1)
    """

    def __init__(self, args):
        self.time_sample = args.time_sample
        if 'beta' in self.time_sample:
            parts = self.time_sample.split('_')
            alpha, beta = float(parts[1]), float(parts[2])
            # Scalar parameters so sample_t can draw whatever batch size it is asked for.
            self.beta = Beta(torch.tensor(alpha, dtype=torch.float32), torch.tensor(beta, dtype=torch.float32))
        elif 'discrete' in self.time_sample:
            self.num_timesteps = int(self.time_sample.split('_')[1])
        elif 't' in self.time_sample:
            # 'uniform' contains no 't', so it never reaches this branch.
            self.start_time = float(self.time_sample.split('t')[1])

    def sample_t(self, batch_size):
        """Return a float tensor of shape ``(batch_size,)`` of interpolation times.

        Raises:
            NotImplementedError: if ``time_sample`` matches none of the supported modes.
        """
        if self.time_sample == 'uniform':
            return torch.rand(batch_size)
        if 'beta' in self.time_sample:
            return self.beta.sample((batch_size,))
        if 'discrete' in self.time_sample:
            return torch.randint(high=self.num_timesteps + 1, size=(batch_size,)).float() / self.num_timesteps
        if 't' in self.time_sample:
            return torch.rand(batch_size) * (1 - self.start_time) + self.start_time
        raise NotImplementedError(f'Unknown time_sample: {self.time_sample}')

    def __call__(self, t, x0, x1):
        """Return ``(1 - t) * x0 + t * x1``, with per-sample ``t`` of shape ``(batch,)``."""
        # Reshape t to (batch, 1, ..., 1) so it broadcasts over x0's trailing dimensions.
        expanded_t = t.view(t.size(0), *(1,) * (x0.dim() - 1))
        return (1 - expanded_t) * x0 + expanded_t * x1


# ------------------------
# EMA
# ------------------------
class EMA(Optimizer):
    """Optimizer wrapper that keeps an exponential moving average (EMA) of the parameters.

    EMA Codes adapted from https://github.com/NVlabs/LSGM/blob/main/util/ema.py

    The wrapped optimizer performs the real update. After each ``step`` the EMA of
    every trainable parameter ``p`` is stored in ``optimizer.state[p]['ema']``.

    ``Optimizer.__init__`` is never called here. In recent PyTorch the inherited
    ``state_dict`` / ``load_state_dict`` / ``zero_grad`` rely on attributes set
    there, so those methods are delegated to the wrapped optimizer instead.
    """

    def __init__(self, opt, ema_decay):
        self.ema_decay = ema_decay
        self.apply_ema = self.ema_decay > 0.
        self.optimizer = opt
        # Alias the wrapped optimizer's containers so reads through the wrapper see live state.
        self.state = opt.state
        self.param_groups = opt.param_groups

    def step(self, *args, **kwargs):
        """Step the wrapped optimizer, then update ema <- decay * ema + (1 - decay) * p.

        The EMA is initialised from the parameter value right after the first step.
        Returns whatever the wrapped optimizer's ``step`` returns.
        """
        retval = self.optimizer.step(*args, **kwargs)

        # stop here if we are not applying EMA
        if not self.apply_ema:
            return retval

        decay = self.ema_decay
        for group in self.optimizer.param_groups:
            for p in group['params']:
                if not p.requires_grad:
                    continue
                state = self.optimizer.state[p]
                if 'ema' not in state:
                    state['ema'] = p.data.clone()
                # Out-of-place (mul, not mul_): after swap_parameters_with_ema(False), p.data and the
                # stored EMA share storage, so an in-place update would corrupt both.
                state['ema'] = state['ema'].mul(decay).add_(p.data, alpha=1. - decay)

        return retval

    def zero_grad(self, *args, **kwargs):
        """Delegate to the wrapped optimizer."""
        return self.optimizer.zero_grad(*args, **kwargs)

    def state_dict(self):
        """Return the wrapped optimizer's state dict (it includes the 'ema' entries)."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load into the wrapped optimizer and re-alias ``state`` / ``param_groups``."""
        self.optimizer.load_state_dict(state_dict)
        # Loading replaces the wrapped optimizer's containers, so refresh the wrapper's aliases.
        self.state = self.optimizer.state
        self.param_groups = self.optimizer.param_groups

    def swap_parameters_with_ema(self, store_params_in_ema):
        """Replace each trainable parameter with its EMA value.

        If ``store_params_in_ema`` is true, the current parameter values are kept in
        the 'ema' slot, so calling this twice restores the original parameters.
        """
        # stop here if we are not applying EMA
        if not self.apply_ema:
            warnings.warn('swap_parameters_with_ema was called when there is no EMA weights.')
            return

        for group in self.optimizer.param_groups:
            for p in group['params']:
                if not p.requires_grad:
                    continue
                ema = self.optimizer.state[p]['ema']
                if store_params_in_ema:
                    tmp = p.data.detach()
                    p.data = ema.detach()
                    self.optimizer.state[p]['ema'] = tmp
                else:
                    p.data = ema.detach()

class Logger:
    """Text logging, sample images, checkpoints and FID evaluation during training.

    Each periodic action fires only when ``self.iter`` is a multiple of its
    ``*_every`` setting. ``step()`` advances the counter.
    """

    def __init__(self, args, evaltool):
        self.exp_path = f'./train_logs/{args.problem_name}/{args.exp}'
        os.makedirs(self.exp_path, exist_ok=True)
        # Snapshot the run configuration next to the outputs.
        jsonstr = json.dumps(args.__dict__, indent=4)
        with open(os.path.join(self.exp_path, 'config.json'), 'w') as f:
            f.write(jsonstr)
        # Mode 'w': an existing log.txt for this experiment is overwritten.
        with open(os.path.join(self.exp_path, 'log.txt'), 'w') as f:
            f.write("Start Training")
            f.write('\n')

        self.use_ema = args.use_ema
        self.print_every = args.print_every
        self.save_image_every = args.save_image_every
        self.save_ckpt_every = args.save_ckpt_every
        self.fid_every = args.fid_every
        self.iter = 1
        self.evaltool = evaltool

    def step(self):
        """Advance the iteration counter used by all periodic actions."""
        self.iter += 1

    def __call__(self, text):
        """Append ``text`` as one line of log.txt on every ``print_every``-th iteration."""
        if self.iter % self.print_every == 0:
            with open(os.path.join(self.exp_path, 'log.txt'), 'a') as f:
                f.write(text)
                f.write('\n')

    def save_image(self, info):
        """Every ``save_image_every`` iterations, save real/generated batches for both directions.

        ``info`` must hold 'netG1' (used with 'f') and 'netG2' (used with 'b').
        """
        if self.iter % self.save_image_every == 0:
            # Sampling only; no autograd graph is needed.
            with torch.no_grad():
                real_source, generated_target = self.evaltool.generate(info['netG1'], 'f')
                generated_source, real_target = self.evaltool.generate(info['netG2'], 'b')
            # Bare ``save_image`` resolves to the module-level function, not this method.
            save_image(real_source, os.path.join(self.exp_path, f'iter_{self.iter}_real_source.png'))
            save_image(generated_target, os.path.join(self.exp_path, f'iter_{self.iter}_generated_target.png'))
            save_image(generated_source, os.path.join(self.exp_path, f'iter_{self.iter}_generated_source.png'))
            save_image(real_target, os.path.join(self.exp_path, f'iter_{self.iter}_real_target.png'))

    def swap_net(self, info):
        """If EMA is enabled, swap parameters with their EMA for every 'optimizerG*' entry.

        Swaps with ``store_params_in_ema=True``, so calling this twice restores the weights.
        """
        if self.use_ema:
            for key, item in info.items():
                if 'optimizerG' in key:
                    item.swap_parameters_with_ema(store_params_in_ema=True)

    def save_ckpt(self, info):
        """Every ``save_ckpt_every`` iterations, save ``state_dict()`` of every 'net*' entry (EMA weights if enabled)."""
        if self.iter % self.save_ckpt_every == 0:
            self.swap_net(info)
            try:
                for key, item in info.items():
                    if 'net' in key:
                        torch.save(item.state_dict(), os.path.join(self.exp_path, f'{key}_{self.iter}.pth'))
            finally:
                # Always swap back so training never continues on the EMA weights.
                self.swap_net(info)

    def calculate_fid(self, info):
        """Every ``fid_every`` iterations, evaluate FID (with EMA weights if enabled) and log it."""
        if self.iter % self.fid_every == 0:
            self.swap_net(info)
            try:
                fid_target_train, fid_target_test, fid_source_train, fid_source_test = self.evaltool.calculate_fid(info)
            finally:
                # Always swap back, even if evaluation fails.
                self.swap_net(info)
            text = f'FID_target: {fid_target_train}, {fid_target_test}, FID_source: {fid_source_train}, {fid_source_test}, '
            self.__call__(text)
            print(text)


