import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from torch.utils.data import DataLoader
from datasets.datasets_oneway import init_dataset
from models import *
from utils import save_checkpoint, mse_fn, per_element_rel_mse_fn, create_coordinate_grid, encode_coordinates
from visualization import visualize_results
import math
import time

class BaseTrainer:
    """Base trainer for an INR whose layers are modulated by per-sample latents.

    Every training sample owns one learnable latent (``self.latents``).
    ``self.Modulator`` maps a latent to modulation values that are passed to
    ``self.INR`` together with the encoded coordinate grid.  Subclasses must
    implement :meth:`_init_model` and assign ``self.INR`` there.

    For each mini-batch, :meth:`train_step` first runs an inner loop that
    steps ``optim_latent`` (latents only) and then performs a single
    ``optim_meta`` step on the Modulator and INR weights.
    """

    # Number of outer iterations between two calls to ``visualize``.
    VIS_INTERVAL = 20

    def __init__(self, args, device, fig_train_save_path, param_save_path):
        """Store the configuration and build data, latents, networks and optimizers.

        Args:
            args: Namespace-like configuration object; its attributes are read
                throughout the class (and ``hidden_dim`` may be rewritten, see
                :meth:`_init_mod_latent`).
            device: Device on which tensors and modules are placed.
            fig_train_save_path: Forwarded to ``visualize_results``.
            param_save_path: Forwarded to ``save_checkpoint``.
        """
        self.args = args
        self.device = device
        self.fig_train_save_path = fig_train_save_path
        self.param_save_path = param_save_path

        # Order matters: the modulator needs ``self.mod_dim`` (set together
        # with the latents) and the optimizers need Modulator and INR.
        self._init_data()
        self._init_mod_latent()
        self._init_modulator()
        self._init_model()
        self._init_training()

    def _init_data(self):
        """Create the training dataset, its loader, the normalizer and the coordinate grid."""
        self.train_dataset = init_dataset(self.args)

        if self.args.normalize:
            self.normalizer = self.train_dataset.get_normalizer()
        else:
            self.normalizer = None

        self.train_loader = DataLoader(self.train_dataset, batch_size=self.args.batch_size, shuffle=True)
        self.train_ts = create_coordinate_grid(self.args.seq_len1, self.args.seq_len2, self.device, self.args.dataset)

    def _init_mod_latent(self):
        """Allocate one zero-initialised latent per training sample and derive ``mod_dim``.

        Side effect: for ``args.inr == 'wire'`` the value of ``args.hidden_dim``
        is replaced in place by ``int(hidden_dim / sqrt(2))``.
        """
        args = self.args

        if args.modulation == 'spatial':
            latent_shape = (args.latent_channel, args.latent_size, args.latent_size)
        else:
            latent_shape = (args.latent_dim,)

        self.latents = nn.ParameterList([
            nn.Parameter(torch.zeros(latent_shape, device=self.device))
            for _ in range(args.ntrain)
        ]).to(self.device)

        if args.inr == 'wire':
            args.hidden_dim = int(args.hidden_dim / np.sqrt(2))

        # For any other modulation type ``mod_dim`` is intentionally left
        # unset here; the base ``_init_modulator`` reports that case.
        if args.modulation in ('shift', 'scale', 'spatial'):
            self.mod_dim = args.hidden_dim
        elif args.modulation == 'film':
            self.mod_dim = args.hidden_dim * 2
        elif args.modulation == 'gfm':
            self.mod_dim = args.n_fourier_bases * 2

    def _init_model(self):
        """Build ``self.INR``; must be provided by subclasses."""
        raise NotImplementedError

    def _init_training(self):
        """Create the latent and meta optimizers and select the loss function."""
        meta_params = list(self.Modulator.parameters()) + list(self.INR.parameters())

        self.optim_latent = optim.Adam([{'params': self.latents, 'lr': self.args.lr_inner}])
        self.optim_meta = optim.Adam([{'params': meta_params, 'lr': self.args.lr_outer}])

        if self.args.dataset in ['airfoil', 'pipe']:
            self.compute_loss = per_element_rel_mse_fn
        else:
            self.compute_loss = mse_fn

    def _init_modulator(self):
        """Build ``self.Modulator``, the network mapping a latent to modulations.

        * ``'spatial'``: a single 3x3 ``Conv2d`` with zero-initialised bias and
          ``mod_dim * (num_layers + 1)`` output channels.
        * otherwise: an MLP with ``num_mod_layers`` linear layers (ReLU between
          them) whose output width is ``mod_dim * (num_layers + 1)`` for the
          ``'functa'`` INR and ``mod_dim * num_layers`` for every other INR.

        Raises:
            ValueError: if ``self.mod_dim`` was not set because
                ``args.modulation`` is not one of the supported types.
        """
        args = self.args

        mod_dim = getattr(self, 'mod_dim', None)
        if mod_dim is None:
            raise ValueError(f"Unsupported modulation type: {args.modulation!r}")

        if args.modulation == 'spatial':
            conv = nn.Conv2d(in_channels=args.latent_channel,
                             out_channels=mod_dim * (args.num_layers + 1),
                             kernel_size=3, padding=1).to(self.device)
            nn.init.zeros_(conv.bias)
            self.Modulator = conv
            return

        layers = []
        in_dim = args.latent_dim
        for _ in range(args.num_mod_layers - 1):
            layers.append(nn.Linear(in_dim, args.mod_latent_dim))
            layers.append(nn.ReLU(inplace=True))
            in_dim = args.mod_latent_dim

        n_modulated = args.num_layers + 1 if args.inr == 'functa' else args.num_layers
        # Use the running ``in_dim`` rather than ``mod_latent_dim``: with a
        # single modulator layer there is no hidden layer, so the output
        # layer is fed the latent directly.
        layers.append(nn.Linear(in_dim, mod_dim * n_modulated))
        self.Modulator = nn.Sequential(*layers).to(self.device)

    def _predict_latent(self, latent):
        """Return the INR prediction on the training grid for a single latent."""
        mod = self.get_modulation(latent)
        coord = encode_coordinates(self.train_ts, self.args.n_fourier, self.args.inr)
        return self.get_prediction(coord, mod)

    def train_step(self, target_data, data_idx):
        """Run the inner latent updates and one outer meta update for a mini-batch.

        Args:
            target_data: Batch of targets; ``target_data[j]`` is compared with
                the prediction for the latent ``self.latents[data_idx[j]]``.
            data_idx: Dataset indices of the batch entries.

        Returns:
            The loss summed over the batch, as a tensor detached from the
            autograd graph (the graph has already been consumed by
            ``backward``).
        """
        train_y = target_data.to(self.device)
        batch_size = len(data_idx)

        # Sampling k == population size is a shuffle, so clamping covers both
        # the "update every latent" and the "update a random subset" case.
        num_inner = min(self.args.sample_num, batch_size)
        batch_sample_lst = random.sample(range(batch_size), num_inner)

        # Inner optimization: only ``optim_latent`` steps, so only latents move.
        for j in batch_sample_lst:
            for _ in range(self.args.maxiter_inner):
                self.optim_latent.zero_grad()
                pred = self._predict_latent(self.latents[data_idx[j]])
                loss = self.compute_loss(pred.unsqueeze(0), train_y[j].unsqueeze(0))
                loss.backward()
                self.optim_latent.step()

        # Outer optimization: zero_grad also discards the Modulator/INR
        # gradients that the inner-loop backward passes accumulated.
        self.optim_meta.zero_grad()
        loss_total = 0

        for j in range(batch_size):
            pred = self._predict_latent(self.latents[data_idx[j]])
            loss_total += self.compute_loss(pred.unsqueeze(0), train_y[j].unsqueeze(0))

        loss_total.backward()
        self.optim_meta.step()

        return loss_total.detach()

    def visualize(self, vis_target_data, frame, normalizer):
        """Predict the visualization samples and hand them to ``visualize_results``.

        Switches Modulator, INR and latents to eval mode and runs without
        gradient tracking.

        Args:
            vis_target_data: Pair ``(targets, dataset_indices)``.
            frame: Forwarded to ``visualize_results``.
            normalizer: Forwarded to ``visualize_results``.

        Returns:
            CPU tensor of shape ``(num_vis, seq_len1, seq_len2, out_dim)``
            holding the predictions.
        """
        with torch.no_grad():
            self.Modulator.eval()
            self.INR.eval()
            self.latents.eval()

            train_y, vis_data_idx = vis_target_data
            train_y = train_y.to(self.device)
            num_vis = len(vis_data_idx)

            total_pred = torch.zeros(num_vis, self.args.seq_len1, self.args.seq_len2, self.args.out_dim)
            for fig_idx in range(num_vis):
                total_pred[fig_idx] = self._predict_latent(self.latents[vis_data_idx[fig_idx]])

            data = train_y.cpu()
            predictions = total_pred.cpu()
            visualize_results(self.args, frame, data, predictions, self.fig_train_save_path, normalizer)

            return total_pred

    def get_modulation(self, latent):
        """Map one latent to the modulation tensor expected by the INR.

        Returns:
            * ``'spatial'``: the conv output, nearest-neighbour resized to
              ``(seq_len1, seq_len2)`` and reshaped to
              ``(-1, mod_dim, num_layers + 1)``.
            * ``'functa'`` INR: shape ``(mod_dim, num_layers + 1)``.
            * otherwise: shape ``(mod_dim, num_layers)``.
        """
        mod = self.Modulator(latent)

        if self.args.modulation == 'spatial':
            mod = mod.unsqueeze(0)
            mod = torch.nn.functional.interpolate(mod, size=(self.args.seq_len1, self.args.seq_len2), mode='nearest')
            # Move channels last before splitting them into (mod_dim, layers).
            mod = mod.permute(0, 2, 3, 1).reshape(-1, self.mod_dim, self.args.num_layers + 1)
            return mod

        if self.args.inr == 'functa':
            return mod.reshape(self.mod_dim, self.args.num_layers + 1)
        return mod.reshape(self.mod_dim, self.args.num_layers)

    def get_prediction(self, coord, mod):
        """Evaluate the INR and reshape to ``(seq_len1, seq_len2, out_dim)``."""
        pred = self.INR(coord, mod)
        return pred.reshape(self.args.seq_len1, self.args.seq_len2, self.args.out_dim)

    def train(self):
        """Run ``args.maxiter_outer`` passes over the training loader.

        After every pass the summed loss divided by the number of samples is
        printed; whenever it improves on the best value so far a checkpoint
        is written via ``save_checkpoint``.  Every ``VIS_INTERVAL`` passes
        the visualization samples are rendered.
        """
        print(self.Modulator)
        print(self.INR)
        best_train_loss = float('inf')
        frame = 2
        vis_target_data = self.train_dataset.get_vis_data()

        for itr_outer in range(self.args.maxiter_outer):
            self.Modulator.train()
            self.INR.train()
            self.latents.train()

            train_loss_total = 0
            total_data_num = 0

            # The CUDA memory statistics cannot be reset on hosts without
            # CUDA, so only touch them when a CUDA runtime is present.
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()

            for target_data, data_idx in self.train_loader:
                loss_total = self.train_step(target_data, data_idx)
                # Detach so the running sum (and later ``best_train_loss``)
                # does not keep autograd history of every batch alive.
                train_loss_total += loss_total.detach()
                total_data_num += len(data_idx)

            avg_loss = train_loss_total / total_data_num
            print(f'<Iter {itr_outer}> Loss [TOTAL]: {avg_loss:.4f}')

            is_train_best = avg_loss < best_train_loss
            if is_train_best:
                best_train_loss = avg_loss
                # Save checkpoint.  NOTE: 'loss' is the summed loss of the
                # last mini-batch of this pass, not the epoch average.
                state = {
                    'INR': self.INR.state_dict(),
                    'Modulator': self.Modulator.state_dict(),
                    'Latents': self.latents.state_dict(),
                    'loss': loss_total.item()
                }
                save_checkpoint(state, is_train_best, None, self.param_save_path, itr_outer, best_train_loss, None)

            if (itr_outer + 1) % self.VIS_INTERVAL == 0:
                self.visualize(vis_target_data, frame, self.normalizer)
                frame += 2
            
                
                
class FunctaTrainer(BaseTrainer):
    """Trainer whose INR is a ``Functa`` network."""

    def _init_model(self):
        """Instantiate ``Functa`` from the configuration and print a summary."""
        cfg = self.args
        inr_kwargs = {
            'in_features': cfg.enc_dim,
            'out_features': cfg.out_dim,
            'num_hidden_layers': cfg.num_layers,
            'hidden_features': cfg.hidden_dim,
            'mod_type': cfg.modulation,
            'mod_dim': self.mod_dim,
        }
        network = Functa(**inr_kwargs)
        self.INR = network.to(self.device)

        # Summary of the chosen modulation and the two networks.
        print(f"Modulation: {self.args.modulation}")
        print(f"Modulator: {self.Modulator}")
        print(f"INR: {self.INR}")
                        


class SIRENTrainer(BaseTrainer):
    # Trainer that builds a sine-activated INR (see ``nonlinearity='sine'``).
    def _init_model(self):
        # Build ``self.INR`` according to ``args.modulation``.
        # NOTE(review): only the 'gfm' branch is visible in this part of the
        # file; as shown, any other modulation leaves ``self.INR`` unset --
        # confirm whether further branches follow.
        if self.args.modulation == 'gfm':
            # INR_GFM comes from the wildcard ``models`` import; the meaning
            # of its keyword arguments is defined there, not in this file.
            self.INR = INR_GFM(
                in_features=self.args.enc_dim,
                out_features=self.args.out_dim,
                num_hidden_layers=self.args.num_layers,
                hidden_features=self.args.hidden_dim,
                outermost_linear=True,
                nonlinearity='sine',
                scale=self.args.scale,
                weight_init=None,
                n_fourier_bases=self.args.n_fourier_bases,
                high_freq=self.args.high_freq,
                low_freq=self.args.low_freq,
                phi_dim=self.args.phi_dim,
                oneway=True
            ).to(self.device)
