import torch
import torch.nn as nn
import torch.optim as optim
import os
import numpy as np
import random
import argparse
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from torch.utils.data import DataLoader
from trainer_twoway import SIRENTrainer, FunctaTrainer
from trainer_twoway_3d import FunctaTrainer3D, SIRENTrainer3D
from torcheval.metrics import PeakSignalNoiseRatio

from datasets.datasets_twoway import init_dataset
from utils import *
from visualization import visualize_results
from matplotlib.colors import ListedColormap

# Module-level LpLoss instance built with size_average=False. `LpLoss` is not
# imported by name above, so it presumably arrives via `from utils import *`
# (TODO confirm). It is not referenced again in the visible part of this file.
myloss = LpLoss(size_average=False)


class TestTrainer(SIRENTrainer):
    """Evaluate a trained two-way INR checkpoint on the stored per-sample latents.

    The class reuses the parent trainer's initializers to build the modulators
    and INRs, loads their weights plus the latents from ``args.ckpt_path``, and
    reports PSNR / MSE of the reconstructions against the dataset samples.

    To evaluate a different trainer family, swap the base class for
    ``FunctaTrainer``, ``FunctaTrainer3D`` or ``SIRENTrainer3D`` (all imported
    at the top of this file).
    """

    def __init__(self, args, device, fig_save_path):
        """Store the run configuration, then build the data and load the models.

        Args:
            args: Parsed/merged configuration namespace.
            device: ``torch.device`` used for coordinate grids and checkpoint tensors.
            fig_save_path: Directory that receives the figures written by ``test``.
        """
        # The parent constructor is not called; only the pieces needed for
        # evaluation are initialized below.
        self.args = args
        self.device = device
        self.fig_save_path = fig_save_path
        self.frame = 0

        self._init_data()
        self._load_model()

    def _init_data(self):
        """Create the dataset/loader and derive configuration that depends on it."""
        # Only the first dataset returned by init_dataset is used here.
        self.train_dataset, _ = init_dataset(self.args)
        # batch_size=1 and no shuffling: test() processes one sample at a time
        # and looks up its latent by the index the loader yields.
        self.train_loader = DataLoader(self.train_dataset, batch_size=1, shuffle=False)

        if self.args.normalize:
            self.x_normalizer, self.y_normalizer = self.train_dataset.get_normalizers()

        if self.args.inr == 'wire':
            self.args.hidden_dim = int(self.args.hidden_dim / np.sqrt(2))

        # NOTE(review): mod_dim is left unset for 'scale' / 'film' modulation;
        # confirm the parent initializers do not need it for those modes.
        if self.args.modulation == 'shift':
            self.mod_dim = self.args.hidden_dim
        elif self.args.modulation == 'gfm':
            self.mod_dim = self.args.mod_dim

        if self.args.dataset == 'fwi':
            (self.args.seis_min, self.args.seis_max,
             self.args.vel_min, self.args.vel_max) = self.train_dataset.get_minmax()

    def _load_model(self):
        """Build the networks via the parent initializers and load the checkpoint."""
        super()._init_mod_latent()
        super()._init_modulator()
        super()._init_model()

        # map_location keeps the load working when the checkpoint was saved on a
        # different device (e.g. a GPU checkpoint on a CPU-only host) and places
        # the latents on the device used for the coordinate grids.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(self.args.ckpt_path, map_location=self.device)

        self.INR_obs.load_state_dict(ckpt['INR_obs'])
        self.INR_src.load_state_dict(ckpt['INR_src'])
        self.Modulator_obs.load_state_dict(ckpt['Modulator_obs'])
        self.Modulator_src.load_state_dict(ckpt['Modulator_src'])

        self.latents = ckpt['Latents']

        for model in (self.Modulator_obs, self.Modulator_src, self.INR_obs, self.INR_src):
            # Inference only: switch to eval mode and freeze every parameter.
            model.eval()
            for param in model.parameters():
                param.requires_grad = False

    def _test_forward(self, mod_obs, mod_src):
        """Evaluate both INRs on their full coordinate grids.

        Args:
            mod_obs: Modulation for the observation INR.
            mod_src: Modulation for the source INR.

        Returns:
            Tuple ``(pred_obs, pred_src)`` reshaped to ``(*grid_dims, out_dim)``.
        """
        args = self.args
        if args.dataset == 'ns3d_twoway':
            obs_dims = (args.seq_len1_obs, args.seq_len2_obs, args.seq_len3_obs)
            src_dims = (args.seq_len1_src, args.seq_len2_src, args.seq_len3_src)
            make_grid = create_coordinate_grid_3d
        else:
            obs_dims = (args.seq_len1_obs, args.seq_len2_obs)
            src_dims = (args.seq_len1_src, args.seq_len2_src)
            make_grid = create_coordinate_grid

        self.train_ts_obs = make_grid(*obs_dims, self.device, args.dataset)
        self.train_ts_src = make_grid(*src_dims, self.device, args.dataset)

        # The predictions are detached right after this call, so there is no
        # reason to record an autograd graph for the forward passes.
        with torch.no_grad():
            pred_obs = self.INR_obs(
                encode_coordinates(self.train_ts_obs, args.n_fourier, args.inr), mod_obs
            ).reshape(*obs_dims, args.out_dim_obs)
            pred_src = self.INR_src(
                encode_coordinates(self.train_ts_src, args.n_fourier, args.inr), mod_src
            ).reshape(*src_dims, args.out_dim_src)
        return pred_obs, pred_src

    @staticmethod
    def _test_channel_metrics(psnr_metric, pred, target):
        """Return ``(mean PSNR, mean MSE)`` over the last axis of ``pred``/``target``.

        The metric object is reset before every slice: it accumulates state
        across ``update`` calls, so without the reset ``compute`` would return a
        running value over all previously seen slices instead of this slice's PSNR.
        """
        n_slices = pred.shape[-1]
        psnr_sum, mse_sum = 0.0, 0.0
        for idx in range(n_slices):
            # Ellipsis indexing selects along the last axis for any rank, which
            # matches the range above (shape[-1]).
            pred_slice, target_slice = pred[..., idx], target[..., idx]
            psnr_metric.reset()
            psnr_metric.update(pred_slice, target_slice)
            psnr_sum += psnr_metric.compute().item()
            mse_sum += mse_fn(pred_slice, target_slice).item()
        return psnr_sum / n_slices, mse_sum / n_slices

    def _test_save_fwi_figures(self, sample_id, pred_obs, pred_src, max_seis_panels=5):
        """Write velocity and seismic reconstructions of one FWI sample as PNG and PDF.

        Args:
            sample_id: Dataset index used in the output file names.
            pred_obs: Predicted seismic tensor (CPU), without a batch axis.
            pred_src: Predicted velocity tensor (CPU), without a batch axis.
            max_seis_panels: Upper bound on the number of seismic channels plotted.
        """
        x_train = tonumpy_denormalize(pred_obs.unsqueeze(0), self.args.seis_min, self.args.seis_max, exp=True)
        y_train = tonumpy_denormalize(pred_src.unsqueeze(0), self.args.vel_min, self.args.vel_max, exp=False)
        tag = self.args.modulation

        velocity = y_train[0, :, :, 0]
        fig = plt.figure(figsize=(8, 8))
        plt.imshow(velocity, vmax=np.max(velocity), vmin=np.min(velocity))
        plt.xticks([], [])
        plt.yticks([], [])
        plt.tight_layout()
        plt.savefig(os.path.join(self.fig_save_path, f"vel_{sample_id}_{tag}.png"))
        plt.savefig(os.path.join(self.fig_save_path, f"vel_{sample_id}_{tag}.pdf"))
        plt.close(fig)

        # Never index past the channels that actually exist.
        for seis_idx in range(min(max_seis_panels, x_train.shape[-1])):
            panel = x_train[0, :, :, seis_idx]
            # Width / height of the plotted panel, so it renders as a square.
            aspect = panel.shape[1] / panel.shape[0]
            fig = plt.figure(figsize=(8, 8))
            plt.imshow(panel, cmap='gray', vmin=-1e-5, vmax=1e-5, aspect=aspect)
            plt.xticks([], [])
            plt.yticks([], [])
            plt.tight_layout()
            plt.savefig(os.path.join(self.fig_save_path, f"seis_{sample_id}_{seis_idx}_{tag}.png"))
            plt.savefig(os.path.join(self.fig_save_path, f"seis_{sample_id}_{seis_idx}_{tag}.pdf"))
            plt.close(fig)

    def test(self):
        """Reconstruct every sample from its stored latent and report PSNR / MSE.

        Prints per-sample and dataset-averaged metrics for both modalities and,
        for the ``fwi`` dataset, saves figures under ``self.fig_save_path``.
        """
        print("Starting test...")
        psnr_metric = PeakSignalNoiseRatio()
        total_psnr_obs, total_psnr_src, total_mse_obs, total_mse_src = 0.0, 0.0, 0.0, 0.0

        for train_obs_data, train_src_data, train_idx in self.train_loader:
            i = train_idx[0].item()
            latent = self.latents[str(i)]
            mod_obs = self.get_modulation(latent, 'obs')
            mod_src = self.get_modulation(latent, 'src')

            pred_obs, pred_src = self._test_forward(mod_obs, mod_src)

            if self.args.normalize:
                # NOTE(review): x_normalizer is applied to both modalities and
                # y_normalizer is never used; the fixed [:, :, :, 0] slice also
                # requires a 4-D decoded tensor. Left as-is - confirm intent.
                pred_obs = self.x_normalizer.decode(pred_obs.detach().cpu())[:, :, :, 0]
                pred_src = self.x_normalizer.decode(pred_src.detach().cpu())[:, :, :, 0]
                train_obs_data = self.x_normalizer.decode(train_obs_data[0].detach().cpu())[:, :, :, 0]
                train_src_data = self.x_normalizer.decode(train_src_data[0].detach().cpu())[:, :, :, 0]
            else:
                pred_obs = pred_obs.detach().cpu()
                pred_src = pred_src.detach().cpu()
                train_obs_data = train_obs_data[0].detach().cpu()
                train_src_data = train_src_data[0].detach().cpu()

            sample_psnr_obs, sample_mse_obs = self._test_channel_metrics(psnr_metric, pred_obs, train_obs_data)
            sample_psnr_src, sample_mse_src = self._test_channel_metrics(psnr_metric, pred_src, train_src_data)

            total_psnr_obs += sample_psnr_obs
            total_psnr_src += sample_psnr_src
            total_mse_obs += sample_mse_obs
            total_mse_src += sample_mse_src
            print(f"[{i}] PSNR_obs: {sample_psnr_obs:.6f}, PSNR_src: {sample_psnr_src:.6f}, MSE_obs: {sample_mse_obs:.6f}, MSE_src: {sample_mse_src:.6f}")

            if self.args.dataset == 'fwi':
                self._test_save_fwi_figures(i, pred_obs, pred_src)

        n_samples = len(self.train_loader)
        if n_samples == 0:
            # Nothing was evaluated; avoid dividing by zero in the summary.
            print("No samples to evaluate.")
            return

        print(f"Total PSNR_obs: {total_psnr_obs/n_samples:.6f}, Total PSNR_src: {total_psnr_src/n_samples:.6f}, Total MSE_obs: {total_mse_obs/n_samples:.6f}, Total MSE_src: {total_mse_src/n_samples:.6f}")
                
def main():
    """Parse the CLI, seed all RNGs, and run ``TestTrainer.test`` once."""
    cli = argparse.ArgumentParser()
    cli.add_argument('--dataset', type=str, choices=['fwi', 'ns3d_twoway'], required=True)
    cli.add_argument('--modulation', type=str, choices=['shift', 'scale', 'film', 'gfm'], required=True)
    cli.add_argument('--ckpt_path', type=str, required=True)
    cli.add_argument('--device', type=str, default='0')
    cli.add_argument('--seed', type=int, default=0)
    cli.add_argument('--config_name', type=str, required=True)
    cli.add_argument('--inr', type=str, choices=['siren', 'functa'], required=True)

    # load_config receives the parsed namespace and returns the one used below.
    args = load_config(cli.parse_args())

    # Seed every random number generator in the same order as before.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
        seed_fn(args.seed)

    device_name = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)

    # Coordinate dimensionality: three for the 3-D dataset, two otherwise.
    args.enc_dim = 3 if args.dataset == 'ns3d_twoway' else 2

    run_tag = f'{args.inr}_{args.modulation}_{args.config_name}'
    fig_save_path = os.path.join('vis_post', args.dataset, run_tag)
    os.makedirs(fig_save_path, exist_ok=True)

    TestTrainer(args, device, fig_save_path).test()

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main() 