import torch
import torch.nn as nn
import torch.optim as optim
import os
import numpy as np
import random
import argparse
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from torch.utils.data import DataLoader
from trainer_oneway import SIRENTrainer, FunctaTrainer
from trainer_oneway_3d import FunctaTrainer3D, SIRENTrainer3D
from torcheval.metrics import PeakSignalNoiseRatio

from datasets.datasets_oneway import init_dataset
from utils import *
from visualization import visualize_results

# LpLoss instance from utils (constructed with size_average=False — presumably
# summing rather than averaging over the batch; confirm in utils.LpLoss).
# NOTE(review): not referenced in the visible part of this file.
myloss = LpLoss(size_average=False)


class TestTrainer(SIRENTrainer):
    """Evaluate a trained INR + modulator checkpoint and plot its reconstructions.

    Reuses the parent trainer's construction helpers (``_init_mod_latent``,
    ``_init_modulator``, ``_init_model``) but deliberately does not call the
    parent ``__init__``, so no training state is built. To evaluate another
    model family, change the base class to one of the alternatives imported at
    the top of the file: ``FunctaTrainer``, ``FunctaTrainer3D`` or
    ``SIRENTrainer3D``.
    """

    def __init__(self, args, device, fig_save_path):
        """Store the run configuration, build the dataset and restore the checkpoint.

        Args:
            args: Parsed CLI/config namespace (built by ``main``).
            device: ``torch.device`` the models are evaluated on.
            fig_save_path: Directory where result figures are written.
        """
        self.args = args
        self.device = device
        self.fig_save_path = fig_save_path
        self.frame = 0

        self._init_data()
        self._load_model()

    def _init_data(self):
        """Build the dataset/loader and derive model-size settings from ``args``.

        Side effects: shrinks ``args.hidden_dim`` in place for the 'wire' INR
        and sets ``self.mod_dim`` for the 'shift' and 'gfm' modulations.
        """
        self.train_dataset = init_dataset(self.args)
        # One sample per batch, in order: test() evaluates samples individually
        # and reads each sample's index from the loader.
        self.train_loader = DataLoader(self.train_dataset, batch_size=1, shuffle=False)

        if self.args.normalize:
            self.x_normalizer = self.train_dataset.get_normalizer()

        if self.args.inr == 'wire':
            # NOTE(review): width reduced by sqrt(2) — presumably mirrors the
            # training-time configuration; confirm against the trainer.
            self.args.hidden_dim = int(self.args.hidden_dim / np.sqrt(2))

        if self.args.modulation == 'shift':
            self.mod_dim = self.args.hidden_dim
        elif self.args.modulation == 'gfm':
            self.mod_dim = self.args.mod_dim
        # NOTE(review): 'scale', 'film' and 'spatial' leave self.mod_dim unset;
        # confirm the parent helpers do not require it for those modes.

    def _load_model(self):
        """Build modulator/INR via the parent helpers and load checkpoint weights.

        The checkpoint must contain 'INR', 'Modulator' and 'Latents' entries.
        All INR and Modulator parameters are frozen afterwards.
        """
        super()._init_mod_latent()
        super()._init_modulator()
        super()._init_model()
        # map_location: without it a checkpoint saved on a different GPU (or
        # on GPU, when running on a CPU-only host) fails to load or ends up on
        # the wrong device.
        ckpt = torch.load(self.args.ckpt_path, map_location=self.device)

        self.INR.load_state_dict(ckpt['INR'])
        self.Modulator.load_state_dict(ckpt['Modulator'])
        # Per-sample latent codes; test() looks them up by str(sample index).
        self.latents = ckpt['Latents']

        for model in (self.Modulator, self.INR):
            for param in model.parameters():
                param.requires_grad = False

    def test(self):
        """Reconstruct every sample, report PSNR/MSE and save per-sample figures.

        For each sample the stored latent is turned into a modulation, the INR
        is evaluated on the full coordinate grid, and the prediction is compared
        with the ground truth on output channel 0. Per-sample metrics are
        printed, followed by their mean over all samples.
        """
        print("Starting test...")
        n_samples = len(self.train_loader)
        if n_samples == 0:
            print("No samples to test.")
            return

        args = self.args
        dataset = args.dataset
        psnr = PeakSignalNoiseRatio()

        if dataset == 'helm':
            # Sample index -> (a1, a2) pair used in the figure file name. The
            # values are stored doubled (halved again when formatting) so that
            # half-steps between a_start and a_end are integers.
            helm_params = [
                (a1, a2)
                for a1 in range(args.a_start * 2, args.a_end * 2 + 1)
                for a2 in range(args.a_start * 2, args.a_end * 2 + 1)
            ]

        # The evaluation grid is the same for every sample, so build it once.
        if dataset == 'ns3d':
            self.test_ts = create_coordinate_grid_3d(
                args.seq_len1, args.seq_len2, args.seq_len3, self.device, dataset)
        else:
            self.test_ts = create_coordinate_grid(
                args.seq_len1, args.seq_len2, self.device, dataset)

        total_psnr, total_mse = 0.0, 0.0

        # Pure inference: no autograd graph is needed (the loaded latents may
        # still have requires_grad set).
        with torch.no_grad():
            for train_data, train_idx in self.train_loader:
                i = train_idx[0].item()
                mod = self.get_modulation(self.latents[str(i)])

                if dataset == 'ns3d':
                    pred = self.INR(self.test_ts, mod).reshape(
                        args.seq_len1, args.seq_len2, args.seq_len3, args.out_dim)
                else:
                    coords = encode_coordinates(self.test_ts, args.n_fourier, args.inr)
                    pred = self.INR(coords, mod).reshape(
                        args.seq_len1, args.seq_len2, args.out_dim)

                if args.normalize:
                    pred = self.x_normalizer.decode(pred.detach().cpu())
                    target = self.x_normalizer.decode(train_data[0].detach().cpu())
                else:
                    pred = pred.detach().cpu()
                    target = train_data[0].detach().cpu()

                # Reset first: the metric otherwise keeps accumulating squared
                # error (and the data range) across samples, so compute() would
                # return a running aggregate instead of this sample's PSNR.
                psnr.reset()
                psnr.update(pred[:, :, 0], target[:, :, 0])
                sample_psnr = psnr.compute().item()
                sample_mse = mse_fn(pred[:, :, 0], target[:, :, 0]).item()
                total_psnr += sample_psnr
                total_mse += sample_mse
                print(f"[{i}] PSNR: {sample_psnr:.6f}, MSE: {sample_mse:.6f}")

                # Channel-0 prediction, transposed so that the first grid axis
                # runs along the image's horizontal axis.
                img = pred[:, :, 0].numpy().T

                if dataset == 'conv':
                    fig = plt.figure(figsize=(16, 10))
                    ax = fig.add_subplot(111)
                    im = ax.imshow(img, interpolation='nearest',
                                   vmin=0, vmax=2, cmap='rainbow', origin='lower', aspect='auto',
                                   extent=[0, 1, 0, 2 * np.pi])

                    # Colorbar in its own appended axis so it matches the image height.
                    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.10)
                    cbar = fig.colorbar(im, cax=cax)
                    cbar.ax.tick_params(labelsize=35)

                    ax.set_xlabel('t', fontweight='bold', size=60)
                    ax.set_ylabel('x', fontweight='bold', size=60)
                    ax.tick_params(labelsize=50)

                    fig.tight_layout()
                    fig.savefig(f'{self.fig_save_path}/test_conv_{i+1}_{args.modulation}.png')
                    plt.close(fig)

                elif dataset == 'helm':
                    plt.figure(figsize=(20, 16))
                    plt.imshow(img, interpolation='nearest', vmin=-1, vmax=1, cmap='rainbow',
                               origin='lower', aspect='auto', extent=[-1, 1, -1, 1])
                    plt.xlabel('x', fontweight='bold', size=100)
                    plt.ylabel('y', fontweight='bold', size=100)
                    plt.tick_params(labelsize=50)
                    ax = plt.gca()
                    ax.set_xticks(np.arange(-1, 1.01, 0.5))
                    ax.set_yticks(np.arange(-1, 1.01, 0.5))
                    cbar = plt.colorbar(fraction=0.046, pad=0.01)
                    cbar.ax.tick_params(labelsize=50)
                    plt.tight_layout()
                    a1, a2 = helm_params[i]
                    plt.savefig(f'{self.fig_save_path}/test_helm{args.a_start}to{args.a_end}_{a1/2}_{a2/2}_{args.modulation}.png')
                    plt.close()

                elif dataset == 'ks':
                    # Colour scale follows this sample's ground-truth range.
                    vmin, vmax = target.min().item(), target.max().item()
                    fig = plt.figure(figsize=(16, 10))
                    ax = fig.add_subplot(111)
                    ax.imshow(img,
                              origin="lower",
                              aspect="auto",
                              cmap="inferno",
                              vmin=vmin,
                              vmax=vmax)

                    ax.set_xticks([])
                    ax.set_yticks([])

                    fig.tight_layout()
                    stem = f'{self.fig_save_path}/test_ks_{i+1}_{args.modulation}'
                    fig.savefig(f'{stem}.png')
                    fig.savefig(f'{stem}.pdf')
                    plt.close(fig)

        print(f"Total PSNR: {total_psnr/n_samples:.6f}, Total MSE: {total_mse/n_samples:.6f}")
                
def main():
    """Command-line entry point: parse options, seed RNGs and run the test pass.

    CLI flags are merged with the named config by ``load_config``. Figures are
    written under ``vis_post/<dataset>/<inr>_<modulation>_<config_name>``, with
    a parameter-range suffix appended for the 'conv' and 'helm' datasets.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--dataset', type=str, choices=['conv', 'helm', 'ks', 'ns3d'], required=True)
    cli.add_argument('--modulation', type=str, choices=['shift', 'scale', 'film', 'gfm', 'spatial'], required=True)
    cli.add_argument('--ckpt_path', type=str, required=True)
    cli.add_argument('--device', type=str, default='0')
    cli.add_argument('--seed', type=int, default=0)
    cli.add_argument('--config_name', type=str, required=True)
    cli.add_argument('--inr', type=str, choices=['siren', 'functa'], required=True)
    args = load_config(cli.parse_args())

    # Seed every RNG in play: torch (CPU), torch (CUDA), NumPy, stdlib random.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
        seed_fn(args.seed)

    use_cuda = torch.cuda.is_available()
    device = torch.device(f'cuda:{args.device}' if use_cuda else 'cpu')

    # Coordinate dimensionality fed to the encoder: 3 for ns3d, 2 otherwise.
    args.enc_dim = 3 if args.dataset == 'ns3d' else 2

    run_dir = f'{args.inr}_{args.modulation}_{args.config_name}'
    if args.dataset == 'conv':
        run_dir += f'_{args.beta_start}to{args.beta_end}'
    elif args.dataset == 'helm':
        run_dir += f'_{args.a_start}to{args.a_end}'
    fig_save_path = os.path.join('vis_post', args.dataset, run_dir)
    os.makedirs(fig_save_path, exist_ok=True)

    TestTrainer(args, device, fig_save_path).test()


if __name__ == "__main__":
    main()