import torch
import torch.nn.functional as F
import os
import copy
import numpy as np
import datetime
from tqdm import tqdm
from utils import inf_iterator, enable_dropout, cal_stats_metric
from pathlib import Path
from omegaconf import OmegaConf
from prettytable import PrettyTable
from evaluator import Evaluator

# Three-letter residue codes. A residue's position in this list is the class
# index that `idx2letter` (below) maps back to a one-letter code when argmax
# predictions are decoded into sequence strings.
mapdiff_aa = [
    'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
    'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL',
]

# Three-letter -> one-letter amino-acid code.
aa_dict = dict(
    ALA="A", CYS="C", ASP="D", GLU="E", PHE="F",
    GLY="G", HIS="H", ILE="I", LYS="K", LEU="L",
    MET="M", ASN="N", PRO="P", GLN="Q", ARG="R",
    SER="S", THR="T", VAL="V", TRP="W", TYR="Y",
)

# Class index -> one-letter code: each three-letter key of aa_dict is located
# in mapdiff_aa (entries are inserted in aa_dict's order).
idx2letter = dict(
    (mapdiff_aa.index(three_letter), one_letter)
    for three_letter, one_letter in aa_dict.items()
)


class Trainer(object):
    """Train, validate, checkpoint and test a diffusion sequence-design model.

    Training is epoch-based: ``training_epochs`` passes over
    ``train_dataloader``. Every ``save_and_sample_every`` epochs the model is
    evaluated on ``val_dataloader``. Predictions are the mean of
    ``ensemble_num`` ``mc_ddim_sample`` calls made after ``enable_dropout``.
    The model with the highest median per-graph recovery is kept as
    ``best_model``.
    """

    def __init__(
            self,
            accelerator,
            config,
            diffusion_model,
            train_dataloader,
            val_dataloader,
            test_dataloader,
            optimizer,
            device,
            output_dir,
            training_epochs=100,
            scheduler=None,
            train_batch_size=512,
            train_num_steps=200000,
            save_and_sample_every=100,
            num_samples=25,
            ensemble_num=50,
            ddim_steps=50,
            sample_method='ddim',
            experiment=None
    ):
        """Store the training components and create ``<output_dir>/model``.

        Args:
            accelerator: accelerate ``Accelerator`` used for backward, gradient
                clipping, model unwrapping and metric gathering.
            config: OmegaConf config. It is serialized into every checkpoint,
                and ``config.experiment.name`` prefixes checkpoint file names.
            diffusion_model: model whose forward returns the training loss and
                which exposes ``mc_ddim_sample``; moved to ``device``.
            train_dataloader, val_dataloader, test_dataloader: iterables of
                ``(graph_batch, ipa_batch_or_None)`` pairs.
            optimizer: optimizer stepped once per training batch.
            device: device that batches and the model are moved to.
            output_dir: results folder (str or path-like). Checkpoints go to
                ``<output_dir>/model``.
            training_epochs: number of passes over ``train_dataloader``.
            scheduler: optional LR scheduler, stepped after every optimizer step.
            train_batch_size: stored as ``self.batch_size``.
            train_num_steps: stored for reference only. Training length is
                controlled by ``training_epochs``.
            save_and_sample_every: validate every this many epochs.
            num_samples: stored as ``self.num_samples``.
            ensemble_num: number of stochastic samples averaged per batch.
            ddim_steps: ``step`` argument passed to ``mc_ddim_sample``.
            sample_method: sampler name; only ``'ddim'`` is implemented.
            experiment: optional logger exposing ``log(str)``; may be None.
        """
        super().__init__()
        self.accelerator = accelerator
        self.device = device
        self.model = diffusion_model.to(self.device)
        self.config = config
        self.num_samples = num_samples
        self.ensemble_num = ensemble_num
        self.ddim_steps = ddim_steps
        self.save_and_sample_every = save_and_sample_every

        self.batch_size = train_batch_size
        self.training_epoch = training_epochs

        self.train_num_steps = train_num_steps
        self.sample_method = sample_method

        # dataset and dataloader
        self.train_dataloader = train_dataloader
        self.iter_one_epoch = len(train_dataloader)
        self.train_iterator = inf_iterator(train_dataloader)
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader

        # optimizer
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.evaluator = Evaluator()
        self.best_val_step = 0
        self.best_val_epoch = 0
        self.step = 0
        self.epoch = 0
        # best_val_perplexity is never updated: validation does not compute
        # perplexity, so it is logged as inf.
        self.best_val_recovery, self.best_val_perplexity = 0, float('inf')
        self.best_model = None

        self.train_metric_header = ["# Epoch", "# Step", "Train_loss"]
        self.val_metric_header = ["# Epoch", "# Step", "Recovery", "Perplexity"]
        self.test_metric_header = ["# Epoch", "# Step", "Recovery", "Perplexity"]
        self.train_table = PrettyTable(self.train_metric_header)
        self.val_table = PrettyTable(self.val_metric_header)
        self.test_table = PrettyTable(self.test_metric_header)

        self.results_folder = output_dir
        # parents=True also creates a missing output_dir. Path(...) joining
        # accepts both str and pathlib.Path, unlike string concatenation.
        Path(self.results_folder, 'model').mkdir(parents=True, exist_ok=True)
        self.experiment = experiment

    # ------------------------------------------------------------------ helpers

    def _log(self, message):
        """Forward ``message`` to ``experiment.log`` when an experiment logger exists."""
        if self.experiment:
            self.experiment.log(message)

    def _to_device(self, g_batch, ipa_batch):
        """Move a ``(graph batch, optional IPA batch)`` pair onto ``self.device``."""
        ipa_batch = ipa_batch.to(self.device) if ipa_batch is not None else None
        return g_batch.to(self.device), ipa_batch

    def _ensemble_logits(self, model, g_batch, ipa_batch):
        """Return the CPU mean of ``ensemble_num`` ``mc_ddim_sample`` logits.

        Raises:
            ValueError: if ``sample_method`` is not ``'ddim'`` (the only
                implemented sampler) or ``ensemble_num`` < 1. Previously either
                case surfaced as an opaque ``torch.stack`` error on an empty list.
        """
        if self.sample_method != 'ddim':
            raise ValueError(f"unsupported sample_method {self.sample_method!r}; only 'ddim' is implemented")
        if self.ensemble_num < 1:
            raise ValueError(f"ensemble_num must be >= 1, got {self.ensemble_num}")
        sampler = self.accelerator.unwrap_model(model)
        ens_logits = []
        for _ in range(self.ensemble_num):
            logits, _sample_graph = sampler.mc_ddim_sample(
                g_batch, ipa_batch, diverse=True, step=self.ddim_steps)
            ens_logits.append(logits)
        return torch.stack(ens_logits).mean(dim=0).cpu()

    @staticmethod
    def _iter_samples(batch_logits, g_batch):
        """Yield ``(predicted_idx, native_idx)`` for every graph in a batch.

        ``g_batch.batch`` gives the graph id of each node. For each id in
        ``0..max`` this yields the row-wise argmax of that graph's rows in
        ``batch_logits`` and in ``g_batch.x``.
        """
        node_graph = g_batch.batch.cpu().numpy()
        native = g_batch.x.cpu()
        for graph_id in range(node_graph.max() + 1):
            idx = np.where(node_graph == graph_id)
            yield batch_logits[idx].argmax(dim=1), native[idx].argmax(dim=1)

    def _best_or_current_model(self):
        """Return ``best_model``, or ``model`` if validation never recorded one."""
        if self.best_model is None:
            print('No best model recorded (validation never ran or never improved); using the current model')
            return self.model
        return self.best_model

    # --------------------------------------------------------------- public API

    def save(self, accelerator, save_epochs, save_steps, mode='best'):
        """Save a checkpoint to ``<results_folder>/model``.

        Args:
            accelerator: used to unwrap the model before taking its state dict.
            save_epochs, save_steps: recorded in the checkpoint and file name.
            mode: ``'best'`` saves ``best_model``, or the current model when no
                best model was recorded. ``'last'`` saves the current model.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == 'best':
            model = self._best_or_current_model()
        elif mode == 'last':
            model = self.model
        else:
            raise ValueError(f"unknown mode {mode}")
        data = {
            'config': OmegaConf.to_container(self.config, resolve=True),
            'step': save_steps,
            'epoch': save_epochs,
            'model': accelerator.unwrap_model(model).state_dict(),
        }
        save_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        file_name = f'{self.config.experiment.name}_{mode}_{save_epochs}_epochs_{save_steps}_steps_{save_time}.pt'
        torch.save(data, os.path.join(self.results_folder, 'model', file_name))

    def save_table_results(self):
        """Write the train/val/test tables as text files in ``results_folder``."""
        tables = (
            ('train_markdowntable.txt', self.train_table),
            ('val_markdowntable.txt', self.val_table),
            ('test_markdowntable.txt', self.test_table),
        )
        for file_name, table in tables:
            with open(os.path.join(self.results_folder, file_name), 'w') as f:
                f.write(table.get_string())

    def train(self):
        """Run epoch-based training with periodic validation, then checkpoint.

        After each epoch the main process records the mean training loss in
        ``train_table``. Every ``save_and_sample_every`` epochs, validation
        runs and ``best_model`` is updated if median recovery improved. At the
        end the main process saves a 'best' and a 'last' checkpoint.
        """
        for epoch in range(self.training_epoch):
            epoch_total_loss = 0
            for g_batch, ipa_batch in tqdm(self.train_dataloader):
                self.model.train()
                g_batch, ipa_batch = self._to_device(g_batch, ipa_batch)
                loss = self.model(g_batch, ipa_batch)
                self.accelerator.backward(loss)
                if self.accelerator.sync_gradients:
                    self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)

                self.optimizer.step()
                if self.scheduler is not None:
                    self.scheduler.step()
                self.optimizer.zero_grad()

                self.step += 1
                epoch_total_loss += loss.item()

            # Count epochs on every process so best_val_epoch agrees across ranks.
            self.epoch += 1
            if self.accelerator.is_main_process:
                # max(..., 1) avoids ZeroDivisionError for an empty dataloader.
                mean_loss = epoch_total_loss / max(self.iter_one_epoch, 1)
                self.train_table.add_row([self.epoch, self.step, mean_loss])
                self._log(f"train_loss: {mean_loss}, epoch: {self.epoch}")

            torch.cuda.empty_cache()
            if (epoch + 1) % self.save_and_sample_every == 0:
                self._validate()

        if self.accelerator.is_main_process:
            print('Training complete')
            self._log(f"best_val_median_recovery: {self.best_val_recovery}")
            self._log(f"best_val_perplexity: {self.best_val_perplexity}")
            self._log(f"best_val_epoch: {self.best_val_epoch}")

            self.save(self.accelerator, self.best_val_epoch, self.best_val_step, mode='best')
            # Label the last checkpoint with the real optimizer-step count.
            # train_num_steps is only a stored constructor argument.
            self.save(self.accelerator, self.epoch, self.step, mode='last')

    def _validate(self):
        """Compute per-graph recovery on ``val_dataloader`` and track the best model."""
        self.model.eval()
        # NOTE(review): enable_dropout presumably re-activates dropout after
        # eval() so the ensemble samples differ (MC dropout) - confirm in utils.
        enable_dropout(self.model)
        recovery = []
        with torch.no_grad():
            for g_batch, ipa_batch in tqdm(self.val_dataloader):
                g_batch, ipa_batch = self._to_device(g_batch, ipa_batch)
                batch_logits = self._ensemble_logits(self.model, g_batch, ipa_batch)

                if self.accelerator.num_processes > 1:
                    # The graph batch is not a tensor, so gather_for_metrics
                    # falls back to object gathering. That is assumed to return
                    # a flat (logits, graph, logits, graph, ...) sequence, one
                    # pair per process.
                    gathered = self.accelerator.gather_for_metrics((batch_logits, g_batch))
                else:
                    # Nothing to gather. Skipping the call also avoids accelerate
                    # truncating this tuple to the dataset remainder on the last
                    # batch, which could drop that batch.
                    gathered = (batch_logits, g_batch)

                for proc_logits, proc_graph in zip(gathered[0::2], gathered[1::2]):
                    for sample_logits, sample_seq in self._iter_samples(proc_logits, proc_graph):
                        recovery.append(self.evaluator.cal_recovery(sample_logits, sample_seq))

        mean_recovery, median_recovery = cal_stats_metric(recovery)
        print(f'Val median recovery rate (epoch: {self.epoch}) is {median_recovery}')

        if self.accelerator.is_main_process:
            self._log(f"val_median_recovery: {median_recovery}, epoch: {self.epoch}")
            self._log(f"val_mean_recovery: {mean_recovery}, epoch: {self.epoch}")

        if median_recovery > self.best_val_recovery:
            self.best_model = copy.deepcopy(self.model)
            self.best_val_step = self.step
            self.best_val_epoch = self.epoch
            self.best_val_recovery = median_recovery

    def test(self, csv_path='data_cath42.csv'):
        """Evaluate the best model on ``test_dataloader`` and print the metrics.

        Writes one ``[native, prediction]`` row of one-letter sequences per
        graph to ``csv_path``. Recovery, BLOSUM NSSR (42/62/80/90) and
        perplexity are computed by ``self.evaluator``. Median values and
        perplexity are printed, and a row is appended to ``test_table``.

        Args:
            csv_path: output CSV path. The default keeps the historical
                ``data_cath42.csv`` in the working directory.
        """
        import csv

        model = self._best_or_current_model()
        model.eval()
        enable_dropout(model)
        logits_chunks, seq_chunks = [], []
        recovery = []
        nssr42, nssr62, nssr80, nssr90 = [], [], [], []
        with torch.no_grad():
            print('Testing best model')
            with open(csv_path, 'w', newline='') as file:
                writer = csv.writer(file)
                writer.writerow(['Native', 'Prediction'])
                for g_batch, ipa_batch in tqdm(self.test_dataloader):
                    g_batch, ipa_batch = self._to_device(g_batch, ipa_batch)
                    batch_logits = self._ensemble_logits(model, g_batch, ipa_batch)
                    logits_chunks.append(batch_logits)
                    seq_chunks.append(g_batch.x.cpu())

                    for sample_logits, sample_seq in self._iter_samples(batch_logits, g_batch):
                        pred_seq = "".join(idx2letter[i] for i in sample_logits.tolist())
                        native_seq = "".join(idx2letter[i] for i in sample_seq.tolist())
                        writer.writerow([native_seq, pred_seq])

                        recovery.append(self.evaluator.cal_recovery(sample_logits, sample_seq))
                        s42, s62, s80, s90 = self.evaluator.cal_all_blosum_nssr(sample_logits, sample_seq)
                        nssr42.append(s42)
                        nssr62.append(s62)
                        nssr80.append(s80)
                        nssr90.append(s90)

            # Concatenate once instead of growing a tensor per batch. The
            # leading empty float tensor keeps the original dtype promotion.
            all_logits = torch.cat([torch.tensor([]), *logits_chunks])
            all_seq = torch.cat([torch.tensor([]), *seq_chunks])

            _, test_median_recovery = cal_stats_metric(recovery)
            _, test_median_nssr42 = cal_stats_metric(nssr42)
            _, test_median_nssr62 = cal_stats_metric(nssr62)
            _, test_median_nssr80 = cal_stats_metric(nssr80)
            _, test_median_nssr90 = cal_stats_metric(nssr90)
            test_perplexity = self.evaluator.cal_perplexity(all_logits, all_seq)

            print(f'test median recovery rate with best model (step: {self.best_val_step}) is {test_median_recovery}')
            print(f'test perplexity with the best model (step: {self.best_val_step}) is: {test_perplexity}')
            print(f'test_median_nssr42_with_best_model: {test_median_nssr42}')
            print(f'test_median_nssr62_with_best_model: {test_median_nssr62}')
            print(f'test_median_nssr80_with_best_model: {test_median_nssr80}')
            print(f'test_median_nssr90_with_best_model: {test_median_nssr90}')

            # test_table was declared with these exact columns but never filled.
            self.test_table.add_row(
                [self.best_val_epoch, self.best_val_step, test_median_recovery, test_perplexity])

