import numpy as np
import os.path as osp
import torch
import torch.nn as nn
import torch.distributed as dist
from tqdm import tqdm
from torch.optim.lr_scheduler import CyclicLR, ReduceLROnPlateau
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tensorboardX import SummaryWriter
from ..tasks import build_task
from ..models import build_model
from . import register_pipe
from .base_pipe import BasePipe
from ..utils import Diffusion, plot_tsne
from ..utils.early_stop import EarlyStopping
from ..utils import get_pca_embedding, inverse_pca
from torch_pca import PCA

@register_pipe("ldm_pretrain")
class LDMPretrain(BasePipe):
    """Pre-training pipe registered as ``"ldm_pretrain"``.

    Builds a model and a companion ``<model>_Encoder`` from ``args``, optionally
    compiles them and wraps them in ``DistributedDataParallel``, and trains the
    model to reproduce the noise that ``Diffusion.q_sample`` adds to PCA
    embeddings of the training signals.
    """

    def __init__(self, args):
        """Build models, optimizers, schedulers, the task and the train loader.

        Args:
            args: Run configuration. Attributes read here: ``model``, ``device``,
                ``task``, ``load_from_pretrained``, ``use_distribute``,
                ``max_step``, ``optimizer``, ``lr``, ``weight_decay``,
                ``batch_size``, ``min_noise``, ``max_noise``, and optionally
                ``compile_flag`` / ``compile``.
        """
        super(LDMPretrain, self).__init__(args)
        self.encoder = build_model(args.model + "_Encoder").build_model_from_args(args).to(args.device)
        self.model = build_model(args.model).build_model_from_args(args).to(args.device)

        total_params = sum(p.numel() for p in self.model.parameters())
        print(f'{total_params:,} total parameters.')
        total_trainable_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f'{total_trainable_params:,} training parameters.')

        # The encoder checkpoint lives next to the main model checkpoint.
        parent_dir = osp.dirname(self._checkpoint)
        self.encoder_dir = osp.join(parent_dir, args.model + f"_encoder_{args.task}.pt")
        if args.load_from_pretrained:
            self.load_from_pretrained()
            if osp.exists(self.encoder_dir):
                # Load onto CPU first so a checkpoint saved on another device
                # (e.g. a different GPU index) can still be restored;
                # load_state_dict copies the values onto the encoder's device.
                ck_pt = torch.load(self.encoder_dir, map_location="cpu")
                self.encoder.load_state_dict(ck_pt)

        # ``compile_flag`` is optional in the configuration.
        if getattr(args, "compile_flag", False):
            self.compile()
            if hasattr(self.args, "compile"):
                # Forward the user-supplied torch.compile options, falling
                # back to torch.compile's defaults for any missing key.
                compile_cfg = self.args.compile
                self.encoder = torch.compile(
                    self.encoder,
                    mode=compile_cfg["mode"] if "mode" in compile_cfg else "default",
                    fullgraph=compile_cfg["fullgraph"] if "fullgraph" in compile_cfg else False,
                    dynamic=compile_cfg["dynamic"] if "dynamic" in compile_cfg else None,
                )
            else:
                self.encoder = torch.compile(self.encoder)

        if args.use_distribute:
            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[args.device])
            self.encoder = nn.parallel.DistributedDataParallel(self.encoder, device_ids=[args.device])
        print("-----------------------load model done-----------------------")

        self.max_step = args.max_step
        # NOTE(review): the encoder optimizer/scheduler are created but only
        # referenced by the disabled encoder stage in ``train``.
        self.autoencoder_optimizer = self.candidate_optimizer[args.optimizer](
            self.encoder.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay)
        # NOTE(review): the model optimizer uses a fixed lr of 1e-4 rather
        # than ``args.lr`` -- confirm this is intended.
        self.optimizer = self.candidate_optimizer[args.optimizer](
            self.model.parameters(),
            lr=1e-4, weight_decay=args.weight_decay)

        task_name = "signal_prediction"
        self.task = build_task(args, task_name)
        self.loss_fn = self.task.get_loss_func()
        # train_dataset = self.task.get_pretrain_data()
        train_dataset, _, _ = self.task.get_data()
        # With a DistributedSampler the sampler controls ordering, so the
        # loader's own ``shuffle`` stays False; ``train`` calls ``set_epoch``.
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=args.device != "cpu",
            sampler=DistributedSampler(train_dataset) if args.use_distribute else None,
            drop_last=True)
        self.diffusion = Diffusion(self.max_step, self.args.min_noise, self.args.max_noise, args.device)

        # self.scaler = torch.cuda.amp.GradScaler()
        self.autoencoder_scheduler = ReduceLROnPlateau(self.autoencoder_optimizer, 'min', factor=0.5, patience=3, verbose=True, min_lr=1e-8)
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min', factor=0.5, patience=20, verbose=True, min_lr=1e-8)

    def train(self):
        """Train ``self.model`` on noised PCA embeddings of the training data.

        For every batch: flatten the signals, project them with PCA, draw a
        random timestep per sample, noise the embedding with
        ``self.diffusion.q_sample`` and minimise ``self.loss_fn`` between the
        sampled noise and the model output. The plateau scheduler and the
        ``EarlyStopping`` helper are both stepped once per batch.
        """
        # ---- Disabled stage: encoder (autoencoder) pre-training ----
        # stopper = EarlyStopping(10, self.encoder_dir, 
        #                         self.args.compile_flag, self.args.use_distribute)
        # self.encoder.train()
        # # target = torch.tensor([]).to(self.device)

        # for epoch in range(self.args.num_epochs):
        #     for i, (data, _, y, snr) in enumerate(tqdm(self.train_loader)):
        #         data = data.transpose(1, 2)
        #         data = data.to(self.device, non_blocking=True)
        #         # indices = [index for index, element in enumerate(snr) if element >= 10]
        #         # target = data[indices]
        #         # if target.size(0) == 0:
        #         #     continue
        #         # target = torch.cat((target, data[indices]), dim=0)
        #         # target_label = torch.cat((target_label, y[indices]), dim=0)

        #         out = self.encoder(data)[0]
        #         loss = self.loss_fn(data, out)
        #         print(loss)
        #         self.autoencoder_optimizer.zero_grad()
        #         loss.backward()
        #         self.autoencoder_optimizer.step()
        #         self.autoencoder_scheduler.step(loss.item())
                
        #         if i % self.args.evaluate_interval == 0:
        #             print(f"---Encoder---loss:{loss.item()}---lr:{self.autoencoder_optimizer.param_groups[0]['lr']}---")
        #         early_stop = stopper.loss_step(loss, self.encoder)
        #         if early_stop:
        #             print("Early Stop!\tEpoch:" + str(epoch))
        #             break
        #     if early_stop:
        #         break
            
        # import sys
        # sys.exit()
        pca_model = PCA(self.args.encoder_latent_dim)
        # ``compile_flag`` is optional in the configuration (see ``__init__``),
        # so read it defensively instead of raising AttributeError here.
        stopper = EarlyStopping(self.args.patience, self._checkpoint,
                                getattr(self.args, "compile_flag", False),
                                self.args.use_distribute)
        # self.encoder.eval()
        self.model.train()
        early_stop = torch.tensor(False, dtype=torch.bool, device=self.args.device)
        sampler = self.train_loader.sampler
        for epoch in range(self.args.num_epochs):
            # A DistributedSampler seeds its shuffle with the epoch number;
            # without set_epoch every epoch replays the same ordering.
            if isinstance(sampler, DistributedSampler):
                sampler.set_epoch(epoch)
            for i, (data, _, y, _) in enumerate(tqdm(self.train_loader)):
                num_sample = data.size(0)
                data = data.transpose(1, 2)
                data = data.to(self.device, non_blocking=True)

                # mu, std = self.encoder.encode(data.unsqueeze(1))
                # z0 = self.encoder.reparameterize(mu, std)

                # NOTE(review): fit_transform re-fits the PCA on every batch,
                # so the latent basis can differ between batches -- confirm
                # this is intended rather than fitting once up front.
                z0 = pca_model.fit_transform(data.reshape(num_sample, -1))
                # from ..utils.plot import plot_tsne
                # plot_tsne(z0.detach().to("cpu"), y.detach().to("cpu"), self.args.dataset[0], "all", "./", self.task.get_classes())
                # import sys
                # sys.exit()
                
                # output = pca_model.inverse_transform(z0)
                # One random diffusion timestep per sample, kept on CPU.
                # NOTE(review): verify q_sample / the model accept CPU
                # timesteps when the data lives on another device.
                t = torch.randint(0, self.max_step, (num_sample, ), dtype=torch.int64).to("cpu")
                z, noise = self.diffusion.q_sample(z0, t)
                out = self.model(z, t)
                loss = self.loss_fn(noise, out)
                print(loss)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                # self.scaler.scale(loss).backward()
                # self.scaler.step(self.optimizer)
                # self.scaler.update()
                # Stepped per batch, so the scheduler's patience counts
                # batches rather than epochs.
                self.scheduler.step(loss.item())
                
                if i % self.args.evaluate_interval == 0:
                    print(f"---Model---loss:{loss.item()}---lr:{self.optimizer.param_groups[0]['lr']}---")
                # NOTE(review): the early-stop flag is computed but not acted
                # on while the break logic below stays disabled.
                early_stop = torch.tensor(stopper.loss_step(loss, self.model), dtype=torch.bool, device=self.args.device)
            #     if early_stop:
            #         # dist.all_reduce(early_stop, op=dist.ReduceOp.MAX)
            #         print("Early Stop!\tEpoch:" + str(epoch))
            #         break
            # if early_stop:
            #     break

        # dist.destroy_process_group()
