from multiprocessing.synchronize import Barrier as BarrierClass
from pathlib import Path

import pandas as pd
import torch
import torch.distributed as dist
from loguru import logger

from tabicl.config.config_pretrain import ConfigPretrain
from tabicl.core.enums import BenchmarkName, DataSplit, DownstreamTask, Phase, Task
from tabicl.core.get_loss import get_loss_pretrain
from tabicl.core.get_model import get_model_pretrain
from tabicl.core.get_optimizer import get_optimizer_pretrain
from tabicl.core.get_scheduler import get_scheduler_pretrain
from tabicl.core.grad_scaler import GradScaler
from tabicl.core.metrics import MetricsFewSteps, MetricsFullRun
from tabicl.core.metrics_plot import plot_loss
from tabicl.core.trainer_pretrain_evaluate import create_config_benchmark_sweep
from tabicl.core.trainer_pretrain_init import (create_synthetic_dataloader, log_flops, log_parameter_count,
                                               prepare_ddp_model)
from tabicl.core.utils import move_optimizer_to
from tabicl.data.benchmarks import BENCHMARKS
from tabicl.data.dataset_synthetic import SyntheticDataset
from tabicl.sweeps.run_sweep import run_sweep
from tabicl.utils.paths_and_filenames import DEFAULT_RESULTS_TEST_FILE_NAME, DEFAULT_RESULTS_VAL_FILE_NAME


class TrainerPretrain():
    """Pretraining loop on synthetic data with periodic benchmark validation.

    The trainer owns the base model, its DDP wrapper, the optimizer, scheduler,
    gradient scaler, loss and metric trackers. Every rank runs the training
    loop; logging, metric bookkeeping, checkpoint writing and validation sweeps
    are performed on rank 0 only.

    NOTE(review): the training loop moves batches to the literal device "cuda"
    while ``__init__`` uses ``cfg.device``; this presumably relies on the
    current CUDA device being set per process by the caller -- confirm.
    """

    def __init__(
            self, 
            cfg: ConfigPretrain,
            barrier: BarrierClass
        ) -> None:
        """Build the model, data pipeline, optimization objects and metric trackers.

        Args:
            cfg: Pretraining configuration.
            barrier: Process-level barrier shared by all training processes. It
                is used to synchronize the ranks around the validation sweeps
                and the final checkpoint write.
        """

        self.cfg = cfg
        self.barrier = barrier
        self.model_base = get_model_pretrain(cfg).to(cfg.device)

        log_parameter_count(cfg, self.model_base)
        log_flops(cfg)
        self.model_ddp = prepare_ddp_model(cfg, self.model_base)

        self.synthetic_dataset = SyntheticDataset(cfg)
        self.synthetic_dataloader = create_synthetic_dataloader(cfg, self.synthetic_dataset)

        self.optimizer = get_optimizer_pretrain(cfg, self.model_ddp)
        self.scheduler = get_scheduler_pretrain(cfg, self.optimizer)
        self.scaler = GradScaler(
            enabled=cfg.optim.grad_scaler.enabled,
            scale_init=cfg.optim.grad_scaler.scale_init,
            scale_min=cfg.optim.grad_scaler.scale_min,
            growth_interval=cfg.optim.grad_scaler.growth_interval
        )
        self.loss = get_loss_pretrain(cfg)

        self.metrics_few_steps = MetricsFewSteps(cfg.data.task)
        self.metrics_full_run = MetricsFullRun(cfg.data.task)

        self.step = 0
        self.dataloader = None
        # Path of the most recent checkpoint; None until the first checkpoint is recorded.
        self.last_weights_path = None


    def train(self) -> None:
        """Run the full pretraining loop for ``cfg.optim.steps`` optimizer steps.

        Each step accumulates gradients over ``cfg.optim.gradient_accumulation_steps``
        micro-batches, clips the gradient norm, and steps the optimizer, scheduler
        and gradient scaler. Rank 0 additionally updates and periodically logs
        and saves the metrics. Every ``cfg.optim.eval_every_n_steps`` steps the
        model is moved to the CPU, rank 0 runs the validation sweeps, and the
        model is moved back to the GPU.

        After the last step the final weights are written once, by rank 0. The
        other ranks only record the checkpoint path, and all ranks wait on the
        barrier so the file is complete when this method returns.
        """

        self.model_ddp.train()
        self.dataloader = iter(self.synthetic_dataloader)

        is_main_process = dist.get_rank() == 0
        accumulation_steps = self.cfg.optim.gradient_accumulation_steps

        for step in range(1, self.cfg.optim.steps+1):
            self.step = step

            preds = []
            y_queries = []

            self.optimizer.zero_grad(set_to_none=True)

            for _ in range(accumulation_steps):
                # GPU tensors stay local to the helper, so they are released on return.
                pred_cpu, y_query_cpu = self._run_micro_batch()
                preds.append(pred_cpu)
                y_queries.append(y_query_cpu)

            pred = torch.cat(preds, dim=0).float()  # float because it could be half precision
            y_query = torch.cat(y_queries, dim=0)

            # Gradients must be unscaled before clipping so the threshold applies to true magnitudes.
            self.scaler.unscale_(self.optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(self.model_ddp.parameters(), self.cfg.optim.max_grad_norm).item()
            self.scaler.step(self.optimizer)
            self.scheduler.step()
            self.scaler.update()

            if is_main_process:
                self.update_metrics(pred, y_query, self.scaler, grad_norm)

                if self.log_this_step():
                    self.log_training_metrics()
                    self.metrics_full_run.save(self.cfg.output_dir)

            if self.eval_this_step():
                # Remove all references to the tensors to free up memory for evaluation
                del pred, y_query, preds, y_queries
                self.move_model_to_cpu()
                self.evaluate_current_model()
                self.wait_and_move_model_to_gpu()

        self.move_model_to_cpu()

        # Only one process writes the checkpoint: all ranks target the same file,
        # and concurrent writers could corrupt it.
        if is_main_process:
            self.save_weights()
        else:
            self.last_weights_path = self._weights_path()

        # Make sure the checkpoint is fully written before any rank returns.
        self.barrier.wait()


    def _run_micro_batch(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Run forward and backward on the next batch of the training iterator.

        The loss is divided by the number of accumulation steps before the
        scaled backward pass, so the accumulated gradient corresponds to the
        mean loss over the micro-batches of one optimizer step.

        Returns:
            Detached CPU copies of the prediction and of the query targets.
        """

        with torch.autocast(device_type="cuda", dtype=getattr(torch, self.cfg.optim.precision)):

            dataset: dict[str, torch.Tensor] = next(self.dataloader)

            x_support = dataset['x_support'].to("cuda", non_blocking=True)
            y_support = dataset['y_support'].to("cuda", non_blocking=True)
            x_query = dataset['x_query'].to("cuda", non_blocking=True)
            y_query = dataset['y_query'].to("cuda", non_blocking=True)
            padding_features = dataset['padding_features'].to("cuda", non_blocking=True)
            padding_obs_support = dataset['padding_obs_support'].to("cuda", non_blocking=True)
            padding_obs_query = dataset['padding_obs_query'].to("cuda", non_blocking=True)

            pred = self.model_ddp(x_support, y_support, x_query, padding_features, padding_obs_support, padding_obs_query)
            loss = self.loss(pred, y_query)
            loss = loss / self.cfg.optim.gradient_accumulation_steps

        # The backward pass runs outside the autocast context.
        self.scaler.scale(loss).backward()

        return pred.detach().cpu(), y_query.detach().cpu()
    
    
    def evaluate_current_model(self) -> None:
        """Checkpoint the current weights and run all validation sweeps (rank 0 only).

        Returns immediately on every rank other than 0. On rank 0, one sweep is
        run per (benchmark, downstream task) combination from ``cfg.testing``,
        the normalized validation and test scores are logged and recorded in
        the full-run metrics, and finally the metrics are saved and the loss
        plot is regenerated.
        """

        if dist.get_rank() != 0:
            return

        output_dir = self.prepare_output_dir()
        weights_path = self.save_weights()

        for benchmark in self.cfg.testing.benchmarks_valid:
            for task in self.cfg.testing.downstream_tasks:

                logger.info(f"Starting validation sweep: {benchmark.value} {task.value}")

                output_dir_task = output_dir / f"{benchmark.value}_{task.value}"
                plot_name = f"{self.cfg.model_name.value}_{task.value} Pretrain Step {self.step}"
                normalized_scores = self.validate(
                    output_dir_task, 
                    weights_path, 
                    plot_name=plot_name, 
                    benchmark=benchmark, 
                    task=task
                )

                logger.info(f"Finished validation sweep: {benchmark.value} {task.value}")
                logger.info(f"Normalized Validation Accuracy: {normalized_scores[DataSplit.VALID]:.4f}")
                logger.info(f"Normalized Test Accuracy: {normalized_scores[DataSplit.TEST]:.4f}")

                self.metrics_full_run.update_val(normalized_scores[DataSplit.VALID], normalized_scores[DataSplit.TEST], self.step, task=task)
        
        self.metrics_full_run.save(self.cfg.output_dir)
        plot_loss(self.cfg)


    def log_this_step(self) -> bool:
        """Return True when the current step is a multiple of ``cfg.optim.log_every_n_steps``."""
        return self.step % self.cfg.optim.log_every_n_steps == 0
    

    def log_training_metrics(self) -> None:
        """Log the metrics accumulated since the last log call, then reset them.

        Regression runs report MSE, MAE and R2; all other tasks report
        cross-entropy and accuracy.
        """

        if self.cfg.data.task == Task.REGRESSION:
            logger.info(f"Step {self.step} | MSE: {self.metrics_few_steps.loss:.4f} | MAE: {self.metrics_few_steps.loss2:.4f} | R2: {self.metrics_few_steps.score:.4f}")
        else:
            logger.info(f"Step {self.step} | CrossEntropy: {self.metrics_few_steps.loss:.4f} | Accuracy: {self.metrics_few_steps.score:.4f}")

        self.metrics_few_steps.reset()


    def eval_this_step(self) -> bool:
        """Return True when the current step is a multiple of ``cfg.optim.eval_every_n_steps``."""
        return self.step % self.cfg.optim.eval_every_n_steps == 0


    def update_metrics(self, pred: torch.Tensor, y_query: torch.Tensor, scaler: GradScaler, grad_norm: float) -> None:
        """Feed one optimizer step's results into both metric trackers.

        Args:
            pred: Predictions of all micro-batches of the step, concatenated on the CPU.
            y_query: Matching query targets, concatenated on the CPU.
            scaler: The trainer's gradient scaler, forwarded to the full-run metrics.
            grad_norm: Gradient norm returned by the clipping call of this step.
        """
        
        self.metrics_few_steps.update(pred, y_query)
        self.metrics_full_run.update(pred, y_query, scaler, grad_norm)


    def move_model_to_cpu(self) -> None:
        """Move model and optimizer state to the CPU, drop the DDP wrapper, and sync all ranks.

        The DDP wrapper attribute is deleted here; ``wait_and_move_model_to_gpu``
        creates a new one. Blocks on the shared barrier until every process has
        reached this point.
        """
        # For some reason, something is still on the gpu after running this. But what?
        # I guess DDP cannot be removed from the GPU.

        self.model_ddp = self.model_ddp.to('cpu')
        self.model_base = self.model_base.to('cpu')
        del self.model_ddp
        move_optimizer_to(self.optimizer, 'cpu')
        torch.cuda.empty_cache()

        # wait until all gpus have moved the model to cpu
        self.barrier.wait()


    def wait_and_move_model_to_gpu(self) -> None:
        """Wait for all ranks, then move the model back to the GPU and re-wrap it in DDP."""

        # We cannot use the torch distributed barrier here, because that blocks the execution on the gpus.
        # This barrier only blocks execution on the cpu of the current process, which doesn't interfere with the validation sweep.
        self.barrier.wait()

        # see https://github.com/pytorch/pytorch/issues/104336
        self.model_base = self.model_base.to("cuda")
        self.model_ddp = torch.nn.parallel.DistributedDataParallel(self.model_base, device_ids=[torch.cuda.current_device()], find_unused_parameters=False)
        move_optimizer_to(self.optimizer, "cuda")


    def prepare_output_dir(self) -> Path:
        """Create (if needed) and return the output directory of the current step."""
        
        output_dir = self.cfg.output_dir / f"step_{self.step}"
        output_dir.mkdir(parents=True, exist_ok=True)
        return output_dir


    def _weights_path(self) -> Path:
        """Return the checkpoint path for the current step without touching the filesystem."""
        return self.cfg.output_dir / 'weights' / f"model_step_{self.step}.pt"


    def save_weights(self) -> Path:
        """Write the base model's state dict to the checkpoint file of the current step.

        Creates the ``weights`` directory if needed and records the path in
        ``last_weights_path``.

        Returns:
            The path the state dict was written to.
        """

        weights_path = self._weights_path()
        weights_path.parent.mkdir(parents=True, exist_ok=True)
        self.last_weights_path = weights_path

        state_dict = self.model_base.state_dict()
        torch.save(state_dict, weights_path)

        return weights_path


    def validate(
        self, 
        output_dir: Path, 
        weights_path: Path, 
        plot_name: str, 
        benchmark: BenchmarkName, 
        task: DownstreamTask
    ) -> dict[DataSplit, float]:
        """Run one validation sweep and read back its normalized scores.

        Args:
            output_dir: Directory handed to the sweep; the result CSV files are read from it.
            weights_path: Checkpoint handed to the sweep configuration.
            plot_name: Name handed to the sweep; also the row label looked up in the result files.
            benchmark: Key into ``BENCHMARKS`` selecting the benchmark to run.
            task: Downstream task of the sweep.

        Returns:
            Mapping from ``DataSplit.VALID`` and ``DataSplit.TEST`` to the value
            in the last column of the ``plot_name`` row of the respective file.
            NOTE(review): assumes ``plot_name`` is a unique index label in both
            result files -- confirm against the sweep's output format.
        """

        cfg_sweep = create_config_benchmark_sweep(
            cfg=self.cfg,
            benchmark=BENCHMARKS[benchmark],
            output_dir=output_dir,
            weights_path=weights_path,
            plot_name=plot_name,
            phase=Phase.VALIDATION,
            task=task
        )
        run_sweep(cfg_sweep)

        default_results_val = pd.read_csv(output_dir / DEFAULT_RESULTS_VAL_FILE_NAME, index_col=0)
        normalized_score_val = default_results_val.loc[plot_name].iloc[-1]

        default_results_test = pd.read_csv(output_dir / DEFAULT_RESULTS_TEST_FILE_NAME, index_col=0)
        normalized_score_test = default_results_test.loc[plot_name].iloc[-1]

        return {
            DataSplit.VALID: normalized_score_val,
            DataSplit.TEST: normalized_score_test
        }
    



    