import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from torch.utils.data import DataLoader
from datasets.datasets_twoway import init_dataset
from models import *
from utils import save_checkpoint, mse_fn, per_element_rel_mse_fn, create_coordinate_grid, encode_coordinates
from visualization import visualize_results
import math

class BaseTrainer:
    """Shared training loop for two-way (obs <-> src) modulated INR models.

    One latent vector per training sample is optimised in an inner loop
    (``optim_latent``). The two modulator MLPs and the two INRs are updated
    in an outer loop (``optim_meta``). Subclasses must implement
    ``_init_model`` and set ``self.INR_obs`` and ``self.INR_src``.
    """

    def __init__(self, args, device, fig_train_save_path, param_save_path):
        """Build data, latents, modulators, INRs and optimisers.

        Args:
            args: Experiment configuration namespace. It is mutated here
                (``hidden_dim`` for WIRE; min/max statistics for 'fwi').
            device: Torch device that modules and latents are moved to.
            fig_train_save_path: Output path handed to ``visualize_results``.
            param_save_path: Output path handed to ``save_checkpoint``.
        """
        self.args = args
        self.device = device
        self.fig_train_save_path = fig_train_save_path
        self.param_save_path = param_save_path
        self.valid_frame = 0

        # Order matters: _init_mod_latent sets self.mod_dim (used by the
        # modulators), and _init_training needs both modulators and INRs.
        self._init_data()
        self._init_mod_latent()
        self._init_modulator()
        self._init_model()
        self._init_training()

    def _init_data(self):
        """Load the training set and build the loader and coordinate grids.

        Sets ``train_dataset``, ``x_normalizer``/``y_normalizer`` (``None``
        unless ``args.normalize``), ``train_loader`` (shuffled), and the
        obs/src coordinate grids ``train_ts_obs``/``train_ts_src``.
        For the 'fwi' dataset, also writes the dataset min/max into ``args``.
        """
        self.train_dataset, _ = init_dataset(self.args)

        if self.args.normalize:
            self.x_normalizer, self.y_normalizer = self.train_dataset.get_normalizers()
        else:
            self.x_normalizer, self.y_normalizer = None, None

        if self.args.dataset == 'fwi':
            self.args.seis_min, self.args.seis_max, self.args.vel_min, self.args.vel_max = self.train_dataset.get_minmax()

        self.train_loader = DataLoader(self.train_dataset, batch_size=self.args.batch_size, shuffle=True)
        self.train_ts_obs = create_coordinate_grid(self.args.seq_len1_obs, self.args.seq_len2_obs, self.device, self.args.dataset)
        self.train_ts_src = create_coordinate_grid(self.args.seq_len1_src, self.args.seq_len2_src, self.device, self.args.dataset)

    def _init_mod_latent(self):
        """Create one zero-initialised latent per training sample and set ``mod_dim``.

        Raises:
            ValueError: If ``args.modulation`` is not one of
                'shift', 'scale', 'film' or 'gfm'.
        """
        self.latents = nn.ParameterList([
            nn.Parameter(torch.zeros(self.args.latent_dim).to(self.device))
            for _ in range(self.args.ntrain)
        ]).to(self.device)

        # WIRE shrinks the hidden width by sqrt(2); this mutates args.
        if self.args.inr == 'wire':
            self.args.hidden_dim = int(self.args.hidden_dim / np.sqrt(2))

        if self.args.modulation in ['shift', 'scale']:
            self.mod_dim = self.args.hidden_dim
        elif self.args.modulation == 'film':
            # Twice the hidden width (presumably one scale and one shift per unit).
            self.mod_dim = self.args.hidden_dim * 2
        elif self.args.modulation == 'gfm':
            self.mod_dim = self.args.n_fourier_bases * 2
        else:
            # Fail here rather than with an AttributeError on self.mod_dim later.
            raise ValueError(f"Unsupported modulation type: {self.args.modulation!r}")

    def _init_model(self):
        """Create ``self.INR_obs`` and ``self.INR_src``; implemented by subclasses."""
        raise NotImplementedError

    def _init_training(self):
        """Create the latent/meta Adam optimisers and select the loss function."""
        meta_params = (list(self.Modulator_obs.parameters()) +
                       list(self.Modulator_src.parameters()) +
                       list(self.INR_obs.parameters()) +
                       list(self.INR_src.parameters()))

        self.optim_latent = optim.Adam([{'params': self.latents, 'lr': self.args.lr_inner}])
        self.optim_meta = optim.Adam([{'params': meta_params, 'lr': self.args.lr_outer}])

        if self.args.dataset in ['airfoil', 'pipe']:
            self.compute_loss = per_element_rel_mse_fn
        else:
            self.compute_loss = mse_fn

    def _init_modulator(self):
        """Build the obs and src modulator MLPs (obs first, then src)."""
        self.Modulator_obs = self._build_modulator()
        self.Modulator_src = self._build_modulator()

    def _build_modulator(self):
        """Return a fresh ``num_mod_layers``-deep MLP mapping a latent to modulations."""
        layers = []
        in_dim = self.args.latent_dim
        for _ in range(self.args.num_mod_layers - 1):
            layers.append(nn.Linear(in_dim, self.mod_dim))
            layers.append(nn.ReLU(inplace=True))
            in_dim = self.mod_dim
        # functa emits num_layers + 1 modulation sets (see get_modulation's reshape).
        if self.args.inr == 'functa':
            num_mod_sets = self.args.num_layers + 1
        else:
            num_mod_sets = self.args.num_layers
        # Use in_dim rather than mod_dim: with num_mod_layers == 1 this single
        # layer reads the latent (latent_dim features) directly.
        layers.append(nn.Linear(in_dim, self.mod_dim * num_mod_sets))
        return nn.Sequential(*layers).to(self.device)

    def _sample_losses(self, latent, grid_obs, grid_src, target_obs, target_src):
        """Return ``(obs_loss, src_loss)`` for one sample's latent and its targets."""
        mod_obs = self.get_modulation(latent, 'obs')
        mod_src = self.get_modulation(latent, 'src')

        pred_obs = self.get_prediction(grid_obs, mod_obs, 'obs')
        pred_src = self.get_prediction(grid_src, mod_src, 'src')

        return (self.compute_loss(pred_obs.unsqueeze(0), target_obs.unsqueeze(0)),
                self.compute_loss(pred_src.unsqueeze(0), target_src.unsqueeze(0)))

    def train_step(self, input_data, output_data, data_idx):
        """Run the inner (latent) and outer (meta) optimisation on one batch.

        Args:
            input_data: Batch of obs targets; sample j is ``input_data[j]``.
            output_data: Batch of src targets; sample j is ``output_data[j]``.
            data_idx: Dataset indices of the batch samples, used to select
                each sample's latent from ``self.latents``.

        Returns:
            ``(loss_obs, loss_src, loss_total)``, each divided by the batch size.
            They are computed before the meta update.
        """
        train_obs = input_data.to(self.device)
        train_src = output_data.to(self.device)
        n_batch = len(data_idx)

        # Encoded grids are loop-invariant, so they are computed once per step.
        # NOTE(review): assumes encode_coordinates is deterministic.
        grid_obs = encode_coordinates(self.train_ts_obs, self.args.n_fourier, self.args.inr)
        grid_src = encode_coordinates(self.train_ts_src, self.args.n_fourier, self.args.inr)

        # Inner loop: adapt the latents of at most `sample_num` randomly
        # ordered batch samples, keeping the networks fixed (only
        # optim_latent steps here).
        batch_sample_lst = random.sample(range(n_batch), min(self.args.sample_num, n_batch))
        for j in batch_sample_lst:
            latent = self.latents[data_idx[j]]
            for _ in range(self.args.maxiter_inner):
                self.optim_latent.zero_grad()
                loss_obs_j, loss_src_j = self._sample_losses(latent, grid_obs, grid_src, train_obs[j], train_src[j])
                loss = loss_obs_j + loss_src_j
                loss.backward()
                self.optim_latent.step()

        # Outer loop: update modulators and INRs over the whole batch.
        # zero_grad also discards gradients accumulated during the inner loop.
        self.optim_meta.zero_grad()
        loss_obs = 0
        loss_src = 0
        for j in range(n_batch):
            loss_obs_j, loss_src_j = self._sample_losses(
                self.latents[data_idx[j]], grid_obs, grid_src, train_obs[j], train_src[j])
            loss_obs += loss_obs_j
            loss_src += loss_src_j

        loss_total = loss_obs + loss_src
        loss_total.backward()
        self.optim_meta.step()

        return loss_obs / n_batch, loss_src / n_batch, loss_total / n_batch

    def visualize(self, vis_input_data, vis_output_data, frame, x_normalizer, y_normalizer):
        """Predict the visualisation samples and plot them via ``visualize_results``.

        Switches every module to eval mode and does not switch back;
        ``train()`` re-enables train mode at the start of each epoch.
        NOTE(review): sample i is paired with ``self.latents[i]``. This
        assumes the visualisation data are the first training samples.

        Returns:
            ``(total_pred_obs, total_pred_src)`` with shapes
            ``(num_vis, seq_len1_*, seq_len2_*, out_dim_*)``.
        """
        with torch.no_grad():
            for module in (self.Modulator_obs, self.Modulator_src, self.INR_obs, self.INR_src, self.latents):
                module.eval()

            train_obs = vis_input_data.to(self.device)
            train_src = vis_output_data.to(self.device)
            num_vis = vis_input_data.shape[0]

            # Loop-invariant coordinate encodings (assumed deterministic).
            grid_obs = encode_coordinates(self.train_ts_obs, self.args.n_fourier, self.args.inr)
            grid_src = encode_coordinates(self.train_ts_src, self.args.n_fourier, self.args.inr)

            total_pred_obs = torch.zeros(num_vis, self.args.seq_len1_obs, self.args.seq_len2_obs, self.args.out_dim_obs)
            total_pred_src = torch.zeros(num_vis, self.args.seq_len1_src, self.args.seq_len2_src, self.args.out_dim_src)
            for fig_idx in range(num_vis):
                mod_obs = self.get_modulation(self.latents[fig_idx], 'obs')
                mod_src = self.get_modulation(self.latents[fig_idx], 'src')

                pred_obs = self.INR_obs(grid_obs, mod_obs).reshape(self.args.seq_len1_obs, self.args.seq_len2_obs, -1)
                pred_src = self.INR_src(grid_src, mod_src).reshape(self.args.seq_len1_src, self.args.seq_len2_src, -1)

                total_pred_obs[fig_idx] = pred_obs
                total_pred_src[fig_idx] = pred_src

            data = (train_obs.cpu(), train_src.cpu())
            predictions = (total_pred_obs.cpu(), total_pred_src.cpu())
            visualize_results(self.args, frame, data, predictions, self.fig_train_save_path, x_normalizer, y_normalizer)

            return total_pred_obs, total_pred_src

    def get_modulation(self, latent, mod_type):
        """Map a latent to per-layer modulations.

        Args:
            latent: A single latent vector from ``self.latents``.
            mod_type: 'obs' selects ``Modulator_obs``; any other value
                selects ``Modulator_src``.

        Returns:
            For functa, a tensor reshaped to ``(mod_dim, num_layers + 1)``;
            otherwise a tensor reshaped to ``(-1, mod_dim, num_layers)``.
        """
        if mod_type == 'obs':
            mod = self.Modulator_obs(latent)
        else:
            mod = self.Modulator_src(latent)

        if self.args.inr == 'functa':
            return mod.reshape(self.mod_dim, (self.args.num_layers + 1))
        else:
            return mod.reshape(-1, self.mod_dim, self.args.num_layers)

    def get_prediction(self, grid, mod, pred_type):
        """Evaluate the obs/src INR on an encoded grid.

        Args:
            grid: Encoded coordinates passed straight to the INR.
            mod: Modulations, as produced by ``get_modulation``.
            pred_type: 'obs' selects ``INR_obs``; any other value selects ``INR_src``.

        Returns:
            The prediction reshaped to ``(seq_len1, seq_len2, out_dim)`` of that branch.
        """
        if pred_type == 'obs':
            pred = self.INR_obs(grid, mod)
        else:
            pred = self.INR_src(grid, mod)

        output_dim = self.args.out_dim_obs if pred_type == 'obs' else self.args.out_dim_src
        seq_len1 = self.args.seq_len1_obs if pred_type == 'obs' else self.args.seq_len1_src
        seq_len2 = self.args.seq_len2_obs if pred_type == 'obs' else self.args.seq_len2_src

        return pred.reshape(seq_len1, seq_len2, output_dim)

    def train(self):
        """Run ``args.maxiter_outer`` epochs and checkpoint each new best epoch.

        The epoch loss is the sum of the per-batch averaged total losses.
        Visualisation runs on epochs 0, 50, 100, ...
        """
        best_train_loss = float('inf')
        frame = 0
        vis_input_data, vis_output_data = self.train_dataset.get_vis_data()

        for itr_outer in range(self.args.maxiter_outer):
            for module in (self.Modulator_obs, self.Modulator_src, self.INR_obs, self.INR_src, self.latents):
                module.train()

            train_loss_total = 0
            last_batch_loss = None  # stays None if the loader yields no batches

            for input_data, output_data, data_idx in self.train_loader:
                loss_obs, loss_src, loss_total = self.train_step(input_data, output_data, data_idx)
                print(f'<Iter {itr_outer}> Loss [OBS]: {loss_obs.item():.4f}, [src]: {loss_src.item():.4f} | [TOTAL]: {loss_total.item():.4f}')
                last_batch_loss = loss_total.item()
                train_loss_total += last_batch_loss

            # Only checkpoint when at least one batch produced a loss.
            is_train_best = last_batch_loss is not None and train_loss_total < best_train_loss
            if is_train_best:
                best_train_loss = train_loss_total
                state = {
                    'INR_obs': self.INR_obs.state_dict(),
                    'INR_src': self.INR_src.state_dict(),
                    'Modulator_obs': self.Modulator_obs.state_dict(),
                    'Modulator_src': self.Modulator_src.state_dict(),
                    'Latents': self.latents.state_dict(),
                    # Loss of the epoch's last batch (per-sample average).
                    'loss': last_batch_loss
                }

                save_checkpoint(state, is_train_best, None, self.param_save_path, itr_outer, best_train_loss, None)

            # True for itr_outer = 0, 50, 100, ...
            if (itr_outer + 1) % 50 == 1:
                self.visualize(vis_input_data, vis_output_data, frame, self.x_normalizer, self.y_normalizer)
                frame += 5
                
class SIRENTrainer(BaseTrainer):
    """Trainer whose obs/src INRs are sine-activated ``INR_GFM`` networks.

    Only ``args.modulation == 'gfm'`` is supported.
    """

    def _init_model(self):
        """Create ``self.INR_obs`` and ``self.INR_src`` (obs first, then src).

        Raises:
            ValueError: If ``args.modulation`` is not 'gfm'. No INR would be
                built in that case, and ``_init_training`` would otherwise
                fail later with an unrelated AttributeError.
        """
        if self.args.modulation != 'gfm':
            raise ValueError(
                f"SIRENTrainer only supports modulation='gfm', got {self.args.modulation!r}"
            )
        self.INR_obs = self._build_gfm_inr(self.args.out_dim_obs)
        self.INR_src = self._build_gfm_inr(self.args.out_dim_src)

    def _build_gfm_inr(self, out_features):
        """Return a sine ``INR_GFM`` built from ``args`` with ``out_features`` outputs."""
        return INR_GFM(
            in_features=self.args.enc_dim,
            out_features=out_features,
            num_hidden_layers=self.args.num_layers,
            hidden_features=self.args.hidden_dim,
            outermost_linear=True,
            nonlinearity='sine',
            scale=self.args.scale,
            weight_init=None,
            n_fourier_bases=self.args.n_fourier_bases,
            high_freq=self.args.high_freq,
            low_freq=self.args.low_freq,
            phi_dim=self.args.phi_dim
        ).to(self.device)

class FunctaTrainer(BaseTrainer):
    """Trainer whose obs/src INRs are ``Functa`` networks."""

    def _init_model(self):
        """Build the obs and src Functa INRs (obs first), then print a model summary."""
        # Everything except the output width is shared between the two INRs.
        common_kwargs = dict(
            in_features=self.args.enc_dim,
            num_hidden_layers=self.args.num_layers,
            hidden_features=self.args.hidden_dim,
            mod_type=self.args.modulation,
            mod_dim=self.mod_dim,
        )
        self.INR_obs = Functa(out_features=self.args.out_dim_obs, **common_kwargs).to(self.device)
        self.INR_src = Functa(out_features=self.args.out_dim_src, **common_kwargs).to(self.device)

        print(f"Modulation: {self.args.modulation}")
        print(f"Modulator_obs: {self.Modulator_obs}")
        print(f"Modulator_src: {self.Modulator_src}")
        print(f"INR_obs: {self.INR_obs}")
        print(f"INR_src: {self.INR_src}")
                        