import torch
import matplotlib.pyplot as plt
import os
import numpy as np
import math
import yaml
import pickle
# Module-level loss instances, created once at import time.
# Mean-squared error using the loss class's default reduction.
mse_fn = torch.nn.MSELoss()
# Squared error kept per element (reduction disabled).
per_element_mse_fn = torch.nn.MSELoss(reduction="none")
# Binary cross-entropy on raw logits, kept per element (reduction disabled).
per_element_bce_fn = torch.nn.BCEWithLogitsLoss(reduction="none")


class LpLoss(object):
    """Absolute and relative Lp-norm losses evaluated per example of a batch.

    Both inputs are flattened to ``(batch, -1)`` before the norm is taken,
    so dimension 0 is treated as the batch dimension.

    Args:
        d: Exponent numerator used by :meth:`abs` in the ``h ** (d / p)``
            scaling factor (presumably the spatial dimension -- confirm
            against callers).
        p: Order of the norm.
        size_average: When reducing, average over the batch if true,
            otherwise sum.
        reduction: If false, return the un-reduced per-example values.
    """

    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super(LpLoss, self).__init__()

        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def _reduce(self, per_example):
        """Apply the configured batch reduction to per-example values."""
        if not self.reduction:
            return per_example
        if self.size_average:
            return torch.mean(per_example)
        return torch.sum(per_example)

    def abs(self, x, y):
        """Return the scaled absolute Lp norm of ``x - y``.

        The per-example norm is multiplied by ``h ** (d / p)`` with
        ``h = 1 / (x.size(1) - 1)``; a size of 1 along dimension 1 therefore
        divides by zero.
        """
        num_examples = x.size()[0]

        h = 1.0 / (x.size()[1] - 1.0)

        # reshape (rather than view) so non-contiguous inputs are accepted,
        # matching what rel() already does.
        diff = x.reshape(num_examples, -1) - y.reshape(num_examples, -1)
        all_norms = (h ** (self.d / self.p)) * torch.norm(diff, self.p, 1)

        return self._reduce(all_norms)

    def rel(self, x, y):
        """Return the Lp norm of ``x - y`` divided by the Lp norm of ``y``.

        No guard is applied against a zero norm of ``y``.
        """
        num_examples = x.size()[0]

        y_flat = y.reshape(num_examples, -1)
        diff_norms = torch.norm(x.reshape(num_examples, -1) - y_flat, self.p, 1)
        y_norms = torch.norm(y_flat, self.p, 1)

        return self._reduce(diff_norms / y_norms)

    def __call__(self, x, y):
        """Calling the instance evaluates the relative loss."""
        return self.rel(x, y)



def per_element_rel_mse_fn(x, y, reduction=True):
    """Return the relative L2 error of ``x`` against ``y`` for each example.

    Both tensors are flattened to ``(batch, -1)``; the result holds one
    value per entry of dimension 0, equal to ``||x - y||_2 / ||y||_2``.

    Args:
        x: Tensor compared against ``y``.
        y: Tensor whose norm is the denominator.
        reduction: Accepted but not used by this function.
    """
    batch = x.shape[0]
    lhs_flat = x.reshape(batch, -1)
    rhs_flat = y.reshape(batch, -1)

    error_norm = torch.norm(lhs_flat - rhs_flat, p=2, dim=1)
    reference_norm = torch.norm(rhs_flat, p=2, dim=1)
    return error_norm / reference_norm


def batch_mse_rel_fn(x1, x2):
    """Return the per-example relative L2 error of ``x1`` against ``x2``.

    The per-example values from ``per_element_rel_mse_fn`` are viewed as
    ``(batch, -1)`` and averaged over the trailing dimension.
    """
    batch = x1.shape[0]
    relative_error = per_element_rel_mse_fn(x1, x2)
    return torch.mean(relative_error.view(batch, -1), dim=1)


def get_save_paths(args):
    """Build the visualisation and parameter output directories and create them.

    The directory layout is
    ``<dataset tag>/<inr>_<modulation>/latent<latent_dim>_<config_name>``,
    rooted at ``./vis`` (with a trailing ``TRAIN`` component) and ``./param``.

    Args:
        args: Object exposing ``dataset``, ``inr``, ``modulation``,
            ``latent_dim`` and ``config_name``; plus the range attributes
            needed by the selected dataset (``beta_start``/``beta_end`` for
            ``conv``, ``a_start``/``a_end`` for ``helm``, ``t1``/``t2`` for
            ``ns3d``).

    Returns:
        Tuple ``(train_vis_path, param_path)``; both directories exist on
        return.

    Raises:
        ValueError: If ``args.dataset`` is not a recognised dataset name.
    """
    # Datasets whose directory tag does not depend on any other argument.
    fixed_tags = {
        'fwi': 'fwi',
        'ks': 'ks',
        'ns3d_twoway': 'ns3d_twoway',
        'airfoil': 'Airfoil',
        'pipe': 'Pipe',
    }

    dataset = args.dataset
    if dataset == 'conv':
        fname_data = f'conv_{args.beta_start}to{args.beta_end}'
    elif dataset == 'helm':
        fname_data = f'helm_{args.a_start}to{args.a_end}'
    elif dataset == 'ns3d':
        fname_data = f'ns3d_{args.t1}to{args.t2}'
    elif dataset in fixed_tags:
        fname_data = fixed_tags[dataset]
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on fname_data.
        supported = ['conv', 'helm', 'ns3d'] + list(fixed_tags)
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {supported}"
        )

    model_dir = f'{args.inr}_{args.modulation}'

    base_path = f'{fname_data}/{model_dir}/latent{args.latent_dim}_{args.config_name}'
    train_vis_path = os.path.join('./vis', base_path, 'TRAIN')
    param_path = os.path.join('./param', base_path)

    os.makedirs(train_vis_path, exist_ok=True)
    os.makedirs(param_path, exist_ok=True)

    return train_vis_path, param_path

def save_checkpoint(state, is_train_best, is_val_best, save_path, epoch=None, best_loss=None, best_val_loss=None):
    """Write ``state`` to disk as a best-train and/or periodic checkpoint.

    Args:
        state: Object handed to ``torch.save``.
        is_train_best: If true, write ``model_best_train.ckpt`` and print
            ``best_loss``.
        is_val_best: Accepted but not used by this function.
        save_path: Directory that receives the checkpoint files.
        epoch: Zero-based epoch index; when ``epoch + 1`` is a multiple of
            200, ``model_<epoch+1>.ckpt`` is written as well.
        best_loss: Value reported in the printed message.
        best_val_loss: Accepted but not used by this function.
    """

    def _write(filename):
        torch.save(state, os.path.join(save_path, filename))

    if is_train_best:
        _write("model_best_train.ckpt")
        print(f"saving model... Best train loss: {best_loss}")

    is_periodic_epoch = epoch is not None and (epoch + 1) % 200 == 0
    if is_periodic_epoch:
        _write(f"model_{epoch+1}.ckpt")

def create_coordinate_grid(seq_len1, seq_len2, device, dataset):
    """Build the 2-D coordinate tensor used as network input for ``dataset``.

    The raw coordinates are divided by their last row, optionally rescaled
    along one column depending on the dataset, converted to a float32
    tensor on ``device`` and flagged as requiring gradients.

    Args:
        seq_len1: Number of samples along the first axis.
        seq_len2: Number of samples along the second axis.
        device: Target device for the returned tensor.
        dataset: Dataset name. ``airfoil``/``pipe`` use a unit grid built
            with ``indexing='ij'``; ``conv``/``helm`` load a pickled grid
            from ``./datasets``; every other name uses a grid with spacing
            0.04.

    Returns:
        Float32 tensor of coordinates with ``requires_grad`` set.
    """
    if dataset in ['airfoil', 'pipe']:
        axis_first = torch.linspace(0, 1, steps=seq_len1)
        axis_second = torch.linspace(0, 1, steps=seq_len2)
        grid_first, grid_second = torch.meshgrid(axis_first, axis_second, indexing='ij')
        columns = (grid_first.flatten()[:, None], grid_second.flatten()[:, None])
        train_ts = np.hstack(columns)
    elif dataset == 'conv':
        # NOTE: pickle.load can execute arbitrary code; the grid file must
        # come from a trusted source.
        with open('./datasets/convection/convection_grid.pkl', 'rb') as f:
            train_ts = pickle.load(f)
    elif dataset == 'helm':
        # NOTE: same pickle caveat as above. The loaded object is converted
        # with .numpy(), so it is presumably a torch tensor.
        with open('./datasets/helmholtz/helmholtz_grid.pkl', 'rb') as f:
            train_ts = pickle.load(f)
            train_ts = train_ts.numpy()
    else:
        ticks_first = np.linspace(0, 0.04*(seq_len1-1), seq_len1)
        ticks_second = np.linspace(0, 0.04*(seq_len2-1), seq_len2)

        # 'ks' swaps which axis is passed to meshgrid first.
        if dataset == 'ks':
            mesh_args = (ticks_second, ticks_first)
        else:
            mesh_args = (ticks_first, ticks_second)
        col0, col1 = np.meshgrid(*mesh_args)

        train_ts = np.hstack((col0.flatten()[:, None], col1.flatten()[:, None]))

    # Normalise every column by the value found in the last row.
    train_ts = train_ts / train_ts[-1]

    # Dataset-specific stretch of a single column.
    if dataset == 'ks':
        scaled_col, factor = 0, 10
    elif dataset == 'conv':
        scaled_col, factor = 1, 2.5
    elif dataset == 'fwi' and seq_len1 != seq_len2:
        scaled_col, factor = 0, 14
    else:
        scaled_col, factor = None, None

    if scaled_col is not None:
        train_ts[:, scaled_col] = train_ts[:, scaled_col] * factor

    coords = torch.tensor(train_ts, dtype=torch.float32).to(device)
    coords.requires_grad = True

    return coords



def create_coordinate_grid_3d(seq_len1, seq_len2, seq_len3, device, dataset, interp=False):
    """Build the 3-D coordinate tensor used as network input for ``dataset``.

    Each axis is sampled with spacing 0.04, combined with
    ``indexing='ij'``, divided by the last row, and the first two columns
    are stretched for the ``ns3d`` (x2) and ``ns3d_twoway`` (x6) datasets.

    Args:
        seq_len1: Number of samples along the first axis.
        seq_len2: Number of samples along the second axis.
        seq_len3: Number of samples along the third axis.
        device: Target device for the returned tensor.
        dataset: Dataset name selecting the column scaling.
        interp: If true, keep only every second sample of the third axis.

    Returns:
        Float32 tensor with three columns and ``requires_grad`` set.
    """
    axes = [
        np.linspace(0, 0.04*(length-1), length)
        for length in (seq_len1, seq_len2, seq_len3)
    ]

    if interp:
        axes[2] = axes[2][::2]

    grids = np.meshgrid(axes[0], axes[1], axes[2], indexing='ij')
    train_ts = np.hstack([grid.flatten()[:, None] for grid in grids])
    # Normalise every column by the value found in the last row.
    train_ts = train_ts / train_ts[-1]

    if dataset == 'ns3d':
        planar_scale = 2
    elif dataset == 'ns3d_twoway':
        planar_scale = 6
    else:
        planar_scale = None

    if planar_scale is not None:
        for col in (0, 1):
            train_ts[:, col] = train_ts[:, col] * planar_scale

    coords = torch.tensor(train_ts, dtype=torch.float32).to(device)
    coords.requires_grad = True

    return coords

def encode_coordinates(coord, N, inr):
    """Optionally apply a sinusoidal frequency encoding to 2-D coordinates.

    A 3-D ``coord`` is first reshaped to ``(-1, 2)``. Unless ``inr`` is
    ``'ffn'`` the (possibly reshaped) coordinates are returned unchanged.
    For ``'ffn'`` each of the two coordinate columns is multiplied by the
    frequencies ``2**k * pi`` for ``k = 0..N`` and passed through sine and
    cosine.

    Args:
        coord: Coordinate tensor whose last dimension holds two values.
        N: Highest power of two used for the frequencies.
        inr: Network type name; only ``'ffn'`` triggers the encoding.

    Returns:
        ``coord`` itself, or a tensor with ``4 * (N + 1)`` features per
        point ordered as sin(x), cos(x), sin(y), cos(y).
    """
    if coord.dim() == 3:
        coord = coord.reshape(-1, 2)

    if inr != 'ffn':
        return coord

    freqs = (2 ** torch.arange(N + 1, device=coord.device)) * math.pi

    pieces = []
    for axis in (0, 1):
        # (N + 1, num_points) table of frequency * coordinate products.
        phase = freqs[:, None] * coord[:, axis][None, :]
        pieces.append(torch.sin(phase))
        pieces.append(torch.cos(phase))

    return torch.cat(pieces, dim=0).permute(1, 0)

def create_coordinate_grid_test(seq_len1, seq_len2, device, dataset, unseen_pct=0.5):
    """Build the coordinate tensor used for evaluation on ``dataset``.

    The raw grid is divided by its last row. For ``conv`` only rows whose
    second column is ``<= unseen_pct`` are kept and that column is scaled by
    2.5; for ``ks`` only rows whose second column is ``< unseen_pct`` are
    kept and the first column is scaled by 10. ``helm`` is returned without
    filtering.

    Args:
        seq_len1: Number of samples along the first axis (``ks`` only).
        seq_len2: Number of samples along the second axis (``ks`` only).
        device: Target device for the returned tensor.
        dataset: One of ``'conv'``, ``'helm'`` or ``'ks'``.
        unseen_pct: Threshold on the normalised second column.

    Returns:
        Float32 tensor of coordinates with ``requires_grad`` set.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset == 'conv':
        # NOTE: pickle.load can execute arbitrary code; the grid file must
        # come from a trusted source.
        with open('./datasets/convection/convection_grid.pkl', 'rb') as f:
            train_ts = pickle.load(f)
    elif dataset == 'helm':
        # NOTE: same pickle caveat as above.
        with open('./datasets/helmholtz/helmholtz_grid.pkl', 'rb') as f:
            train_ts = pickle.load(f)
            train_ts = train_ts.numpy()
    elif dataset == 'ks':
        train_ts1 = np.linspace(0, 0.04*(seq_len1-1), seq_len1)
        train_ts2 = np.linspace(0, 0.04*(seq_len2-1), seq_len2)
        x_coord, y_coord = np.meshgrid(train_ts2, train_ts1)
        train_ts = np.hstack((x_coord.flatten()[:, None], y_coord.flatten()[:, None]))
    else:
        # Previously this fell through and failed with an UnboundLocalError
        # on train_ts.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'conv', 'helm' or 'ks'"
        )

    # Normalise every column by the value found in the last row.
    train_ts = train_ts / train_ts[-1]

    if dataset == 'conv':
        train_ts = train_ts[np.where(train_ts[:, 1] <= unseen_pct)]
        train_ts[:, 1] = train_ts[:, 1] * 2.5
    elif dataset == 'ks':
        train_ts = train_ts[np.where(train_ts[:, 1] < unseen_pct)]
        train_ts[:, 0] = train_ts[:, 0] * 10

    train_ts = torch.tensor(train_ts, dtype=torch.float32).to(device)
    train_ts.requires_grad = True

    return train_ts


class UnitGaussianNormalizer(object):
    """Standardise data with a mean and standard deviation taken over dim 0.

    Args:
        x: Tensor from which the statistics are computed.
        eps: Constant added to the standard deviation before dividing.
    """

    def __init__(self, x, eps=0.00001):
        super(UnitGaussianNormalizer, self).__init__()

        self.eps = eps
        self.mean = torch.mean(x, dim=0)
        self.std = torch.std(x, dim=0)

    def encode(self, x):
        """Return ``(x - mean) / (std + eps)``."""
        scale = self.std + self.eps
        return (x - self.mean) / scale

    def decode(self, x):
        """Invert :meth:`encode`: ``x * (std + eps) + mean``."""
        scale = self.std + self.eps
        return (x * scale) + self.mean

    def cuda(self):
        """Move the stored statistics to the GPU (in place, returns None)."""
        for stat in ('mean', 'std'):
            setattr(self, stat, getattr(self, stat).cuda())

    def cpu(self):
        """Move the stored statistics to the CPU (in place, returns None)."""
        for stat in ('mean', 'std'):
            setattr(self, stat, getattr(self, stat).cpu())


def load_config(args):
    """Load ``config/<args.config_name>.yaml`` and copy its entries onto ``args``.

    The YAML file is expected to map category names to mappings; every
    key/value pair of every category is set as an attribute on ``args``
    (later categories overwrite earlier ones on key clashes). The parsed
    document itself is stored as ``args.config``.

    Args:
        args: Object exposing ``config_name``; mutated in place.

    Returns:
        The same ``args`` object.
    """
    yaml_file = os.path.join('config', f'{args.config_name}.yaml')
    with open(yaml_file, 'r') as handle:
        config = yaml.safe_load(handle)

    for category in config:
        section = config[category]
        for option, setting in section.items():
            setattr(args, option, setting)

    args.config = config
    return args


class RangeNormalizer(object):
    """Affinely map each feature to ``[low, high]`` using min/max over dim 0.

    Every feature (all dimensions after the first, flattened) gets its own
    scale ``a`` and offset ``b`` such that ``a * min + b == low`` and
    ``a * max + b == high``. A feature that is constant over dim 0 has
    ``max == min`` and therefore a division by zero in ``a``.

    Args:
        x: Tensor from which the per-feature minimum and maximum are taken.
        low: Lower end of the target range.
        high: Upper end of the target range.
    """

    def __init__(self, x, low=0.0, high=1.0):
        super(RangeNormalizer, self).__init__()
        mymin = torch.min(x, 0)[0].reshape(-1)
        mymax = torch.max(x, 0)[0].reshape(-1)

        self.a = (high - low) / (mymax - mymin)
        self.b = -self.a * mymax + high

    def encode(self, x):
        """Map ``x`` into the target range; the input shape is preserved."""
        s = x.size()
        x = x.reshape(s[0], -1)
        x = self.a * x + self.b
        return x.reshape(s)

    def decode(self, x):
        """Invert :meth:`encode`; the input shape is preserved."""
        s = x.size()
        # reshape (not view) so non-contiguous tensors are accepted, matching
        # what encode() already does.
        x = x.reshape(s[0], -1)
        x = (x - self.b) / self.a
        return x.reshape(s)
    

class AirfoilRangeNormalizer(object):
    """Min-max normaliser mapping data to ``[0, 1]``.

    For a 4-D input the data is flattened to ``(-1, 2)`` and a separate
    minimum/maximum is kept for each of the two resulting columns; for a
    3-D input a single global minimum/maximum is used.

    Args:
        x: 3-D or 4-D tensor from which the statistics are taken.

    Raises:
        ValueError: If ``x`` is neither 3-D nor 4-D.
    """

    def __init__(self, x):
        super(AirfoilRangeNormalizer, self).__init__()

        if x.dim() == 4:
            # reshape (not view) so non-contiguous inputs are accepted.
            pairs = x.reshape(-1, 2)
            self.mymin = torch.min(pairs, 0)[0]
            self.mymax = torch.max(pairs, 0)[0]
        elif x.dim() == 3:
            self.mymin = torch.min(x)
            self.mymax = torch.max(x)
        else:
            # Fail at construction instead of leaving mymin/mymax unset and
            # raising AttributeError on the first encode()/decode() call.
            raise ValueError(
                f"AirfoilRangeNormalizer expects a 3-D or 4-D tensor, got {x.dim()}-D"
            )

    def encode(self, x):
        """Return ``(x - min) / (max - min)``."""
        return (x - self.mymin) / (self.mymax - self.mymin)

    def decode(self, x):
        """Invert :meth:`encode`: ``x * (max - min) + min``."""
        return x * (self.mymax - self.mymin) + self.mymin
    
# normalization, Gaussian
class GaussianNormalizer(object):
    """Standardise data with a single global mean and standard deviation.

    Args:
        x: Tensor from which the scalar statistics are computed.
        eps: Constant added to the standard deviation before dividing.
    """

    def __init__(self, x, eps=0.00001):
        super(GaussianNormalizer, self).__init__()

        self.eps = eps
        self.mean = torch.mean(x)
        self.std = torch.std(x)

    def encode(self, x):
        """Return ``(x - mean) / (std + eps)``."""
        scale = self.std + self.eps
        return (x - self.mean) / scale

    def decode(self, x, sample_idx=None):
        """Invert :meth:`encode`; ``sample_idx`` is accepted but not used."""
        scale = self.std + self.eps
        return (x * scale) + self.mean

    def cuda(self):
        """Move the stored statistics to the GPU (in place, returns None)."""
        for stat in ('mean', 'std'):
            setattr(self, stat, getattr(self, stat).cuda())

    def cpu(self):
        """Move the stored statistics to the CPU (in place, returns None)."""
        for stat in ('mean', 'std'):
            setattr(self, stat, getattr(self, stat).cpu())

def log_transform(data, k=1, c=0):
    """Return the signed log compression ``sign(data) * log1p(|k * data| + c)``."""
    magnitude = np.log1p(np.abs(k * data) + c)
    return magnitude * np.sign(data)

def log_transform_tensor(data, k=1, c=0):
    """Torch counterpart of ``log_transform``: ``sign(data) * log1p(|k * data| + c)``."""
    magnitude = torch.log1p(torch.abs(k * data) + c)
    return magnitude * torch.sign(data)

def exp_transform(data, k=1, c=0):
    """Return ``sign(data) * (expm1(|data|) - c) / k``.

    For positive ``k`` this undoes ``log_transform`` called with the same
    ``k`` and ``c``.
    """
    magnitude = np.expm1(np.abs(data)) - c
    return magnitude * np.sign(data) / k

class LogTransform(object):
    """Callable wrapper that applies ``log_transform`` with fixed ``k`` and ``c``."""

    def __init__(self, k=1, c=0):
        self.k, self.c = k, c

    def __call__(self, data):
        """Return ``log_transform(data)`` using the stored parameters."""
        params = {'k': self.k, 'c': self.c}
        return log_transform(data, **params)
    
class MinMaxNormalize(object):
    """Callable wrapper that applies ``minmax_normalize`` with fixed bounds.

    Args:
        datamin: Value passed as the lower bound.
        datamax: Value passed as the upper bound.
        scale: Forwarded unchanged as the ``scale`` argument.
    """

    def __init__(self, datamin, datamax, scale=2):
        self.datamin, self.datamax = datamin, datamax
        self.scale = scale

    def __call__(self, vid):
        """Return ``minmax_normalize(vid)`` using the stored bounds and scale."""
        bounds = (self.datamin, self.datamax)
        return minmax_normalize(vid, *bounds, self.scale)
    
def minmax_normalize(vid, vmin, vmax, scale=2):
    """Rescale ``vid`` from ``[vmin, vmax]`` to ``[-1, 1]`` or ``[0, 1]``.

    The computation is done out of place: the caller's array is left
    untouched and integer inputs are promoted to floating point instead of
    failing on an in-place true division.

    Args:
        vid: Array/tensor (or scalar) to normalise.
        vmin: Value mapped to the lower end of the output range.
        vmax: Value mapped to the upper end of the output range.
        scale: ``2`` maps to ``[-1, 1]``; any other value maps to ``[0, 1]``.

    Returns:
        A new object holding the normalised values.
    """
    unit = (vid - vmin) / (vmax - vmin)
    return (unit - 0.5) * 2 if scale == 2 else unit
    
def minmax_denormalize(vid, vmin, vmax, scale=2):
    """Map normalised values back to the ``[vmin, vmax]`` range.

    With ``scale == 2`` the input is taken to lie in ``[-1, 1]`` and is
    first shifted to ``[0, 1]``; otherwise it is used as is.
    """
    unit = vid / 2 + 0.5 if scale == 2 else vid
    span = vmax - vmin
    return unit * span + vmin

def tonumpy_denormalize(vid, vmin, vmax, exp=True, k=1, c=0, scale=2):
    """Convert a tensor to NumPy and undo min-max (and optionally log) scaling.

    Args:
        vid: Tensor to convert; it is moved to the CPU and turned into a
            NumPy array before denormalising.
        vmin: Lower bound of the data before any log transform.
        vmax: Upper bound of the data before any log transform.
        exp: If true, the bounds are passed through ``log_transform`` first
            and the result through ``exp_transform`` afterwards.
        k: Forwarded to ``log_transform``/``exp_transform``.
        c: Forwarded to ``log_transform``/``exp_transform``.
        scale: Forwarded to ``minmax_denormalize``.

    Returns:
        The denormalised NumPy array.
    """
    if exp:
        vmin = log_transform(vmin, k=k, c=c)
        vmax = log_transform(vmax, k=k, c=c)

    restored = minmax_denormalize(vid.cpu().numpy(), vmin, vmax, scale)
    if not exp:
        return restored
    return exp_transform(restored, k=k, c=c)