import argparse
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributed as dist
import torch.cuda.amp as amp
from torch.nn.parallel import DistributedDataParallel
from einops import rearrange
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap as ruamelDict
from dadaptation import DAdaptAdam, DAdaptAdan
from adan_pytorch import Adan
from collections import OrderedDict
import wandb
import logging
import pickle as pkl
import gc
from torchinfo import summary
from collections import defaultdict
try:
    from data_utils.datasets import get_data_loader, DSET_NAME_TO_OBJECT
    from models.avit import build_avit
    from models.vmae import build_vmae
    from utils import logging_utils
    from utils.YParams import YParams
    from utils.load_ckpt_utils import load_ckpt
except:
    from .data_utils.datasets import get_data_loader, DSET_NAME_TO_OBJECT
    from .models.avit import build_avit
    from .models.vmae import build_vmae
    from .utils import logging_utils
    from .utils.YParams import YParams
    from .utils.load_ckpt_utils import load_ckpt
from train_basic import Trainer
from tqdm import tqdm
from pdb import set_trace as bp


def grad_norm(parameters):
    """Return the global L2 norm of the gradients stored on ``parameters``.

    Entries whose ``.grad`` is ``None`` are skipped. The result is a Python
    float; it is ``0.0`` when no entry carries a gradient.
    """
    with torch.no_grad():
        # Sum of squared gradient entries, accumulated parameter by parameter.
        squared_sum = sum(
            p.grad.data.pow(2).sum().item()
            for p in parameters
            if p.grad is not None
        )
    return squared_sum ** .5

def grad_clone(parameters):
    """Return a list with one gradient copy per entry of ``parameters``.

    An entry without a gradient (``.grad is None``) contributes a zero tensor
    shaped like the parameter itself, so the output always lines up
    one-to-one with the input.
    """
    with torch.no_grad():
        return [
            torch.zeros_like(p) if p.grad is None else p.grad.clone()
            for p in parameters
        ]

def param_norm(parameters):
    """Return the global L2 norm over every tensor in ``parameters``.

    The result is a Python float (``0.0`` for an empty iterable).
    """
    with torch.no_grad():
        squared_sum = sum(p.pow(2).sum().item() for p in parameters)
    return squared_sum ** .5

def param_diff(params1, params2):
    """Return the L2 norm of the element-wise difference of two parameter lists.

    The two iterables are walked in lockstep with ``zip``, so any surplus
    entries in the longer one are ignored. The result is a Python float.
    """
    with torch.no_grad():
        squared_sum = sum(
            (after - before).pow(2).sum().item()
            for before, after in zip(params1, params2)
        )
    return squared_sum ** .5

def add_weight_decay(model, weight_decay=1e-5, inner_lr=1e-3, skip_list=()):
    """Split a model's trainable parameters into decayed / non-decayed groups.

    Adapted from Ross Wightman:
    https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994/3

    A parameter goes into the no-decay group when, after ``squeeze()``, it has
    at most one dimension (typically a bias or a scale) or when its name is in
    ``skip_list``. Parameters with ``requires_grad=False`` are left out of
    both groups. ``inner_lr`` is accepted but not used by this function.

    Returns a two-element list of optimizer parameter groups: first the
    no-decay group (``weight_decay`` 0), then the decay group.
    """
    groups = {True: [], False: []}  # keyed by "is exempt from weight decay"
    for name, param in model.named_parameters():
        if param.requires_grad:
            exempt = len(param.squeeze().shape) <= 1 or name in skip_list
            groups[exempt].append(param)
    return [
            {'params': groups[True], 'weight_decay': 0.,},
            {'params': groups[False], 'weight_decay': weight_decay}]

class Inferencer(Trainer):
    """Evaluation-only variant of ``Trainer``.

    Builds the validation/test data loaders and the model, restores a
    checkpoint when the params ask for it, and computes per-dataset error
    metrics. ``Trainer.__init__`` is not invoked by this class;
    ``initialize_model`` and ``single_print`` are not defined here and are
    expected to come from ``Trainer``.
    """

    def __init__(self, params, global_rank, local_rank, device, sweep_id=None):
        """Set up data, model and (optionally) restore weights.

        Args:
            params: configuration object; attributes such as ``log_to_screen``,
                ``debug_grad``, ``resuming``, ``pretrained`` and the checkpoint
                paths are read from it.
            global_rank: rank of this process across all nodes.
            local_rank: rank of this process on its node.
            device: device tensors and checkpoints are moved to.
            sweep_id: wandb sweep identifier, or ``None`` outside a sweep.
        """
        self.device = device
        self.params = params
        self.global_rank = global_rank
        self.local_rank = local_rank
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.sweep_id = sweep_id
        self.log_to_screen = params.log_to_screen
        # Basic setup
        self.debug_grad = params.debug_grad
        # bfloat16 when CUDA reports support for it, float16 otherwise.
        self.mp_type = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.half
        # Number of autoregressive steps scored per sample (1 when unset).
        self.rollout = getattr(params, "rollout", 1)

        self.iters = 0
        self.initialize_data(self.params)
        print(f"Initializing model on rank {self.global_rank}")
        self.initialize_model(self.params)
        if params.resuming:
            print("Loading checkpoint %s"%params.checkpoint_path)
            print('LOADING CHECKPOINTTTTTT')
            self.restore_checkpoint(params.checkpoint_path)
        if params.resuming == False and params.pretrained:
            print("Starting from pretrained model at %s"%params.pretrained_ckpt_path)
            self.restore_checkpoint(params.pretrained_ckpt_path)


    def initialize_data(self, params):
        """Create the validation and test loaders (and the demo batch, if any).

        Both loaders are built from ``params.valid_data_paths`` and differ only
        in the ``split`` argument. When ``params.num_demos > 0`` one batch of
        that size is drawn from the first per-file dataset of the validation
        set and kept on ``self.device`` as ``demo_xs`` / ``demo_ys``.
        """
        # With tied batches every process passes rank 0 to the loader factory.
        if params.tie_batches:
            in_rank = 0
        else:
            in_rank = self.global_rank
        if self.log_to_screen:
            print(f"Initializing data on rank {self.global_rank}")
        self.valid_data_loader, self.valid_dataset, _ = get_data_loader(params, params.valid_data_paths, dist.is_initialized(), split='val', rank=in_rank)
        self.test_data_loader, self.test_dataset, _ = get_data_loader(params, params.valid_data_paths, dist.is_initialized(), split='test', rank=in_rank)
        if params.num_demos > 0:
            loader = torch.utils.data.DataLoader(self.valid_dataset.sub_dsets[0].get_per_file_dsets()[0], batch_size=params.num_demos, num_workers=self.params.num_data_workers, shuffle=True, generator= torch.Generator().manual_seed(0), drop_last=True)
            # Each batch is a 3-tuple; the middle element is not used for demos.
            self.demo_xs, _, self.demo_ys = map(lambda x: x.to(self.device), next(iter(loader)))
            self.demo_xs = rearrange(self.demo_xs, 'b t c h w -> t b c h w')


    def restore_checkpoint(self, checkpoint_path):
        """Load model weights and the iteration counter from ``checkpoint_path``.

        The checkpoint must contain ``'model_state'`` and ``'iters'``. If the
        state dict does not match the model as-is, a leading ``'module.'`` is
        stripped from the keys and loading is retried. When
        ``params.pretrained`` is set, the freeze policy is applied and the
        projections are expanded for every dataset in
        ``params.append_datasets``.

        Raises:
            RuntimeError: if the state dict fits the model neither with nor
                without the ``'module.'`` prefix.
        """
        # Map onto the device this instance was built for, so a CPU-only
        # process does not try to deserialize onto a CUDA device.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        model_state = checkpoint['model_state']
        try:
            self.model.load_state_dict(model_state)
        except RuntimeError:
            # load_state_dict reports key mismatches as RuntimeError; retry
            # with the 7-character 'module.' prefix removed where present.
            new_state_dict = OrderedDict()
            for key, val in model_state.items():
                name = key[7:] if key.startswith('module.') else key
                new_state_dict[name] = val
            self.model.load_state_dict(new_state_dict)
        self.iters = checkpoint['iters']
        if self.params.pretrained:
            if self.params.freeze_middle:
                self.model.module.freeze_middle()
            elif self.params.freeze_processor:
                self.model.module.freeze_processor()
            else:
                self.model.module.unfreeze()
            # See how much we need to expand the projections
            exp_proj = 0
            # Iterate through the appended datasets and add on enough embeddings for all of them.
            for add_on in self.params.append_datasets:
                exp_proj += len(DSET_NAME_TO_OBJECT[add_on]._specifics()[2])
            self.model.module.expand_projections(exp_proj)
        # Drop the reference to the loaded checkpoint.
        checkpoint = None


    def validate_one_epoch(self, dataset, full=False):
        """Compute error metrics for every per-file dataset inside ``dataset``.

        Args:
            dataset: object exposing ``sub_dsets``, ``get_state_names()`` and
                ``subset_dict``, as used below.
            full: when false, stop after a few dozen batches per per-file
                dataset; when true, iterate each one completely.

        Returns:
            dict mapping ``'<dataset title>/<metric>'`` (plus the aggregate
            ``'valid_nrmse'``) to the batch-averaged value. Values are Python
            floats when ``torch.distributed`` is initialized; otherwise they
            are left as accumulated by the loop.

        Note: need to split datasets for meaningful metrics, but TBD.
        """
        # Don't bother with full validation set between epochs
        self.model.eval()
        cutoff = float('inf') if full else 40
        self.single_print('STARTING VALIDATION!!!')
        with torch.inference_mode():
            # Autocast is created with enabled=False here. Original author's
            # note: something weird happens when this is turned off.
            with amp.autocast(False, dtype=self.mp_type):
                field_labels = dataset.get_state_names()
                distinct_dsets = list({dset.title for dset_group in dataset.sub_dsets
                                       for dset in dset_group.get_per_file_dsets()})
                # Number of scored batches per dataset title, used for averaging.
                counts = {dset: 0 for dset in distinct_dsets}
                logs = {}
                # Iterate through all folder specific datasets
                for subset_group in dataset.sub_dsets:
                    for subset in subset_group.get_per_file_dsets():
                        dset_type = subset.title
                        self.single_print('VALIDATING ON', dset_type)
                        # Create data loader for each
                        if self.params.use_ddp:
                            temp_loader = torch.utils.data.DataLoader(subset, batch_size=self.params.batch_size,
                                                                    num_workers=self.params.num_data_workers,
                                                                    sampler=torch.utils.data.distributed.DistributedSampler(subset,
                                                                                                                            drop_last=True)
                                    )
                        else:
                            # shuffle is off, so batches are read in dataset order.
                            temp_loader = torch.utils.data.DataLoader(subset, batch_size=self.params.batch_size,
                                                                    num_workers=self.params.num_data_workers,
                                                                    shuffle=False, generator= torch.Generator().manual_seed(0),
                                                                    drop_last=True)
                        count = 0
                        for batch_idx, data in tqdm(enumerate(temp_loader), total=len(temp_loader)):
                            # Only do a few batches of each dataset if not doing full validation
                            if count > cutoff:
                                del(temp_loader)
                                break
                            inp, bcs, tar = map(lambda x: x.to(self.device), data)
                            # Skip samples too short to supply n_steps of context
                            # plus `rollout` targets (history + the one target frame).
                            if inp.shape[1] + 1 < self.params.n_steps + self.rollout: continue
                            count += 1
                            counts[dset_type] += 1
                            # Full trajectory along the time axis: history followed by the target.
                            tar_all = torch.cat([inp, tar.unsqueeze(1)], dim=1)
                            inp = rearrange(inp, 'b t c h w -> t b c h w')

                            self.model.target = tar # TODO: just for debugging purpose
                            nmse = 0

                            # https://github.com/pdebench/PDEBench/blob/main/pdebench/models/metrics.py#L287

                            # Sliding context window of n_steps frames, time-major.
                            xx = inp[:self.params.n_steps]
                            for t in range(self.rollout):
                                tar = tar_all[:, self.params.n_steps+t] # TODO:
                                if self.params.num_demos == 0:
                                    output = self.model(xx)
                                else:
                                    # NOTE(review): this branch feeds the full `inp`, not the
                                    # rolled window `xx` - confirm that is intended for rollout > 1.
                                    output = self.model.forward_icl(inp, self.demo_xs, self.demo_ys)
                                # Drop the oldest frame and append the prediction.
                                xx = torch.cat((xx[1:], output.unsqueeze(0)), dim=0)

                                # I don't think this is the true metric, but PDE bench averages spatial RMSE over batches (MRMSE?) rather than root after mean
                                # And we want the comparison to be consistent
                                # output/tar are (b, c, ...): every axis after the first two is reduced.
                                spatial_dims = tuple(range(output.ndim))[2:]
                                residuals = output - tar
                                nmse += (residuals.pow(2).mean(spatial_dims, keepdim=True)
                                        / (1e-7 + tar.pow(2).mean(spatial_dims, keepdim=True))).sqrt()
                            nmse /= self.rollout

                            # nrmse is averaged over the rollout; rmse and l1 below use the
                            # residuals of the final rollout step only.
                            logs[f'{dset_type}/valid_nrmse'] = logs.get(f'{dset_type}/valid_nrmse',0) + nmse.mean()
                            # Accumulate under the same key that is written; reading a
                            # different key would reset the running sum on every batch.
                            logs[f'{dset_type}/valid_rmse'] = (logs.get(f'{dset_type}/valid_rmse',0)
                                                                + residuals.pow(2).mean(spatial_dims).sqrt().mean())
                            logs[f'{dset_type}/valid_l1'] = (logs.get(f'{dset_type}/valid_l1', 0)
                                                                + residuals.abs().mean())

                            # Per-field metrics: channel i of the output corresponds to `field`.
                            for i, field in enumerate(dataset.subset_dict[subset.type]):
                                field_name = field_labels[field]
                                logs[f'{dset_type}/{field_name}_valid_nrmse'] = (logs.get(f'{dset_type}/{field_name}_valid_nrmse', 0)
                                                                                + nmse[:, i].mean())
                                logs[f'{dset_type}/{field_name}_valid_rmse'] = (logs.get(f'{dset_type}/{field_name}_valid_rmse', 0)
                                                                                    + residuals[:, i:i+1].pow(2).mean(spatial_dims).sqrt().mean())
                                logs[f'{dset_type}/{field_name}_valid_l1'] = (logs.get(f'{dset_type}/{field_name}_valid_l1', 0)
                                                                            +  residuals[:, i].abs().mean())
                        else:
                            # Loop finished without hitting the cutoff; release the loader.
                            del(temp_loader)

            self.single_print('DONE VALIDATING - NOW SYNCING')
            # Turn the running sums into per-batch averages.
            for k, v in logs.items():
                dset_type = k.split('/')[0]
                logs[k] = v/counts[dset_type]

            # Aggregate score: unweighted mean of the per-dataset nrmse values.
            logs['valid_nrmse'] = 0
            for dset_type in distinct_dsets:
                logs['valid_nrmse'] += logs[f'{dset_type}/valid_nrmse']/len(distinct_dsets)

            if dist.is_initialized():
                # Sorted keys so every rank issues the collectives in the same order.
                for key in sorted(logs.keys()):
                    dist.all_reduce(logs[key].detach()) # There was a bug with means when I implemented this - dont know if fixed
                    logs[key] = float(logs[key].item()/dist.get_world_size())
            self.single_print('DONE SYNCING - NOW LOGGING')
        return logs


    def validate(self):
        """Run a full pass over the test dataset and report the results.

        Initializes wandb when ``params.log_to_wandb`` is set (pulling sweep
        hyper-parameters into ``params`` when a sweep id is present), prints a
        model summary on rank 0, evaluates ``self.test_dataset`` with
        ``full=True`` and logs the metrics together with the elapsed time.
        """
        # This is set up this way based on old code to allow wandb sweeps
        if self.params.log_to_wandb:
            if self.sweep_id:
                wandb.init(dir=self.params.experiment_dir)
                hpo_config = wandb.config.as_dict()
                self.params.update_params(hpo_config)
            else:
                wandb.init(dir=self.params.experiment_dir, config=self.params, name=self.params.name, group=self.params.group,
                           project=self.params.project, entity=self.params.entity, resume=True)

        if self.global_rank == 0:
            summary(self.model)
        if self.params.log_to_wandb:
            wandb.watch(self.model)
        self.single_print("Starting Training Loop...")

        valid_start = time.time()
        # Only the test split is evaluated here, over the full set.
        test_logs = self.validate_one_epoch(self.test_dataset, True)

        post_start = time.time()
        test_logs['time/valid_time'] = post_start-valid_start
        if self.params.log_to_wandb:
            wandb.log(test_logs)
        gc.collect()
        torch.cuda.empty_cache()

        if self.global_rank == 0:
            self.single_print(f'Time for valid: {post_start-valid_start}.')
            self.single_print('Test loss: {}'.format(test_logs['valid_nrmse']))


if __name__ == '__main__':
    # Entry point: parse CLI options, prepare the experiment directory and
    # logging, build an Inferencer and run validation on the test split.
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_name", default='00', type=str)
    parser.add_argument("--use_ddp", action='store_true', help='Use distributed data parallel')
    parser.add_argument("--yaml_config", default='./config/base_config.yaml', type=str)
    parser.add_argument("--config", default='basic_config', type=str)
    parser.add_argument("--sweep_id", default=None, type=str, help='sweep config from ./configs/sweeps.yaml')
    parser.add_argument("--num_demos", default=0, type=int, help='number of demos to use (from validation set).')
    parser.add_argument("--rollout", default=1, type=int, help='rollout steps.')
    parser.add_argument("--vmae_pretrained", default=None, type=str, help='path to pretrained ckpt.')
    args = parser.parse_args()
    params = YParams(os.path.abspath(args.yaml_config), args.config)
    params.use_ddp = args.use_ddp
    params.rollout = args.rollout
    # NOTE(review): --num_demos is parsed but never copied into params, while
    # Inferencer.initialize_data reads params.num_demos - confirm whether the
    # flag is meant to override the YAML value.
    if args.vmae_pretrained is not None:
        params.vmae_pretrained = args.vmae_pretrained
    # Set up distributed training
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    global_rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if args.use_ddp:
        dist.init_process_group("nccl")
        torch.cuda.set_device(local_rank) # Torch docs recommend just using device, but I had weird memory issues without setting this.
    device = torch.device(local_rank) if torch.cuda.is_available() else torch.device("cpu")

    # Modify params
    # The configured batch size is split evenly across processes.
    params['batch_size'] = int(params.batch_size//world_size)
    params['startEpoch'] = 0
    if args.sweep_id:
        jid = os.environ['SLURM_JOBID'] # so different sweeps dont resume
        expDir = os.path.join(params.exp_dir, args.sweep_id, args.config, str(args.run_name), jid)
    else:
        expDir = os.path.join(params.exp_dir, args.config, str(args.run_name))

    params['old_exp_dir'] = expDir # I dont remember what this was for but not removing it yet
    params['experiment_dir'] = os.path.abspath(expDir)
    params['checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/ckpt.tar')
    params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
    params['old_checkpoint_path'] = os.path.join(params.old_exp_dir, 'training_checkpoints/best_ckpt.tar')

    # Have rank 0 make the directories. exist_ok also covers the case where
    # expDir already exists but training_checkpoints/ does not.
    if global_rank == 0:
        os.makedirs(os.path.join(expDir, 'training_checkpoints/'), exist_ok=True)
    # Resume only when a checkpoint file is already present.
    params['resuming'] = os.path.isfile(params.checkpoint_path)

    # WANDB things
    params['name'] =  str(args.run_name)
    # File logging is configured once, on rank 0 only.
    if global_rank == 0:
        logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'out.log'))
        logging_utils.log_versions()
        params.log()

    # Only rank 0 talks to wandb / the console.
    params['log_to_wandb'] = (global_rank==0) and params['log_to_wandb']
    params['log_to_screen'] = (global_rank==0) and params['log_to_screen']
    torch.backends.cudnn.benchmark = True

    if global_rank == 0:
        # Save the resolved hyper-parameters (stringified) next to the run outputs.
        hparams = ruamelDict()
        yaml = YAML()
        for key, value in params.params.items():
            hparams[str(key)] = str(value)
        with open(os.path.join(expDir, 'hyperparams.yaml'), 'w') as hpfile:
            yaml.dump(hparams, hpfile)
    inferencer = Inferencer(params, global_rank, local_rank, device, sweep_id=args.sweep_id)
    if args.sweep_id and inferencer.global_rank==0:
        print(args.sweep_id, inferencer.params.entity, inferencer.params.project)
        # The sweep agent runs the same entry point as the non-sweep path;
        # validate() is the method that handles the sweep configuration.
        wandb.agent(args.sweep_id, function=inferencer.validate, count=1, entity=inferencer.params.entity, project=inferencer.params.project)
    else:
        inferencer.validate()
    if params.log_to_screen:
        print('DONE ---- rank %d'%global_rank)
