import os.path as osp
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from ..tasks import build_task
from ..models import build_model
from . import register_pipe
from .base_pipe import BasePipe
from ..utils.early_stop import EarlyStopping
from ..utils.plot import plot_confusion_matrix, plot_tsne
from ..utils import Diffusion, get_pca_embedding, inverse_pca
from torch_pca import PCA

@register_pipe("ldm_tune")
class LDMTune(BasePipe):
    """Linear-probe tuning pipe for modulation classification.

    A single ``nn.Linear`` head (``self.classifier``) is trained on latent
    features. ``self.encoder`` and ``self.model`` (the network passed to
    ``Diffusion.p_sample``) are built and optionally loaded from checkpoints,
    but only the classifier's parameters are given to the optimizer.

    NOTE(review): training features come from a per-batch PCA of the raw
    signal (``_train_step``), whereas evaluation features come from the
    encoder followed by reverse diffusion (``_test_step``). These are different
    feature spaces; confirm the mismatch is intentional.
    """

    def __init__(self, args):
        """Build models, data loaders, optimizer and the linear classifier.

        Args:
            args: run configuration. Fields read here include ``model``,
                ``device``, ``dataset``, ``task``, ``load_from_pretrained``,
                ``use_distribute``, ``encoder_latent_dim``, ``optimizer``,
                ``weight_decay``, ``batch_size``, ``max_step``, ``min_noise``,
                ``max_noise``, ``minSNR``, ``maxSNR``, ``output_dir`` and the
                optional ``compile_flag``, ``compile`` and ``plot``.
        """
        super(LDMTune, self).__init__(args)
        self.model_name = args.model
        self.encoder = build_model(args.model + "_Encoder").build_model_from_args(args).to(args.device)
        self.model = build_model(args.model).build_model_from_args(args).to(args.device)
        total_params = sum(p.numel() for p in self.model.parameters())
        print(f'{total_params:,} total parameters.')
        total_trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f'{total_trainable_params:,} training parameters.')

        # The encoder checkpoint is stored next to the main checkpoint file.
        parent_dir = osp.dirname(self._checkpoint)
        self.encoder_dir = osp.join(parent_dir, args.model + f"_encoder_{args.dataset[0]}_{args.task}.pt")
        if args.load_from_pretrained:
            self.load_from_pretrained()
            if osp.exists(self.encoder_dir):
                print("loaded!")
                self.encoder.load_state_dict(torch.load(self.encoder_dir))

        if getattr(args, "compile_flag", False):
            # self.compile() is provided by BasePipe (presumably compiles self.model).
            self.compile()
            if hasattr(self.args, "compile"):
                opts = self.args.compile
                self.encoder = torch.compile(
                    self.encoder,
                    mode=opts["mode"] if "mode" in opts else "default",
                    fullgraph=opts["fullgraph"] if "fullgraph" in opts else False,
                    dynamic=opts["dynamic"] if "dynamic" in opts else None,
                )
            else:
                self.encoder = torch.compile(self.encoder)
        if args.use_distribute:
            # NOTE(review): once wrapped in DDP, `self.encoder.encode(...)` in
            # _test_step needs `self.encoder.module.encode(...)` — confirm.
            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[args.device])
            self.encoder = nn.parallel.DistributedDataParallel(self.encoder, device_ids=[args.device])

        print("-----------------------load model done-----------------------")

        # The task is fixed to modulation classification regardless of args.task.
        task_name = "modulation_classification"
        self.task = build_task(args, task_name)
        self.loss_fn = self.task.get_loss_func()
        train_dataset, val_dataset, test_dataset = self.task.get_data()
        self.classes = self.task.get_classes()
        self.diffusion = Diffusion(args.max_step, args.min_noise, args.max_noise, args.device)

        # Only the linear head is optimized.
        self.classifier = nn.Linear(args.encoder_latent_dim, len(self.classes)).to(args.device)
        self.optimizer = self.candidate_optimizer[args.optimizer](self.classifier.parameters(),
                                                                    lr=1e-4, weight_decay=args.weight_decay)
        self.train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False)
        self.val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False)
        self.test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False)

        # NOTE(review): neither the scaler nor the scheduler is used/stepped
        # anywhere in this class — confirm whether they should be.
        self.scaler = torch.cuda.amp.GradScaler()
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min', factor=0.5, patience=3, verbose=True, min_lr=5e-5)
        # SNR buckets used for per-SNR accuracy: minSNR..maxSNR inclusive, step 2.
        self.SNR_list = list(range(args.minSNR, args.maxSNR + 1, 2))
        self.output_dir = args.output_dir
        self.plot = getattr(args, "plot", False)

        self.pca_model = PCA(self.args.encoder_latent_dim)

    def train(self):
        """Train the classifier with early stopping, then evaluate on the test split.

        Prints per-epoch train/validation metrics and the final test accuracy.
        When ``self.plot`` is set, also writes confusion matrices, plus per-SNR
        confusion matrices and t-SNE plots for the test split, to
        ``self.output_dir``.
        """
        stopper = EarlyStopping(self.args.patience, self._checkpoint,
                                getattr(self.args, "compile_flag", False), self.args.use_distribute)
        # Bound up front so the check below never reads an unassigned name.
        early_stop = False
        for epoch in range(self.args.num_epochs):
            self.model.train()
            train_loss, train_acc, train_true, train_pred = self._train_step()
            print(f"Epoch:{epoch}, train_loss={train_loss}, train_acc={train_acc['Avg']}")
            print(train_acc)
            if self.plot:
                plot_confusion_matrix(train_true, train_pred, self.dataset_name, "all", self.output_dir, self.classes)

            if epoch % self.args.evaluate_interval == 0:
                val_loss, val_acc, val_true, val_pred, _ = self._test_step("val")
                print(f"Epoch:{epoch}, val_loss={val_loss}, val_acc={val_acc['Avg']}")
                # NOTE(review): the stopper is handed self.model, not self.classifier,
                # so the trained head may not be what gets checkpointed/restored.
                early_stop = stopper.loss_step(val_loss, self.model)
                if self.plot:
                    plot_confusion_matrix(val_true, val_pred, self.dataset_name, "all", self.output_dir, self.classes)

            if early_stop:
                print("Early Stop!\tEpoch:" + str(epoch))
                break

        stopper.load_model(self.model)
        _, test_acc, test_true, test_pred, test_SNR = self._test_step("test")
        performance = test_acc["Avg"]
        print(f"test acc={performance}")
        print(test_acc)
        if self.plot:
            self._plot_per_snr(test_SNR, test_pred, test_true)

    def _plot_per_snr(self, snr_values, preds, trues):
        """Plot a confusion matrix and t-SNE per SNR bucket; return per-class accuracies.

        Args:
            snr_values: per-sample SNR values (already unwrapped to scalars).
            preds: per-sample predicted class indices.
            trues: per-sample true class indices.

        Returns:
            dict mapping each SNR that has samples to a list of per-class
            accuracies (ordered like ``self.classes``; NaN for absent classes).
        """
        mod_dic = {}
        for snr in self.SNR_list:
            pairs = [(p, t) for s, p, t in zip(snr_values, preds, trues) if s == snr]
            if not pairs:
                # Nothing to plot for this SNR.
                continue
            pred_cm = [p for p, _ in pairs]
            true_cm = [t for _, t in pairs]

            correct_per_cls = np.zeros(len(self.classes))
            total_per_cls = np.zeros(len(self.classes))
            for p, t in pairs:
                total_per_cls[t] += 1
                if p == t:
                    correct_per_cls[t] += 1
            with np.errstate(divide="ignore", invalid="ignore"):
                cls_acc = correct_per_cls / total_per_cls
            mod_dic[snr] = cls_acc.tolist()

            plot_confusion_matrix(true_cm, pred_cm, self.dataset_name, snr, self.output_dir, self.classes)
            plot_tsne(np.array(pred_cm), np.array(true_cm), self.dataset_name, snr, self.output_dir, self.classes)
        return mod_dic

    @staticmethod
    def _flatten_snr(batch_snr):
        """Unwrap per-sample SNR values that arrive as one-element lists."""
        return [s[0] if isinstance(s, list) else s for s in batch_snr]

    @staticmethod
    def _count_snr(batch_snr, preds, trues, snr_count, snr_correct):
        """Accumulate per-SNR sample and hit counts in place."""
        for snr, p, t in zip(batch_snr, preds, trues):
            snr_count[snr] += 1
            if p == t:
                snr_correct[snr] += 1

    @staticmethod
    def _snr_accuracy(snr_count, snr_correct, snr_list):
        """Return {snr: accuracy, ..., 'Avg': overall accuracy}; NaN where no samples."""
        acc = {}
        for key in snr_list:
            n = snr_count[key]
            acc[key] = snr_correct[key] / float(n) if n else float("nan")
        total = sum(snr_count[key] for key in snr_list)
        hits = sum(snr_correct[key] for key in snr_list)
        acc['Avg'] = hits / float(total) if total else float("nan")
        return acc

    def _train_step(self):
        """Run one training epoch of the linear classifier.

        Returns:
            (loss, acc, y_true, y_pred): the sample-weighted mean loss, a dict
            of per-SNR accuracies plus ``'Avg'``, and flat lists of true and
            predicted class indices.
        """
        self.classifier.train()
        snr_count = {key: 0 for key in self.SNR_list}
        snr_correct = {key: 0 for key in self.SNR_list}
        y_true = []
        y_pred = []
        num_total = 0
        loss = 0.0

        for batch_x, _batch_stft, batch_y, batch_SNR in tqdm(self.train_loader):
            num_sample = batch_x.size(0)
            num_total += num_sample
            batch_SNR = self._flatten_snr(batch_SNR.numpy().tolist())
            batch_y = batch_y.to(self.device)
            batch_x = batch_x.transpose(1, 2).to(self.device)

            # PCA is refitted on every batch, so the projection basis changes
            # from batch to batch. NOTE(review): fitting encoder_latent_dim
            # components presumably needs at least that many samples; the last
            # batch can be smaller because drop_last=False — confirm.
            z = self.pca_model.fit_transform(batch_x.reshape(num_sample, -1))

            batch_out = self.classifier(z)
            batch_loss = self.loss_fn(batch_out, batch_y)

            train_pred = batch_out.detach().cpu().numpy().argmax(1).tolist()
            train_true = batch_y.detach().cpu().numpy().tolist()
            y_true.extend(train_true)
            y_pred.extend(train_pred)
            self._count_snr(batch_SNR, train_pred, train_true, snr_count, snr_correct)

            loss += batch_loss.item() * num_sample
            self.optimizer.zero_grad()
            batch_loss.backward()
            self.optimizer.step()

        loss /= num_total
        return loss, self._snr_accuracy(snr_count, snr_correct, self.SNR_list), y_true, y_pred

    def _test_step(self, mode):
        """Evaluate the classifier on the validation or test split.

        Features are obtained by encoding, reparameterizing, then running the
        full reverse-diffusion chain (``max_step - 1`` down to ``0``).

        Args:
            mode: a string containing ``"val"`` or ``"test"``.

        Returns:
            (loss, acc, y_true, y_pred, eval_SNR): the sample-weighted mean
            loss, per-SNR accuracies plus ``'Avg'``, flat lists of true and
            predicted class indices, and the per-sample SNR values.

        Raises:
            ValueError: if ``mode`` names neither split.
        """
        if "val" in mode:
            loader = self.val_loader
        elif "test" in mode:
            loader = self.test_loader
        else:
            raise ValueError(f"mode must contain 'val' or 'test', got {mode!r}")

        self.model.eval()
        self.encoder.eval()
        self.classifier.eval()
        snr_count = {key: 0 for key in self.SNR_list}
        snr_correct = {key: 0 for key in self.SNR_list}
        y_true = []
        y_pred = []
        eval_SNR = []
        num_total = 0
        loss = 0.0

        with torch.no_grad():
            for batch_x, _batch_stft, batch_y, batch_SNR in loader:
                num_sample = batch_x.size(0)
                num_total += num_sample
                # Unwrap before recording, so callers can compare SNRs to scalars.
                batch_SNR = self._flatten_snr(batch_SNR.numpy().tolist())
                batch_y = batch_y.to(self.device)
                batch_x = batch_x.transpose(1, 2).to(self.device)

                mu, std = self.encoder.encode(batch_x.unsqueeze(1))
                z = self.encoder.reparameterize(mu, std)
                # Each denoising step must consume the previous step's output.
                for t in reversed(range(self.args.max_step)):
                    z = self.diffusion.p_sample(self.model, z, t)
                batch_out = self.classifier(z)

                batch_loss = self.loss_fn(batch_out, batch_y)
                test_pred = batch_out.detach().cpu().numpy().argmax(1).tolist()
                test_true = batch_y.detach().cpu().numpy().tolist()
                y_true.extend(test_true)
                y_pred.extend(test_pred)
                eval_SNR.extend(batch_SNR)
                self._count_snr(batch_SNR, test_pred, test_true, snr_count, snr_correct)

                loss += batch_loss.item() * num_sample

        loss /= num_total
        return loss, self._snr_accuracy(snr_count, snr_correct, self.SNR_list), y_true, y_pred, eval_SNR

import math
@torch.jit.script
def gaussian(window_size: int, tfdiff: float):
    """Return a normalised 1-D Gaussian window of length ``window_size``.

    ``tfdiff`` plays the role of the standard deviation. The peak sits at
    index ``window_size // 2``, so for even sizes the window is not
    symmetric. The weights are divided by their sum so they add up to one.
    """
    center = window_size // 2
    denom = float(2 * tfdiff ** 2)
    weights = [math.exp(-((pos - center) ** 2) / denom) for pos in range(window_size)]
    kernel = torch.tensor(weights)
    return kernel / kernel.sum()

@torch.jit.script
def create_window(height: int, width: int):
    """Return a separable 2-D Gaussian kernel shaped (1, 1, height, width).

    The kernel is the outer product of two 1-D ``gaussian`` windows, both
    with spread 1.5. It is returned contiguous, ready to be used as a
    single-channel ``conv2d`` weight.
    """
    column = gaussian(height, 1.5).unsqueeze(1)  # (height, 1)
    row = gaussian(width, 1.5).unsqueeze(0)      # (1, width)
    return column.mm(row).view(1, 1, height, width).contiguous()


def eval_ssim(pred, data, height, width, device):
    """Compute an SSIM-style similarity between ``pred`` and ``data``.

    Local means, variances and covariance are estimated with a Gaussian
    window from ``create_window(height, width)``, cast to complex64 and moved
    to ``device``. Because that window has one input and one output channel
    and ``groups=1``, the inputs must be single-channel 4-D tensors.
    Presumably they are complex (batch, 1, H, W) maps — confirm against
    callers. The covariance term uses only its real part. The function
    returns twice the real part of the mean SSIM map.
    """
    kernel = create_window(height, width).to(torch.complex64).to(device)
    pad = [height // 2, width // 2]

    def _blur(x):
        # Gaussian-weighted local average.
        return torch.nn.functional.conv2d(x, kernel, padding=pad, groups=1)

    mu_p = _blur(pred)
    mu_d = _blur(data)
    mu_p_sq = mu_p.pow(2.)
    mu_d_sq = mu_d.pow(2.)
    mu_pd = mu_p * mu_d
    var_p = _blur(pred * pred) - mu_p_sq
    var_d = _blur(data * data) - mu_d_sq
    cov_pd = _blur(pred * data) - mu_pd

    # Standard SSIM stabilisers (K1 = 0.01, K2 = 0.03 with unit dynamic range).
    c1 = 0.01 ** 2
    c2 = 0.03 ** 2
    numerator = (2 * mu_p * mu_d + c1) * (2 * cov_pd.real + c2)
    denominator = (mu_p_sq + mu_d_sq + c1) * (var_p + var_d + c2)
    ssim_map = numerator / denominator
    return 2 * ssim_map.mean().real