from json import load
import os, sys, time
import numpy as np
import argparse
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torchvision.transforms import functional
import wandb
import matplotlib.pyplot as plt
from datetime import datetime
import logging
from utils import logging_utils
logging_utils.config_logger()
from utils.YParams import YParams
from utils.data_utils import get_data_loader
from utils.optimizer_utils import set_scheduler, set_optimizer
from utils.loss_utils import LossMSE
from utils.misc_utils import compute_grad_norm, vis_fields, l2_err, vis_field_single, vis_attention_single, vis_fields_named
from utils.domains import DomainXY
from utils.sweeps import sweep_name_suffix
from utils.trainer import Trainer, set_seed, count_parameters
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap as ruamelDict
from collections import OrderedDict

# models
import models.ffn
import models.fno


class Pretrainer(Trainer):
    """Self-supervised pretrainer: the model learns to reconstruct its inputs.

    Each step feeds (optionally blurred) inputs plus a mask tensor to the model
    and compares the prediction against the *un-blurred* inputs with the
    configured data loss. When ``params['mask_ratio'] > 0`` prediction and
    reference are both multiplied by ``(1 - masks)`` before loss/error.

    Attributes and methods used here but not defined here (``world_rank``,
    ``world_size``, ``local_rank``, ``device``, ``sweep_id``,
    ``log_to_screen``, ``log_to_wandb``, ``get_model_wt_norm``,
    ``load_model``, ``restore_checkpoint``, ``save_logs``) are presumably
    provided by the base ``Trainer`` — confirm there.
    """

    # Epochs at which an additional, epoch-tagged checkpoint is kept.
    _MILESTONE_EPOCHS = (50, 100, 200, 300, 400)

    def build_and_run(self):
        """Build data loaders, model, optimizer, scheduler and loss, then train.

        Optionally loads initial weights (``params.weights``) or resumes from
        ``params.checkpoint_path`` before calling :meth:`train`.

        Raises:
            ValueError: if ``params.model`` or ``params.loss_func`` is unsupported.
        """
        if self.sweep_id and dist.is_initialized():
            # Broadcast rank 0's sweep config so every rank trains with the
            # same hyperparameters.
            from mpi4py import MPI
            comm = MPI.COMM_WORLD
            rank = comm.Get_rank()
            assert self.world_rank == rank
            if rank != 0:
                self.params = None
            self.params = comm.bcast(self.params, root=0)
            self.params.device = self.device  # keep this rank's device, not rank 0's

        if self.world_rank == 0:
            logging.info(self.params.log())

        set_seed(self.params, self.world_size)

        # Configured batch sizes are global; each rank processes an equal share.
        self.params['global_batch_size'] = self.params.batch_size
        self.params['local_batch_size'] = int(self.params.batch_size // self.world_size)
        self.params['global_valid_batch_size'] = self.params.valid_batch_size
        self.params['local_valid_batch_size'] = int(self.params.valid_batch_size // self.world_size)

        if self.world_rank == 0:
            self._dump_hyperparams()

        distributed = dist.is_initialized()
        pack = self.params.pack_data
        masking = self.params['mask_ratio']
        self.train_data_loader, self.train_dataset, self.train_sampler = get_data_loader(
            self.params, self.params.train_path, distributed, train=True, pack=pack, masking=masking)
        self.val_data_loader, self.val_dataset, self.valid_sampler = get_data_loader(
            self.params, self.params.val_path, distributed, train=False, pack=pack, masking=masking)
        self.test_data_loader, self.test_dataset, self.test_sampler = get_data_loader(
            self.params, self.params.test_path, distributed, train=False, pack=pack, masking=masking)

        # domain grid
        self.domain = DomainXY(self.params)

        if self.params.model == 'fno':
            self.model = models.fno.fno_pretrain(self.params).to(self.device)
        else:
            raise ValueError("Error, model arch invalid.")

        if distributed:
            # output_device must be a single device index, not a list.
            self.model = DistributedDataParallel(self.model,
                                                 device_ids=[self.local_rank],
                                                 output_device=self.local_rank)

        self.optimizer = set_optimizer(self.params, self.model)
        self.scheduler = set_scheduler(self.params, self.optimizer)

        if self.params.loss_func == "mse":
            self.loss_func = LossMSE(self.params, self.model)
        else:
            raise ValueError("Error,  loss func invalid.")

        self.iters = 0
        self.startEpoch = 0

        if hasattr(self.params, 'weights'):
            # Initial weights take precedence over resuming from a checkpoint.
            self.params.resuming = False
            logging.info("Loading IC weights %s", self.params.weights)
            self.load_model(self.params.weights)

        if self.params.resuming:
            logging.info("Loading checkpoint %s", self.params.checkpoint_path)
            self.restore_checkpoint(self.params.checkpoint_path)

        self.epoch = self.startEpoch
        self.logs = {}
        self.train_loss = self.data_loss = self.bc_loss = self.pde_loss = self.grad = 0.0
        n_params = count_parameters(self.model)
        if self.log_to_screen:
            logging.info(self.model)
            logging.info('number of model parameters: {} M'.format(n_params))

        # launch training
        self.train()

    def _dump_hyperparams(self):
        """Write every param (stringified) to ``<experiment_dir>/hyperparams.yaml``."""
        hparams = ruamelDict()
        for key, value in self.params.params.items():
            hparams[str(key)] = str(value)
        yaml = YAML()
        with open(os.path.join(self.params['experiment_dir'], 'hyperparams.yaml'), 'w') as hpfile:
            yaml.dump(hparams, hpfile)

    def train(self):
        """Run epochs ``startEpoch .. params.max_epochs - 1``.

        Each epoch trains, evaluates on val and test, steps the scheduler,
        tracks the epoch with the lowest val loss (with the val/test metrics
        from that epoch), checkpoints on rank 0 and logs to wandb/screen.
        """
        if self.log_to_screen:
            logging.info("Starting training loop...")
        best_loss = np.inf
        best_loss_test = np.inf
        best_epoch = 0
        best_err = 1
        best_err_test = 1
        self.logs['best_epoch'] = best_epoch
        plot_figs = self.params.plot_figs
        visualize = self.log_to_wandb and plot_figs

        for epoch in range(self.startEpoch, self.params.max_epochs):
            self.epoch = epoch
            if dist.is_initialized():
                # reshuffle the distributed sampler differently every epoch
                self.train_sampler.set_epoch(epoch)
            start = time.time()

            tr_time = self.train_one_epoch()
            val_time, fields = self.val_one_epoch()
            if visualize:
                self._log_fields_figure('vis', fields)
            test_time, fields_test = self.val_one_epoch(test=True)
            if visualize:
                self._log_fields_figure('vis_test', fields_test)

            self.logs['wt_norm'] = self.get_model_wt_norm(self.model)

            if self.params.scheduler == 'reducelr':
                # NOTE(review): the plateau scheduler is driven by the *train*
                # loss, not the val loss — confirm this is intended.
                self.scheduler.step(self.logs['train_loss'])
            elif self.params.scheduler == 'cosine':
                self.scheduler.step()

            # `<=` so ties favour the later epoch.
            is_best_loss = self.logs['val_loss'] <= best_loss
            if is_best_loss:
                best_loss = self.logs['val_loss']
                best_err = self.logs['val_err']
                best_loss_test = self.logs['test_loss']
                best_err_test = self.logs['test_err']
                best_epoch = self.epoch
            self.logs['best_val_loss'] = best_loss
            self.logs['best_val_err'] = best_err
            self.logs['best_test_loss'] = best_loss_test
            self.logs['best_test_err'] = best_err_test
            self.logs['best_epoch'] = best_epoch

            if self.params.save_checkpoint and self.world_rank == 0:
                # checkpoint at the end of every epoch
                if is_best_loss:
                    self.save_logs(tag="_best")
                self.save_checkpoint(self.params.checkpoint_path, is_best=is_best_loss)
                if epoch in self._MILESTONE_EPOCHS:
                    self.save_checkpoint(self.params.checkpoint_path.replace('.tar', f'_{epoch}.tar'), is_best=False)

            if self.log_to_wandb:
                self.logs['learning_rate'] = self.optimizer.param_groups[0]['lr']
                self.logs['time_per_epoch'] = tr_time
                wandb.log(self.logs, step=self.epoch + 1)

            if self.log_to_screen:
                logging.info('Time taken for epoch {} is {} sec; with {}/{}/{} in tr/val/test'.format(
                    self.epoch + 1, time.time() - start, tr_time, val_time, test_time))
                logging.info('Loss (total = data) {} = {}'.format(self.logs['train_loss'], self.logs['data_loss']))

        if self.log_to_wandb:
            wandb.finish()

    def _log_fields_figure(self, key, fields):
        """Render ``fields`` with ``vis_fields_named`` and store it as ``self.logs[key]``."""
        fig = vis_fields_named(fields, self.params, self.domain)
        self.logs[key] = wandb.Image(fig)
        plt.close(fig)

    def train_one_epoch(self):
        """Run one optimisation pass over ``train_data_loader``.

        Per-batch loss, gradient norm and reconstruction error are summed,
        averaged over batches and (when distributed) over ranks, and stored in
        ``self.logs`` as Python floats.

        Returns:
            Seconds spent in forward/backward/optimizer work; excludes data
            transfer and input blurring.
        """
        tr_time = 0.0
        self.model.train()

        keys = ('train_loss', 'data_loss', 'bc_loss', 'pde_loss', 'grad', 'tr_err')
        # One device buffer; each log entry is a 1-element view into it, so the
        # in-place `+=` / `/=` below accumulate directly in the buffer.
        logs_buff = torch.zeros(len(keys), dtype=torch.float32, device=self.device)
        for slot, key in enumerate(keys):
            self.logs[key] = logs_buff[slot].view(-1)

        for inputs, _targets, masks in self.train_data_loader:
            self.iters += 1
            if not self.params.pack_data:  # send to gpu if not already packed in the dataloader
                inputs, masks = inputs.to(self.device), masks.to(self.device)

            inp_blur = self._blur_inputs(inputs)

            tr_start = time.time()
            self.model.zero_grad()

            u = self.model(inp_blur, masks)
            # The target is the un-blurred input.
            loss_data = self._data_loss(inputs, u, masks)
            loss = loss_data

            loss.backward()
            self.optimizer.step()

            grad_norm = compute_grad_norm(self.model)
            tr_err = self._recon_err(inputs, u, masks)

            # add all the minibatch losses
            self.logs['train_loss'] += loss.detach()
            self.logs['data_loss'] += loss_data.detach()
            self.logs['grad'] += grad_norm
            self.logs['tr_err'] += tr_err

            tr_time += time.time() - tr_start

        n_batches = len(self.train_data_loader)
        for key in ('train_loss', 'data_loss', 'grad', 'tr_err'):
            self.logs[key] /= n_batches

        # bc_loss / pde_loss stay zero but are converted too, so every entry is a float.
        self._reduce_logs(keys)
        return tr_time

    def _blur_inputs(self, inputs):
        """Return a detached copy of ``inputs``, Gaussian-blurred per sample if configured.

        When ``params.blur`` exists and sums to a positive value, each sample
        gets its own sigma from ``random.uniform(*params.blur)``. Otherwise a
        plain detached clone is returned.
        """
        if not (hasattr(self.params, "blur") and sum(self.params['blur']) > 0):
            return inputs.detach().clone()

        # assumes inputs is (batch, channels, H, W) — kernel size is capped by dim 2; TODO confirm
        max_kernel = (inputs.shape[2] // 2) * 2 - 1
        blurred = []
        for sample_idx in range(len(inputs)):
            sample = inputs[sample_idx:sample_idx + 1].detach().clone()
            sigma = random.uniform(*self.params['blur'])
            # Odd kernel width derived from sigma, capped to fit the spatial size.
            # https://github.com/scipy/scipy/blob/v1.11.4/scipy/ndimage/_filters.py#L232
            kernel = min(int((sigma * 4 + 1) / 2) * 2 + 1, max_kernel)
            if kernel >= 2:
                sample = functional.gaussian_blur(sample, kernel_size=[kernel, kernel], sigma=sigma)
            blurred.append(sample)
        return torch.cat(blurred, dim=0).contiguous()

    def _data_loss(self, inputs, u, masks):
        """Data loss of prediction ``u`` vs ``inputs``; both weighted by ``1 - masks`` when masking."""
        if self.params['mask_ratio'] > 0:
            return self.loss_func.data(inputs, u * (1 - masks), inputs * (1 - masks))
        return self.loss_func.data(inputs, u, inputs)

    def _recon_err(self, inputs, u, masks):
        """Detached ``l2_err`` of ``u`` vs ``inputs``; both weighted by ``1 - masks`` when masking."""
        u, inputs = u.detach(), inputs.detach()
        if self.params['mask_ratio'] > 0:
            keep = 1 - masks.detach()
            return l2_err(u * keep, inputs * keep)
        return l2_err(u, inputs)

    def _reduce_logs(self, keys):
        """Average ``self.logs[key]`` across ranks (if distributed) and store it as a float.

        Single-process runs are converted to float as well, so downstream
        comparisons and logging see the same type in both modes.
        """
        distributed = dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        for key in keys:
            value = self.logs[key].detach()  # shares storage, so all_reduce updates it in place
            if distributed:
                dist.all_reduce(value)
            # todo change loss to unscaled
            self.logs[key] = float(value / world_size)

    def val_one_epoch(self, test=False):
        """Evaluate on the validation (or test) loader without gradient tracking.

        Averages loss and reconstruction error over batches and ranks into
        ``self.logs['{val,test}_loss']`` / ``['{val,test}_err']`` (floats) and
        captures one randomly chosen sample for visualisation.

        Args:
            test: evaluate ``test_data_loader`` instead of ``val_data_loader``.

        Returns:
            ``(elapsed_seconds, fields)``; ``fields`` maps ``source``,
            ``target``, ``pred``, ``mask`` and ``error`` to numpy arrays of
            channel 0 of the chosen sample.
        """
        self.model.eval()  # evaluation only; gradients are disabled below
        val_start = time.time()
        log_key = "test" if test else "val"
        err_key = '%s_err' % log_key
        loss_key = '%s_loss' % log_key

        logs_buff = torch.zeros((2), dtype=torch.float32, device=self.device)
        self.logs[err_key] = logs_buff[0].view(-1)
        self.logs[loss_key] = logs_buff[1].view(-1)

        loader = self.test_data_loader if test else self.val_data_loader
        vis_batch = np.random.randint(0, len(loader))
        img_idx = np.random.randint(0, self.params.local_valid_batch_size)
        fields = None
        with torch.no_grad():
            for batch_idx, (inputs, targets, masks) in enumerate(loader):
                if not self.params.pack_data:
                    inputs, targets, masks = inputs.to(self.device), targets.to(self.device), masks.to(self.device)
                u = self.model(inputs, masks)
                self.logs[err_key] += self._recon_err(inputs, u, masks)
                self.logs[loss_key] += self._data_loss(inputs, u, masks).detach()
                if batch_idx == vis_batch:
                    # The chosen batch may be a short final batch; keep the index in range.
                    j = min(img_idx, inputs.shape[0] - 1)
                    fields = {
                        "source": inputs[j, 0].detach().cpu().numpy(),
                        "target": targets[j, 0].detach().cpu().numpy(),
                        "pred": u[j, 0].detach().cpu().numpy(),
                        "mask": masks[j, 0].detach().cpu().numpy(),
                        "error": torch.abs(u[j, 0] - inputs[j, 0]).detach().cpu().numpy(),
                    }

        n_batches = len(loader)
        self.logs[loss_key] /= n_batches
        self.logs[err_key] /= n_batches
        self._reduce_logs((loss_key, err_key))

        val_time = time.time() - val_start
        return val_time, fields

    @staticmethod
    def _backbone_path(checkpoint_path):
        """Path of the backbone-only checkpoint derived from ``checkpoint_path``.

        Substitutes ``"ckpt"`` with ``"backbone"``. If the path contains no
        ``"ckpt"`` (so that substitution would overwrite the full checkpoint),
        ``_backbone`` is inserted before the extension instead.
        """
        backbone_path = checkpoint_path.replace("ckpt", "backbone")
        if backbone_path == checkpoint_path:
            root, ext = os.path.splitext(checkpoint_path)
            backbone_path = root + "_backbone" + ext
        return backbone_path

    def save_checkpoint(self, checkpoint_path, is_best=False, model=None):
        """Save a full checkpoint and a backbone-only checkpoint.

        The full checkpoint stores iters, epoch, the complete model state dict
        and optimizer/scheduler state. The backbone checkpoint keeps only the
        state-dict entries whose key contains ``"encoder"``, renamed to
        ``"backbone"``. With ``is_best`` both are also written with a
        ``_best.tar`` suffix.

        Args:
            checkpoint_path: destination of the full checkpoint.
            is_best: additionally write the ``_best`` copies.
            model: model to save; defaults to ``self.model``.
        """
        if model is None:
            model = self.model
        model_state = model.state_dict()
        optimizer_state = self.optimizer.state_dict()
        scheduler_state = self.scheduler.state_dict() if self.scheduler is not None else None

        def _ckpt(state):
            return {'iters': self.iters, 'epoch': self.epoch, 'model_state': state,
                    'optimizer_state_dict': optimizer_state, 'scheduler_state_dict': scheduler_state}

        full_ckpt = _ckpt(model_state)
        backbone_ckpt = _ckpt({k.replace("encoder", "backbone"): v
                               for k, v in model_state.items() if "encoder" in k})
        backbone_path = self._backbone_path(checkpoint_path)

        torch.save(full_ckpt, checkpoint_path)
        torch.save(backbone_ckpt, backbone_path)
        if is_best:
            torch.save(full_ckpt, checkpoint_path.replace('.tar', '_best.tar'))
            torch.save(backbone_ckpt, backbone_path.replace('.tar', '_best.tar'))
