import math
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tensorboardX import SummaryWriter
from ..tasks import build_task
from ..models import build_model
from . import register_pipe
from .base_pipe import BasePipe
from ..utils import SignalDiffusion
from ..utils.early_stop import EarlyStopping

@register_pipe("rf_ddm_pretrain")
class RFDDMPretrain(BasePipe):
    """Pretraining pipeline registered under the name ``"rf_ddm_pretrain"``.

    Each training step degrades a clean batch with ``self.diffusion.degrade_fn``
    at a randomly sampled step ``t``. It then minimizes the ``"signal_prediction"``
    task loss between the clean batch and the model's output on the degraded input.
    """

    def __init__(self, args):
        """Build the model, optimizer, training data loader, diffusion helper and scheduler.

        Args:
            args: Run configuration. Attributes read here include ``model``,
                ``device``, ``load_from_pretrained``, ``compile_flag`` (optional),
                ``use_distribute``, ``max_step``, ``optimizer``, ``lr``,
                ``weight_decay`` and ``batch_size``. Through ``self.args`` it also
                reads ``extra_dim``, ``blur_noise``, ``min_noise`` and ``max_noise``.
        """
        super().__init__(args)
        self.model = build_model(args.model).build_model_from_args(args).to(args.device)
        total_params = sum(p.numel() for p in self.model.parameters())
        print(f'{total_params:,} total parameters.')
        total_trainable_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f'{total_trainable_params:,} training parameters.')
        if args.load_from_pretrained:
            self.load_from_pretrained()
        # ``compile_flag`` is optional in the config; a missing value means "don't compile".
        if getattr(args, "compile_flag", False):
            self.compile()
        if args.use_distribute:
            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[args.device])
        print("-----------------------load model done-----------------------")
        self.max_step = args.max_step
        # NOTE(review): ``candidate_optimizer`` is presumably a name -> optimizer-class
        # mapping provided by BasePipe; confirm.
        self.optimizer = self.candidate_optimizer[args.optimizer](
            self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        task_name = "signal_prediction"
        self.task = build_task(args, task_name)
        self.loss_fn = self.task.get_loss_func()
        train_dataset = self.task.get_pretrain_data()
        # Under DDP the DistributedSampler shards the dataset across ranks; otherwise
        # batches are read in dataset order (shuffle=False). drop_last=True keeps every
        # batch at exactly ``batch_size`` (presumably required because SignalDiffusion
        # below is constructed with that fixed batch size — confirm).
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=args.device != "cpu",
            sampler=DistributedSampler(train_dataset) if args.use_distribute else None,
            drop_last=True,
        )
        self.diffusion = SignalDiffusion(self.args.batch_size, self.args.extra_dim, self.max_step,
                                         self.args.blur_noise, self.args.min_noise, self.args.max_noise)

        # NOTE: the scheduler is created here but ``train`` never calls ``scheduler.step``.
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min', factor=0.5, patience=3, verbose=True, min_lr=1e-8)

    def train(self):
        """Run the pretraining loop until ``args.num_epochs`` finish or early stopping fires.

        ``EarlyStopping.loss_step`` is consulted after every optimizer step. When
        training is distributed, all ranks agree on the stop decision, so they leave
        the loop together. Any initialized process group is destroyed on exit.
        """
        stopper = EarlyStopping(self.args.patience, self._checkpoint,
                                getattr(self.args, "compile_flag", False),
                                self.args.use_distribute)
        self.model.train()
        sampler = self.train_loader.sampler
        stop = False
        for epoch in range(self.args.num_epochs):
            if isinstance(sampler, DistributedSampler):
                # Without this the DistributedSampler yields the same order every epoch.
                sampler.set_epoch(epoch)
            for i, (data, _, _, _) in enumerate(self.train_loader):
                # Swap dims 1 and 2 before the model sees the batch
                # (presumably (batch, length, channels) -> (batch, channels, length); TODO confirm).
                data = data.transpose(1, 2)
                data = data.to(self.device, non_blocking=True)
                # A single diffusion step ``t`` is shared by the whole batch.
                # NOTE(review): ``t`` stays on the CPU even when the model runs on a GPU;
                # presumably degrade_fn and the model handle this — confirm.
                t = torch.randint(0, self.max_step, (1,), dtype=torch.int64).to("cpu")
                x = self.diffusion.degrade_fn(data, t)
                out = self.model(x, t)
                # NOTE(review): fixed x10000 loss scale; its purpose is not stated here — confirm.
                loss = self.loss_fn(data, out) * 10000
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                if i % self.args.evaluate_interval == 0:
                    print(f"---loss:{loss.item()}---lr:{self.optimizer.param_groups[0]['lr']}---")
                stop = self._any_rank_requests_stop(stopper.loss_step(loss, self.model))
                if stop:
                    print("Early Stop!\tEpoch:" + str(epoch))
                    break
            if stop:
                break

        self._teardown_process_group()

    def _any_rank_requests_stop(self, local_stop):
        """Return True if this rank, or any rank when distributed, asked to stop early."""
        if not self.args.use_distribute:
            return bool(local_stop)
        flag = torch.tensor(int(bool(local_stop)), dtype=torch.int32, device=self.args.device)
        # Every rank reaches this collective on every step, so it cannot deadlock; MAX
        # turns "any rank wants to stop" into a decision shared by all ranks.
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
        return bool(flag.item())

    def _teardown_process_group(self):
        """Destroy the default process group if one exists (no-op for single-process runs)."""
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()
