import os
import argparse
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from timeit import default_timer
from attrdict import AttrDict

from task_configs import get_config, get_metric, get_optimizer_scheduler
from data_loaders_seq import load_pde 
from utils import count_params, count_trainable_params, denormalize
from embedder import get_tgt_model

def get_data(root, dataset, batch_size, valid_split, maxsize=None, get_shape=False, args=None):
    """Build the train/validation/test loaders for ``dataset``.

    Args:
        root: Dataset root directory, forwarded to the loader factory.
        dataset: Task name; ``'PDE'`` and the ``"your_new_task"`` placeholder
            are recognised.
        batch_size: Batch size forwarded to the loader factory.
        valid_split: When falsy, the test loader is reused as the
            validation loader.
        maxsize: Unused; kept for call compatibility.
        get_shape: Unused; kept for call compatibility.
        args: Namespace providing ``pde_subset`` and ``size`` for ``'PDE'``.

    Returns:
        ``(train_loader, val_loader, test_loader, n_train, n_val, n_test,
        data_kwargs)``. The ``n_*`` values are ``len()`` of the matching
        loader (0 when the loader is ``None``). ``data_kwargs`` is ``None``
        for the tasks handled here.

    Raises:
        NotImplementedError: If ``dataset`` is not a recognised task.
    """
    # Callers in this file unpack seven values; the last one carries optional
    # extras (they look up 'decoder' / 'transform' in it when it is not None).
    data_kwargs = None

    if dataset == "your_new_task":  # modify this to experiment with a new task
        train_loader, val_loader, test_loader = None, None, None
    elif dataset == 'PDE':
        train_loader, val_loader, test_loader = load_pde(root, batch_size, subset=args.pde_subset, size=args.size)
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise NotImplementedError(f"Data loader not implemented for dataset {dataset!r}")

    n_train, n_val, n_test = (
        len(loader) if loader is not None else 0
        for loader in (train_loader, val_loader, test_loader)
    )

    if not valid_split:
        val_loader = test_loader
        n_val = n_test

    return train_loader, val_loader, test_loader, n_train, n_val, n_test, data_kwargs

def main(use_determined, args, info=None, context=None):
    """Restore the saved "best_" checkpoint and evaluate it on the PDE subsets.

    Args:
        use_determined: True when running under Determined; selects the
            dataset root and is forwarded to the checkpoint helpers.
        args: Hyperparameter namespace with attribute access. It is mutated
            here (``device``, ``pde_subset``, possibly ``embedder_epochs``)
            and replaced by whatever ``get_config`` and
            ``get_optimizer_scheduler`` return.
        info: Unused; kept for call compatibility.
        context: Forwarded to the model/checkpoint/evaluation helpers.
    """
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    root = '/datasets' if use_determined else './datasets'

    torch.cuda.empty_cache()
    # Seed every RNG source used by the pipeline.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.reproducibility:
        cudnn.deterministic = True
        cudnn.benchmark = False
    else:
        cudnn.benchmark = True

    _dims, sample_shape, num_classes, loss, args = get_config(root, args)

    # An existing checkpoint means the embedder stage can be skipped.
    if load_embedder(use_determined, args):
        args.embedder_epochs = 0

    model, _ = get_tgt_model(args, root, sample_shape, num_classes, loss, False, use_determined, context, eval_mode=True)
    args.pde_subset = "Burgers"

    # Only the first six values are always present; the optional seventh
    # (data_kwargs) is read when available so this works with either shape.
    data = get_data(root, args.dataset, args.batch_size, args.valid_split, args=args)
    n_train, n_val, n_test = data[3:6]
    data_kwargs = data[6] if len(data) > 6 else None
    metric, _ = get_metric(root, args.dataset)
    decoder = data_kwargs['decoder'] if data_kwargs is not None and 'decoder' in data_kwargs else None

    # First pass (test=True): restore network weights only, to learn ep_start.
    model, ep_start, *_ = load_state(use_determined, args, context, model, None, None, n_train, freq=args.validation_freq, test=True, dirname="best_")

    predictor_done = args.predictor_epochs == 0 or ep_start >= args.predictor_epochs
    args, model, optimizer, scheduler = get_optimizer_scheduler(args, model, module=None if predictor_done else 'predictor', n_train=n_train)

    if args.device == 'cuda':
        model.cuda()
        try:
            loss.cuda()
        except AttributeError:
            # The loss is a plain callable without a .cuda() method.
            pass
        if decoder is not None:
            decoder.cuda()

    print("\n------- Experiment Summary --------")
    print("id:", args.experiment_id)
    print("dataset:", args.dataset, "\tbatch size:", args.batch_size, "\tlr:", args.optimizer.params.lr)
    print("num train batch:", n_train, "\tnum validation batch:", n_val, "\tnum test batch:", n_test)
    print("finetune method:", args.finetune_method)
    print("param count:", count_params(model), count_trainable_params(model))

    # Second pass (test=False): also restores optimizer, scheduler and RNG state.
    model, *_ = load_state(use_determined, args, context, model, optimizer, scheduler, n_train, freq=args.validation_freq, dirname="best_")

    print("\n------- Start Evaluation --------")

    eval_multiple(use_determined, context, root, args, model, loss, metric)
        

def eval_multiple(use_determined, context, root, args, model, loss, metric,
                  pde_subsets=("SW", "DS", "Burgers", "ADV", "1DCFD", "2DCFD", "NS", "RD", "RD2D")):
    """Evaluate ``model`` on the test split of each PDE subset and print results.

    Args:
        use_determined: Unused here; kept for call compatibility.
        context: Forwarded to ``evaluate``.
        root: Dataset root directory passed to ``get_data``.
        args: Hyperparameter namespace; ``args.pde_subset`` is overwritten
            with each subset name in turn (and left at the last one).
        model: Network to evaluate.
        loss: Loss callable forwarded to ``evaluate``.
        metric: Metric callable forwarded to ``evaluate``.
        pde_subsets: Subset names to evaluate, in order. Defaults to the
            full list this script has always used.
    """
    for pde_subset in pde_subsets:
        args.pde_subset = pde_subset

        # Index instead of unpacking: only the test loader (position 2) and
        # its batch count (position 5) are needed, so the call does not
        # depend on how many trailing values get_data returns.
        data = get_data(root, args.dataset, args.batch_size, args.valid_split, args=args)
        test_loader, n_test = data[2], data[5]

        test_time_start = default_timer()
        test_loss, test_score = evaluate(context, args, model, test_loader, loss, metric, n_test)
        test_time_end = default_timer()

        print("[test]", args.pde_subset, "\ttime elapsed:", "%.4f" % (test_time_end - test_time_start), "\ttest loss:", "%.4f" % test_loss, "\ttest score:", "%.4f" % test_score)
            


def evaluate(context, args, model, loader, loss, metric, n_eval):
    """Run an autoregressive rollout over ``loader`` and average loss/metric.

    For every batch the model is applied ``y.shape[1]`` times, feeding each
    prediction back in as the next input. Predictions are concatenated along
    dim 1, accumulated until at least ``args.eval_batch_size`` samples are
    buffered (or the loader ends), denormalized, and scored.

    Args:
        context: Unused here; kept for call compatibility.
        args: Namespace providing ``device`` and ``eval_batch_size``.
        model: Network called as ``model(x, text_embeddings=...)``.
        loader: Iterable of ``(x, y)`` batches. ``x`` may be a list
            ``[x, text_embeddings]`` and ``y`` may be a list ``[y, mask]``.
            ``loader.dataset`` must expose ``mean`` and ``std``.
        loss: Callable ``loss(outs, ys)`` returning a scalar tensor.
        metric: Callable ``metric(outs, ys)`` returning a scalar tensor.
        n_eval: Ignored; the number of scored chunks is counted locally.

    Returns:
        ``(eval_loss, eval_score)`` averaged over the scored chunks.

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    model.eval()

    eval_loss, eval_score = 0, 0

    # n_eval is deliberately reset: it counts the chunks scored below.
    ys, outs, n_eval, n_data = [], [], 0, 0
    masks = []

    with torch.no_grad():
        for i, data in enumerate(loader):
            x, y = data
            if isinstance(x, list):
                x, text_embeddings = x
                text_embeddings = text_embeddings.to(args.device)
            else:
                text_embeddings = None

            if isinstance(y, list):
                y, mask = y
                y = y.to(args.device)
                mask = mask.to(args.device)
                y *= mask
                masks.append(mask)
            else:
                y = y.to(args.device)
                mask = None

            seqout = []
            for seqi in range(y.shape[1]):

                x = x.to(args.device)

                out = model(x, text_embeddings=text_embeddings)

                if mask is not None:
                    out *= mask

                seqout.append(out)
                # Feed the prediction back in as the next step's input.
                x = out.detach()

            seqout = torch.cat(seqout, 1)

            outs.append(seqout)
            ys.append(y)
            n_data += x.shape[0]

            if n_data >= args.eval_batch_size or i == len(loader) - 1:
                outs = torch.cat(outs, 0)
                ys = torch.cat(ys, 0)
                outs, ys = denormalize(outs, ys, loader.dataset.mean, loader.dataset.std)
                # Masks exist only when the loader supplied them; torch.cat
                # on an empty list raises, so skip masking otherwise.
                if masks:
                    masks = torch.cat(masks, 0)
                    outs *= masks
                    ys *= masks
                eval_loss += loss(outs, ys).item()
                eval_score += metric(outs, ys).item()
                n_eval += 1

                ys, outs, n_data = [], [], 0
                masks = []

        eval_loss /= n_eval
        eval_score /= n_eval

    return eval_loss, eval_score


def load_embedder(use_determined, args):
    """Report whether a previously saved checkpoint exists for this run.

    Under Determined the cluster info's ``latest_checkpoint`` decides;
    otherwise the local ``results/<dataset>/<method>_<id>/<seed>`` directory
    is checked for a ``state_dict.pt`` file.
    """
    if use_determined:
        # `det` is expected to be a module-level name when running under Determined.
        latest = det.get_cluster_info().latest_checkpoint
        return latest is not None

    run_name = "_".join((str(args.finetune_method), str(args.experiment_id)))
    run_dir = "/".join(("results", args.dataset, run_name, str(args.seed)))
    return os.path.isfile(os.path.join(run_dir, 'state_dict.pt'))


def load_state(use_determined, args, context, model, optimizer, scheduler, n_train, checkpoint_id=None, test=False, freq=1, dirname="last"):
    """Restore a saved run into ``model`` (and optionally optimizer/scheduler/RNG).

    Args:
        use_determined: When True the checkpoint is downloaded through the
            Determined client; otherwise it is read from
            ``results/<dataset>/<method>_<id>/<dirname><seed>``.
        args: Namespace providing ``dataset``, ``finetune_method``,
            ``experiment_id`` and ``seed``.
        context: Determined context used to re-report metrics (only when
            ``use_determined`` and not ``test``).
        model: Network whose weights are loaded in place.
        optimizer: Optimizer to restore; only touched when ``test`` is False.
        scheduler: Scheduler to restore; only touched when ``test`` is False.
        n_train: Batches per epoch, used to compute reported step counts.
        checkpoint_id: Determined checkpoint to load; defaults to the latest.
        test: When True, only the network weights are restored.
        freq: Validation frequency used to derive the epoch count from the
            length of the saved score history.
        dirname: Prefix of the local checkpoint directory name.

    Returns:
        ``(model, epochs, checkpoint_id, train_score, train_losses,
        embedder_stats)``. When no checkpoint is found the result is
        ``(model, 0, 0, [], [], None)``.
    """
    if not use_determined:
        path = 'results/' + args.dataset + '/' + str(args.finetune_method) + '_' + str(args.experiment_id) + "/" + dirname + str(args.seed)
        if not os.path.isfile(os.path.join(path, 'state_dict.pt')):
            return model, 0, 0, [], [], None
    else:
        # `det` and `client` are expected to be module-level names when
        # running under Determined.
        if checkpoint_id is None:
            info = det.get_cluster_info()
            checkpoint_id = info.latest_checkpoint
            if checkpoint_id is None:
                return model, 0, 0, [], [], None

        checkpoint = client.get_checkpoint(checkpoint_id)
        path = checkpoint.download()

    train_score = np.load(os.path.join(path, 'train_score.npy'))
    train_losses = np.load(os.path.join(path, 'train_losses.npy'))
    embedder_stats = np.load(os.path.join(path, 'embedder_stats.npy'))
    epochs = freq * (len(train_score) - 1) + 1
    checkpoint_id = checkpoint_id if use_determined else epochs - 1
    # Deserialize onto the CPU (as already done for the RNG state below) so a
    # checkpoint written on a GPU can be read on a CPU-only host;
    # load_state_dict then copies the values into the existing parameters.
    model_state_dict = torch.load(os.path.join(path, 'state_dict.pt'), map_location='cpu')
    model.load_state_dict(model_state_dict['network_state_dict'])

    if not test:
        optimizer.load_state_dict(model_state_dict['optimizer_state_dict'])
        scheduler.load_state_dict(model_state_dict['scheduler_state_dict'])

        rng_state_dict = torch.load(os.path.join(path, 'rng_state.ckpt'), map_location='cpu')
        torch.set_rng_state(rng_state_dict['cpu_rng_state'])
        torch.cuda.set_rng_state(rng_state_dict['gpu_rng_state'])
        np.random.set_state(rng_state_dict['numpy_rng_state'])
        random.setstate(rng_state_dict['py_rng_state'])

        if use_determined:
            # Best effort: replay the saved history to the Determined
            # tracker; a reporting failure must not abort the restore.
            try:
                for ep in range(epochs):
                    if ep % freq == 0:
                        context.train.report_training_metrics(steps_completed=(ep + 1) * n_train, metrics={"train loss": train_losses[ep // freq]})
                        context.train.report_validation_metrics(steps_completed=(ep + 1) * n_train, metrics={"val score": train_score[ep // freq]})
            except Exception as err:
                print("load error:", err)

    return model, epochs, checkpoint_id, list(train_score), list(train_losses), embedder_stats



# Script entry point: with --config, run locally from a YAML file; without
# it, run as a Determined trial using the cluster-provided hyperparameters.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='UPS Eval Autoregressive')
    parser.add_argument('--config', type=str, default=None, help='config file name')

    args = parser.parse_args()
    if args.config is not None:
        import yaml

        # Parse inside the `with`, run after it: the config file handle is
        # closed before the (long) evaluation starts.
        with open(args.config, 'r') as stream:
            hparams = yaml.safe_load(stream)['hyperparameters']
        args = AttrDict(hparams)
        main(False, args)

    else:
        # These imports bind module-level names (`det`, `client`) that
        # load_embedder / load_state rely on in Determined mode.
        import determined as det
        from determined.experimental import client
        from determined.pytorch import DataLoader

        info = det.get_cluster_info()
        args = AttrDict(info.trial.hparams)

        with det.core.init() as context:
            main(True, args, info, context)