import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


# Target corners for the op-amp predictor, named
# "<process>_t<temperature>_v<supply>". The order is significant:
# Predictor.predict uses a corner's position in this tuple to pick its slice
# of the network's flat output vector.
OPAMP_TARGET_CORNERS = tuple(
    f'{process}_t{temperature}_v{supply}'
    for process in ('tt', 'ss', 'ff', 'sf', 'fs')
    for temperature in ('0.0', '100.0')
    for supply in ('1.0', '1.1', '1.2')
)

# Corner at which the input metrics are measured.
OPAMP_INPUT_CORNER = 'tt_t27.0_v1.2'


# Metric names (input and output) and sizing parameter names, in feature order.
OPAMP_METRICS = ('gain', 'ibias', 'phm', 'ugbw')
OPAMP_SIZING_NAMES = ('mp1', 'mn1', 'mp3', 'mn3', 'mn4', 'mn5', 'cc')


# Target corners for the folded-cascode predictor, named
# "<process>_t<temperature>_v<supply>". Order is significant: it fixes each
# corner's slice of the network's flat output vector.
FOLD_TARGET_CORNERS = tuple(
    f'{process}_t{temperature}_v{supply}'
    for process in ('tt', 'ss', 'ff', 'sf', 'fs')
    for temperature in ('0.0', '100.0')
    for supply in ('1.6', '1.8')
)

# Corner at which the input metrics are measured.
FOLD_INPUT_CORNER = 'tt_t27.0_v1.8'
# Metric names (input and output) and sizing parameter names, in feature order.
FOLD_METRICS = ('power', 'gain', 'cmrr', 'psrr', 'pm_dm', 'rms_noise_out_dm',
                'lg_ugb')
FOLD_SIZING_NAMES = ('L1', 'L2', 'L3', 'L4', 'L5', 'L6', 'L7',
                     'W1', 'W2', 'W3', 'W4', 'W5', 'W6', 'W7',
                     'N1', 'N2', 'N8', 'N9',
                     'MCAP', 'Cf')


# Target corners for the StrongARM predictor, named
# "<process>_t<temperature>_v<supply>". Order is significant: it fixes each
# corner's slice of the network's flat output vector.
STRONGARM_TARGET_CORNERS = tuple(
    f'{process}_t{temperature}_v{supply}'
    for process in ('tt', 'ss', 'ff', 'sf', 'fs')
    for temperature in ('0.0', '100.0')
    for supply in ('1.1', '1.2')
)

# Corner at which the input metrics are measured.
STRONGARM_INPUT_CORNER = 'tt_t27.0_v1.2'

# Metric names (input and output) and sizing parameter names, in feature order.
STRONGARM_METRICS = ('Power', 'delay', 'reset', 'input_ref_noise', 'reset_val',
                     'rise_val', 'area')
STRONGARM_SIZING_NAMES = ('w1', 'w2', 'w3', 'w4', 'w6', 'w8',
                          'Cl_finger')


class PVTFC(nn.Module):
    """Fully-connected regression network.

    Layout: ``input_layer`` (``Linear(in_ch, out_ch)``, no activation), then
    ``layer_num`` hidden ``Linear(out_ch, out_ch)`` layers each followed by a
    ReLU, then the ``regress`` head (``Linear(out_ch, n_out_feat)``).

    Args:
        in_ch: length of the input feature vector.
        out_ch: hidden width.
        layer_num: number of hidden layers.
        n_out_feat: length of the output vector.
        in_feat_name: key under which ``forward`` finds its input tensor.
    """

    def __init__(self,
                 in_ch: int,
                 out_ch: int,
                 layer_num: int,
                 n_out_feat: int,
                 in_feat_name: str,
                 ):
        super().__init__()
        # The attribute names input_layer / layers / regress are the
        # state_dict keys of saved checkpoints; do not rename them.
        self.input_layer = nn.Linear(in_ch, out_ch)
        self.in_feat_name = in_feat_name
        self.layers = nn.ModuleList(
            nn.Linear(out_ch, out_ch) for _ in range(layer_num)
        )
        self.regress = nn.Linear(out_ch, n_out_feat)

    def forward(self, x):
        """Apply the network to ``x[self.in_feat_name]``.

        ``x`` is a mapping; only the entry under ``self.in_feat_name`` is read.
        A trailing dimension of size 1 is squeezed from the result, which only
        has an effect when ``n_out_feat == 1``.
        """
        hidden = self.input_layer(x[self.in_feat_name])
        for layer in self.layers:
            hidden = F.relu(layer(hidden))
        return self.regress(hidden).squeeze(dim=-1)


class Predictor(object):
    """Predict a circuit's metrics at a target corner from input-corner data.

    Wraps a :class:`PVTFC` checkpoint together with the normalisation
    statistics stored alongside it. ``predict`` standardises the feature
    vector ``[perf[m] for m in metrics] + [sizing[s] for s in sizing_names]``
    with the statistics stored under ``in_feat_name``. It runs the network
    once and de-standardises the ``n_metrics``-long slice of the flat output
    that belongs to the requested target corner.

    The base class has no usable defaults for ``in_feat_name``,
    ``target_corners``, ``metrics`` or ``sizing_names``. Use a
    circuit-specific subclass or pass them explicitly.

    Consistency expected by ``predict``:
    ``in_ch == len(metrics) + len(sizing_names)`` and
    ``n_out_feat == len(target_corners) * n_metrics``.
    """

    def __init__(
            self,
            ckpt_path,
            mean_std_path,
            in_ch=11,
            out_ch=512,
            layer_num=6,
            n_out_feat=120,
            n_sizings=7,
            n_metrics=4,
            in_feat_name=None,
            target_corners=None,
            metrics=None,
            sizing_names=None,
    ):
        self.ckpt_path = ckpt_path
        self.mean_std_path = mean_std_path
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.layer_num = layer_num
        self.n_out_feat = n_out_feat
        self.n_sizings = n_sizings  # stored for reference; not used by predict
        self.n_metrics = n_metrics
        self.in_feat_name = in_feat_name
        self.target_corners = target_corners
        self.metrics = metrics
        self.sizing_names = sizing_names

        self.model = None
        self.mean_std = None

        self.load_model()

    def load_model(self):
        """Build the network and load its weights and normalisation stats.

        Both files are loaded onto the CPU. The checkpoint must hold the
        state dict under the ``'model'`` key. ``torch.load`` may unpickle
        arbitrary objects, so only load files from trusted sources.
        """
        model = PVTFC(in_ch=self.in_ch,
                      out_ch=self.out_ch,
                      layer_num=self.layer_num,
                      n_out_feat=self.n_out_feat,
                      in_feat_name=self.in_feat_name)
        model.eval()
        state_dict = torch.load(self.ckpt_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict['model'])
        # Indexed in predict as mean_std['mean'|'std'][feature_or_corner_name].
        self.mean_std = torch.load(self.mean_std_path, map_location=torch.device('cpu'))
        self.model = model

    def predict(
            self,
            perf: dict,
            sizing: dict,
            corner: str,
            temp: float,
            vdd: float
            ) -> dict:
        """Predict ``self.metrics`` for one design at one target corner.

        Args:
            perf: metric name -> value at the input corner; must contain
                every name in ``self.metrics``.
            sizing: sizing name -> value; must contain every name in
                ``self.sizing_names``.
            corner: process-corner prefix of the target, e.g. ``'fs'``.
            temp: target temperature, formatted with one decimal
                (``100`` -> ``'100.0'``).
            vdd: target supply, formatted with one decimal.

        Returns:
            Dict mapping each name in ``self.metrics`` to its de-standardised
            predicted value.

        Raises:
            KeyError: if ``perf`` or ``sizing`` lacks a required name.
            ValueError: if ``f"{corner}_t{temp:.1f}_v{vdd:.1f}"`` is not one
                of ``self.target_corners``.
        """
        metrics = self.metrics
        # Feature layout: metrics in self.metrics order, then sizings in
        # self.sizing_names order.
        raw = [perf[name] for name in metrics] + \
              [sizing[name] for name in self.sizing_names]
        inputs = np.array(raw)
        inputs = (inputs - self.mean_std['mean'][self.in_feat_name]) / \
            self.mean_std['std'][self.in_feat_name]

        corner_name = f"{corner}_t{temp:.1f}_v{vdd:.1f}"
        # Resolve the corner before running the network so a typo fails fast
        # with a readable message (tuple.index alone says "x not in tuple").
        try:
            idx = self.target_corners.index(corner_name)
        except ValueError:
            raise ValueError(
                f"unknown target corner {corner_name!r}; expected one of "
                f"{tuple(self.target_corners)}") from None

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            outputs = self.model(
                {self.in_feat_name: torch.tensor(inputs).float()})

        # The flat output holds n_metrics consecutive values per target
        # corner, in self.target_corners order.
        n = self.n_metrics
        output = outputs[idx * n:(idx + 1) * n].detach().numpy()
        output = output * self.mean_std['std'][corner_name][:n] + \
            self.mean_std['mean'][corner_name][:n]
        return {name: output[k] for k, name in enumerate(metrics)}


class OPAMPPredictor(Predictor):
    """Predictor preconfigured for the op-amp checkpoint.

    Defaults: 11 inputs (4 metrics + 7 sizings) and 120 outputs
    (30 target corners x 4 metrics), using the module-level ``OPAMP_*``
    corner/metric/sizing tables.
    """

    def __init__(
            self,
            ckpt_path,
            mean_std_path,
            in_ch=11,
            out_ch=512,
            layer_num=6,
            n_out_feat=120,
            n_sizings=7,
            n_metrics=4,
            in_feat_name=OPAMP_INPUT_CORNER,
            target_corners=OPAMP_TARGET_CORNERS,
            metrics=OPAMP_METRICS,
            sizing_names=OPAMP_SIZING_NAMES,
    ):
        # Forward everything positionally; the base signature has the same order.
        super().__init__(
            ckpt_path, mean_std_path,
            in_ch, out_ch, layer_num, n_out_feat,
            n_sizings, n_metrics,
            in_feat_name, target_corners, metrics, sizing_names,
        )


class FOLDPredictor(Predictor):
    """Predictor preconfigured for the folded-cascode ("fold") checkpoint.

    Defaults: 27 inputs (7 metrics + 20 sizings) and 140 outputs
    (20 target corners x 7 metrics), using the module-level ``FOLD_*``
    corner/metric/sizing tables.
    """

    def __init__(
            self,
            ckpt_path,
            mean_std_path,
            in_ch=27,
            out_ch=512,
            layer_num=6,
            n_out_feat=140,
            n_sizings=20,
            n_metrics=7,
            in_feat_name=FOLD_INPUT_CORNER,
            target_corners=FOLD_TARGET_CORNERS,
            metrics=FOLD_METRICS,
            sizing_names=FOLD_SIZING_NAMES,
    ):
        # Forward everything positionally; the base signature has the same order.
        super().__init__(
            ckpt_path, mean_std_path,
            in_ch, out_ch, layer_num, n_out_feat,
            n_sizings, n_metrics,
            in_feat_name, target_corners, metrics, sizing_names,
        )


class STRONGARMPredictor(Predictor):
    """Predictor preconfigured for the StrongARM checkpoint.

    Defaults: 14 inputs (7 metrics + 7 sizings) and 140 outputs
    (20 target corners x 7 metrics), using the module-level ``STRONGARM_*``
    corner/metric/sizing tables.
    """

    def __init__(
            self,
            ckpt_path,
            mean_std_path,
            in_ch=14,
            out_ch=512,
            layer_num=6,
            n_out_feat=140,
            n_sizings=7,
            n_metrics=7,
            in_feat_name=STRONGARM_INPUT_CORNER,
            target_corners=STRONGARM_TARGET_CORNERS,
            metrics=STRONGARM_METRICS,
            sizing_names=STRONGARM_SIZING_NAMES,
    ):
        # Forward everything positionally; the base signature has the same order.
        super().__init__(
            ckpt_path, mean_std_path,
            in_ch, out_ch, layer_num, n_out_feat,
            n_sizings, n_metrics,
            in_feat_name, target_corners, metrics, sizing_names,
        )


if __name__ == '__main__':
    # Demo: for each circuit, load its checkpoint + normalisation stats from
    # the current directory and predict its metrics at corner 'fs',
    # temp=100, at the supply listed in demo_cases below.

    # Sample designs: metrics at each predictor's input corner plus sizings.
    # opamp_*0 and strongarm_*0 are not used below; they are extra examples.
    opamp_perf0 = {
        'gain': 1.25874120e+01,
        'ibias': 1.73677800e-04,
        'phm': 2.36779190e+01,
        'ugbw': 3.66280501e+07,
    }
    opamp_sizing0 = {
        'mp1': 3.80794719e+01,
        'mn1': 9.51207206e+01,
        'mp3': 7.34673975e+01,
        'mn3': 6.02671917e+01,
        'mn4': 1.64458458e+01,
        'mn5': 1.64434589e+01,
        'cc': 6.75027874e-13,
    }
    opamp_perf1 = {
        'gain': 2.01990830e+02,
        'ibias': 6.81637840e-05,
        'phm': 1.17113481e+01,
        'ugbw': 1.04102115e+07,
    }
    opamp_sizing1 = {
        'mp1': 2.78408074e+01,
        'mn1': 5.03635869e+01,
        'mp3': 4.44335296e+01,
        'mn3': 4.78976294e+01,
        'mn4': 6.19879119e+01,
        'mn5': 2.38651828e+01,
        'cc': 1.43357101e-12,
    }

    fold_perf0 = {
        'power': 0.0009450954000000001,
        'gain': 93.75359290131087,
        'cmrr': 94.03283812677019,
        'psrr': 121.21730401438766,
        'pm_dm': 0,
        'rms_noise_out_dm': 0.000451972,
        'lg_ugb': 0,
    }
    fold_sizing0 = {
        'L1': 8.616630190610885e-07,
        'L2': 1.9103001165390015e-06,
        'L3': 1.512228922843933e-06,
        'L4': 1.2695584738254547e-06,
        'L5': 4.6395393311977384e-07,
        'L6': 4.6391005277633667e-07,
        'L7': 2.8571219503879544e-07,
        'W1': 0.00012995853698730467,
        'W2': 9.026298065185545e-05,
        'W3': 0.00010628094406127927,
        'W4': 3.3227338027954097e-06,
        'W5': 0.00014549369865417478,
        'W6': 0.00012490660995483397,
        'W7': 3.2039904098510734e-05,
        'N1': 4.45467409491539,
        'N2': 4.484685599803925,
        'N8': 6.780602812767029,
        'N9': 10.970372200012207,
        'MCAP': 9.20695549249649e-13,
        'Cf': 2.9831683754920957e-12,
    }

    strongarm_perf0 = {
        'Power': 2.65107e-06,
        'delay': 1.06965e-08,
        'reset': 2.28001e-09,
        'input_ref_noise': 7.34304e-05,
        'reset_val': 2.47109e-07,
        'rise_val': 1.2,
        'area': 5.36627018995285e-11,
    }
    strongarm_sizing0 = {
        'w1': 1.8864607191681863e-05,
        'w2': 4.7546560330390934e-05,
        'w3': 3.665865702152252e-05,
        'w4': 3.0021220234632495e-05,
        'w6': 7.986608126759528e-06,
        'w8': 7.985407927036285e-06,
        'Cl_finger': 26.844250857830048,
    }
    strongarm_perf1 = {
        'Power': 5.22678e-06,
        'delay': 7.09205e-09,
        'reset': 1.23755e-09,
        'input_ref_noise': 5.32441e-05,
        'reset_val': 6.26552e-07,
        'rise_val': 1.2,
        'area': 7.212794614112376e-11,
    }
    strongarm_sizing1 = {
        'w1': 4.3338247671127324e-05,
        'w2': 3.0143504118919375e-05,
        'w3': 3.546785119771957e-05,
        'w4': 1.2446961051225662e-06,
        'w6': 4.850211217284203e-05,
        'w8': 4.1658994681835175e-05,
        'Cl_finger': 71.57833993434906,
    }

    # (predictor class, checkpoint, mean/std file, perf, sizing, vdd)
    demo_cases = (
        (OPAMPPredictor,
         './opamp_60e97520.pt', './opamp_60e97520_mean_std.pt',
         opamp_perf1, opamp_sizing1, 1.2),
        (FOLDPredictor,
         './fold_b910a775.pt', './fold_b910a775_mean_std.pt',
         fold_perf0, fold_sizing0, 1.6),
        (STRONGARMPredictor,
         './strongarm_143ddcc3.pt', './strongarm_143ddcc3_mean_std.pt',
         strongarm_perf1, strongarm_sizing1, 1.1),
    )

    # Each predictor is loaded right before its own prediction, as before.
    for predictor_cls, ckpt_file, stats_file, perf, sizing, supply in demo_cases:
        predictor = predictor_cls(
            ckpt_path=ckpt_file,
            mean_std_path=stats_file,
        )
        print(predictor.predict(
            perf=perf,
            sizing=sizing,
            corner='fs',
            temp=100,
            vdd=supply
        ))
